Compare commits

..

68 Commits

Author SHA1 Message Date
75f288ff02 feat: add summary index migration script. 2026-01-27 18:46:14 +08:00
e82dba104e fix: summary index migration script. 2026-01-27 18:12:31 +08:00
f2e7154c6f Merge remote-tracking branch 'origin/main' into feat/knowledgebase-summaryIndex 2026-01-27 16:48:39 +08:00
e482588ef8 fix: ConsoleDatasetListQuery request.args.to_dict() (#31598) 2026-01-27 17:12:52 +09:00
e1cb37e967 fix: summary tokens. 2026-01-27 16:11:09 +08:00
b15d9b04ae fix: summary tokens. 2026-01-27 16:10:26 +08:00
b66bd5f5a8 feat: enhance quota panel with installed providers mapping (#31546)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-01-27 15:43:37 +08:00
c8abe1c306 test: add tests for dataset document detail (#31274)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-01-27 15:43:27 +08:00
eca26a9b9b feat: Enhances OpenTelemetry node parsers (#30706)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-27 15:30:21 +08:00
febc9b930d chore: update react and next version (#31593) 2026-01-27 14:06:09 +08:00
7f873a9b2c fix: fix summary index bug. 2026-01-27 13:42:07 +08:00
lif
d13638f6e4 test: wrap test cleanup in act() to prevent window is not defined error (#31558)
Signed-off-by: majiayu000 <1835304752@qq.com>
2026-01-27 11:25:14 +08:00
b4eef76c14 fix: billing account deletion (#31556) 2026-01-27 11:18:23 +08:00
e01fa1b26b fix: fix summary index bug. 2026-01-27 11:15:01 +08:00
cbf7f646d9 chore(deps): bump pypdf from 6.6.0 to 6.6.2 in /api (#31568)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2026-01-27 11:06:13 +08:00
c58647d39c refactor(web): extract MCP components and add comprehensive tests (#31517)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
Co-authored-by: Claude Haiku 4.5 <noreply@anthropic.com>
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-01-27 11:05:59 +08:00
E.G
f6be9cd90d refactor: replace request.args.get with Pydantic BaseModel validation (#31104)
Co-authored-by: GlobalStar117 <GlobalStar117@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-27 10:48:42 +08:00
360f3bb32f chore(deps): bump pycryptodome from 3.19.1 to 3.23.0 in /api (#31504)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-27 10:43:05 +08:00
52176515b0 fix: fix summary index bug. 2026-01-27 10:31:56 +08:00
lif
8519b16cfc docs: add ESLint guide to AGENTS.md (#31559)
Signed-off-by: majiayu000 <1835304752@qq.com>
2026-01-27 09:32:55 +08:00
f00d823f9f chore: move agent notes into docstrings (#31560) 2026-01-27 09:32:26 +08:00
e48419937b feat: chatflow support multimodal (#31293)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-27 00:24:48 +08:00
ca4bb0921b fix: fix summary index bug. 2026-01-26 18:56:58 +08:00
81e269e591 fix: fix summary index bug. 2026-01-23 23:03:29 +08:00
5df75d7ffa fix: fix summary index bug. 2026-01-23 22:33:42 +08:00
ccfd3e6f6d fix: fix summary index bug. 2026-01-23 21:19:12 +08:00
328c1990ee fix: fix summary index bug. 2026-01-23 21:18:06 +08:00
76d18ca3dd fix: fix summary index bug. 2026-01-23 20:07:08 +08:00
b953e4fe9b fix: fix summary index bug. 2026-01-23 19:03:06 +08:00
9841b8c5b5 fix: fix summary index bug. 2026-01-23 16:50:46 +08:00
55245b5841 fix: fix summary index bug. 2026-01-23 15:55:40 +08:00
833db6ba0b fix: fix summary index bug. 2026-01-23 13:38:41 +08:00
87186b6c73 fix: fix summary index bug. 2026-01-23 13:09:59 +08:00
0769a1c73a Merge remote-tracking branch 'origin/main' into feat/knowledgebase-summaryIndex 2026-01-23 13:09:51 +08:00
6e56d23de9 fix: fix summary index bug. 2026-01-22 16:51:57 +08:00
e1b987b48b fix: fix summary index bug. 2026-01-22 14:05:19 +08:00
c125350fb5 Merge remote-tracking branch 'origin/main' into feat/knowledgebase-summaryIndex 2026-01-22 14:02:41 +08:00
fb51e2f36d Merge remote-tracking branch 'origin/main' into feat/knowledgebase-summaryIndex 2026-01-22 10:45:25 +08:00
5d732edbb0 Merge remote main and resolve conflicts for summaryindex feature
- Resolved conflicts in 9 task files by adopting session_factory pattern from main
- Preserved all summaryindex functionality including enable/disable logic
- Updated all task files to use session_factory.create_session() instead of db.session
- Merged new features from main (FileService, DocumentBatchDownloadZipPayload, etc.)
2026-01-21 16:03:54 +08:00
63d33fe93f fix: fix summary index bug. 2026-01-20 18:14:43 +08:00
008a5f361d fix: fix summary index bug. 2026-01-20 11:53:16 +08:00
4fb08ae7d2 fix: fix summary index bug. 2026-01-16 20:24:18 +08:00
fcb2fe55e7 fix: fix summary index bug. 2026-01-16 18:55:10 +08:00
869e70964f fix: fix summary index bug. 2026-01-15 18:09:48 +08:00
74245fea8e fix: fix summary index bug. 2026-01-15 17:57:15 +08:00
22d0c55363 fix: fix summary index bug. 2026-01-15 15:10:38 +08:00
f4d20a02aa feat: fix summary index bug. 2026-01-15 11:06:18 +08:00
7eb65b07c8 feat: Make summary index support vision, and make the code more standardized. 2026-01-14 17:52:27 +08:00
9b7e807690 feat: summary index (#30950) 2026-01-14 11:26:44 +08:00
af86f8de6f Merge branch 'feat/knowledgebase-summaryIndex' into feat/summary-index 2026-01-14 11:25:15 +08:00
ec78676949 Merge branch 'deploy/dev' into feat/summary-index 2026-01-13 21:30:50 +08:00
76da8b4ff3 Merge remote-tracking branch 'origin/deploy/dev' 2026-01-12 17:09:25 +08:00
25bfc1cc3b feat: implement Summary Index feature. 2026-01-12 16:52:21 +08:00
1fcf6e4943 Update 2025_12_16_1817-03ea244985ce_add_type_column_not_null_default_tool.py 2025-12-17 11:12:59 +08:00
f4a7efde3d update migration script. 2025-12-16 18:30:12 +08:00
38d4f0fd96 Merge remote-tracking branch 'origin/deploy/dev' 2025-12-16 18:25:54 +08:00
ec4f885dad update migration script. 2025-12-16 18:19:24 +08:00
3781c2a025 [autofix.ci] apply automated fixes 2025-12-16 08:37:32 +00:00
3782f17dc7 Optimize code. 2025-12-16 16:35:15 +08:00
29698aeed2 Merge remote-tracking branch 'origin/deploy/dev' 2025-12-16 16:26:19 +08:00
15ff8efb15 merge alembic head 2025-12-16 16:20:04 +08:00
407e1c8276 [autofix.ci] apply automated fixes 2025-12-16 08:14:05 +00:00
e368825c21 Merge remote-tracking branch 'upstream/main' 2025-12-16 15:50:49 +08:00
8dad6b6a6d Add "type" field to PipelineRecommendedPlugin model; Add query param "type" to recommended-plugins api. 2025-12-16 14:34:59 +08:00
2f54965a72 Add "type" field to PipelineRecommendedPlugin model; Add query param "type" to recommended-plugins api. 2025-12-16 10:43:45 +08:00
a1a3fa0283 Add "type" field to PipelineRecommendedPlugin model; Add query param "type" to recommended-plugins api. 2025-12-15 16:44:32 +08:00
ff7344f3d3 Add "type" field to PipelineRecommendedPlugin model; Add query param "type" to recommended-plugins api. 2025-12-15 16:38:44 +08:00
bcd33be22a Add "type" field to PipelineRecommendedPlugin model; Add query param "type" to recommended-plugins api. 2025-12-15 16:33:06 +08:00
235 changed files with 41513 additions and 3141 deletions

View File

@ -25,6 +25,30 @@ pnpm type-check:tsgo
pnpm test
```
### Frontend Linting
ESLint is used for frontend code quality. Available commands:
```bash
# Lint all files (report only)
pnpm lint
# Lint and auto-fix issues
pnpm lint:fix
# Lint specific files or directories
pnpm lint:fix app/components/base/button/
pnpm lint:fix app/components/base/button/index.tsx
# Lint quietly (errors only, no warnings)
pnpm lint:quiet
# Check code complexity
pnpm lint:complexity
```
**Important**: Always run `pnpm lint:fix` before committing. The pre-commit hook runs `lint-staged` which only lints staged files.
## Testing & Quality Practices
- Follow TDD: red → green → refactor.

View File

View File

@ -1,27 +0,0 @@
# Notes: `large_language_model.py`
## Purpose
Provides the base `LargeLanguageModel` implementation used by the model runtime to invoke plugin-backed LLMs and to
bridge plugin daemon streaming semantics back into API-layer entities (`LLMResult`, `LLMResultChunk`).
## Key behaviors / invariants
- `invoke(..., stream=False)` still calls the plugin in streaming mode and then synthesizes a single `LLMResult` from
the first yielded `LLMResultChunk`.
- Plugin invocation is wrapped by `_invoke_llm_via_plugin(...)`, and `stream=False` normalization is handled by
`_normalize_non_stream_plugin_result(...)` / `_build_llm_result_from_first_chunk(...)`.
- Tool call deltas are merged incrementally via `_increase_tool_call(...)` to support multiple provider chunking
patterns (IDs anchored to first chunk, every chunk, or missing entirely).
- A tool-call delta with an empty `id` requires at least one existing tool call; otherwise we raise `ValueError` to
surface invalid delta sequences explicitly.
- Callback invocation is centralized in `_run_callbacks(...)` to ensure consistent error handling/logging.
- For compatibility with dify issue `#17799`, `prompt_messages` may be removed by the plugin daemon in chunks and must
be re-attached in this layer before callbacks/consumers use them.
- Callback hooks (`on_before_invoke`, `on_new_chunk`, `on_after_invoke`, `on_invoke_error`) must not break invocation
unless `callback.raise_error` is true.
## Test focus
- `api/tests/unit_tests/core/model_runtime/__base/test_increase_tool_call.py` validates tool-call delta merging and
patches `_gen_tool_call_id` for deterministic IDs.

View File

@ -1,97 +1,47 @@
# API Agent Guide
## Agent Notes (must-check)
## Notes for Agent (must-check)
Before you start work on any backend file under `api/`, you MUST check whether a related note exists under:
Before changing any backend code under `api/`, you MUST read the surrounding docstrings and comments. These notes contain required context (invariants, edge cases, trade-offs) and are treated as part of the spec.
- `agent-notes/<same-relative-path-as-target-file>.md`
Look for:
Rules:
- The module (file) docstring at the top of a source code file
- Docstrings on classes and functions/methods
- Paragraph/block comments for non-obvious logic
- **Path mapping**: for a target file `<path>/<name>.py`, the note must be `agent-notes/<path>/<name>.py.md` (same folder structure, same filename, plus `.md`).
- **Before working**:
- If the note exists, read it first and follow any constraints/decisions recorded there.
- If the note conflicts with the current code, or references an "origin" file/path that has been deleted, renamed, or migrated, treat the **code as the single source of truth** and update the note to match reality.
- If the note does not exist, create it with a short architecture/intent summary and any relevant invariants/edge cases.
- **During working**:
- Keep the note in sync as you discover constraints, make decisions, or change approach.
- If you move/rename a file, migrate its note to the new mapped path (and fix any outdated references inside the note).
- Record non-obvious edge cases, trade-offs, and the test/verification plan as you go (not just at the end).
- Keep notes **coherent**: integrate new findings into the relevant sections and rewrite for clarity; avoid append-only “recent fix” / changelog-style additions unless the note is explicitly intended to be a changelog.
- **When finishing work**:
- Update the related note(s) to reflect what changed, why, and any new edge cases/tests.
- If a file is deleted, remove or clearly deprecate the corresponding note so it cannot be mistaken as current guidance.
- Keep notes concise and accurate; they are meant to prevent repeated rediscovery.
### What to write where
## Skill Index
- Keep notes scoped: module notes cover module-wide context, class notes cover class-wide context, function/method notes cover behavioural contracts, and paragraph/block comments cover local “why”. Avoid duplicating the same content across scopes unless repetition prevents misuse.
- **Module (file) docstring**: purpose, boundaries, key invariants, and “gotchas” that a new reader must know before editing.
- Include cross-links to the key collaborators (modules/services) when discovery is otherwise hard.
- Prefer stable facts (invariants, contracts) over ephemeral “today we…” notes.
- **Class docstring**: responsibility, lifecycle, invariants, and how it should be used (or not used).
- If the class is intentionally stateful, note what state exists and what methods mutate it.
- If concurrency/async assumptions matter, state them explicitly.
- **Function/method docstring**: behavioural contract.
- Document arguments, return shape, side effects (DB writes, external I/O, task dispatch), and raised domain exceptions.
- Add examples only when they prevent misuse.
- **Paragraph/block comments**: explain *why* (trade-offs, historical constraints, surprising edge cases), not what the code already states.
- Keep comments adjacent to the logic they justify; delete or rewrite comments that no longer match reality.
Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it.
### Rules (must follow)
### Platform Foundations
In this section, “notes” means module/class/function docstrings plus any relevant paragraph/block comments.
#### [Infrastructure Overview](agent_skills/infra.md)
- **When to read this**
- You need to understand where a feature belongs in the architecture.
- Youre wiring storage, Redis, vector stores, or OTEL.
- Youre about to add CLI commands or async jobs.
- **What it covers**
- Configuration stack (`configs/app_config.py`, remote settings)
- Storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`)
- Redis conventions (`extensions/ext_redis.py`)
- Plugin runtime topology
- Vector-store factory (`core/rag/datasource/vdb/*`)
- Observability hooks
- SSRF proxy usage
- Core CLI commands
### Plugin & Extension Development
#### [Plugin Systems](agent_skills/plugin.md)
- **When to read this**
- Youre building or debugging a marketplace plugin.
- You need to know how manifests, providers, daemons, and migrations fit together.
- **What it covers**
- Plugin manifests (`core/plugin/entities/plugin.py`)
- Installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands)
- Runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent)
- Daemon coordination (`core/plugin/entities/plugin_daemon.py`)
- How provider registries surface capabilities to the rest of the platform
#### [Plugin OAuth](agent_skills/plugin_oauth.md)
- **When to read this**
- You must integrate OAuth for a plugin or datasource.
- Youre handling credential encryption or refresh flows.
- **Topics**
- Credential storage
- Encryption helpers (`core/helper/provider_encryption.py`)
- OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`)
- How console/API layers expose the flows
### Workflow Entry & Execution
#### [Trigger Concepts](agent_skills/trigger.md)
- **When to read this**
- Youre debugging why a workflow didnt start.
- Youre adding a new trigger type or hook.
- You need to trace async execution, draft debugging, or webhook/schedule pipelines.
- **Details**
- Start-node taxonomy
- Webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`)
- Async orchestration (`services/async_workflow_service.py`, Celery queues)
- Debug event bus
- Storage/logging interactions
## General Reminders
- All skill docs assume you follow the coding style rules below—run the lint/type/test commands before submitting changes.
- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`).
- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules.
- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`.
- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently.
- **Before working**
- Read the notes in the area youll touch; treat them as part of the spec.
- If a docstring or comment conflicts with the current code, treat the **code as the single source of truth** and update the docstring or comment to match reality.
- If important intent/invariants/edge cases are missing, add them in the closest docstring or comment (module for overall scope, function for behaviour).
- **During working**
- Keep the notes in sync as you discover constraints, make decisions, or change approach.
- If you move/rename responsibilities across modules/classes, update the affected docstrings and comments so readers can still find the “why” and the invariants.
- Record non-obvious edge cases, trade-offs, and the test/verification plan in the nearest docstring or comment that will stay correct.
- Keep the notes **coherent**: integrate new findings into the relevant docstrings and comments; avoid append-only “recent fix” / changelog-style additions.
- **When finishing**
- Update the notes to reflect what changed, why, and any new edge cases/tests.
- Remove or rewrite any comments that could be mistaken as current guidance but no longer apply.
- Keep docstrings and comments concise and accurate; they are meant to prevent repeated rediscovery.
## Coding Style
@ -226,7 +176,7 @@ Before opening a PR / submitting:
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
- Document non-obvious behaviour with concise comments.
- Document non-obvious behaviour with concise docstrings and comments.
### Miscellaneous

View File

@ -36,6 +36,16 @@ class NotionEstimatePayload(BaseModel):
doc_language: str = Field(default="English")
class DataSourceNotionListQuery(BaseModel):
dataset_id: str | None = Field(default=None, description="Dataset ID")
credential_id: str = Field(..., description="Credential ID", min_length=1)
datasource_parameters: dict[str, Any] | None = Field(default=None, description="Datasource parameters JSON string")
class DataSourceNotionPreviewQuery(BaseModel):
credential_id: str = Field(..., description="Credential ID", min_length=1)
register_schema_model(console_ns, NotionEstimatePayload)
@ -136,26 +146,15 @@ class DataSourceNotionListApi(Resource):
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = request.args.get("dataset_id", default=None, type=str)
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
query = DataSourceNotionListQuery.model_validate(request.args.to_dict())
# Get datasource_parameters from query string (optional, for GitHub and other datasources)
datasource_parameters_str = request.args.get("datasource_parameters", default=None, type=str)
datasource_parameters = {}
if datasource_parameters_str:
try:
datasource_parameters = json.loads(datasource_parameters_str)
if not isinstance(datasource_parameters, dict):
raise ValueError("datasource_parameters must be a JSON object.")
except json.JSONDecodeError:
raise ValueError("Invalid datasource_parameters JSON format.")
datasource_parameters = query.datasource_parameters or {}
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_tenant_id,
credential_id=credential_id,
credential_id=query.credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
@ -164,8 +163,8 @@ class DataSourceNotionListApi(Resource):
exist_page_ids = []
with Session(db.engine) as session:
# import notion in the exist dataset
if dataset_id:
dataset = DatasetService.get_dataset(dataset_id)
if query.dataset_id:
dataset = DatasetService.get_dataset(query.dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
if dataset.data_source_type != "notion_import":
@ -173,7 +172,7 @@ class DataSourceNotionListApi(Resource):
documents = session.scalars(
select(Document).filter_by(
dataset_id=dataset_id,
dataset_id=query.dataset_id,
tenant_id=current_tenant_id,
data_source_type="notion_import",
enabled=True,
@ -240,13 +239,12 @@ class DataSourceNotionApi(Resource):
def get(self, page_id, page_type):
_, current_tenant_id = current_account_with_tenant()
credential_id = request.args.get("credential_id", default=None, type=str)
if not credential_id:
raise ValueError("Credential id is required.")
query = DataSourceNotionPreviewQuery.model_validate(request.args.to_dict())
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=current_tenant_id,
credential_id=credential_id,
credential_id=query.credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)

View File

@ -146,6 +146,7 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None
summary_index_setting: dict[str, Any] | None = None
partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
@ -176,7 +177,18 @@ class IndexingEstimatePayload(BaseModel):
return result
register_schema_models(console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload)
class ConsoleDatasetListQuery(BaseModel):
page: int = Field(default=1, description="Page number")
limit: int = Field(default=20, description="Number of items per page")
keyword: str | None = Field(default=None, description="Search keyword")
include_all: bool = Field(default=False, description="Include all datasets")
ids: list[str] = Field(default_factory=list, description="Filter by dataset IDs")
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
register_schema_models(
console_ns, DatasetCreatePayload, DatasetUpdatePayload, IndexingEstimatePayload, ConsoleDatasetListQuery
)
def _get_retrieval_methods_by_vector_type(vector_type: str | None, is_mock: bool = False) -> dict[str, list[str]]:
@ -275,18 +287,19 @@ class DatasetListApi(Resource):
@enterprise_license_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids")
query = ConsoleDatasetListQuery.model_validate(request.args.to_dict())
# provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids")
include_all = request.args.get("include_all", default="false").lower() == "true"
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, current_tenant_id)
if query.ids:
datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id)
else:
datasets, total = DatasetService.get_datasets(
page, limit, current_tenant_id, current_user, search, tag_ids, include_all
query.page,
query.limit,
current_tenant_id,
current_user,
query.keyword,
query.tag_ids,
query.include_all,
)
# check embedding setting
@ -318,7 +331,13 @@ class DatasetListApi(Resource):
else:
item.update({"partial_member_list": []})
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
response = {
"data": data,
"has_more": len(datasets) == query.limit,
"limit": query.limit,
"total": total,
"page": query.page,
}
return response, 200
@console_ns.doc("create_dataset")

View File

@ -41,10 +41,11 @@ from fields.document_fields import (
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
from models.dataset import DocumentPipelineExecutionLog
from models.dataset import DocumentPipelineExecutionLog, DocumentSegmentSummary
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService
from tasks.generate_summary_index_task import generate_summary_index_task
from ..app.error import (
ProviderModelCurrentlyNotSupportError,
@ -110,6 +111,10 @@ class DocumentRenamePayload(BaseModel):
name: str
class GenerateSummaryPayload(BaseModel):
document_list: list[str]
class DocumentBatchDownloadZipPayload(BaseModel):
"""Request payload for bulk downloading documents as a zip archive."""
@ -132,6 +137,7 @@ register_schema_models(
RetrievalModel,
DocumentRetryPayload,
DocumentRenamePayload,
GenerateSummaryPayload,
DocumentBatchDownloadZipPayload,
)
@ -319,6 +325,86 @@ class DatasetDocumentListApi(Resource):
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
# Check if dataset has summary index enabled
has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True
# Filter documents that need summary calculation
documents_need_summary = [doc for doc in documents if doc.need_summary is True]
document_ids_need_summary = [str(doc.id) for doc in documents_need_summary]
# Calculate summary_index_status for documents that need summary (only if dataset summary index is enabled)
summary_status_map = {}
if has_summary_index and document_ids_need_summary:
# Get all segments for these documents (excluding qa_model and re_segment)
segments = (
db.session.query(DocumentSegment.id, DocumentSegment.document_id)
.where(
DocumentSegment.document_id.in_(document_ids_need_summary),
DocumentSegment.status != "re_segment",
DocumentSegment.tenant_id == current_tenant_id,
)
.all()
)
# Group segments by document_id
document_segments_map = {}
for segment in segments:
doc_id = str(segment.document_id)
if doc_id not in document_segments_map:
document_segments_map[doc_id] = []
document_segments_map[doc_id].append(segment.id)
# Get all summary records for these segments
all_segment_ids = [seg.id for seg in segments]
summaries = {}
if all_segment_ids:
summary_records = (
db.session.query(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id.in_(all_segment_ids),
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.enabled == True, # Only count enabled summaries
)
.all()
)
summaries = {summary.chunk_id: summary.status for summary in summary_records}
# Calculate summary_index_status for each document
for doc_id in document_ids_need_summary:
segment_ids = document_segments_map.get(doc_id, [])
if not segment_ids:
# No segments, status is None (not started)
summary_status_map[doc_id] = None
continue
# Check if there are any "not_started" or "generating" status summaries
# Only check enabled=True summaries (already filtered in query)
# If segment has no summary record (summaries.get returns None),
# it means the summary is disabled (enabled=False) or not created yet, ignore it
has_pending_summaries = any(
summaries.get(segment_id) is not None # Ensure summary exists (enabled=True)
and summaries[segment_id] in ("not_started", "generating")
for segment_id in segment_ids
)
if has_pending_summaries:
# Task is still running (not started or generating)
summary_status_map[doc_id] = "SUMMARIZING"
else:
# All enabled=True summaries are "completed" or "error", task finished
# Or no enabled=True summaries exist (all disabled)
summary_status_map[doc_id] = None
# Add summary_index_status to each document
for document in documents:
if has_summary_index and document.need_summary is True:
# Get status from map, default to None (not queued yet)
document.summary_index_status = summary_status_map.get(str(document.id))
else:
# Return null if summary index is not enabled or document doesn't need summary
document.summary_index_status = None
if fetch:
for document in documents:
completed_segments = (
@ -804,6 +890,7 @@ class DocumentApi(DocumentResource):
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
"need_summary": document.need_summary if document.need_summary is not None else False,
}
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
@ -839,6 +926,7 @@ class DocumentApi(DocumentResource):
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
"need_summary": document.need_summary if document.need_summary is not None else False,
}
return response, 200
@ -1262,3 +1350,216 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
"input_data": log.input_data,
"datasource_node_id": log.datasource_node_id,
}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/generate-summary")
class DocumentGenerateSummaryApi(Resource):
@console_ns.doc("generate_summary_for_documents")
@console_ns.doc(description="Generate summary index for documents")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
@console_ns.response(200, "Summary generation started successfully")
@console_ns.response(400, "Invalid request or dataset configuration")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
"""
Generate summary index for specified documents.
This endpoint checks if the dataset configuration supports summary generation
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
then asynchronously generates summary indexes for the provided documents.
"""
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# Check permissions
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Validate request payload
payload = GenerateSummaryPayload.model_validate(console_ns.payload or {})
document_list = payload.document_list
if not document_list:
raise ValueError("document_list cannot be empty.")
# Check if dataset configuration supports summary generation
if dataset.indexing_technique != "high_quality":
raise ValueError(
f"Summary generation is only available for 'high_quality' indexing technique. "
f"Current indexing technique: {dataset.indexing_technique}"
)
summary_index_setting = dataset.summary_index_setting
if not summary_index_setting or not summary_index_setting.get("enable"):
raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.")
# Verify all documents exist and belong to the dataset
documents = (
db.session.query(Document)
.filter(
Document.id.in_(document_list),
Document.dataset_id == dataset_id,
)
.all()
)
if len(documents) != len(document_list):
found_ids = {doc.id for doc in documents}
missing_ids = set(document_list) - found_ids
raise NotFound(f"Some documents not found: {list(missing_ids)}")
# Dispatch async tasks for each document
for document in documents:
# Skip qa_model documents as they don't generate summaries
if document.doc_form == "qa_model":
logger.info("Skipping summary generation for qa_model document %s", document.id)
continue
# Dispatch async task
generate_summary_index_task.delay(dataset_id, document.id)
logger.info(
"Dispatched summary generation task for document %s in dataset %s",
document.id,
dataset_id,
)
return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/summary-status")
class DocumentSummaryStatusApi(DocumentResource):
@console_ns.doc("get_document_summary_status")
@console_ns.doc(description="Get summary index generation status for a document")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.response(200, "Summary status retrieved successfully")
@console_ns.response(404, "Document not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
"""
Get summary index generation status for a document.
Returns:
- total_segments: Total number of segments in the document
- summary_status: Dictionary with status counts
- completed: Number of summaries completed
- generating: Number of summaries being generated
- error: Number of summaries with errors
- not_started: Number of segments without summary records
- summaries: List of summary records with status and content preview
"""
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
# Get document
document = self.get_document(dataset_id, document_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# Check permissions
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Get all segments for this document
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
)
.all()
)
total_segments = len(segments)
# Get all summary records for these segments
segment_ids = [segment.id for segment in segments]
summaries = []
if segment_ids:
summaries = (
db.session.query(DocumentSegmentSummary)
.filter(
DocumentSegmentSummary.document_id == document_id,
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
)
.all()
)
# Create a mapping of chunk_id to summary
summary_map = {summary.chunk_id: summary for summary in summaries}
# Count statuses
status_counts = {
"completed": 0,
"generating": 0,
"error": 0,
"not_started": 0,
}
summary_list = []
for segment in segments:
summary = summary_map.get(segment.id)
if summary:
status = summary.status
status_counts[status] = status_counts.get(status, 0) + 1
summary_list.append(
{
"segment_id": segment.id,
"segment_position": segment.position,
"status": summary.status,
"summary_preview": (
summary.summary_content[:100] + "..."
if summary.summary_content and len(summary.summary_content) > 100
else summary.summary_content
),
"error": summary.error,
"created_at": int(summary.created_at.timestamp()) if summary.created_at else None,
"updated_at": int(summary.updated_at.timestamp()) if summary.updated_at else None,
}
)
else:
status_counts["not_started"] += 1
summary_list.append(
{
"segment_id": segment.id,
"segment_position": segment.position,
"status": "not_started",
"summary_preview": None,
"error": None,
"created_at": None,
"updated_at": None,
}
)
return {
"total_segments": total_segments,
"summary_status": status_counts,
"summaries": summary_list,
}, 200

View File

@ -32,7 +32,7 @@ from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.helper import escape_like_pattern
from libs.login import current_account_with_tenant, login_required
from models.dataset import ChildChunk, DocumentSegment
from models.dataset import ChildChunk, DocumentSegment, DocumentSegmentSummary
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
@ -41,6 +41,23 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
def _get_segment_with_summary(segment, dataset_id):
"""Helper function to marshal segment and add summary information."""
segment_dict = marshal(segment, segment_fields)
# Query summary for this segment (only enabled summaries)
summary = (
db.session.query(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
)
.first()
)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
status: list[str] = Field(default_factory=list)
@ -63,6 +80,7 @@ class SegmentUpdatePayload(BaseModel):
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None
summary: str | None = None # Summary content for summary index
class BatchImportPayload(BaseModel):
@ -180,8 +198,32 @@ class DatasetDocumentSegmentListApi(Resource):
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
# Query summaries for all segments in this page (batch query for efficiency)
segment_ids = [segment.id for segment in segments.items]
summaries = {}
if segment_ids:
summary_records = (
db.session.query(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.dataset_id == dataset_id,
)
.all()
)
# Only include enabled summaries
summaries = {
summary.chunk_id: summary.summary_content for summary in summary_records if summary.enabled is True
}
# Add summary to each segment
segments_with_summary = []
for segment in segments.items:
segment_dict = marshal(segment, segment_fields)
segment_dict["summary"] = summaries.get(segment.id)
segments_with_summary.append(segment_dict)
response = {
"data": marshal(segments.items, segment_fields),
"data": segments_with_summary,
"limit": limit,
"total": segments.total,
"total_pages": segments.pages,
@ -327,7 +369,7 @@ class DatasetDocumentSegmentAddApi(Resource):
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.create_segment(payload_dict, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
@ -389,10 +431,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
# Update segment (summary update with change detection is handled in SegmentService.update_segment)
segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
@setup_required
@login_required

View File

@ -98,12 +98,19 @@ class BedrockRetrievalPayload(BaseModel):
knowledge_id: str
class ExternalApiTemplateListQuery(BaseModel):
page: int = Field(default=1, description="Page number")
limit: int = Field(default=20, description="Number of items per page")
keyword: str | None = Field(default=None, description="Search keyword")
register_schema_models(
console_ns,
ExternalKnowledgeApiPayload,
ExternalDatasetCreatePayload,
ExternalHitTestingPayload,
BedrockRetrievalPayload,
ExternalApiTemplateListQuery,
)
@ -124,19 +131,17 @@ class ExternalApiTemplateListApi(Resource):
@account_initialization_required
def get(self):
_, current_tenant_id = current_account_with_tenant()
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
query = ExternalApiTemplateListQuery.model_validate(request.args.to_dict())
external_knowledge_apis, total = ExternalDatasetService.get_external_knowledge_apis(
page, limit, current_tenant_id, search
query.page, query.limit, current_tenant_id, query.keyword
)
response = {
"data": [item.to_dict() for item in external_knowledge_apis],
"has_more": len(external_knowledge_apis) == limit,
"limit": limit,
"has_more": len(external_knowledge_apis) == query.limit,
"limit": query.limit,
"total": total,
"page": page,
"page": query.page,
}
return response, 200

View File

@ -1,6 +1,13 @@
from flask_restx import Resource
from flask_restx import Resource, fields
from controllers.common.schema import register_schema_model
from fields.hit_testing_fields import (
child_chunk_fields,
document_fields,
files_fields,
hit_testing_record_fields,
segment_fields,
)
from libs.login import login_required
from .. import console_ns
@ -14,13 +21,45 @@ from ..wraps import (
register_schema_model(console_ns, HitTestingPayload)
def _get_or_create_model(model_name: str, field_def):
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
# Register models for flask_restx to avoid dict type issues in Swagger
document_model = _get_or_create_model("HitTestingDocument", document_fields)
segment_fields_copy = segment_fields.copy()
segment_fields_copy["document"] = fields.Nested(document_model)
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
files_model = _get_or_create_model("HitTestingFile", files_fields)
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
# Response model for hit testing API
hit_testing_response_fields = {
"query": fields.String,
"records": fields.List(fields.Nested(hit_testing_record_model)),
}
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc("test_dataset_retrieval")
@console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
@console_ns.response(200, "Hit testing completed successfully")
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@setup_required

View File

@ -3,7 +3,7 @@ from typing import Any
from flask import request
from flask_restx import Resource, marshal_with
from pydantic import BaseModel
from pydantic import BaseModel, Field
from sqlalchemy import and_, select
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
@ -28,6 +28,10 @@ class InstalledAppUpdatePayload(BaseModel):
is_pinned: bool | None = None
class InstalledAppsListQuery(BaseModel):
app_id: str | None = Field(default=None, description="App ID to filter by")
logger = logging.getLogger(__name__)
@ -37,13 +41,13 @@ class InstalledAppsListApi(Resource):
@account_initialization_required
@marshal_with(installed_app_list_fields)
def get(self):
app_id = request.args.get("app_id", default=None, type=str)
query = InstalledAppsListQuery.model_validate(request.args.to_dict())
current_user, current_tenant_id = current_account_with_tenant()
if app_id:
if query.app_id:
installed_apps = db.session.scalars(
select(InstalledApp).where(
and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id)
and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == query.app_id)
)
).all()
else:

View File

@ -40,6 +40,7 @@ register_schema_models(
TagBasePayload,
TagBindingPayload,
TagBindingRemovePayload,
TagListQueryParam,
)

View File

@ -87,6 +87,14 @@ class TagUnbindingPayload(BaseModel):
target_id: str
class DatasetListQuery(BaseModel):
page: int = Field(default=1, description="Page number")
limit: int = Field(default=20, description="Number of items per page")
keyword: str | None = Field(default=None, description="Search keyword")
include_all: bool = Field(default=False, description="Include all datasets")
tag_ids: list[str] = Field(default_factory=list, description="Filter by tag IDs")
register_schema_models(
service_api_ns,
DatasetCreatePayload,
@ -96,6 +104,7 @@ register_schema_models(
TagDeletePayload,
TagBindingPayload,
TagUnbindingPayload,
DatasetListQuery,
)
@ -113,15 +122,11 @@ class DatasetListApi(DatasetApiResource):
)
def get(self, tenant_id):
"""Resource for getting datasets."""
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
query = DatasetListQuery.model_validate(request.args.to_dict())
# provider = request.args.get("provider", default="vendor")
search = request.args.get("keyword", default=None, type=str)
tag_ids = request.args.getlist("tag_ids")
include_all = request.args.get("include_all", default="false").lower() == "true"
datasets, total = DatasetService.get_datasets(
page, limit, tenant_id, current_user, search, tag_ids, include_all
query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all
)
# check embedding setting
provider_manager = ProviderManager()
@ -147,7 +152,13 @@ class DatasetListApi(DatasetApiResource):
item["embedding_available"] = False
else:
item["embedding_available"] = True
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
response = {
"data": data,
"has_more": len(datasets) == query.limit,
"limit": query.limit,
"total": total,
"page": query.page,
}
return response, 200
@service_api_ns.expect(service_api_ns.models[DatasetCreatePayload.__name__])

View File

@ -69,7 +69,14 @@ class DocumentTextUpdate(BaseModel):
return self
for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate]:
class DocumentListQuery(BaseModel):
page: int = Field(default=1, description="Page number")
limit: int = Field(default=20, description="Number of items per page")
keyword: str | None = Field(default=None, description="Search keyword")
status: str | None = Field(default=None, description="Document status filter")
for m in [ProcessRule, RetrievalModel, DocumentTextCreatePayload, DocumentTextUpdate, DocumentListQuery]:
service_api_ns.schema_model(m.__name__, m.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)) # type: ignore
@ -460,34 +467,33 @@ class DocumentListApi(DatasetApiResource):
def get(self, tenant_id, dataset_id):
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
status = request.args.get("status", default=None, type=str)
query_params = DocumentListQuery.model_validate(request.args.to_dict())
dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
if status:
query = DocumentService.apply_display_status_filter(query, status)
if query_params.status:
query = DocumentService.apply_display_status_filter(query, query_params.status)
if search:
search = f"%{search}%"
if query_params.keyword:
search = f"%{query_params.keyword}%"
query = query.where(Document.name.like(search))
query = query.order_by(desc(Document.created_at), desc(Document.position))
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
paginated_documents = db.paginate(
select=query, page=query_params.page, per_page=query_params.limit, max_per_page=100, error_out=False
)
documents = paginated_documents.items
response = {
"data": marshal(documents, document_fields),
"has_more": len(documents) == limit,
"limit": limit,
"has_more": len(documents) == query_params.limit,
"limit": query_params.limit,
"total": paginated_documents.total,
"page": page,
"page": query_params.page,
}
return response

View File

@ -236,4 +236,7 @@ class AgentChatAppRunner(AppRunner):
queue_manager=queue_manager,
stream=application_generate_entity.stream,
agent=True,
message_id=message.id,
user_id=application_generate_entity.user_id,
tenant_id=app_config.tenant_id,
)

View File

@ -79,6 +79,7 @@ class AppGenerateResponseConverter(ABC):
"document_name": resource["document_name"],
"score": resource["score"],
"content": resource["content"],
"summary": resource.get("summary"),
}
)
metadata["retriever_resources"] = updated_resources

View File

@ -1,6 +1,8 @@
import base64
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from mimetypes import guess_extension
from typing import TYPE_CHECKING, Any, Union
from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity
@ -11,10 +13,16 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
ModelConfigWithCredentialsEntity,
)
from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent
from core.app.entities.queue_entities import (
QueueAgentMessageEvent,
QueueLLMChunkEvent,
QueueMessageEndEvent,
QueueMessageFileEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.app.features.hosting_moderation.hosting_moderation import HostingModerationFeature
from core.external_data_tool.external_data_fetch import ExternalDataFetch
from core.file.enums import FileTransferMethod, FileType
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
@ -22,6 +30,7 @@ from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
TextPromptMessageContent,
)
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
@ -29,7 +38,10 @@ from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform
from models.model import App, AppMode, Message, MessageAnnotation
from core.tools.tool_file_manager import ToolFileManager
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
if TYPE_CHECKING:
from core.file.models import File
@ -203,6 +215,9 @@ class AppRunner:
queue_manager: AppQueueManager,
stream: bool,
agent: bool = False,
message_id: str | None = None,
user_id: str | None = None,
tenant_id: str | None = None,
):
"""
Handle invoke result
@ -210,21 +225,41 @@ class AppRunner:
:param queue_manager: application queue manager
:param stream: stream
:param agent: agent
:param message_id: message id for multimodal output
:param user_id: user id for multimodal output
:param tenant_id: tenant id for multimodal output
:return:
"""
if not stream and isinstance(invoke_result, LLMResult):
self._handle_invoke_result_direct(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
self._handle_invoke_result_direct(
invoke_result=invoke_result,
queue_manager=queue_manager,
)
elif stream and isinstance(invoke_result, Generator):
self._handle_invoke_result_stream(invoke_result=invoke_result, queue_manager=queue_manager, agent=agent)
self._handle_invoke_result_stream(
invoke_result=invoke_result,
queue_manager=queue_manager,
agent=agent,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
)
else:
raise NotImplementedError(f"unsupported invoke result type: {type(invoke_result)}")
def _handle_invoke_result_direct(self, invoke_result: LLMResult, queue_manager: AppQueueManager, agent: bool):
def _handle_invoke_result_direct(
self,
invoke_result: LLMResult,
queue_manager: AppQueueManager,
):
"""
Handle invoke result direct
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param agent: agent
:param message_id: message id for multimodal output
:param user_id: user id for multimodal output
:param tenant_id: tenant id for multimodal output
:return:
"""
queue_manager.publish(
@ -235,13 +270,22 @@ class AppRunner:
)
def _handle_invoke_result_stream(
self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool
self,
invoke_result: Generator[LLMResultChunk, None, None],
queue_manager: AppQueueManager,
agent: bool,
message_id: str | None = None,
user_id: str | None = None,
tenant_id: str | None = None,
):
"""
Handle invoke result
:param invoke_result: invoke result
:param queue_manager: application queue manager
:param agent: agent
:param message_id: message id for multimodal output
:param user_id: user id for multimodal output
:param tenant_id: tenant id for multimodal output
:return:
"""
model: str = ""
@ -259,12 +303,26 @@ class AppRunner:
text += message.content
elif isinstance(message.content, list):
for content in message.content:
if not isinstance(content, str):
# TODO(QuantumGhost): Add multimodal output support for easy ui.
_logger.warning("received multimodal output, type=%s", type(content))
if isinstance(content, str):
text += content
elif isinstance(content, TextPromptMessageContent):
text += content.data
elif isinstance(content, ImagePromptMessageContent):
if message_id and user_id and tenant_id:
try:
self._handle_multimodal_image_content(
content=content,
message_id=message_id,
user_id=user_id,
tenant_id=tenant_id,
queue_manager=queue_manager,
)
except Exception:
_logger.exception("Failed to handle multimodal image output")
else:
_logger.warning("Received multimodal output but missing required parameters")
else:
text += content # failback to str
text += content.data if hasattr(content, "data") else str(content)
if not model:
model = result.model
@ -289,6 +347,101 @@ class AppRunner:
PublishFrom.APPLICATION_MANAGER,
)
def _handle_multimodal_image_content(
self,
content: ImagePromptMessageContent,
message_id: str,
user_id: str,
tenant_id: str,
queue_manager: AppQueueManager,
):
"""
Handle multimodal image content from LLM response.
Save the image and create a MessageFile record.
:param content: ImagePromptMessageContent instance
:param message_id: message id
:param user_id: user id
:param tenant_id: tenant id
:param queue_manager: queue manager
:return:
"""
_logger.info("Handling multimodal image content for message %s", message_id)
image_url = content.url
base64_data = content.base64_data
_logger.info("Image URL: %s, Base64 data present: %s", image_url, base64_data)
if not image_url and not base64_data:
_logger.warning("Image content has neither URL nor base64 data")
return
tool_file_manager = ToolFileManager()
# Save the image file
try:
if image_url:
# Download image from URL
_logger.info("Downloading image from URL: %s", image_url)
tool_file = tool_file_manager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
file_url=image_url,
conversation_id=None,
)
_logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
elif base64_data:
if base64_data.startswith("data:"):
base64_data = base64_data.split(",", 1)[1]
image_binary = base64.b64decode(base64_data)
mimetype = content.mime_type or "image/png"
extension = guess_extension(mimetype) or ".png"
tool_file = tool_file_manager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_binary=image_binary,
mimetype=mimetype,
filename=f"generated_image{extension}",
)
_logger.info("Image saved successfully, tool_file_id: %s", tool_file.id)
else:
return
except Exception:
_logger.exception("Failed to save image file")
return
# Create MessageFile record
message_file = MessageFile(
message_id=message_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
belongs_to="assistant",
url=f"/files/tools/{tool_file.id}",
upload_file_id=tool_file.id,
created_by_role=(
CreatorUserRole.ACCOUNT
if queue_manager.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}
else CreatorUserRole.END_USER
),
created_by=user_id,
)
db.session.add(message_file)
db.session.commit()
db.session.refresh(message_file)
# Publish QueueMessageFileEvent
queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file.id),
PublishFrom.APPLICATION_MANAGER,
)
_logger.info("QueueMessageFileEvent published for message_file_id: %s", message_file.id)
def moderation_for_inputs(
self,
*,

View File

@ -226,5 +226,10 @@ class ChatAppRunner(AppRunner):
# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream,
message_id=message.id,
user_id=application_generate_entity.user_id,
tenant_id=app_config.tenant_id,
)

View File

@ -184,5 +184,10 @@ class CompletionAppRunner(AppRunner):
# handle invoke result
self._handle_invoke_result(
invoke_result=invoke_result, queue_manager=queue_manager, stream=application_generate_entity.stream
invoke_result=invoke_result,
queue_manager=queue_manager,
stream=application_generate_entity.stream,
message_id=message.id,
user_id=application_generate_entity.user_id,
tenant_id=app_config.tenant_id,
)

View File

@ -39,6 +39,7 @@ from core.app.entities.task_entities import (
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
StreamEvent,
StreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
@ -70,6 +71,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
_task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
_precomputed_event_type: StreamEvent | None = None
def __init__(
self,
@ -342,11 +344,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent):
event_type = self._message_cycle_manager.get_message_event_type(message_id=self._message_id)
# Determine the event type once, on first LLM chunk, and reuse for subsequent chunks
if not hasattr(self, "_precomputed_event_type") or self._precomputed_event_type is None:
self._precomputed_event_type = self._message_cycle_manager.get_message_event_type(
message_id=self._message_id
)
yield self._message_cycle_manager.message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
event_type=event_type,
event_type=self._precomputed_event_type,
)
else:
yield self._agent_message_to_stream_response(

View File

@ -5,7 +5,7 @@ from threading import Thread
from typing import Union
from flask import Flask, current_app
from sqlalchemy import exists, select
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
@ -30,6 +30,7 @@ from core.app.entities.task_entities import (
StreamEvent,
WorkflowTaskState,
)
from core.db.session_factory import session_factory
from core.llm_generator.llm_generator import LLMGenerator
from core.tools.signature import sign_tool_file
from extensions.ext_database import db
@ -57,13 +58,15 @@ class MessageCycleManager:
self._message_has_file: set[str] = set()
def get_message_event_type(self, message_id: str) -> StreamEvent:
# Fast path: cached determination from prior QueueMessageFileEvent
if message_id in self._message_has_file:
return StreamEvent.MESSAGE_FILE
with Session(db.engine, expire_on_commit=False) as session:
has_file = session.query(exists().where(MessageFile.message_id == message_id)).scalar()
# Use SQLAlchemy 2.x style session.scalar(select(...))
with session_factory.create_session() as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id))
if has_file:
if message_file:
self._message_has_file.add(message_id)
return StreamEvent.MESSAGE_FILE
@ -199,6 +202,8 @@ class MessageCycleManager:
message_file = session.scalar(select(MessageFile).where(MessageFile.id == event.message_file_id))
if message_file and message_file.url is not None:
self._message_has_file.add(message_file.message_id)
# get tool file id
tool_file_id = message_file.url.split("/")[-1]
# trim extension

View File

@ -1,4 +1,4 @@
from collections.abc import Generator, Mapping
from collections.abc import Generator
from typing import Any
from core.datasource.__base.datasource_plugin import DatasourcePlugin
@ -34,7 +34,7 @@ class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
def get_online_document_pages(
self,
user_id: str,
datasource_parameters: Mapping[str, Any],
datasource_parameters: dict[str, Any],
provider_type: str,
) -> Generator[OnlineDocumentPagesMessage, None, None]:
manager = PluginDatasourceManager()

View File

@ -3,6 +3,7 @@ from pydantic import BaseModel, Field, field_validator
class PreviewDetail(BaseModel):
content: str
summary: str | None = None
child_chunks: list[str] | None = None

View File

@ -311,14 +311,18 @@ class IndexingRunner:
qa_preview_texts: list[QAPreviewDetail] = []
total_segments = 0
# doc_form represents the segmentation method (general, parent-child, QA)
index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
# one extract_setting is one source document
for extract_setting in extract_settings:
# extract
processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
)
# Extract document content
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
# Cleaning and segmentation
documents = index_processor.transform(
text_docs,
current_user=None,
@ -361,6 +365,12 @@ class IndexingRunner:
if doc_form and doc_form == "qa_model":
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
# Generate summary preview
summary_index_setting = tmp_processing_rule.get("summary_index_setting")
if summary_index_setting and summary_index_setting.get("enable") and preview_texts:
preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
def _extract(

View File

@ -434,3 +434,20 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
You should edit the prompt according to the IDEAL OUTPUT."""
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""
DEFAULT_GENERATOR_SUMMARY_PROMPT = (
"""Summarize the following content. Extract only the key information and main points. """
"""Remove redundant details.
Requirements:
1. Write a concise summary in plain text
2. Use the same language as the input content
3. Focus on important facts, concepts, and details
4. If images are included, describe their key information
5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions"
6. Write directly without extra words
Output only the summary text. Start summarizing now:
"""
)

View File

@ -389,15 +389,15 @@ class RetrievalService:
.all()
}
records = []
include_segment_ids = set()
segment_child_map = {}
valid_dataset_documents = {}
image_doc_ids: list[Any] = []
child_index_node_ids = []
index_node_ids = []
doc_to_document_map = {}
summary_segment_ids = set() # Track segments retrieved via summary
summary_score_map: dict[str, float] = {} # Map original_chunk_id to summary score
# First pass: collect all document IDs and identify summary documents
for document in documents:
document_id = document.metadata.get("document_id")
if document_id not in dataset_documents:
@ -408,16 +408,39 @@ class RetrievalService:
continue
valid_dataset_documents[document_id] = dataset_document
doc_id = document.metadata.get("doc_id") or ""
doc_to_document_map[doc_id] = document
# Check if this is a summary document
is_summary = document.metadata.get("is_summary", False)
if is_summary:
# For summary documents, find the original chunk via original_chunk_id
original_chunk_id = document.metadata.get("original_chunk_id")
if original_chunk_id:
summary_segment_ids.add(original_chunk_id)
# Save summary's score for later use
summary_score = document.metadata.get("score")
if summary_score is not None:
try:
summary_score_float = float(summary_score)
# If the same segment has multiple summary hits, take the highest score
if original_chunk_id not in summary_score_map:
summary_score_map[original_chunk_id] = summary_score_float
else:
summary_score_map[original_chunk_id] = max(
summary_score_map[original_chunk_id], summary_score_float
)
except (ValueError, TypeError):
# Skip invalid score values
pass
continue # Skip adding to other lists for summary documents
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
doc_id = document.metadata.get("doc_id") or ""
doc_to_document_map[doc_id] = document
if document.metadata.get("doc_type") == DocType.IMAGE:
image_doc_ids.append(doc_id)
else:
child_index_node_ids.append(doc_id)
else:
doc_id = document.metadata.get("doc_id") or ""
doc_to_document_map[doc_id] = document
if document.metadata.get("doc_type") == DocType.IMAGE:
image_doc_ids.append(doc_id)
else:
@ -433,6 +456,7 @@ class RetrievalService:
attachment_map: dict[str, list[dict[str, Any]]] = {}
child_chunk_map: dict[str, list[ChildChunk]] = {}
doc_segment_map: dict[str, list[str]] = {}
segment_summary_map: dict[str, str] = {} # Map segment_id to summary content
with session_factory.create_session() as session:
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
@ -447,6 +471,7 @@ class RetrievalService:
doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
else:
doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
@ -470,6 +495,7 @@ class RetrievalService:
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
for index_node_segment in index_node_segments:
doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
if segment_ids:
document_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
@ -481,6 +507,42 @@ class RetrievalService:
if index_node_segments:
segments.extend(index_node_segments)
# Handle summary documents: query segments by original_chunk_id
if summary_segment_ids:
summary_segment_ids_list = list(summary_segment_ids)
summary_segment_stmt = select(DocumentSegment).where(
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.id.in_(summary_segment_ids_list),
)
summary_segments = session.execute(summary_segment_stmt).scalars().all() # type: ignore
segments.extend(summary_segments)
# Add summary segment IDs to segment_ids for summary query
for seg in summary_segments:
if seg.id not in segment_ids:
segment_ids.append(seg.id)
# Batch query summaries for segments retrieved via summary (only enabled summaries)
if summary_segment_ids:
from models.dataset import DocumentSegmentSummary
summaries = (
session.query(DocumentSegmentSummary)
.filter(
DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)),
DocumentSegmentSummary.status == "completed",
DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries
)
.all()
)
for summary in summaries:
if summary.summary_content:
segment_summary_map[summary.chunk_id] = summary.summary_content
include_segment_ids = set()
segment_child_map: dict[str, dict[str, Any]] = {}
records: list[dict[str, Any]] = []
for segment in segments:
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
@ -489,30 +551,43 @@ class RetrievalService:
if ds_dataset_document and ds_dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
# Check if this segment was retrieved via summary
# Use summary score as base score if available, otherwise 0.0
max_score = summary_score_map.get(segment.id, 0.0)
if child_chunks or attachment_infos:
child_chunk_details = []
max_score = 0.0
for child_chunk in child_chunks:
document = doc_to_document_map[child_chunk.index_node_id]
document = doc_to_document_map.get(child_chunk.index_node_id)
child_score = document.metadata.get("score", 0.0) if document else 0.0
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0) if document else 0.0,
"score": child_score,
}
child_chunk_details.append(child_chunk_detail)
max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0)
max_score = max(max_score, child_score)
for attachment_info in attachment_infos:
file_document = doc_to_document_map[attachment_info["id"]]
max_score = max(
max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
)
file_document = doc_to_document_map.get(attachment_info["id"])
if file_document:
max_score = max(
max_score, file_document.metadata.get("score", 0.0)
)
map_detail = {
"max_score": max_score,
"child_chunks": child_chunk_details,
}
segment_child_map[segment.id] = map_detail
else:
# No child chunks or attachments, use summary score if available
summary_score = summary_score_map.get(segment.id)
if summary_score is not None:
segment_child_map[segment.id] = {
"max_score": summary_score,
"child_chunks": [],
}
record: dict[str, Any] = {
"segment": segment,
}
@ -520,14 +595,23 @@ class RetrievalService:
else:
if segment.id not in include_segment_ids:
include_segment_ids.add(segment.id)
max_score = 0.0
segment_document = doc_to_document_map.get(segment.index_node_id)
if segment_document:
max_score = max(max_score, segment_document.metadata.get("score", 0.0))
# Check if this segment was retrieved via summary
# Use summary score if available (summary retrieval takes priority)
max_score = summary_score_map.get(segment.id, 0.0)
# If not retrieved via summary, use original segment's score
if segment.id not in summary_score_map:
segment_document = doc_to_document_map.get(segment.index_node_id)
if segment_document:
max_score = max(max_score, segment_document.metadata.get("score", 0.0))
# Also consider attachment scores
for attachment_info in attachment_infos:
file_doc = doc_to_document_map.get(attachment_info["id"])
if file_doc:
max_score = max(max_score, file_doc.metadata.get("score", 0.0))
record = {
"segment": segment,
"score": max_score,
@ -576,9 +660,16 @@ class RetrievalService:
else None
)
# Extract summary if this segment was retrieved via summary
summary_content = segment_summary_map.get(segment.id)
# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(
segment=segment, child_chunks=child_chunks_list, score=score, files=files
segment=segment,
child_chunks=child_chunks_list,
score=score,
files=files,
summary=summary_content,
)
result.append(retrieval_segment)

View File

@ -20,3 +20,4 @@ class RetrievalSegments(BaseModel):
child_chunks: list[RetrievalChildChunk] | None = None
score: float | None = None
files: list[dict[str, str | int]] | None = None
summary: str | None = None # Summary content if retrieved via summary index

View File

@ -22,3 +22,4 @@ class RetrievalSourceMetadata(BaseModel):
doc_metadata: dict[str, Any] | None = None
title: str | None = None
files: list[dict[str, Any]] | None = None
summary: str | None = None

View File

@ -13,6 +13,7 @@ from urllib.parse import unquote, urlparse
import httpx
from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
from core.helper import ssrf_proxy
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.doc_type import DocType
@ -45,6 +46,17 @@ class BaseIndexProcessor(ABC):
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
raise NotImplementedError
@abstractmethod
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
For each segment in preview_texts, generate a summary using LLM and attach it to the segment.
The summary can be stored in a new attribute, e.g., summary.
This method should be implemented by subclasses.
"""
raise NotImplementedError
@abstractmethod
def load(
self,

View File

@ -1,9 +1,26 @@
"""Paragraph index processor."""
import logging
import re
import uuid
from collections.abc import Mapping
from typing import Any
logger = logging.getLogger(__name__)
from core.entities.knowledge_entities import PreviewDetail
from core.file import File, FileTransferMethod, FileType, file_manager
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessageContentUnionTypes,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.provider_manager import ProviderManager
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.retrieval_service import RetrievalService
@ -17,12 +34,16 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs import helper
from models import UploadFile
from models.account import Account
from models.dataset import Dataset, DatasetProcessRule
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService
class ParagraphIndexProcessor(BaseIndexProcessor):
@ -108,6 +129,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword.add_texts(documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
if node_ids:
@ -227,3 +271,318 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
}
else:
raise ValueError("Chunks is not a list")
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
For each segment, concurrently call generate_summary to generate a summary
and write it to the summary attribute of PreviewDetail.
In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception.
"""
import concurrent.futures
from flask import current_app
# Capture Flask app context for worker threads
flask_app = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")
def process(preview: PreviewDetail) -> None:
"""Generate summary for a single preview item."""
if flask_app:
# Ensure Flask app context in worker thread
with flask_app.app_context():
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
preview.summary = summary
else:
# Fallback: try without app context (may fail)
summary, _ = self.generate_summary(tenant_id, preview.content, summary_index_setting)
preview.summary = summary
# Generate summaries concurrently using ThreadPoolExecutor
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
timeout_seconds = min(300, 60 * len(preview_texts))
errors: list[Exception] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor:
futures = [
executor.submit(process, preview)
for preview in preview_texts
]
# Wait for all tasks to complete with timeout
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
# Cancel tasks that didn't complete in time
if not_done:
timeout_error_msg = (
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
)
logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
# In preview mode, timeout is also an error
errors.append(TimeoutError(timeout_error_msg))
for future in not_done:
future.cancel()
# Wait a bit for cancellation to take effect
concurrent.futures.wait(not_done, timeout=5)
# Collect exceptions from completed futures
for future in done:
try:
future.result() # This will raise any exception that occurred
except Exception as e:
logger.exception("Error in summary generation future")
errors.append(e)
# In preview mode (indexing-estimate), if there are any errors, fail the request
if errors:
error_messages = [str(e) for e in errors]
error_summary = (
f"Failed to generate summaries for {len(errors)} chunk(s). "
f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors
)
if len(errors) > 3:
error_summary += f" (and {len(errors) - 3} more)"
logger.error("Summary generation failed in preview mode: %s", error_summary)
raise ValueError(error_summary)
return preview_texts
@staticmethod
def generate_summary(
tenant_id: str,
text: str,
summary_index_setting: dict | None = None,
segment_id: str | None = None,
) -> tuple[str, LLMUsage]:
"""
Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt,
and supports vision models by including images from the segment attachments or text content.
Args:
tenant_id: Tenant ID
text: Text content to summarize
summary_index_setting: Summary index configuration
segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table
Returns:
Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object
"""
if not summary_index_setting or not summary_index_setting.get("enable"):
raise ValueError("summary_index_setting is required and must be enabled to generate summary.")
model_name = summary_index_setting.get("model_name")
model_provider_name = summary_index_setting.get("model_provider_name")
summary_prompt = summary_index_setting.get("summary_prompt")
# Import default summary prompt
if not summary_prompt:
summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id, model_provider_name, ModelType.LLM
)
model_instance = ModelInstance(provider_model_bundle, model_name)
# Get model schema to check if vision is supported
model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
supports_vision = model_schema and model_schema.features and ModelFeature.VISION in model_schema.features
# Extract images if model supports vision
image_files = []
if supports_vision:
# First, try to get images from SegmentAttachmentBinding (preferred method)
if segment_id:
image_files = ParagraphIndexProcessor._extract_images_from_segment_attachments(tenant_id, segment_id)
# If no images from attachments, fall back to extracting from text
if not image_files:
image_files = ParagraphIndexProcessor._extract_images_from_text(tenant_id, text)
# Build prompt messages
prompt_messages = []
if image_files:
# If we have images, create a UserPromptMessage with both text and images
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
# Add images first
for file in image_files:
try:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=ImagePromptMessageContent.DETAIL.LOW
)
prompt_message_contents.append(file_content)
except Exception as e:
logger.warning("Failed to convert image file to prompt message content: %s", str(e))
continue
# Add text content
if prompt_message_contents: # Only add text if we successfully added images
prompt_message_contents.append(TextPromptMessageContent(data=f"{summary_prompt}\n{text}"))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
# If image conversion failed, fall back to text-only
prompt = f"{summary_prompt}\n{text}"
prompt_messages.append(UserPromptMessage(content=prompt))
else:
# No images, use simple text prompt
prompt = f"{summary_prompt}\n{text}"
prompt_messages.append(UserPromptMessage(content=prompt))
result = model_instance.invoke_llm(prompt_messages=prompt_messages, model_parameters={}, stream=False)
summary_content = getattr(result.message, "content", "")
usage = result.usage
# Deduct quota for summary generation (same as workflow nodes)
from core.workflow.nodes.llm import llm_utils
try:
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
except Exception as e:
# Log but don't fail summary generation if quota deduction fails
logger.warning("Failed to deduct quota for summary generation: %s", str(e))
return summary_content, usage
@staticmethod
def _extract_images_from_text(tenant_id: str, text: str) -> list[File]:
"""
Extract images from markdown text and convert them to File objects.
Args:
tenant_id: Tenant ID
text: Text content that may contain markdown image links
Returns:
List of File objects representing images found in the text
"""
# Extract markdown images using regex pattern
pattern = r"!\[.*?\]\((.*?)\)"
images = re.findall(pattern, text)
if not images:
return []
upload_file_id_list = []
for image in images:
# For data before v0.10.0
pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue
# For data after v0.10.0
pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
match = re.search(pattern, image)
if match:
upload_file_id = match.group(1)
upload_file_id_list.append(upload_file_id)
continue
# For tools directory - direct file formats (e.g., .png, .jpg, etc.)
pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
match = re.search(pattern, image)
if match:
# Tool files are handled differently, skip for now
continue
if not upload_file_id_list:
return []
# Get unique IDs for database query
unique_upload_file_ids = list(set(upload_file_id_list))
upload_files = (
db.session.query(UploadFile)
.where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
.all()
)
# Create File objects from UploadFile records
file_objects = []
for upload_file in upload_files:
# Only process image files
if not upload_file.mime_type or "image" not in upload_file.mime_type:
continue
mapping = {
"upload_file_id": upload_file.id,
"transfer_method": FileTransferMethod.LOCAL_FILE.value,
"type": FileType.IMAGE.value,
}
try:
file_obj = build_from_mapping(
mapping=mapping,
tenant_id=tenant_id,
)
file_objects.append(file_obj)
except Exception as e:
logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e))
continue
return file_objects
@staticmethod
def _extract_images_from_segment_attachments(tenant_id: str, segment_id: str) -> list[File]:
"""
Extract images from SegmentAttachmentBinding table (preferred method).
This matches how DatasetRetrieval gets segment attachments.
Args:
tenant_id: Tenant ID
segment_id: Segment ID to fetch attachments for
Returns:
List of File objects representing images found in segment attachments
"""
from sqlalchemy import select
# Query attachments from SegmentAttachmentBinding table
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.segment_id == segment_id,
SegmentAttachmentBinding.tenant_id == tenant_id,
)
).all()
if not attachments_with_bindings:
return []
file_objects = []
for _, upload_file in attachments_with_bindings:
# Only process image files
if not upload_file.mime_type or "image" not in upload_file.mime_type:
continue
try:
# Create File object directly (similar to DatasetRetrieval)
file_obj = File(
id=upload_file.id,
filename=upload_file.name,
extension="." + upload_file.extension,
mime_type=upload_file.mime_type,
tenant_id=tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
remote_url=upload_file.source_url,
related_id=upload_file.id,
size=upload_file.size,
storage_key=upload_file.key,
)
file_objects.append(file_obj)
except Exception as e:
logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e))
continue
return file_objects

View File

@ -1,11 +1,13 @@
"""Paragraph index processor."""
import json
import logging
import uuid
from collections.abc import Mapping
from typing import Any
from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
@ -25,6 +27,9 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
from models.dataset import Document as DatasetDocument
from services.account_service import AccountService
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
class ParentChildIndexProcessor(BaseIndexProcessor):
@ -135,6 +140,29 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# node_ids is segment's node_ids
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
if dataset.indexing_technique == "high_quality":
delete_child_chunks = kwargs.get("delete_child_chunks") or False
precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")
@ -326,3 +354,93 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
"preview": preview,
"total_segments": len(parent_childs.parent_child_chunks),
}
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary
and write it to the summary attribute of PreviewDetail.
In preview mode (indexing-estimate), if any summary generation fails, the method will raise an exception.
Note: For parent-child structure, we only generate summaries for parent chunks.
"""
import concurrent.futures
from flask import current_app
# Capture Flask app context for worker threads
flask_app = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")
def process(preview: PreviewDetail) -> None:
"""Generate summary for a single preview item (parent chunk)."""
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
if flask_app:
# Ensure Flask app context in worker thread
with flask_app.app_context():
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=tenant_id,
text=preview.content,
summary_index_setting=summary_index_setting,
)
preview.summary = summary
else:
# Fallback: try without app context (may fail)
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=tenant_id,
text=preview.content,
summary_index_setting=summary_index_setting,
)
preview.summary = summary
# Generate summaries concurrently using ThreadPoolExecutor
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
timeout_seconds = min(300, 60 * len(preview_texts))
errors: list[Exception] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_texts))) as executor:
futures = [
executor.submit(process, preview)
for preview in preview_texts
]
# Wait for all tasks to complete with timeout
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
# Cancel tasks that didn't complete in time
if not_done:
timeout_error_msg = (
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
)
logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
# In preview mode, timeout is also an error
errors.append(TimeoutError(timeout_error_msg))
for future in not_done:
future.cancel()
# Wait a bit for cancellation to take effect
concurrent.futures.wait(not_done, timeout=5)
# Collect exceptions from completed futures
for future in done:
try:
future.result() # This will raise any exception that occurred
except Exception as e:
logger.exception("Error in summary generation future")
errors.append(e)
# In preview mode (indexing-estimate), if there are any errors, fail the request
if errors:
error_messages = [str(e) for e in errors]
error_summary = (
f"Failed to generate summaries for {len(errors)} chunk(s). "
f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors
)
if len(errors) > 3:
error_summary += f" (and {len(errors) - 3} more)"
logger.error("Summary generation failed in preview mode: %s", error_summary)
raise ValueError(error_summary)
return preview_texts

View File

@ -11,6 +11,7 @@ import pandas as pd
from flask import Flask, current_app
from werkzeug.datastructures import FileStorage
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.llm_generator import LLMGenerator
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
@ -25,9 +26,10 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper
from models.account import Account
from models.dataset import Dataset
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import Rule
from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
@ -144,6 +146,30 @@ class QAIndexProcessor(BaseIndexProcessor):
vector.create_multimodal(multimodal_documents)
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Note: qa_model doesn't generate summaries, but we clean them for completeness
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)
vector = Vector(dataset)
if node_ids:
vector.delete_by_ids(node_ids)
@ -212,6 +238,17 @@ class QAIndexProcessor(BaseIndexProcessor):
"total_segments": len(qa_chunks.qa_chunks),
}
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
QA model doesn't generate summaries, so this method returns preview_texts unchanged.
Note: QA model uses question-answer pairs, which don't require summary generation.
"""
# QA model doesn't generate summaries, return as-is
return preview_texts
def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
format_documents = []
if document_node.page_content is None or not document_node.page_content.strip():

View File

@ -236,20 +236,24 @@ class DatasetRetrieval:
if records:
for record in records:
segment = record.segment
# Build content: if summary exists, add it before the segment content
if segment.answer:
document_context_list.append(
DocumentContext(
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
score=record.score,
)
)
segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}"
else:
document_context_list.append(
DocumentContext(
content=segment.get_sign_content(),
score=record.score,
)
segment_content = segment.get_sign_content()
# If summary exists, prepend it to the content
if record.summary:
final_content = f"{record.summary}\n{segment_content}"
else:
final_content = segment_content
document_context_list.append(
DocumentContext(
content=final_content,
score=record.score,
)
)
if vision_enabled:
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
@ -316,6 +320,9 @@ class DatasetRetrieval:
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source.content = segment.content
# Add summary if this segment was retrieved via summary
if hasattr(record, 'summary') and record.summary:
source.summary = record.summary
retrieval_resource_list.append(source)
if hit_callback and retrieval_resource_list:
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True)

View File

@ -169,20 +169,24 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
if records:
for record in records:
segment = record.segment
# Build content: if summary exists, add it before the segment content
if segment.answer:
document_context_list.append(
DocumentContext(
content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
score=record.score,
)
)
segment_content = f"question:{segment.get_sign_content()} answer:{segment.answer}"
else:
document_context_list.append(
DocumentContext(
content=segment.get_sign_content(),
score=record.score,
)
segment_content = segment.get_sign_content()
# If summary exists, prepend it to the content
if record.summary:
final_content = f"{record.summary}\n{segment_content}"
else:
final_content = segment_content
document_context_list.append(
DocumentContext(
content=final_content,
score=record.score,
)
)
if self.return_resource:
for record in records:
@ -216,6 +220,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
source.content = f"question:{segment.content} \nanswer:{segment.answer}"
else:
source.content = segment.content
# Add summary if this segment was retrieved via summary
if hasattr(record, 'summary') and record.summary:
source.summary = record.summary
retrieval_resource_list.append(source)
if self.return_resource and retrieval_resource_list:

View File

@ -8,7 +8,7 @@ intercept and respond to GraphEngine events.
from abc import ABC, abstractmethod
from core.workflow.graph_engine.protocols.command_channel import CommandChannel
from core.workflow.graph_events import GraphEngineEvent
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from core.workflow.runtime import ReadOnlyGraphRuntimeState
@ -98,7 +98,7 @@ class GraphEngineLayer(ABC):
"""
pass
def on_node_run_start(self, node: Node) -> None: # noqa: B027
def on_node_run_start(self, node: Node) -> None:
"""
Called immediately before a node begins execution.
@ -109,9 +109,11 @@ class GraphEngineLayer(ABC):
Args:
node: The node instance about to be executed
"""
pass
return
def on_node_run_end(self, node: Node, error: Exception | None) -> None: # noqa: B027
def on_node_run_end(
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
"""
Called after a node finishes execution.
@ -121,5 +123,6 @@ class GraphEngineLayer(ABC):
Args:
node: The node instance that just finished execution
error: Exception instance if the node failed, otherwise None
result_event: The final result event from node execution (succeeded/failed/paused), if any
"""
pass
return

View File

@ -1,61 +0,0 @@
"""
Node-level OpenTelemetry parser interfaces and defaults.
"""
import json
from typing import Protocol
from opentelemetry.trace import Span
from opentelemetry.trace.status import Status, StatusCode
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.tool.entities import ToolNodeData
class NodeOTelParser(Protocol):
"""Parser interface for node-specific OpenTelemetry enrichment."""
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None: ...
class DefaultNodeOTelParser:
"""Fallback parser used when no node-specific parser is registered."""
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
span.set_attribute("node.id", node.id)
if node.execution_id:
span.set_attribute("node.execution_id", node.execution_id)
if hasattr(node, "node_type") and node.node_type:
span.set_attribute("node.type", node.node_type.value)
if error:
span.record_exception(error)
span.set_status(Status(StatusCode.ERROR, str(error)))
else:
span.set_status(Status(StatusCode.OK))
class ToolNodeOTelParser:
"""Parser for tool nodes that captures tool-specific metadata."""
def __init__(self) -> None:
self._delegate = DefaultNodeOTelParser()
def parse(self, *, node: Node, span: "Span", error: Exception | None) -> None:
self._delegate.parse(node=node, span=span, error=error)
tool_data = getattr(node, "_node_data", None)
if not isinstance(tool_data, ToolNodeData):
return
span.set_attribute("tool.provider.id", tool_data.provider_id)
span.set_attribute("tool.provider.type", tool_data.provider_type.value)
span.set_attribute("tool.provider.name", tool_data.provider_name)
span.set_attribute("tool.name", tool_data.tool_name)
span.set_attribute("tool.label", tool_data.tool_label)
if tool_data.plugin_unique_identifier:
span.set_attribute("tool.plugin.id", tool_data.plugin_unique_identifier)
if tool_data.credential_id:
span.set_attribute("tool.credential.id", tool_data.credential_id)
if tool_data.tool_configurations:
span.set_attribute("tool.config", json.dumps(tool_data.tool_configurations, ensure_ascii=False))

View File

@ -18,12 +18,15 @@ from typing_extensions import override
from configs import dify_config
from core.workflow.enums import NodeType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_engine.layers.node_parsers import (
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from extensions.otel.parser import (
DefaultNodeOTelParser,
LLMNodeOTelParser,
NodeOTelParser,
RetrievalNodeOTelParser,
ToolNodeOTelParser,
)
from core.workflow.nodes.base.node import Node
from extensions.otel.runtime import is_instrument_flag_enabled
logger = logging.getLogger(__name__)
@ -72,6 +75,8 @@ class ObservabilityLayer(GraphEngineLayer):
"""Initialize parser registry for node types."""
self._parsers = {
NodeType.TOOL: ToolNodeOTelParser(),
NodeType.LLM: LLMNodeOTelParser(),
NodeType.KNOWLEDGE_RETRIEVAL: RetrievalNodeOTelParser(),
}
def _get_parser(self, node: Node) -> NodeOTelParser:
@ -119,7 +124,9 @@ class ObservabilityLayer(GraphEngineLayer):
logger.warning("Failed to create OpenTelemetry span for node %s: %s", node.id, e)
@override
def on_node_run_end(self, node: Node, error: Exception | None) -> None:
def on_node_run_end(
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
"""
Called when a node finishes execution.
@ -139,7 +146,7 @@ class ObservabilityLayer(GraphEngineLayer):
span = node_context.span
parser = self._get_parser(node)
try:
parser.parse(node=node, span=span, error=error)
parser.parse(node=node, span=span, error=error, result_event=result_event)
span.end()
finally:
token = node_context.token

View File

@ -17,7 +17,7 @@ from typing_extensions import override
from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event
from core.workflow.nodes.base.node import Node
from .ready_queue import ReadyQueue
@ -131,6 +131,7 @@ class Worker(threading.Thread):
node.ensure_execution_id()
error: Exception | None = None
result_event: GraphNodeEventBase | None = None
# Execute the node with preserved context if execution context is provided
if self._execution_context is not None:
@ -140,22 +141,26 @@ class Worker(threading.Thread):
node_events = node.run()
for event in node_events:
self._event_queue.put(event)
if is_node_result_event(event):
result_event = event
except Exception as exc:
error = exc
raise
finally:
self._invoke_node_run_end_hooks(node, error)
self._invoke_node_run_end_hooks(node, error, result_event)
else:
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()
for event in node_events:
self._event_queue.put(event)
if is_node_result_event(event):
result_event = event
except Exception as exc:
error = exc
raise
finally:
self._invoke_node_run_end_hooks(node, error)
self._invoke_node_run_end_hooks(node, error, result_event)
def _invoke_node_run_start_hooks(self, node: Node) -> None:
"""Invoke on_node_run_start hooks for all layers."""
@ -166,11 +171,13 @@ class Worker(threading.Thread):
# Silently ignore layer errors to prevent disrupting node execution
continue
def _invoke_node_run_end_hooks(self, node: Node, error: Exception | None) -> None:
def _invoke_node_run_end_hooks(
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
"""Invoke on_node_run_end hooks for all layers."""
for layer in self._layers:
try:
layer.on_node_run_end(node, error)
layer.on_node_run_end(node, error, result_event)
except Exception:
# Silently ignore layer errors to prevent disrupting node execution
continue

View File

@ -44,6 +44,7 @@ from .node import (
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
is_node_result_event,
)
__all__ = [
@ -73,4 +74,5 @@ __all__ = [
"NodeRunStartedEvent",
"NodeRunStreamChunkEvent",
"NodeRunSucceededEvent",
"is_node_result_event",
]

View File

@ -56,3 +56,26 @@ class NodeRunRetryEvent(NodeRunStartedEvent):
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
reason: PauseReason = Field(..., description="pause reason")
def is_node_result_event(event: GraphNodeEventBase) -> bool:
"""
Check if an event is a final result event from node execution.
A result event indicates the completion of a node execution and contains
runtime information such as inputs, outputs, or error details.
Args:
event: The event to check
Returns:
True if the event is a node result event (succeeded/failed/paused), False otherwise
"""
return isinstance(
event,
(
NodeRunSucceededEvent,
NodeRunFailedEvent,
NodeRunPauseRequestedEvent,
),
)

View File

@ -62,6 +62,21 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
inputs = {"variable_selector": variable_selector}
process_data = {"documents": value if isinstance(value, list) else [value]}
# Ensure storage_key is loaded for File objects
files_to_check = value if isinstance(value, list) else [value]
files_needing_storage_key = [
f for f in files_to_check if isinstance(f, File) and not f.storage_key and f.related_id
]
if files_needing_storage_key:
from sqlalchemy.orm import Session
from extensions.ext_database import db
from factories.file_factory import StorageKeyLoader
with Session(bind=db.engine) as session:
storage_key_loader = StorageKeyLoader(session, tenant_id=self.tenant_id)
storage_key_loader.load_storage_keys(files_needing_storage_key)
try:
if isinstance(value, list):
extracted_text_list = list(map(_extract_text_from_file, value))
@ -415,6 +430,16 @@ def _download_file_content(file: File) -> bytes:
response.raise_for_status()
return response.content
else:
# Check if storage_key is set
if not file.storage_key:
raise FileDownloadError(f"File storage_key is missing for file: {file.filename}")
# Check if file exists before downloading
from extensions.ext_storage import storage
if not storage.exists(file.storage_key):
raise FileDownloadError(f"File not found in storage: {file.storage_key}")
return file_manager.download(file)
except Exception as e:
raise FileDownloadError(f"Error downloading file: {str(e)}") from e

View File

@ -158,3 +158,5 @@ class KnowledgeIndexNodeData(BaseNodeData):
type: str = "knowledge-index"
chunk_structure: str
index_chunk_variable_selector: list[str]
indexing_technique: str | None = None
summary_index_setting: dict | None = None

View File

@ -1,9 +1,11 @@
import concurrent.futures
import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any
from flask import current_app
from sqlalchemy import func, select
from core.app.entities.app_invoke_entities import InvokeFrom
@ -16,7 +18,9 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.template import Template
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
from services.summary_index_service import SummaryIndexService
from tasks.generate_summary_index_task import generate_summary_index_task
from .entities import KnowledgeIndexNodeData
from .exc import (
@ -67,7 +71,20 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
# index knowledge
try:
if is_preview:
outputs = self._get_preview_output(node_data.chunk_structure, chunks)
# Preview mode: generate summaries for chunks directly without saving to database
# Format preview and generate summaries on-the-fly
# Get indexing_technique and summary_index_setting from node_data (workflow graph config)
# or fallback to dataset if not available in node_data
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
outputs = self._get_preview_output_with_summaries(
node_data.chunk_structure,
chunks,
dataset=dataset,
indexing_technique=indexing_technique,
summary_index_setting=summary_index_setting,
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
@ -148,6 +165,11 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
)
.scalar()
)
# Update need_summary based on dataset's summary_index_setting
if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True:
document.need_summary = True
else:
document.need_summary = False
db.session.add(document)
# update document segment status
db.session.query(DocumentSegment).where(
@ -163,6 +185,9 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
db.session.commit()
# Generate summary index if enabled
self._handle_summary_index_generation(dataset, document, variable_pool)
return {
"dataset_id": ds_id_value,
"dataset_name": dataset_name_value,
@ -173,9 +198,307 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
"display_status": "completed",
}
def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]:
def _handle_summary_index_generation(
self,
dataset: Dataset,
document: Document,
variable_pool: VariablePool,
) -> None:
"""
Handle summary index generation based on mode (debug/preview or production).
Args:
dataset: Dataset containing the document
document: Document to generate summaries for
variable_pool: Variable pool to check invoke_from
"""
# Only generate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
return
# Check if summary index is enabled
summary_index_setting = dataset.summary_index_setting
if not summary_index_setting or not summary_index_setting.get("enable"):
return
# Skip qa_model documents
if document.doc_form == "qa_model":
return
# Determine if in preview/debug mode
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER
# Determine if only parent chunks should be processed
only_parent_chunks = dataset.chunk_structure == "parent_child_index"
if is_preview:
try:
# Query segments that need summary generation
query = db.session.query(DocumentSegment).filter_by(
dataset_id=dataset.id,
document_id=document.id,
status="completed",
enabled=True,
)
segments = query.all()
if not segments:
logger.info("No segments found for document %s", document.id)
return
# Filter segments based on mode
segments_to_process = []
for segment in segments:
# Skip if summary already exists
existing_summary = (
db.session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed")
.first()
)
if existing_summary:
continue
# For parent-child mode, all segments are parent chunks, so process all
segments_to_process.append(segment)
if not segments_to_process:
logger.info("No segments need summary generation for document %s", document.id)
return
# Use ThreadPoolExecutor for concurrent generation
flask_app = current_app._get_current_object() # type: ignore
max_workers = min(10, len(segments_to_process)) # Limit to 10 workers
def process_segment(segment: DocumentSegment) -> None:
"""Process a single segment in a thread with Flask app context."""
with flask_app.app_context():
try:
SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting)
except Exception:
logger.exception(
"Failed to generate summary for segment %s",
segment.id,
)
# Continue processing other segments
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = [executor.submit(process_segment, segment) for segment in segments_to_process]
# Wait for all tasks to complete
concurrent.futures.wait(futures)
logger.info(
"Successfully generated summary index for %s segments in document %s",
len(segments_to_process),
document.id,
)
except Exception:
logger.exception("Failed to generate summary index for document %s", document.id)
# Don't fail the entire indexing process if summary generation fails
else:
# Production mode: asynchronous generation
logger.info(
"Queuing summary index generation task for document %s (production mode)",
document.id,
)
try:
generate_summary_index_task.delay(dataset.id, document.id, None)
logger.info("Summary index generation task queued for document %s", document.id)
except Exception:
logger.exception(
"Failed to queue summary index generation task for document %s",
document.id,
)
# Don't fail the entire indexing process if task queuing fails
def _get_preview_output_with_summaries(
self,
chunk_structure: str,
chunks: Any,
dataset: Dataset,
indexing_technique: str | None = None,
summary_index_setting: dict | None = None,
) -> Mapping[str, Any]:
"""
Generate preview output with summaries for chunks in preview mode.
This method generates summaries on-the-fly without saving to database.
Args:
chunk_structure: Chunk structure type
chunks: Chunks to generate preview for
dataset: Dataset object (for tenant_id)
indexing_technique: Indexing technique from node config or dataset
summary_index_setting: Summary index setting from node config or dataset
"""
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
return index_processor.format_preview(chunks)
preview_output = index_processor.format_preview(chunks)
# Check if summary index is enabled
if indexing_technique != "high_quality":
return preview_output
if not summary_index_setting or not summary_index_setting.get("enable"):
return preview_output
# Generate summaries for chunks
if "preview" in preview_output and isinstance(preview_output["preview"], list):
chunk_count = len(preview_output["preview"])
logger.info(
"Generating summaries for %s chunks in preview mode (dataset: %s)",
chunk_count,
dataset.id,
)
# Use ParagraphIndexProcessor's generate_summary method
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
# Get Flask app for application context in worker threads
flask_app = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")
def generate_summary_for_chunk(preview_item: dict) -> None:
"""Generate summary for a single chunk."""
if "content" in preview_item:
# Set Flask application context in worker thread
if flask_app:
with flask_app.app_context():
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
)
if summary:
preview_item["summary"] = summary
else:
# Fallback: try without app context (may fail)
summary, _ = ParagraphIndexProcessor.generate_summary(
tenant_id=dataset.tenant_id,
text=preview_item["content"],
summary_index_setting=summary_index_setting,
)
if summary:
preview_item["summary"] = summary
# Generate summaries concurrently using ThreadPoolExecutor
# Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
timeout_seconds = min(300, 60 * len(preview_output["preview"]))
errors: list[Exception] = []
with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor:
futures = [
executor.submit(generate_summary_for_chunk, preview_item)
for preview_item in preview_output["preview"]
]
# Wait for all tasks to complete with timeout
done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
# Cancel tasks that didn't complete in time
if not_done:
timeout_error_msg = (
f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
)
logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
# In preview mode, timeout is also an error
errors.append(TimeoutError(timeout_error_msg))
for future in not_done:
future.cancel()
# Wait a bit for cancellation to take effect
concurrent.futures.wait(not_done, timeout=5)
# Collect exceptions from completed futures
for future in done:
try:
future.result() # This will raise any exception that occurred
except Exception as e:
logger.exception("Error in summary generation future")
errors.append(e)
# In preview mode, if there are any errors, fail the request
if errors:
error_messages = [str(e) for e in errors]
error_summary = (
f"Failed to generate summaries for {len(errors)} chunk(s). "
f"Errors: {'; '.join(error_messages[:3])}" # Show first 3 errors
)
if len(errors) > 3:
error_summary += f" (and {len(errors) - 3} more)"
logger.error("Summary generation failed in preview mode: %s", error_summary)
raise KnowledgeIndexNodeError(error_summary)
completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None)
logger.info(
"Completed summary generation for preview chunks: %s/%s succeeded",
completed_count,
len(preview_output["preview"]),
)
return preview_output
def _get_preview_output(
self,
chunk_structure: str,
chunks: Any,
dataset: Dataset | None = None,
variable_pool: VariablePool | None = None,
) -> Mapping[str, Any]:
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
preview_output = index_processor.format_preview(chunks)
# If dataset is provided, try to enrich preview with summaries
if dataset and variable_pool:
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
if document_id:
document = db.session.query(Document).filter_by(id=document_id.value).first()
if document:
# Query summaries for this document
summaries = (
db.session.query(DocumentSegmentSummary)
.filter_by(
dataset_id=dataset.id,
document_id=document.id,
status="completed",
enabled=True,
)
.all()
)
if summaries:
# Create a map of segment content to summary for matching
# Use content matching as chunks in preview might not be indexed yet
summary_by_content = {}
for summary in summaries:
segment = (
db.session.query(DocumentSegment)
.filter_by(id=summary.chunk_id, dataset_id=dataset.id)
.first()
)
if segment:
# Normalize content for matching (strip whitespace)
normalized_content = segment.content.strip()
summary_by_content[normalized_content] = summary.summary_content
# Enrich preview with summaries by content matching
if "preview" in preview_output and isinstance(preview_output["preview"], list):
matched_count = 0
for preview_item in preview_output["preview"]:
if "content" in preview_item:
# Normalize content for matching
normalized_chunk_content = preview_item["content"].strip()
if normalized_chunk_content in summary_by_content:
preview_item["summary"] = summary_by_content[normalized_chunk_content]
matched_count += 1
if matched_count > 0:
logger.info(
"Enriched preview with %s existing summaries (dataset: %s, document: %s)",
matched_count,
dataset.id,
document.id,
)
return preview_output
@classmethod
def version(cls) -> str:

View File

@ -419,6 +419,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
else:
source["content"] = segment.get_sign_content()
# Add summary if available
if record.summary:
source["summary"] = record.summary
retrieval_resource_list.append(source)
if retrieval_resource_list:
retrieval_resource_list = sorted(

View File

@ -685,6 +685,8 @@ class LLMNode(Node[LLMNodeData]):
if "content" not in item:
raise InvalidContextStructureError(f"Invalid context structure: {item}")
if item.get("summary"):
context_str += item["summary"] + "\n"
context_str += item["content"] + "\n"
retriever_resource = self._convert_to_original_retriever_resource(item)
@ -746,6 +748,7 @@ class LLMNode(Node[LLMNodeData]):
page=metadata.get("page"),
doc_metadata=metadata.get("doc_metadata"),
files=context_dict.get("files"),
summary=context_dict.get("summary"),
)
return source

View File

@ -102,6 +102,8 @@ def init_app(app: DifyApp) -> Celery:
imports = [
"tasks.async_workflow_tasks", # trigger workers
"tasks.trigger_processing_tasks", # async trigger processing
"tasks.generate_summary_index_task", # summary index generation
"tasks.regenerate_summary_index_task", # summary index regeneration
]
day = dify_config.CELERY_BEAT_SCHEDULER_TIME

View File

@ -0,0 +1,20 @@
"""
OpenTelemetry node parsers for workflow nodes.
This module provides parsers that extract node-specific metadata and set
OpenTelemetry span attributes according to semantic conventions.
"""
from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps
from extensions.otel.parser.llm import LLMNodeOTelParser
from extensions.otel.parser.retrieval import RetrievalNodeOTelParser
from extensions.otel.parser.tool import ToolNodeOTelParser
__all__ = [
"DefaultNodeOTelParser",
"LLMNodeOTelParser",
"NodeOTelParser",
"RetrievalNodeOTelParser",
"ToolNodeOTelParser",
"safe_json_dumps",
]

View File

@ -0,0 +1,117 @@
"""
Base parser interface and utilities for OpenTelemetry node parsers.
"""
import json
from typing import Any, Protocol
from opentelemetry.trace import Span
from opentelemetry.trace.status import Status, StatusCode
from pydantic import BaseModel
from core.file.models import File
from core.variables import Segment
from core.workflow.enums import NodeType
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes
def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str:
"""
Safely serialize objects to JSON, handling non-serializable types.
Handles:
- Segment types (ArrayFileSegment, FileSegment, etc.) - converts to their value
- File objects - converts to dict using to_dict()
- BaseModel objects - converts using model_dump()
- Other types - falls back to str() representation
Args:
obj: Object to serialize
ensure_ascii: Whether to ensure ASCII encoding
Returns:
JSON string representation of the object
"""
def _convert_value(value: Any) -> Any:
"""Recursively convert non-serializable values."""
if value is None:
return None
if isinstance(value, (bool, int, float, str)):
return value
if isinstance(value, Segment):
# Convert Segment to its underlying value
return _convert_value(value.value)
if isinstance(value, File):
# Convert File to dict
return value.to_dict()
if isinstance(value, BaseModel):
# Convert Pydantic model to dict
return _convert_value(value.model_dump(mode="json"))
if isinstance(value, dict):
return {k: _convert_value(v) for k, v in value.items()}
if isinstance(value, (list, tuple)):
return [_convert_value(item) for item in value]
# Fallback to string representation for unknown types
return str(value)
try:
converted = _convert_value(obj)
return json.dumps(converted, ensure_ascii=ensure_ascii)
except (TypeError, ValueError) as e:
# If conversion still fails, return error message as string
return json.dumps(
{"error": f"Failed to serialize: {type(obj).__name__}", "message": str(e)}, ensure_ascii=ensure_ascii
)
class NodeOTelParser(Protocol):
"""Parser interface for node-specific OpenTelemetry enrichment."""
def parse(
self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None: ...
class DefaultNodeOTelParser:
"""Fallback parser used when no node-specific parser is registered."""
def parse(
self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
span.set_attribute("node.id", node.id)
if node.execution_id:
span.set_attribute("node.execution_id", node.execution_id)
if hasattr(node, "node_type") and node.node_type:
span.set_attribute("node.type", node.node_type.value)
span.set_attribute(GenAIAttributes.FRAMEWORK, "dify")
node_type = getattr(node, "node_type", None)
if isinstance(node_type, NodeType):
if node_type == NodeType.LLM:
span.set_attribute(GenAIAttributes.SPAN_KIND, "LLM")
elif node_type == NodeType.KNOWLEDGE_RETRIEVAL:
span.set_attribute(GenAIAttributes.SPAN_KIND, "RETRIEVER")
elif node_type == NodeType.TOOL:
span.set_attribute(GenAIAttributes.SPAN_KIND, "TOOL")
else:
span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK")
else:
span.set_attribute(GenAIAttributes.SPAN_KIND, "TASK")
# Extract inputs and outputs from result_event
if result_event and result_event.node_run_result:
node_run_result = result_event.node_run_result
if node_run_result.inputs:
span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs))
if node_run_result.outputs:
span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs))
if error:
span.record_exception(error)
span.set_status(Status(StatusCode.ERROR, str(error)))
else:
span.set_status(Status(StatusCode.OK))

View File

@ -0,0 +1,155 @@
"""
Parser for LLM nodes that captures LLM-specific metadata.
"""
import logging
from collections.abc import Mapping
from typing import Any
from opentelemetry.trace import Span
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
from extensions.otel.semconv.gen_ai import LLMAttributes
logger = logging.getLogger(__name__)
def _format_input_messages(process_data: Mapping[str, Any]) -> str:
"""
Format input messages from process_data for LLM spans.
Args:
process_data: Process data containing prompts
Returns:
JSON string of formatted input messages
"""
try:
if not isinstance(process_data, dict):
return safe_json_dumps([])
prompts = process_data.get("prompts", [])
if not prompts:
return safe_json_dumps([])
valid_roles = {"system", "user", "assistant", "tool"}
input_messages = []
for prompt in prompts:
if not isinstance(prompt, dict):
continue
role = prompt.get("role", "")
text = prompt.get("text", "")
if not role or role not in valid_roles:
continue
if text:
message = {"role": role, "parts": [{"type": "text", "content": text}]}
input_messages.append(message)
return safe_json_dumps(input_messages)
except Exception as e:
logger.warning("Failed to format input messages: %s", e, exc_info=True)
return safe_json_dumps([])
def _format_output_messages(outputs: Mapping[str, Any]) -> str:
"""
Format output messages from outputs for LLM spans.
Args:
outputs: Output data containing text and finish_reason
Returns:
JSON string of formatted output messages
"""
try:
if not isinstance(outputs, dict):
return safe_json_dumps([])
text = outputs.get("text", "")
finish_reason = outputs.get("finish_reason", "")
if not text:
return safe_json_dumps([])
valid_finish_reasons = {"stop", "length", "content_filter", "tool_call", "error"}
if finish_reason not in valid_finish_reasons:
finish_reason = "stop"
output_message = {
"role": "assistant",
"parts": [{"type": "text", "content": text}],
"finish_reason": finish_reason,
}
return safe_json_dumps([output_message])
except Exception as e:
logger.warning("Failed to format output messages: %s", e, exc_info=True)
return safe_json_dumps([])
class LLMNodeOTelParser:
"""Parser for LLM nodes that captures LLM-specific metadata."""
def __init__(self) -> None:
self._delegate = DefaultNodeOTelParser()
def parse(
self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
self._delegate.parse(node=node, span=span, error=error, result_event=result_event)
if not result_event or not result_event.node_run_result:
return
node_run_result = result_event.node_run_result
process_data = node_run_result.process_data or {}
outputs = node_run_result.outputs or {}
# Extract usage data (from process_data or outputs)
usage_data = process_data.get("usage") or outputs.get("usage") or {}
# Model and provider information
model_name = process_data.get("model_name") or ""
model_provider = process_data.get("model_provider") or ""
if model_name:
span.set_attribute(LLMAttributes.REQUEST_MODEL, model_name)
if model_provider:
span.set_attribute(LLMAttributes.PROVIDER_NAME, model_provider)
# Token usage
if usage_data:
prompt_tokens = usage_data.get("prompt_tokens", 0)
completion_tokens = usage_data.get("completion_tokens", 0)
total_tokens = usage_data.get("total_tokens", 0)
span.set_attribute(LLMAttributes.USAGE_INPUT_TOKENS, prompt_tokens)
span.set_attribute(LLMAttributes.USAGE_OUTPUT_TOKENS, completion_tokens)
span.set_attribute(LLMAttributes.USAGE_TOTAL_TOKENS, total_tokens)
# Prompts and completion
prompts = process_data.get("prompts", [])
if prompts:
prompts_json = safe_json_dumps(prompts)
span.set_attribute(LLMAttributes.PROMPT, prompts_json)
text_output = str(outputs.get("text", ""))
if text_output:
span.set_attribute(LLMAttributes.COMPLETION, text_output)
# Finish reason
finish_reason = outputs.get("finish_reason") or ""
if finish_reason:
span.set_attribute(LLMAttributes.RESPONSE_FINISH_REASON, finish_reason)
# Structured input/output messages
gen_ai_input_message = _format_input_messages(process_data)
gen_ai_output_message = _format_output_messages(outputs)
span.set_attribute(LLMAttributes.INPUT_MESSAGE, gen_ai_input_message)
span.set_attribute(LLMAttributes.OUTPUT_MESSAGE, gen_ai_output_message)

View File

@ -0,0 +1,105 @@
"""
Parser for knowledge retrieval nodes that captures retrieval-specific metadata.
"""
import logging
from collections.abc import Sequence
from typing import Any
from opentelemetry.trace import Span
from core.variables import Segment
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
from extensions.otel.semconv.gen_ai import RetrieverAttributes
logger = logging.getLogger(__name__)
def _format_retrieval_documents(retrieval_documents: list[Any]) -> list:
"""
Format retrieval documents for semantic conventions.
Args:
retrieval_documents: List of retrieval document dictionaries
Returns:
List of formatted semantic documents
"""
try:
if not isinstance(retrieval_documents, list):
return []
semantic_documents = []
for doc in retrieval_documents:
if not isinstance(doc, dict):
continue
metadata = doc.get("metadata", {})
content = doc.get("content", "")
title = doc.get("title", "")
score = metadata.get("score", 0.0)
document_id = metadata.get("document_id", "")
semantic_metadata = {}
if title:
semantic_metadata["title"] = title
if metadata.get("source"):
semantic_metadata["source"] = metadata["source"]
elif metadata.get("_source"):
semantic_metadata["source"] = metadata["_source"]
if metadata.get("doc_metadata"):
doc_metadata = metadata["doc_metadata"]
if isinstance(doc_metadata, dict):
semantic_metadata.update(doc_metadata)
semantic_doc = {
"document": {"content": content, "metadata": semantic_metadata, "score": score, "id": document_id}
}
semantic_documents.append(semantic_doc)
return semantic_documents
except Exception as e:
logger.warning("Failed to format retrieval documents: %s", e, exc_info=True)
return []
class RetrievalNodeOTelParser:
"""Parser for knowledge retrieval nodes that captures retrieval-specific metadata."""
def __init__(self) -> None:
self._delegate = DefaultNodeOTelParser()
def parse(
self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
self._delegate.parse(node=node, span=span, error=error, result_event=result_event)
if not result_event or not result_event.node_run_result:
return
node_run_result = result_event.node_run_result
inputs = node_run_result.inputs or {}
outputs = node_run_result.outputs or {}
# Extract query from inputs
query = str(inputs.get("query", "")) if inputs else ""
if query:
span.set_attribute(RetrieverAttributes.QUERY, query)
# Extract and format retrieval documents from outputs
result_value = outputs.get("result") if outputs else None
retrieval_documents: list[Any] = []
if result_value:
value_to_check = result_value
if isinstance(result_value, Segment):
value_to_check = result_value.value
if isinstance(value_to_check, (list, Sequence)):
retrieval_documents = list(value_to_check)
if retrieval_documents:
semantic_retrieval_documents = _format_retrieval_documents(retrieval_documents)
semantic_retrieval_documents_json = safe_json_dumps(semantic_retrieval_documents)
span.set_attribute(RetrieverAttributes.DOCUMENT, semantic_retrieval_documents_json)

View File

@ -0,0 +1,47 @@
"""
Parser for tool nodes that captures tool-specific metadata.
"""
from opentelemetry.trace import Span
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.tool.entities import ToolNodeData
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
from extensions.otel.semconv.gen_ai import ToolAttributes
class ToolNodeOTelParser:
"""Parser for tool nodes that captures tool-specific metadata."""
def __init__(self) -> None:
self._delegate = DefaultNodeOTelParser()
def parse(
self, *, node: Node, span: "Span", error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
self._delegate.parse(node=node, span=span, error=error, result_event=result_event)
tool_data = getattr(node, "_node_data", None)
if not isinstance(tool_data, ToolNodeData):
return
span.set_attribute(ToolAttributes.TOOL_NAME, node.title)
span.set_attribute(ToolAttributes.TOOL_TYPE, tool_data.provider_type.value)
# Extract tool info from metadata (consistent with aliyun_trace)
tool_info = {}
if result_event and result_event.node_run_result:
node_run_result = result_event.node_run_result
if node_run_result.metadata:
tool_info = node_run_result.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
if tool_info:
span.set_attribute(ToolAttributes.TOOL_DESCRIPTION, safe_json_dumps(tool_info))
if result_event and result_event.node_run_result and result_event.node_run_result.inputs:
span.set_attribute(ToolAttributes.TOOL_CALL_ARGUMENTS, safe_json_dumps(result_event.node_run_result.inputs))
if result_event and result_event.node_run_result and result_event.node_run_result.outputs:
span.set_attribute(ToolAttributes.TOOL_CALL_RESULT, safe_json_dumps(result_event.node_run_result.outputs))

View File

@ -1,6 +1,13 @@
"""Semantic convention shortcuts for Dify-specific spans."""
from .dify import DifySpanAttributes
from .gen_ai import GenAIAttributes
from .gen_ai import ChainAttributes, GenAIAttributes, LLMAttributes, RetrieverAttributes, ToolAttributes
__all__ = ["DifySpanAttributes", "GenAIAttributes"]
__all__ = [
"ChainAttributes",
"DifySpanAttributes",
"GenAIAttributes",
"LLMAttributes",
"RetrieverAttributes",
"ToolAttributes",
]

View File

@ -62,3 +62,37 @@ class ToolAttributes:
TOOL_CALL_RESULT = "gen_ai.tool.call.result"
"""Tool invocation result."""
class LLMAttributes:
"""LLM operation attribute keys."""
REQUEST_MODEL = "gen_ai.request.model"
"""Model identifier."""
PROVIDER_NAME = "gen_ai.provider.name"
"""Provider name."""
USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
"""Number of input tokens."""
USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
"""Number of output tokens."""
USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
"""Total number of tokens."""
PROMPT = "gen_ai.prompt"
"""Prompt text."""
COMPLETION = "gen_ai.completion"
"""Completion text."""
RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
"""Finish reason for the response."""
INPUT_MESSAGE = "gen_ai.input.messages"
"""Input messages in structured format."""
OUTPUT_MESSAGE = "gen_ai.output.messages"
"""Output messages in structured format."""

View File

@ -39,6 +39,14 @@ dataset_retrieval_model_fields = {
"score_threshold_enabled": fields.Boolean,
"score_threshold": fields.Float,
}
dataset_summary_index_fields = {
"enable": fields.Boolean,
"model_name": fields.String,
"model_provider_name": fields.String,
"summary_prompt": fields.String,
}
external_retrieval_model_fields = {
"top_k": fields.Integer,
"score_threshold": fields.Float,
@ -83,6 +91,7 @@ dataset_detail_fields = {
"embedding_model_provider": fields.String,
"embedding_available": fields.Boolean,
"retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
"summary_index_setting": fields.Nested(dataset_summary_index_fields),
"tags": fields.List(fields.Nested(tag_fields)),
"doc_form": fields.String,
"external_knowledge_info": fields.Nested(external_knowledge_info_fields),

View File

@ -33,6 +33,11 @@ document_fields = {
"hit_count": fields.Integer,
"doc_form": fields.String,
"doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
# Summary index generation status:
# "SUMMARIZING" (when task is queued and generating)
"summary_index_status": fields.String,
# Whether this document needs summary index generation
"need_summary": fields.Boolean,
}
document_with_segments_fields = {
@ -60,6 +65,10 @@ document_with_segments_fields = {
"completed_segments": fields.Integer,
"total_segments": fields.Integer,
"doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
# Summary index generation status:
# "SUMMARIZING" (when task is queued and generating)
"summary_index_status": fields.String,
"need_summary": fields.Boolean, # Whether this document needs summary index generation
}
dataset_and_document_fields = {

View File

@ -58,4 +58,5 @@ hit_testing_record_fields = {
"score": fields.Float,
"tsne_position": fields.Raw,
"files": fields.List(fields.Nested(files_fields)),
"summary": fields.String, # Summary content if retrieved via summary index
}

View File

@ -36,6 +36,7 @@ class RetrieverResource(ResponseModel):
segment_position: int | None = None
index_node_hash: str | None = None
content: str | None = None
summary: str | None = None
created_at: int | None = None
@field_validator("created_at", mode="before")

View File

@ -49,4 +49,5 @@ segment_fields = {
"stopped_at": TimestampField,
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
"attachments": fields.List(fields.Nested(attachment_fields)),
"summary": fields.String, # Summary content for the segment
}

View File

@ -0,0 +1,71 @@
"""add summary index feature
Revision ID: 788d3099ae3a
Revises: 9d77545f524e
Create Date: 2026-01-27 18:15:45.277928
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '788d3099ae3a'
down_revision = '9d77545f524e'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('document_segment_summary',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('document_id', models.types.StringUUID(), nullable=False),
sa.Column('chunk_id', models.types.StringUUID(), nullable=False),
sa.Column('summary_content', models.types.LongText(), nullable=True),
sa.Column('summary_index_node_id', sa.String(length=255), nullable=True),
sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True),
sa.Column('tokens', sa.Integer(), nullable=True),
sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False),
sa.Column('error', models.types.LongText(), nullable=True),
sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
sa.Column('disabled_at', sa.DateTime(), nullable=True),
sa.Column('disabled_by', models.types.StringUUID(), nullable=True),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='document_segment_summary_pkey')
)
with op.batch_alter_table('document_segment_summary', schema=None) as batch_op:
batch_op.create_index('document_segment_summary_chunk_id_idx', ['chunk_id'], unique=False)
batch_op.create_index('document_segment_summary_dataset_id_idx', ['dataset_id'], unique=False)
batch_op.create_index('document_segment_summary_document_id_idx', ['document_id'], unique=False)
batch_op.create_index('document_segment_summary_status_idx', ['status'], unique=False)
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True))
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('documents', schema=None) as batch_op:
batch_op.drop_column('need_summary')
with op.batch_alter_table('datasets', schema=None) as batch_op:
batch_op.drop_column('summary_index_setting')
with op.batch_alter_table('document_segment_summary', schema=None) as batch_op:
batch_op.drop_index('document_segment_summary_status_idx')
batch_op.drop_index('document_segment_summary_document_id_idx')
batch_op.drop_index('document_segment_summary_dataset_id_idx')
batch_op.drop_index('document_segment_summary_chunk_id_idx')
op.drop_table('document_segment_summary')
# ### end Alembic commands ###

View File

@ -72,6 +72,7 @@ class Dataset(Base):
keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10"))
collection_binding_id = mapped_column(StringUUID, nullable=True)
retrieval_model = mapped_column(AdjustedJSON, nullable=True)
summary_index_setting = mapped_column(AdjustedJSON, nullable=True)
built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
icon_info = mapped_column(AdjustedJSON, nullable=True)
runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))
@ -419,6 +420,7 @@ class Document(Base):
doc_metadata = mapped_column(AdjustedJSON, nullable=True)
doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
doc_language = mapped_column(String(255), nullable=True)
need_summary: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))
DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]
@ -1575,3 +1577,36 @@ class SegmentAttachmentBinding(Base):
segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
class DocumentSegmentSummary(Base):
__tablename__ = "document_segment_summary"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="document_segment_summary_pkey"),
sa.Index("document_segment_summary_dataset_id_idx", "dataset_id"),
sa.Index("document_segment_summary_document_id_idx", "document_id"),
sa.Index("document_segment_summary_chunk_id_idx", "chunk_id"),
sa.Index("document_segment_summary_status_idx", "status"),
)
id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# corresponds to DocumentSegment.id or parent chunk id
chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
summary_content: Mapped[str] = mapped_column(LongText, nullable=True)
summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
tokens: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'"))
error: Mapped[str] = mapped_column(LongText, nullable=True)
enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
disabled_by = mapped_column(StringUUID, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
def __repr__(self):
return f"<DocumentSegmentSummary id={self.id} chunk_id={self.chunk_id} status={self.status}>"

View File

@ -64,7 +64,7 @@ dependencies = [
"pandas[excel,output-formatting,performance]~=2.2.2",
"psycogreen~=1.0.2",
"psycopg2-binary~=2.9.6",
"pycryptodome==3.19.1",
"pycryptodome==3.23.0",
"pydantic~=2.11.4",
"pydantic-extra-types~=2.10.3",
"pydantic-settings~=2.11.0",

View File

@ -131,7 +131,7 @@ class BillingService:
headers = {"Content-Type": "application/json", "Billing-Api-Secret-Key": cls.secret_key}
url = f"{cls.base_url}{endpoint}"
response = httpx.request(method, url, json=json, params=params, headers=headers)
response = httpx.request(method, url, json=json, params=params, headers=headers, follow_redirects=True)
if method == "GET" and response.status_code != httpx.codes.OK:
raise ValueError("Unable to retrieve billing information. Please try again later or contact support.")
if method == "PUT":
@ -143,6 +143,9 @@ class BillingService:
raise ValueError("Invalid arguments.")
if method == "POST" and response.status_code != httpx.codes.OK:
raise ValueError(f"Unable to send request to {url}. Please try again later or contact support.")
if method == "DELETE" and response.status_code != httpx.codes.OK:
logger.error("billing_service: DELETE response: %s %s", response.status_code, response.text)
raise ValueError(f"Unable to process delete request {url}. Please try again later or contact support.")
return response.json()
@staticmethod
@ -165,7 +168,7 @@ class BillingService:
def delete_account(cls, account_id: str):
"""Delete account."""
params = {"account_id": account_id}
return cls._send_request("DELETE", "/account/", params=params)
return cls._send_request("DELETE", "/account", params=params)
@classmethod
def is_email_in_freeze(cls, email: str) -> bool:

View File

@ -89,6 +89,7 @@ from tasks.disable_segments_from_index_task import disable_segments_from_index_t
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.regenerate_summary_index_task import regenerate_summary_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task
from tasks.retry_document_indexing_task import retry_document_indexing_task
from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
@ -476,6 +477,11 @@ class DatasetService:
if external_retrieval_model:
dataset.retrieval_model = external_retrieval_model
# Update summary index setting if provided
summary_index_setting = data.get("summary_index_setting", None)
if summary_index_setting is not None:
dataset.summary_index_setting = summary_index_setting
# Update basic dataset properties
dataset.name = data.get("name", dataset.name)
dataset.description = data.get("description", dataset.description)
@ -564,6 +570,9 @@ class DatasetService:
# update Retrieval model
if data.get("retrieval_model"):
filtered_data["retrieval_model"] = data["retrieval_model"]
# update summary index setting
if data.get("summary_index_setting"):
filtered_data["summary_index_setting"] = data.get("summary_index_setting")
# update icon info
if data.get("icon_info"):
filtered_data["icon_info"] = data.get("icon_info")
@ -572,12 +581,27 @@ class DatasetService:
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
db.session.commit()
# Reload dataset to get updated values
db.session.refresh(dataset)
# update pipeline knowledge base node data
DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id)
# Trigger vector index task if indexing technique changed
if action:
deal_dataset_vector_index_task.delay(dataset.id, action)
# If embedding_model changed, also regenerate summary vectors
if action == "update":
regenerate_summary_index_task.delay(
dataset.id,
regenerate_reason="embedding_model_changed",
regenerate_vectors_only=True,
)
# Note: summary_index_setting changes do not trigger automatic regeneration of existing summaries.
# The new setting will only apply to:
# 1. New documents added after the setting change
# 2. Manual summary generation requests
return dataset
@ -616,6 +640,7 @@ class DatasetService:
knowledge_index_node_data["chunk_structure"] = dataset.chunk_structure
knowledge_index_node_data["indexing_technique"] = dataset.indexing_technique # pyright: ignore[reportAttributeAccessIssue]
knowledge_index_node_data["keyword_number"] = dataset.keyword_number
knowledge_index_node_data["summary_index_setting"] = dataset.summary_index_setting
node["data"] = knowledge_index_node_data
updated = True
except Exception:
@ -854,6 +879,58 @@ class DatasetService:
)
filtered_data["collection_binding_id"] = dataset_collection_binding.id
@staticmethod
def _check_summary_index_setting_model_changed(dataset: Dataset, data: dict[str, Any]) -> bool:
"""
Check if summary_index_setting model (model_name or model_provider_name) has changed.
Args:
dataset: Current dataset object
data: Update data dictionary
Returns:
bool: True if summary model changed, False otherwise
"""
# Check if summary_index_setting is being updated
if "summary_index_setting" not in data or data.get("summary_index_setting") is None:
return False
new_summary_setting = data.get("summary_index_setting")
old_summary_setting = dataset.summary_index_setting
# If new setting is disabled, no need to regenerate
if not new_summary_setting or not new_summary_setting.get("enable"):
return False
# If old setting doesn't exist, no need to regenerate (no existing summaries to regenerate)
# Note: This task only regenerates existing summaries, not generates new ones
if not old_summary_setting:
return False
# If old setting was disabled, no need to regenerate (no existing summaries to regenerate)
if not old_summary_setting.get("enable"):
return False
# Compare model_name and model_provider_name
old_model_name = old_summary_setting.get("model_name")
old_model_provider = old_summary_setting.get("model_provider_name")
new_model_name = new_summary_setting.get("model_name")
new_model_provider = new_summary_setting.get("model_provider_name")
# Check if model changed
if old_model_name != new_model_name or old_model_provider != new_model_provider:
logger.info(
"Summary index setting model changed for dataset %s: old=%s/%s, new=%s/%s",
dataset.id,
old_model_provider,
old_model_name,
new_model_provider,
new_model_name,
)
return True
return False
@staticmethod
def update_rag_pipeline_dataset_settings(
session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
@ -889,6 +966,9 @@ class DatasetService:
else:
raise ValueError("Invalid index method")
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
# Update summary_index_setting if provided
if knowledge_configuration.summary_index_setting is not None:
dataset.summary_index_setting = knowledge_configuration.summary_index_setting
session.add(dataset)
else:
if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure:
@ -994,6 +1074,9 @@ class DatasetService:
if dataset.keyword_number != knowledge_configuration.keyword_number:
dataset.keyword_number = knowledge_configuration.keyword_number
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
# Update summary_index_setting if provided
if knowledge_configuration.summary_index_setting is not None:
dataset.summary_index_setting = knowledge_configuration.summary_index_setting
session.add(dataset)
session.commit()
if action:
@ -1964,6 +2047,8 @@ class DocumentService:
DuplicateDocumentIndexingTaskProxy(
dataset.tenant_id, dataset.id, duplicate_document_ids
).delay()
# Note: Summary index generation is triggered in document_indexing_task after indexing completes
# to ensure segments are available. See tasks/document_indexing_task.py
except LockNotOwnedError:
pass
@ -2268,6 +2353,11 @@ class DocumentService:
name: str,
batch: str,
):
# Set need_summary based on dataset's summary_index_setting
need_summary = False
if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True:
need_summary = True
document = Document(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
@ -2281,6 +2371,7 @@ class DocumentService:
created_by=account.id,
doc_form=document_form,
doc_language=document_language,
need_summary=need_summary,
)
doc_metadata = {}
if dataset.built_in_field_enabled:
@ -2505,6 +2596,7 @@ class DocumentService:
embedding_model_provider=knowledge_config.embedding_model_provider,
collection_binding_id=dataset_collection_binding_id,
retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
summary_index_setting=knowledge_config.summary_index_setting,
is_multimodal=knowledge_config.is_multimodal,
)
@ -2686,6 +2778,14 @@ class DocumentService:
if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int):
raise ValueError("Process rule segmentation max_tokens is invalid")
# valid summary index setting
summary_index_setting = args["process_rule"].get("summary_index_setting")
if summary_index_setting and summary_index_setting.get("enable"):
if "model_name" not in summary_index_setting or not summary_index_setting["model_name"]:
raise ValueError("Summary index model name is required")
if "model_provider_name" not in summary_index_setting or not summary_index_setting["model_provider_name"]:
raise ValueError("Summary index model provider name is required")
@staticmethod
def batch_update_document_status(
dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user
@ -3154,6 +3254,35 @@ class SegmentService:
if args.enabled or keyword_changed:
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
# update summary index if summary is provided and has changed
if args.summary is not None:
# When user manually provides summary, allow saving even if summary_index_setting doesn't exist
# summary_index_setting is only needed for LLM generation, not for manual summary vectorization
# Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting
if dataset.indexing_technique == "high_quality":
# Query existing summary from database
from models.dataset import DocumentSegmentSummary
existing_summary = (
db.session.query(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.first()
)
# Check if summary has changed
existing_summary_content = existing_summary.summary_content if existing_summary else None
if existing_summary_content != args.summary:
# Summary has changed, update it
from services.summary_index_service import SummaryIndexService
try:
SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary)
except Exception:
logger.exception("Failed to update summary for segment %s", segment.id)
# Don't fail the entire update if summary update fails
else:
segment_hash = helper.generate_text_hash(content)
tokens = 0
@ -3228,6 +3357,77 @@ class SegmentService:
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
# update segment vector index
VectorService.update_segment_vector(args.keywords, segment, dataset)
# Handle summary index when content changed
if dataset.indexing_technique == "high_quality":
from models.dataset import DocumentSegmentSummary
existing_summary = (
db.session.query(DocumentSegmentSummary)
.where(
DocumentSegmentSummary.chunk_id == segment.id,
DocumentSegmentSummary.dataset_id == dataset.id,
)
.first()
)
if args.summary is None:
# User didn't provide summary, auto-regenerate if segment previously had summary
# Auto-regeneration only happens if summary_index_setting exists and enable is True
if (
existing_summary
and dataset.summary_index_setting
and dataset.summary_index_setting.get("enable") is True
):
# Segment previously had summary, regenerate it with new content
from services.summary_index_service import SummaryIndexService
try:
SummaryIndexService.generate_and_vectorize_summary(
segment, dataset, dataset.summary_index_setting
)
logger.info(
"Auto-regenerated summary for segment %s after content change", segment.id
)
except Exception:
logger.exception("Failed to auto-regenerate summary for segment %s", segment.id)
# Don't fail the entire update if summary regeneration fails
else:
# User provided summary, check if it has changed
# Manual summary updates are allowed even if summary_index_setting doesn't exist
existing_summary_content = existing_summary.summary_content if existing_summary else None
if existing_summary_content != args.summary:
# Summary has changed, use user-provided summary
from services.summary_index_service import SummaryIndexService
try:
SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary)
logger.info(
"Updated summary for segment %s with user-provided content", segment.id
)
except Exception:
logger.exception("Failed to update summary for segment %s", segment.id)
# Don't fail the entire update if summary update fails
else:
# Summary hasn't changed, regenerate based on new content
# Auto-regeneration only happens if summary_index_setting exists and enable is True
if (
existing_summary
and dataset.summary_index_setting
and dataset.summary_index_setting.get("enable") is True
):
from services.summary_index_service import SummaryIndexService
try:
SummaryIndexService.generate_and_vectorize_summary(
segment, dataset, dataset.summary_index_setting
)
logger.info(
"Regenerated summary for segment %s after content change (summary unchanged)",
segment.id,
)
except Exception:
logger.exception("Failed to regenerate summary for segment %s", segment.id)
# Don't fail the entire update if summary regeneration fails
# update multimodel vector index
VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset)
except Exception as e:

View File

@ -119,6 +119,7 @@ class KnowledgeConfig(BaseModel):
data_source: DataSource | None = None
process_rule: ProcessRule | None = None
retrieval_model: RetrievalModel | None = None
summary_index_setting: dict | None = None
doc_form: str = "text_model"
doc_language: str = "English"
embedding_model: str | None = None
@ -141,6 +142,7 @@ class SegmentUpdateArgs(BaseModel):
regenerate_child_chunks: bool = False
enabled: bool | None = None
attachment_ids: list[str] | None = None
summary: str | None = None # Summary content for summary index
class ChildChunkUpdateArgs(BaseModel):

View File

@ -116,6 +116,8 @@ class KnowledgeConfiguration(BaseModel):
embedding_model: str = ""
keyword_number: int | None = 10
retrieval_model: RetrievalSetting
# add summary index setting
summary_index_setting: dict | None = None
@field_validator("embedding_model_provider", mode="before")
@classmethod

View File

@ -0,0 +1,826 @@
"""Summary index service for generating and managing document segment summaries."""
import logging
import time
import uuid
from datetime import UTC, datetime
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import Document
from extensions.ext_database import db
from libs import helper
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
from models.dataset import Document as DatasetDocument
logger = logging.getLogger(__name__)
class SummaryIndexService:
"""Service for generating and managing summary indexes."""
@staticmethod
def generate_summary_for_segment(
segment: DocumentSegment,
dataset: Dataset,
summary_index_setting: dict,
) -> tuple[str, LLMUsage]:
"""
Generate summary for a single segment.
Args:
segment: DocumentSegment to generate summary for
dataset: Dataset containing the segment
summary_index_setting: Summary index configuration
Returns:
Tuple of (summary_content, llm_usage) where llm_usage is LLMUsage object
Raises:
ValueError: If summary_index_setting is invalid or generation fails
"""
# Reuse the existing generate_summary method from ParagraphIndexProcessor
# Use lazy import to avoid circular import
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
summary_content, usage = ParagraphIndexProcessor.generate_summary(
tenant_id=dataset.tenant_id,
text=segment.content,
summary_index_setting=summary_index_setting,
segment_id=segment.id,
)
if not summary_content:
raise ValueError("Generated summary is empty")
return summary_content, usage
@staticmethod
def create_summary_record(
segment: DocumentSegment,
dataset: Dataset,
summary_content: str,
status: str = "generating",
) -> DocumentSegmentSummary:
"""
Create or update a DocumentSegmentSummary record.
If a summary record already exists for this segment, it will be updated instead of creating a new one.
Args:
segment: DocumentSegment to create summary for
dataset: Dataset containing the segment
summary_content: Generated summary content
status: Summary status (default: "generating")
Returns:
Created or updated DocumentSegmentSummary instance
"""
# Check if summary record already exists
existing_summary = (
db.session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
)
if existing_summary:
# Update existing record
existing_summary.summary_content = summary_content
existing_summary.status = status
existing_summary.error = None # Clear any previous errors
# Re-enable if it was disabled
if not existing_summary.enabled:
existing_summary.enabled = True
existing_summary.disabled_at = None
existing_summary.disabled_by = None
db.session.add(existing_summary)
db.session.flush()
return existing_summary
else:
# Create new record (enabled by default)
summary_record = DocumentSegmentSummary(
dataset_id=dataset.id,
document_id=segment.document_id,
chunk_id=segment.id,
summary_content=summary_content,
status=status,
enabled=True, # Explicitly set enabled to True
)
db.session.add(summary_record)
db.session.flush()
return summary_record
@staticmethod
def vectorize_summary(
summary_record: DocumentSegmentSummary,
segment: DocumentSegment,
dataset: Dataset,
) -> None:
"""
Vectorize summary and store in vector database.
Args:
summary_record: DocumentSegmentSummary record
segment: Original DocumentSegment
dataset: Dataset containing the segment
"""
if dataset.indexing_technique != "high_quality":
logger.warning(
"Summary vectorization skipped for dataset %s: indexing_technique is not high_quality",
dataset.id,
)
return
# Reuse existing index_node_id if available (like segment does), otherwise generate new one
old_summary_node_id = summary_record.summary_index_node_id
if old_summary_node_id:
# Reuse existing index_node_id (like segment behavior)
summary_index_node_id = old_summary_node_id
else:
# Generate new index node ID only for new summaries
summary_index_node_id = str(uuid.uuid4())
# Always regenerate hash (in case summary content changed)
summary_hash = helper.generate_text_hash(summary_record.summary_content)
# Delete old vector only if we're reusing the same index_node_id (to overwrite)
# If index_node_id changed, the old vector should have been deleted elsewhere
if old_summary_node_id and old_summary_node_id == summary_index_node_id:
try:
vector = Vector(dataset)
vector.delete_by_ids([old_summary_node_id])
except Exception as e:
logger.warning(
"Failed to delete old summary vector for segment %s: %s. Continuing with new vectorization.",
segment.id,
str(e),
)
# Calculate embedding tokens for summary (for logging and statistics)
embedding_tokens = 0
try:
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens([summary_record.summary_content])
embedding_tokens = tokens_list[0] if tokens_list else 0
except Exception as e:
logger.warning("Failed to calculate embedding tokens for summary: %s", str(e))
# Create document with summary content and metadata
summary_document = Document(
page_content=summary_record.summary_content,
metadata={
"doc_id": summary_index_node_id,
"doc_hash": summary_hash,
"dataset_id": dataset.id,
"document_id": segment.document_id,
"original_chunk_id": segment.id, # Key: link to original chunk
"doc_type": DocType.TEXT,
"is_summary": True, # Identifier for summary documents
},
)
# Vectorize and store with retry mechanism for connection errors
max_retries = 3
retry_delay = 2.0
for attempt in range(max_retries):
try:
vector = Vector(dataset)
# Use duplicate_check=False to ensure re-vectorization even if old vector still exists
# The old vector should have been deleted above, but if deletion failed,
# we still want to re-vectorize (upsert will overwrite)
vector.add_texts([summary_document], duplicate_check=False)
# Log embedding token usage
if embedding_tokens > 0:
logger.info(
"Summary embedding for segment %s used %s tokens",
segment.id,
embedding_tokens,
)
# Success - update summary record with index node info
summary_record.summary_index_node_id = summary_index_node_id
summary_record.summary_index_node_hash = summary_hash
summary_record.tokens = embedding_tokens # Save embedding tokens
summary_record.status = "completed"
# Explicitly update updated_at to ensure it's refreshed even if other fields haven't changed
summary_record.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.add(summary_record)
db.session.flush()
# Success, exit function
return
except (ConnectionError, Exception) as e:
error_str = str(e).lower()
# Check if it's a connection-related error that might be transient
is_connection_error = any(
keyword in error_str
for keyword in [
"connection",
"disconnected",
"timeout",
"network",
"could not connect",
"server disconnected",
"weaviate",
]
)
if is_connection_error and attempt < max_retries - 1:
# Retry for connection errors
wait_time = retry_delay * (2**attempt) # Exponential backoff
logger.warning(
"Vectorization attempt %s/%s failed for segment %s: %s. Retrying in %.1f seconds...",
attempt + 1,
max_retries,
segment.id,
str(e),
wait_time,
)
time.sleep(wait_time)
continue
else:
# Final attempt failed or non-connection error - log and update status
logger.error(
"Failed to vectorize summary for segment %s after %s attempts: %s",
segment.id,
attempt + 1,
str(e),
exc_info=True,
)
summary_record.status = "error"
summary_record.error = f"Vectorization failed: {str(e)}"
# Explicitly update updated_at to ensure it's refreshed
summary_record.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.add(summary_record)
db.session.flush()
raise
@staticmethod
def batch_create_summary_records(
segments: list[DocumentSegment],
dataset: Dataset,
status: str = "not_started",
) -> None:
"""
Batch create summary records for segments with specified status.
If a record already exists, update its status.
Args:
segments: List of DocumentSegment instances
dataset: Dataset containing the segments
status: Initial status for the records (default: "not_started")
"""
segment_ids = [segment.id for segment in segments]
if not segment_ids:
return
# Query existing summary records
existing_summaries = (
db.session.query(DocumentSegmentSummary)
.filter(
DocumentSegmentSummary.chunk_id.in_(segment_ids),
DocumentSegmentSummary.dataset_id == dataset.id,
)
.all()
)
existing_summary_map = {summary.chunk_id: summary for summary in existing_summaries}
# Create or update records
for segment in segments:
existing_summary = existing_summary_map.get(segment.id)
if existing_summary:
# Update existing record
existing_summary.status = status
existing_summary.error = None # Clear any previous errors
if not existing_summary.enabled:
existing_summary.enabled = True
existing_summary.disabled_at = None
existing_summary.disabled_by = None
db.session.add(existing_summary)
else:
# Create new record
summary_record = DocumentSegmentSummary(
dataset_id=dataset.id,
document_id=segment.document_id,
chunk_id=segment.id,
summary_content=None, # Will be filled later
status=status,
enabled=True,
)
db.session.add(summary_record)
@staticmethod
def update_summary_record_error(
segment: DocumentSegment,
dataset: Dataset,
error: str,
) -> None:
"""
Update summary record with error status.
Args:
segment: DocumentSegment
dataset: Dataset containing the segment
error: Error message
"""
summary_record = (
db.session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
.first()
)
if summary_record:
summary_record.status = "error"
summary_record.error = error
db.session.add(summary_record)
db.session.flush()
else:
logger.warning(
"Summary record not found for segment %s when updating error", segment.id
)
@staticmethod
def generate_and_vectorize_summary(
segment: DocumentSegment,
dataset: Dataset,
summary_index_setting: dict,
) -> DocumentSegmentSummary:
"""
Generate summary for a segment and vectorize it.
Assumes summary record already exists (created by batch_create_summary_records).
Args:
segment: DocumentSegment to generate summary for
dataset: Dataset containing the segment
summary_index_setting: Summary index configuration
Returns:
Created DocumentSegmentSummary instance
Raises:
ValueError: If summary generation fails
"""
# Get existing summary record (should have been created by batch_create_summary_records)
summary_record = (
db.session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
.first()
)
if not summary_record:
# If not found (shouldn't happen), create one
logger.warning(
"Summary record not found for segment %s, creating one", segment.id
)
summary_record = SummaryIndexService.create_summary_record(
segment, dataset, summary_content="", status="generating"
)
try:
# Update status to "generating"
summary_record.status = "generating"
summary_record.error = None
db.session.add(summary_record)
db.session.flush()
# Generate summary (returns summary_content and llm_usage)
summary_content, llm_usage = SummaryIndexService.generate_summary_for_segment(
segment, dataset, summary_index_setting
)
# Update summary content
summary_record.summary_content = summary_content
# Log LLM usage for summary generation
if llm_usage and llm_usage.total_tokens > 0:
logger.info(
"Summary generation for segment %s used %s tokens (prompt: %s, completion: %s)",
segment.id,
llm_usage.total_tokens,
llm_usage.prompt_tokens,
llm_usage.completion_tokens,
)
# Vectorize summary (will delete old vector if exists before creating new one)
SummaryIndexService.vectorize_summary(summary_record, segment, dataset)
# Status will be updated to "completed" by vectorize_summary on success
db.session.commit()
logger.info("Successfully generated and vectorized summary for segment %s", segment.id)
return summary_record
except Exception as e:
logger.exception("Failed to generate summary for segment %s", segment.id)
# Update summary record with error status
summary_record.status = "error"
summary_record.error = str(e)
db.session.add(summary_record)
db.session.commit()
raise
@staticmethod
def generate_summaries_for_document(
dataset: Dataset,
document: DatasetDocument,
summary_index_setting: dict,
segment_ids: list[str] | None = None,
only_parent_chunks: bool = False,
) -> list[DocumentSegmentSummary]:
"""
Generate summaries for all segments in a document including vectorization.
Args:
dataset: Dataset containing the document
document: DatasetDocument to generate summaries for
summary_index_setting: Summary index configuration
segment_ids: Optional list of specific segment IDs to process
only_parent_chunks: If True, only process parent chunks (for parent-child mode)
Returns:
List of created DocumentSegmentSummary instances
"""
# Only generate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
logger.info(
"Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'",
dataset.id,
dataset.indexing_technique,
)
return []
if not summary_index_setting or not summary_index_setting.get("enable"):
logger.info("Summary index is disabled for dataset %s", dataset.id)
return []
# Skip qa_model documents
if document.doc_form == "qa_model":
logger.info("Skipping summary generation for qa_model document %s", document.id)
return []
logger.info(
"Starting summary generation for document %s in dataset %s, segment_ids: %s, only_parent_chunks: %s",
document.id,
dataset.id,
len(segment_ids) if segment_ids else "all",
only_parent_chunks,
)
# Query segments (only enabled segments)
query = db.session.query(DocumentSegment).filter_by(
dataset_id=dataset.id,
document_id=document.id,
status="completed",
enabled=True, # Only generate summaries for enabled segments
)
if segment_ids:
query = query.filter(DocumentSegment.id.in_(segment_ids))
segments = query.all()
if not segments:
logger.info("No segments found for document %s", document.id)
return []
# Batch create summary records with "not_started" status before processing
# This ensures all records exist upfront, allowing status tracking
SummaryIndexService.batch_create_summary_records(
segments=segments,
dataset=dataset,
status="not_started",
)
db.session.commit() # Commit initial records
summary_records = []
for segment in segments:
# For parent-child mode, only process parent chunks
# In parent-child mode, all DocumentSegments are parent chunks,
# so we process all of them. Child chunks are stored in ChildChunk table
# and are not DocumentSegments, so they won't be in the segments list.
# This check is mainly for clarity and future-proofing.
if only_parent_chunks:
# In parent-child mode, all segments in the query are parent chunks
# Child chunks are not DocumentSegments, so they won't appear here
# We can process all segments
pass
try:
summary_record = SummaryIndexService.generate_and_vectorize_summary(
segment, dataset, summary_index_setting
)
summary_records.append(summary_record)
except Exception as e:
logger.exception("Failed to generate summary for segment %s", segment.id)
# Update summary record with error status
SummaryIndexService.update_summary_record_error(
segment=segment,
dataset=dataset,
error=str(e),
)
# Continue with other segments
continue
db.session.commit() # Commit any remaining changes
logger.info(
"Completed summary generation for document %s: %s summaries generated and vectorized",
document.id,
len(summary_records),
)
return summary_records
@staticmethod
def disable_summaries_for_segments(
dataset: Dataset,
segment_ids: list[str] | None = None,
disabled_by: str | None = None,
) -> None:
"""
Disable summary records and remove vectors from vector database for segments.
Unlike delete, this preserves the summary records but marks them as disabled.
Args:
dataset: Dataset containing the segments
segment_ids: List of segment IDs to disable summaries for. If None, disable all.
disabled_by: User ID who disabled the summaries
"""
from libs.datetime_utils import naive_utc_now
query = db.session.query(DocumentSegmentSummary).filter_by(
dataset_id=dataset.id,
enabled=True, # Only disable enabled summaries
)
if segment_ids:
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
summaries = query.all()
if not summaries:
return
logger.info(
"Disabling %s summary records for dataset %s, segment_ids: %s",
len(summaries),
dataset.id,
len(segment_ids) if segment_ids else "all",
)
# Remove from vector database (but keep records)
if dataset.indexing_technique == "high_quality":
summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
if summary_node_ids:
try:
vector = Vector(dataset)
vector.delete_by_ids(summary_node_ids)
except Exception as e:
logger.warning("Failed to remove summary vectors: %s", str(e))
# Disable summary records (don't delete)
now = naive_utc_now()
for summary in summaries:
summary.enabled = False
summary.disabled_at = now
summary.disabled_by = disabled_by
db.session.add(summary)
db.session.commit()
logger.info("Disabled %s summary records for dataset %s", len(summaries), dataset.id)
@staticmethod
def enable_summaries_for_segments(
dataset: Dataset,
segment_ids: list[str] | None = None,
) -> None:
"""
Enable summary records and re-add vectors to vector database for segments.
Note: This method enables summaries based on chunk status, not summary_index_setting.enable.
The summary_index_setting.enable flag only controls automatic generation,
not whether existing summaries can be used.
Summary.enabled should always be kept in sync with chunk.enabled.
Args:
dataset: Dataset containing the segments
segment_ids: List of segment IDs to enable summaries for. If None, enable all.
"""
# Only enable summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
return
query = db.session.query(DocumentSegmentSummary).filter_by(
dataset_id=dataset.id,
enabled=False, # Only enable disabled summaries
)
if segment_ids:
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
summaries = query.all()
if not summaries:
return
logger.info(
"Enabling %s summary records for dataset %s, segment_ids: %s",
len(summaries),
dataset.id,
len(segment_ids) if segment_ids else "all",
)
# Re-vectorize and re-add to vector database
enabled_count = 0
for summary in summaries:
# Get the original segment
segment = (
db.session.query(DocumentSegment)
.filter_by(
id=summary.chunk_id,
dataset_id=dataset.id,
)
.first()
)
# Summary.enabled stays in sync with chunk.enabled, only enable summary if the associated chunk is enabled.
if not segment or not segment.enabled or segment.status != "completed":
continue
if not summary.summary_content:
continue
try:
# Re-vectorize summary
SummaryIndexService.vectorize_summary(summary, segment, dataset)
# Enable summary record
summary.enabled = True
summary.disabled_at = None
summary.disabled_by = None
db.session.add(summary)
enabled_count += 1
except Exception:
logger.exception("Failed to re-vectorize summary %s", summary.id)
# Keep it disabled if vectorization fails
continue
db.session.commit()
logger.info("Enabled %s summary records for dataset %s", enabled_count, dataset.id)
@staticmethod
def delete_summaries_for_segments(
dataset: Dataset,
segment_ids: list[str] | None = None,
) -> None:
"""
Delete summary records and vectors for segments (used only for actual deletion scenarios).
For disable/enable operations, use disable_summaries_for_segments/enable_summaries_for_segments.
Args:
dataset: Dataset containing the segments
segment_ids: List of segment IDs to delete summaries for. If None, delete all.
"""
query = db.session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id)
if segment_ids:
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
summaries = query.all()
if not summaries:
return
# Delete from vector database
if dataset.indexing_technique == "high_quality":
summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
if summary_node_ids:
vector = Vector(dataset)
vector.delete_by_ids(summary_node_ids)
# Delete summary records
for summary in summaries:
db.session.delete(summary)
db.session.commit()
logger.info("Deleted %s summary records for dataset %s", len(summaries), dataset.id)
@staticmethod
def update_summary_for_segment(
segment: DocumentSegment,
dataset: Dataset,
summary_content: str,
) -> DocumentSegmentSummary | None:
"""
Update summary for a segment and re-vectorize it.
Args:
segment: DocumentSegment to update summary for
dataset: Dataset containing the segment
summary_content: New summary content
Returns:
Updated DocumentSegmentSummary instance, or None if indexing technique is not high_quality
"""
# Only update summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
return None
# When user manually provides summary, allow saving even if summary_index_setting doesn't exist
# summary_index_setting is only needed for LLM generation, not for manual summary vectorization
# Vectorization uses dataset.embedding_model, which doesn't require summary_index_setting
# Skip qa_model documents
if segment.document and segment.document.doc_form == "qa_model":
return None
try:
# Check if summary_content is empty (whitespace-only strings are considered empty)
if not summary_content or not summary_content.strip():
# If summary is empty, only delete existing summary vector and record
summary_record = (
db.session.query(DocumentSegmentSummary)
.filter_by(chunk_id=segment.id, dataset_id=dataset.id)
.first()
)
if summary_record:
# Delete old vector if exists
old_summary_node_id = summary_record.summary_index_node_id
if old_summary_node_id:
try:
vector = Vector(dataset)
vector.delete_by_ids([old_summary_node_id])
except Exception as e:
logger.warning(
"Failed to delete old summary vector for segment %s: %s",
segment.id,
str(e),
)
# Delete summary record since summary is empty
db.session.delete(summary_record)
db.session.commit()
logger.info("Deleted summary for segment %s (empty content provided)", segment.id)
return None
else:
# No existing summary record, nothing to do
logger.info("No summary record found for segment %s, nothing to delete", segment.id)
return None
# Find existing summary record
summary_record = (
db.session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
)
if summary_record:
# Update existing summary
old_summary_node_id = summary_record.summary_index_node_id
# Update summary content
summary_record.summary_content = summary_content
summary_record.status = "generating"
db.session.add(summary_record)
db.session.flush()
# Delete old vector if exists
if old_summary_node_id:
vector = Vector(dataset)
vector.delete_by_ids([old_summary_node_id])
# Re-vectorize summary
SummaryIndexService.vectorize_summary(summary_record, segment, dataset)
db.session.commit()
logger.info("Successfully updated and re-vectorized summary for segment %s", segment.id)
return summary_record
else:
# Create new summary record if doesn't exist
summary_record = SummaryIndexService.create_summary_record(
segment, dataset, summary_content, status="generating"
)
SummaryIndexService.vectorize_summary(summary_record, segment, dataset)
db.session.commit()
logger.info("Successfully created and vectorized summary for segment %s", segment.id)
return summary_record
except Exception:
logger.exception("Failed to update summary for segment %s", segment.id)
# Update summary record with error status if it exists
summary_record = (
db.session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
)
if summary_record:
summary_record.status = "error"
summary_record.error = str(e)
db.session.add(summary_record)
db.session.commit()
raise

View File

@ -118,6 +118,19 @@ def add_document_to_index_task(dataset_document_id: str):
)
session.commit()
# Enable summary indexes for all segments in this document
from services.summary_index_service import SummaryIndexService
segment_ids_list = [segment.id for segment in segments]
if segment_ids_list:
try:
SummaryIndexService.enable_summaries_for_segments(
dataset=dataset,
segment_ids=segment_ids_list,
)
except Exception as e:
logger.warning("Failed to enable summaries for document %s: %s", dataset_document.id, str(e))
end_at = time.perf_counter()
logger.info(
click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")

View File

@ -50,7 +50,9 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)

View File

@ -51,7 +51,9 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)

View File

@ -42,7 +42,9 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
).all()
index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)

View File

@ -47,6 +47,7 @@ def delete_segment_from_index_task(
doc_form = dataset_document.doc_form
# Proceed with index cleanup using the index_node_ids directly
# For actual deletion, we should delete summaries (not just disable them)
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset,
@ -54,6 +55,7 @@ def delete_segment_from_index_task(
with_keywords=True,
delete_child_chunks=True,
precomputed_child_node_ids=child_node_ids,
delete_summaries=True, # Actually delete summaries when segment is deleted
)
if dataset.is_multimodal:
# delete segment attachment binding

View File

@ -60,6 +60,18 @@ def disable_segment_from_index_task(segment_id: str):
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.clean(dataset, [segment.index_node_id])
# Disable summary index for this segment
from services.summary_index_service import SummaryIndexService
try:
SummaryIndexService.disable_summaries_for_segments(
dataset=dataset,
segment_ids=[segment.id],
disabled_by=segment.disabled_by,
)
except Exception as e:
logger.warning("Failed to disable summary for segment %s: %s", segment.id, str(e))
end_at = time.perf_counter()
logger.info(
click.style(

View File

@ -68,6 +68,21 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
index_node_ids.extend(attachment_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
# Disable summary indexes for these segments
from services.summary_index_service import SummaryIndexService
segment_ids_list = [segment.id for segment in segments]
try:
# Get disabled_by from first segment (they should all have the same disabled_by)
disabled_by = segments[0].disabled_by if segments else None
SummaryIndexService.disable_summaries_for_segments(
dataset=dataset,
segment_ids=segment_ids_list,
disabled_by=disabled_by,
)
except Exception as e:
logger.warning("Failed to disable summaries for segments: %s", str(e))
end_at = time.perf_counter()
logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
except Exception:

View File

@ -14,6 +14,7 @@ from enums.cloud_plan import CloudPlan
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from services.feature_service import FeatureService
from tasks.generate_summary_index_task import generate_summary_index_task
logger = logging.getLogger(__name__)
@ -99,6 +100,71 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
# Trigger summary index generation for completed documents if enabled
# Only generate for high_quality indexing technique and when summary_index_setting is enabled
# Re-query dataset to get latest summary_index_setting (in case it was updated)
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.warning("Dataset %s not found after indexing", dataset_id)
return
if dataset.indexing_technique == "high_quality":
summary_index_setting = dataset.summary_index_setting
if summary_index_setting and summary_index_setting.get("enable"):
# expire all session to get latest document's indexing status
session.expire_all()
# Check each document's indexing status and trigger summary generation if completed
for document_id in document_ids:
# Re-query document to get latest status (IndexingRunner may have updated it)
document = (
session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
if document:
logger.info(
"Checking document %s for summary generation: status=%s, doc_form=%s",
document_id,
document.indexing_status,
document.doc_form,
)
if document.indexing_status == "completed" and document.doc_form != "qa_model":
try:
generate_summary_index_task.delay(dataset.id, document_id, None)
logger.info(
"Queued summary index generation task for document %s in dataset %s "
"after indexing completed",
document_id,
dataset.id,
)
except Exception:
logger.exception(
"Failed to queue summary index generation task for document %s",
document_id,
)
# Don't fail the entire indexing process if summary task queuing fails
else:
logger.info(
"Skipping summary generation for document %s: status=%s, doc_form=%s",
document_id,
document.indexing_status,
document.doc_form,
)
else:
logger.warning("Document %s not found after indexing", document_id)
else:
logger.info(
"Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
dataset.id,
summary_index_setting.get("enable") if summary_index_setting else None,
)
else:
logger.info(
"Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
dataset.id,
dataset.indexing_technique,
)
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:

View File

@ -106,6 +106,17 @@ def enable_segment_to_index_task(segment_id: str):
# save vector index
index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)
# Enable summary index for this segment
from services.summary_index_service import SummaryIndexService
try:
SummaryIndexService.enable_summaries_for_segments(
dataset=dataset,
segment_ids=[segment.id],
)
except Exception as e:
logger.warning("Failed to enable summary for segment %s: %s", segment.id, str(e))
end_at = time.perf_counter()
logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception as e:

View File

@ -106,6 +106,18 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
# save vector index
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
# Enable summary indexes for these segments
from services.summary_index_service import SummaryIndexService
segment_ids_list = [segment.id for segment in segments]
try:
SummaryIndexService.enable_summaries_for_segments(
dataset=dataset,
segment_ids=segment_ids_list,
)
except Exception as e:
logger.warning("Failed to enable summaries for segments: %s", str(e))
end_at = time.perf_counter()
logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
except Exception as e:

View File

@ -0,0 +1,112 @@
"""Async task for generating summary indexes."""
import logging
import time
import click
from celery import shared_task
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: list[str] | None = None):
"""
Async generate summary index for document segments.
Args:
dataset_id: Dataset ID
document_id: Document ID
segment_ids: Optional list of specific segment IDs to process. If None, process all segments.
Usage:
generate_summary_index_task.delay(dataset_id, document_id)
generate_summary_index_task.delay(dataset_id, document_id, segment_ids)
"""
logger.info(
click.style(
f"Start generating summary index for document {document_id} in dataset {dataset_id}",
fg="green",
)
)
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red"))
db.session.close()
return
document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
if not document:
logger.error(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
# Only generate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
logger.info(
click.style(
f"Skipping summary generation for dataset {dataset_id}: "
f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'",
fg="cyan",
)
)
db.session.close()
return
# Check if summary index is enabled
summary_index_setting = dataset.summary_index_setting
if not summary_index_setting or not summary_index_setting.get("enable"):
logger.info(
click.style(
f"Summary index is disabled for dataset {dataset_id}",
fg="cyan",
)
)
db.session.close()
return
# Determine if only parent chunks should be processed
only_parent_chunks = dataset.chunk_structure == "parent_child_index"
# Generate summaries
summary_records = SummaryIndexService.generate_summaries_for_document(
dataset=dataset,
document=document,
summary_index_setting=summary_index_setting,
segment_ids=segment_ids,
only_parent_chunks=only_parent_chunks,
)
end_at = time.perf_counter()
logger.info(
click.style(
f"Summary index generation completed for document {document_id}: "
f"{len(summary_records)} summaries generated, latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Failed to generate summary index for document %s", document_id)
# Update document segments with error status if needed
if segment_ids:
db.session.query(DocumentSegment).filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
).update(
{
DocumentSegment.error: f"Summary generation failed: {str(e)}",
},
synchronize_session=False,
)
db.session.commit()
finally:
db.session.close()

View File

@ -0,0 +1,318 @@
"""Task for regenerating summary indexes when dataset settings change."""
import logging
import time
from collections import defaultdict
import click
from celery import shared_task
from sqlalchemy import or_, select
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
from models.dataset import Document as DatasetDocument
from services.summary_index_service import SummaryIndexService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def regenerate_summary_index_task(
dataset_id: str,
regenerate_reason: str = "summary_model_changed",
regenerate_vectors_only: bool = False,
):
"""
Regenerate summary indexes for all documents in a dataset.
This task is triggered when:
1. summary_index_setting model changes (regenerate_reason="summary_model_changed")
- Regenerates summary content and vectors for all existing summaries
2. embedding_model changes (regenerate_reason="embedding_model_changed")
- Only regenerates vectors for existing summaries (keeps summary content)
Args:
dataset_id: Dataset ID
regenerate_reason: Reason for regeneration ("summary_model_changed" or "embedding_model_changed")
regenerate_vectors_only: If True, only regenerate vectors without regenerating summary content
"""
logger.info(
click.style(
f"Start regenerate summary index for dataset {dataset_id}, reason: {regenerate_reason}",
fg="green",
)
)
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red"))
db.session.close()
return
# Only regenerate summary index for high_quality indexing technique
if dataset.indexing_technique != "high_quality":
logger.info(
click.style(
f"Skipping summary regeneration for dataset {dataset_id}: "
f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'",
fg="cyan",
)
)
db.session.close()
return
# Check if summary index is enabled (only for summary_model change)
# For embedding_model change, we still re-vectorize existing summaries even if setting is disabled
summary_index_setting = dataset.summary_index_setting
if not regenerate_vectors_only:
# For summary_model change, require summary_index_setting to be enabled
if not summary_index_setting or not summary_index_setting.get("enable"):
logger.info(
click.style(
f"Summary index is disabled for dataset {dataset_id}",
fg="cyan",
)
)
db.session.close()
return
total_segments_processed = 0
total_segments_failed = 0
if regenerate_vectors_only:
# For embedding_model change: directly query all segments with existing summaries
# Don't require document indexing_status == "completed"
# Include summaries with status "completed" or "error" (if they have content)
segments_with_summaries = (
db.session.query(DocumentSegment, DocumentSegmentSummary)
.join(
DocumentSegmentSummary,
DocumentSegment.id == DocumentSegmentSummary.chunk_id,
)
.join(
DatasetDocument,
DocumentSegment.document_id == DatasetDocument.id,
)
.where(
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.status == "completed", # Segment must be completed
DocumentSegment.enabled == True,
DocumentSegmentSummary.dataset_id == dataset_id,
DocumentSegmentSummary.summary_content.isnot(None), # Must have summary content
# Include completed summaries or error summaries (with content)
or_(
DocumentSegmentSummary.status == "completed",
DocumentSegmentSummary.status == "error",
),
DatasetDocument.enabled == True, # Document must be enabled
DatasetDocument.archived == False, # Document must not be archived
DatasetDocument.doc_form != "qa_model", # Skip qa_model documents
)
.order_by(DocumentSegment.document_id.asc(), DocumentSegment.position.asc())
.all()
)
if not segments_with_summaries:
logger.info(
click.style(
f"No segments with summaries found for re-vectorization in dataset {dataset_id}",
fg="cyan",
)
)
db.session.close()
return
logger.info(
"Found %s segments with summaries for re-vectorization in dataset %s",
len(segments_with_summaries),
dataset_id,
)
# Group by document for logging
segments_by_document = defaultdict(list)
for segment, summary_record in segments_with_summaries:
segments_by_document[segment.document_id].append((segment, summary_record))
logger.info(
"Segments grouped into %s documents for re-vectorization",
len(segments_by_document),
)
for document_id, segment_summary_pairs in segments_by_document.items():
logger.info(
"Re-vectorizing summaries for %s segments in document %s",
len(segment_summary_pairs),
document_id,
)
for segment, summary_record in segment_summary_pairs:
try:
# Delete old vector
if summary_record.summary_index_node_id:
try:
from core.rag.datasource.vdb.vector_factory import Vector
vector = Vector(dataset)
vector.delete_by_ids([summary_record.summary_index_node_id])
except Exception as e:
logger.warning(
"Failed to delete old summary vector for segment %s: %s",
segment.id,
str(e),
)
# Re-vectorize with new embedding model
SummaryIndexService.vectorize_summary(summary_record, segment, dataset)
db.session.commit()
total_segments_processed += 1
except Exception as e:
logger.error(
"Failed to re-vectorize summary for segment %s: %s",
segment.id,
str(e),
exc_info=True,
)
total_segments_failed += 1
# Update summary record with error status
summary_record.status = "error"
summary_record.error = f"Re-vectorization failed: {str(e)}"
db.session.add(summary_record)
db.session.commit()
continue
else:
# For summary_model change: require document indexing_status == "completed"
# Get all documents with completed indexing status
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
).all()
if not dataset_documents:
logger.info(
click.style(
f"No documents found for summary regeneration in dataset {dataset_id}",
fg="cyan",
)
)
db.session.close()
return
logger.info(
"Found %s documents for summary regeneration in dataset %s",
len(dataset_documents),
dataset_id,
)
for dataset_document in dataset_documents:
# Skip qa_model documents
if dataset_document.doc_form == "qa_model":
continue
try:
# Get all segments with existing summaries
segments = (
db.session.query(DocumentSegment)
.join(
DocumentSegmentSummary,
DocumentSegment.id == DocumentSegmentSummary.chunk_id,
)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegmentSummary.dataset_id == dataset_id,
)
.order_by(DocumentSegment.position.asc())
.all()
)
if not segments:
continue
logger.info(
"Regenerating summaries for %s segments in document %s",
len(segments),
dataset_document.id,
)
for segment in segments:
try:
# Get existing summary record
summary_record = (
db.session.query(DocumentSegmentSummary)
.filter_by(
chunk_id=segment.id,
dataset_id=dataset_id,
)
.first()
)
if not summary_record:
logger.warning("Summary record not found for segment %s, skipping", segment.id)
continue
# Regenerate both summary content and vectors (for summary_model change)
SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting)
db.session.commit()
total_segments_processed += 1
except Exception as e:
logger.error(
"Failed to regenerate summary for segment %s: %s",
segment.id,
str(e),
exc_info=True,
)
total_segments_failed += 1
# Update summary record with error status
if summary_record:
summary_record.status = "error"
summary_record.error = f"Regeneration failed: {str(e)}"
db.session.add(summary_record)
db.session.commit()
continue
except Exception as e:
logger.error(
"Failed to process document %s for summary regeneration: %s",
dataset_document.id,
str(e),
exc_info=True,
)
continue
end_at = time.perf_counter()
if regenerate_vectors_only:
logger.info(
click.style(
f"Summary re-vectorization completed for dataset {dataset_id}: "
f"{total_segments_processed} segments processed successfully, "
f"{total_segments_failed} segments failed, "
f"latency: {end_at - start_at:.2f}s",
fg="green",
)
)
else:
logger.info(
click.style(
f"Summary index regeneration completed for dataset {dataset_id}: "
f"{total_segments_processed} segments processed successfully, "
f"{total_segments_failed} segments failed, "
f"latency: {end_at - start_at:.2f}s",
fg="green",
)
)
except Exception:
logger.exception("Regenerate summary index failed for dataset %s", dataset_id)
finally:
db.session.close()

View File

@ -46,6 +46,21 @@ def remove_document_from_index_task(document_id: str):
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
# Disable summary indexes for all segments in this document
from services.summary_index_service import SummaryIndexService
segment_ids_list = [segment.id for segment in segments]
if segment_ids_list:
try:
SummaryIndexService.disable_summaries_for_segments(
dataset=dataset,
segment_ids=segment_ids_list,
disabled_by=document.disabled_by,
)
except Exception as e:
logger.warning("Failed to disable summaries for document %s: %s", document.id, str(e))
index_node_ids = [segment.index_node_id for segment in segments]
if index_node_ids:
try:

View File

@ -0,0 +1,454 @@
"""Test multimodal image output handling in BaseAppRunner."""
from unittest.mock import MagicMock, patch
from uuid import uuid4
import pytest
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueMessageFileEvent
from core.file.enums import FileTransferMethod, FileType
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from models.enums import CreatorUserRole
class TestBaseAppRunnerMultimodal:
"""Test that BaseAppRunner correctly handles multimodal image content."""
@pytest.fixture
def mock_user_id(self):
"""Mock user ID."""
return str(uuid4())
@pytest.fixture
def mock_tenant_id(self):
"""Mock tenant ID."""
return str(uuid4())
@pytest.fixture
def mock_message_id(self):
"""Mock message ID."""
return str(uuid4())
@pytest.fixture
def mock_queue_manager(self):
"""Create a mock queue manager."""
manager = MagicMock()
manager.invoke_from = InvokeFrom.SERVICE_API
return manager
@pytest.fixture
def mock_tool_file(self):
"""Create a mock tool file."""
tool_file = MagicMock()
tool_file.id = str(uuid4())
return tool_file
@pytest.fixture
def mock_message_file(self):
"""Create a mock message file."""
message_file = MagicMock()
message_file.id = str(uuid4())
return message_file
def test_handle_multimodal_image_content_with_url(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
mock_tool_file,
mock_message_file,
):
"""Test handling image from URL."""
# Arrange
image_url = "http://example.com/image.png"
content = ImagePromptMessageContent(
url=image_url,
format="png",
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_url.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert
# Verify tool file was created from URL
mock_mgr.create_file_by_url.assert_called_once_with(
user_id=mock_user_id,
tenant_id=mock_tenant_id,
file_url=image_url,
conversation_id=None,
)
# Verify message file was created with correct parameters
mock_msg_file_class.assert_called_once()
call_kwargs = mock_msg_file_class.call_args[1]
assert call_kwargs["message_id"] == mock_message_id
assert call_kwargs["type"] == FileType.IMAGE
assert call_kwargs["transfer_method"] == FileTransferMethod.TOOL_FILE
assert call_kwargs["belongs_to"] == "assistant"
assert call_kwargs["created_by"] == mock_user_id
# Verify database operations
mock_session.add.assert_called_once_with(mock_message_file)
mock_session.commit.assert_called_once()
mock_session.refresh.assert_called_once_with(mock_message_file)
# Verify event was published
mock_queue_manager.publish.assert_called_once()
publish_call = mock_queue_manager.publish.call_args
assert isinstance(publish_call[0][0], QueueMessageFileEvent)
assert publish_call[0][0].message_file_id == mock_message_file.id
# publish_from might be passed as positional or keyword argument
assert (
publish_call[0][1] == PublishFrom.APPLICATION_MANAGER
or publish_call.kwargs.get("publish_from") == PublishFrom.APPLICATION_MANAGER
)
def test_handle_multimodal_image_content_with_base64(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
mock_tool_file,
mock_message_file,
):
"""Test handling image from base64 data."""
# Arrange
import base64
# Create a small test image (1x1 PNG)
test_image_data = base64.b64encode(
b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde"
).decode()
content = ImagePromptMessageContent(
base64_data=test_image_data,
format="png",
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_raw.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert
# Verify tool file was created from base64
mock_mgr.create_file_by_raw.assert_called_once()
call_kwargs = mock_mgr.create_file_by_raw.call_args[1]
assert call_kwargs["user_id"] == mock_user_id
assert call_kwargs["tenant_id"] == mock_tenant_id
assert call_kwargs["conversation_id"] is None
assert "file_binary" in call_kwargs
assert call_kwargs["mimetype"] == "image/png"
assert call_kwargs["filename"].startswith("generated_image")
assert call_kwargs["filename"].endswith(".png")
# Verify message file was created
mock_msg_file_class.assert_called_once()
# Verify database operations
mock_session.add.assert_called_once()
mock_session.commit.assert_called_once()
mock_session.refresh.assert_called_once()
# Verify event was published
mock_queue_manager.publish.assert_called_once()
def test_handle_multimodal_image_content_with_base64_data_uri(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
mock_tool_file,
mock_message_file,
):
"""Test handling image from base64 data with URI prefix."""
# Arrange
# Data URI format: data:image/png;base64,<base64_data>
test_image_data = (
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=="
)
content = ImagePromptMessageContent(
base64_data=f"data:image/png;base64,{test_image_data}",
format="png",
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_raw.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert - verify that base64 data was extracted correctly (without prefix)
mock_mgr.create_file_by_raw.assert_called_once()
call_kwargs = mock_mgr.create_file_by_raw.call_args[1]
# The base64 data should be decoded, so we check the binary was passed
assert "file_binary" in call_kwargs
def test_handle_multimodal_image_content_without_url_or_base64(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
):
"""Test handling image content without URL or base64 data."""
# Arrange
content = ImagePromptMessageContent(
url="",
base64_data="",
format="png",
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert - should not create any files or publish events
mock_mgr_class.assert_not_called()
mock_msg_file_class.assert_not_called()
mock_session.add.assert_not_called()
mock_queue_manager.publish.assert_not_called()
def test_handle_multimodal_image_content_with_error(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
):
"""Test handling image content when an error occurs."""
# Arrange
image_url = "http://example.com/image.png"
content = ImagePromptMessageContent(
url=image_url,
format="png",
mime_type="image/png",
)
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock to raise exception
mock_mgr = MagicMock()
mock_mgr.create_file_by_url.side_effect = Exception("Network error")
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
# Should not raise exception, just log it
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert - should not create message file or publish event on error
mock_msg_file_class.assert_not_called()
mock_session.add.assert_not_called()
mock_queue_manager.publish.assert_not_called()
def test_handle_multimodal_image_content_debugger_mode(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
mock_tool_file,
mock_message_file,
):
"""Test that debugger mode sets correct created_by_role."""
# Arrange
image_url = "http://example.com/image.png"
content = ImagePromptMessageContent(
url=image_url,
format="png",
mime_type="image/png",
)
mock_queue_manager.invoke_from = InvokeFrom.DEBUGGER
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_url.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert - verify created_by_role is ACCOUNT for debugger mode
call_kwargs = mock_msg_file_class.call_args[1]
assert call_kwargs["created_by_role"] == CreatorUserRole.ACCOUNT
def test_handle_multimodal_image_content_service_api_mode(
self,
mock_user_id,
mock_tenant_id,
mock_message_id,
mock_queue_manager,
mock_tool_file,
mock_message_file,
):
"""Test that service API mode sets correct created_by_role."""
# Arrange
image_url = "http://example.com/image.png"
content = ImagePromptMessageContent(
url=image_url,
format="png",
mime_type="image/png",
)
mock_queue_manager.invoke_from = InvokeFrom.SERVICE_API
with patch("core.app.apps.base_app_runner.ToolFileManager") as mock_mgr_class:
# Setup mock tool file manager
mock_mgr = MagicMock()
mock_mgr.create_file_by_url.return_value = mock_tool_file
mock_mgr_class.return_value = mock_mgr
with patch("core.app.apps.base_app_runner.MessageFile") as mock_msg_file_class:
# Setup mock message file
mock_msg_file_class.return_value = mock_message_file
with patch("core.app.apps.base_app_runner.db.session") as mock_session:
mock_session.add = MagicMock()
mock_session.commit = MagicMock()
mock_session.refresh = MagicMock()
# Act
# Create a mock runner with the method bound
runner = MagicMock()
method = AppRunner._handle_multimodal_image_content
runner._handle_multimodal_image_content = lambda *args, **kwargs: method(runner, *args, **kwargs)
runner._handle_multimodal_image_content(
content=content,
message_id=mock_message_id,
user_id=mock_user_id,
tenant_id=mock_tenant_id,
queue_manager=mock_queue_manager,
)
# Assert - verify created_by_role is END_USER for service API
call_kwargs = mock_msg_file_class.call_args[1]
assert call_kwargs["created_by_role"] == CreatorUserRole.END_USER

View File

@ -1,7 +1,6 @@
"""Unit tests for the message cycle manager optimization."""
from types import SimpleNamespace
from unittest.mock import ANY, Mock, patch
from unittest.mock import Mock, patch
import pytest
from flask import current_app
@ -28,17 +27,14 @@ class TestMessageCycleManagerOptimization:
def test_get_message_event_type_with_message_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE_FILE when message has files."""
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
# Setup mock session and message file
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
mock_message_file = Mock()
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = mock_message_file
# Current implementation uses session.scalar(select(...))
mock_session.scalar.return_value = mock_message_file
# Execute
with current_app.app_context():
@ -46,19 +42,16 @@ class TestMessageCycleManagerOptimization:
# Assert
assert result == StreamEvent.MESSAGE_FILE
mock_session.query.return_value.scalar.assert_called_once()
mock_session.scalar.assert_called_once()
def test_get_message_event_type_without_message_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE when message has no files."""
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
# Setup mock session and no message file
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = None
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
# Current implementation uses session.scalar(select(...))
mock_session.scalar.return_value = None
# Execute
with current_app.app_context():
@ -66,21 +59,18 @@ class TestMessageCycleManagerOptimization:
# Assert
assert result == StreamEvent.MESSAGE
mock_session.query.return_value.scalar.assert_called_once()
mock_session.scalar.assert_called_once()
def test_message_to_stream_response_with_precomputed_event_type(self, message_cycle_manager):
"""MessageCycleManager.message_to_stream_response expects a valid event_type; callers should precompute it."""
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
# Setup mock session and message file
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
mock_message_file = Mock()
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = mock_message_file
# Current implementation uses session.scalar(select(...))
mock_session.scalar.return_value = mock_message_file
# Execute: compute event type once, then pass to message_to_stream_response
with current_app.app_context():
@ -94,11 +84,11 @@ class TestMessageCycleManagerOptimization:
assert result.answer == "Hello world"
assert result.id == "test-message-id"
assert result.event == StreamEvent.MESSAGE_FILE
mock_session.query.return_value.scalar.assert_called_once()
mock_session.scalar.assert_called_once()
def test_message_to_stream_response_with_event_type_skips_query(self, message_cycle_manager):
"""Test that message_to_stream_response skips database query when event_type is provided."""
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
# Execute with event_type provided
result = message_cycle_manager.message_to_stream_response(
answer="Hello world", message_id="test-message-id", event_type=StreamEvent.MESSAGE
@ -109,8 +99,8 @@ class TestMessageCycleManagerOptimization:
assert result.answer == "Hello world"
assert result.id == "test-message-id"
assert result.event == StreamEvent.MESSAGE
# Should not query database when event_type is provided
mock_session_class.assert_not_called()
# Should not open a session when event_type is provided
mock_session_factory.create_session.assert_not_called()
def test_message_to_stream_response_with_from_variable_selector(self, message_cycle_manager):
"""Test message_to_stream_response with from_variable_selector parameter."""
@ -130,24 +120,21 @@ class TestMessageCycleManagerOptimization:
def test_optimization_usage_example(self, message_cycle_manager):
"""Test the optimization pattern that should be used by callers."""
# Step 1: Get event type once (this queries database)
with (
patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class,
patch("core.app.task_pipeline.message_cycle_manager.db", new=SimpleNamespace(engine=Mock())),
):
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
mock_session = Mock()
mock_session_class.return_value.__enter__.return_value = mock_session
# Current implementation uses session.query(...).scalar()
mock_session.query.return_value.scalar.return_value = None # No files
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
# Current implementation uses session.scalar(select(...))
mock_session.scalar.return_value = None # No files
with current_app.app_context():
event_type = message_cycle_manager.get_message_event_type("test-message-id")
# Should query database once
mock_session_class.assert_called_once_with(ANY, expire_on_commit=False)
# Should open session once
mock_session_factory.create_session.assert_called_once()
assert event_type == StreamEvent.MESSAGE
# Step 2: Use event_type for multiple calls (no additional queries)
with patch("core.app.task_pipeline.message_cycle_manager.Session") as mock_session_class:
mock_session_class.return_value.__enter__.return_value = Mock()
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
mock_session_factory.create_session.return_value.__enter__.return_value = Mock()
chunk1_response = message_cycle_manager.message_to_stream_response(
answer="Chunk 1", message_id="test-message-id", event_type=event_type
@ -157,8 +144,8 @@ class TestMessageCycleManagerOptimization:
answer="Chunk 2", message_id="test-message-id", event_type=event_type
)
# Should not query database again
mock_session_class.assert_not_called()
# Should not open session again when event_type provided
mock_session_factory.create_session.assert_not_called()
assert chunk1_response.event == StreamEvent.MESSAGE
assert chunk2_response.event == StreamEvent.MESSAGE

View File

@ -99,3 +99,38 @@ def mock_is_instrument_flag_enabled_true():
"""Mock is_instrument_flag_enabled to return True."""
with patch("core.workflow.graph_engine.layers.observability.is_instrument_flag_enabled", return_value=True):
yield
@pytest.fixture
def mock_retrieval_node():
"""Create a mock Knowledge Retrieval Node."""
node = MagicMock()
node.id = "test-retrieval-node-id"
node.title = "Retrieval Node"
node.execution_id = "test-retrieval-execution-id"
node.node_type = NodeType.KNOWLEDGE_RETRIEVAL
return node
@pytest.fixture
def mock_result_event():
"""Create a mock result event with NodeRunResult."""
from datetime import datetime
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.node_events.base import NodeRunResult
node_run_result = NodeRunResult(
inputs={"query": "test query"},
outputs={"result": [{"content": "test content", "metadata": {}}]},
process_data={},
metadata={},
)
return NodeRunSucceededEvent(
id="test-execution-id",
node_id="test-node-id",
node_type=NodeType.LLM,
start_at=datetime.now(),
node_run_result=node_run_result,
)

View File

@ -4,7 +4,8 @@ Tests for ObservabilityLayer.
Test coverage:
- Initialization and enable/disable logic
- Node span lifecycle (start, end, error handling)
- Parser integration (default and tool-specific)
- Parser integration (default, tool, LLM, and retrieval parsers)
- Result event parameter extraction (inputs/outputs)
- Graph lifecycle management
- Disabled mode behavior
"""
@ -134,9 +135,101 @@ class TestObservabilityLayerParserIntegration:
assert len(spans) == 1
attrs = spans[0].attributes
assert attrs["node.id"] == mock_tool_node.id
assert attrs["tool.provider.id"] == mock_tool_node._node_data.provider_id
assert attrs["tool.provider.type"] == mock_tool_node._node_data.provider_type.value
assert attrs["tool.name"] == mock_tool_node._node_data.tool_name
assert attrs["gen_ai.tool.name"] == mock_tool_node.title
assert attrs["gen_ai.tool.type"] == mock_tool_node._node_data.provider_type.value
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_llm_parser_used_for_llm_node(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node, mock_result_event
):
"""Test that LLM parser is used for LLM nodes and extracts LLM-specific attributes."""
from core.workflow.node_events.base import NodeRunResult
mock_result_event.node_run_result = NodeRunResult(
inputs={},
outputs={"text": "test completion", "finish_reason": "stop"},
process_data={
"model_name": "gpt-4",
"model_provider": "openai",
"usage": {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
"prompts": [{"role": "user", "text": "test prompt"}],
},
metadata={},
)
layer = ObservabilityLayer()
layer.on_graph_start()
layer.on_node_run_start(mock_llm_node)
layer.on_node_run_end(mock_llm_node, None, mock_result_event)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
attrs = spans[0].attributes
assert attrs["node.id"] == mock_llm_node.id
assert attrs["gen_ai.request.model"] == "gpt-4"
assert attrs["gen_ai.provider.name"] == "openai"
assert attrs["gen_ai.usage.input_tokens"] == 10
assert attrs["gen_ai.usage.output_tokens"] == 20
assert attrs["gen_ai.usage.total_tokens"] == 30
assert attrs["gen_ai.completion"] == "test completion"
assert attrs["gen_ai.response.finish_reason"] == "stop"
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_retrieval_parser_used_for_retrieval_node(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_retrieval_node, mock_result_event
):
"""Test that retrieval parser is used for retrieval nodes and extracts retrieval-specific attributes."""
from core.workflow.node_events.base import NodeRunResult
mock_result_event.node_run_result = NodeRunResult(
inputs={"query": "test query"},
outputs={"result": [{"content": "test content", "metadata": {"score": 0.9, "document_id": "doc1"}}]},
process_data={},
metadata={},
)
layer = ObservabilityLayer()
layer.on_graph_start()
layer.on_node_run_start(mock_retrieval_node)
layer.on_node_run_end(mock_retrieval_node, None, mock_result_event)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
attrs = spans[0].attributes
assert attrs["node.id"] == mock_retrieval_node.id
assert attrs["retrieval.query"] == "test query"
assert "retrieval.document" in attrs
@patch("core.workflow.graph_engine.layers.observability.dify_config.ENABLE_OTEL", True)
@pytest.mark.usefixtures("mock_is_instrument_flag_enabled_false")
def test_result_event_extracts_inputs_and_outputs(
self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node, mock_result_event
):
"""Test that result_event parameter allows parsers to extract inputs and outputs."""
from core.workflow.node_events.base import NodeRunResult
mock_result_event.node_run_result = NodeRunResult(
inputs={"input_key": "input_value"},
outputs={"output_key": "output_value"},
process_data={},
metadata={},
)
layer = ObservabilityLayer()
layer.on_graph_start()
layer.on_node_run_start(mock_start_node)
layer.on_node_run_end(mock_start_node, None, mock_result_event)
spans = memory_span_exporter.get_finished_spans()
assert len(spans) == 1
attrs = spans[0].attributes
assert "input.value" in attrs
assert "output.value" in attrs
class TestObservabilityLayerGraphLifecycle:

View File

@ -171,22 +171,26 @@ class TestBillingServiceSendRequest:
"status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND]
)
def test_delete_request_non_200_with_valid_json(self, mock_httpx_request, mock_billing_config, status_code):
"""Test DELETE request with non-200 status code but valid JSON response.
"""Test DELETE request with non-200 status code raises ValueError.
DELETE doesn't check status code, so it returns the error JSON.
DELETE now checks status code and raises ValueError for non-200 responses.
"""
# Arrange
error_response = {"detail": "Error message"}
mock_response = MagicMock()
mock_response.status_code = status_code
mock_response.text = "Error message"
mock_response.json.return_value = error_response
mock_httpx_request.return_value = mock_response
# Act
result = BillingService._send_request("DELETE", "/test", json={"key": "value"})
# Assert
assert result == error_response
# Act & Assert
with patch("services.billing_service.logger") as mock_logger:
with pytest.raises(ValueError) as exc_info:
BillingService._send_request("DELETE", "/test", json={"key": "value"})
assert "Unable to process delete request" in str(exc_info.value)
# Verify error logging
mock_logger.error.assert_called_once()
assert "DELETE response" in str(mock_logger.error.call_args)
@pytest.mark.parametrize(
"status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND]
@ -210,9 +214,9 @@ class TestBillingServiceSendRequest:
"status_code", [httpx.codes.BAD_REQUEST, httpx.codes.INTERNAL_SERVER_ERROR, httpx.codes.NOT_FOUND]
)
def test_delete_request_non_200_with_invalid_json(self, mock_httpx_request, mock_billing_config, status_code):
"""Test DELETE request with non-200 status code and invalid JSON response raises exception.
"""Test DELETE request with non-200 status code raises ValueError before JSON parsing.
DELETE doesn't check status code, so it calls response.json() which raises JSONDecodeError
DELETE now checks status code before calling response.json(), so ValueError is raised
when the response cannot be parsed as JSON (e.g., empty response).
"""
# Arrange
@ -223,8 +227,13 @@ class TestBillingServiceSendRequest:
mock_httpx_request.return_value = mock_response
# Act & Assert
with pytest.raises(json.JSONDecodeError):
BillingService._send_request("DELETE", "/test", json={"key": "value"})
with patch("services.billing_service.logger") as mock_logger:
with pytest.raises(ValueError) as exc_info:
BillingService._send_request("DELETE", "/test", json={"key": "value"})
assert "Unable to process delete request" in str(exc_info.value)
# Verify error logging
mock_logger.error.assert_called_once()
assert "DELETE response" in str(mock_logger.error.call_args)
def test_retry_on_request_error(self, mock_httpx_request, mock_billing_config):
"""Test that _send_request retries on httpx.RequestError."""
@ -789,7 +798,7 @@ class TestBillingServiceAccountManagement:
# Assert
assert result == expected_response
mock_send_request.assert_called_once_with("DELETE", "/account/", params={"account_id": account_id})
mock_send_request.assert_called_once_with("DELETE", "/account", params={"account_id": account_id})
def test_is_email_in_freeze_true(self, mock_send_request):
"""Test checking if email is frozen (returns True)."""

33
api/uv.lock generated
View File

@ -1633,7 +1633,7 @@ requires-dist = [
{ name = "pandas", extras = ["excel", "output-formatting", "performance"], specifier = "~=2.2.2" },
{ name = "psycogreen", specifier = "~=1.0.2" },
{ name = "psycopg2-binary", specifier = "~=2.9.6" },
{ name = "pycryptodome", specifier = "==3.19.1" },
{ name = "pycryptodome", specifier = "==3.23.0" },
{ name = "pydantic", specifier = "~=2.11.4" },
{ name = "pydantic-extra-types", specifier = "~=2.10.3" },
{ name = "pydantic-settings", specifier = "~=2.11.0" },
@ -4796,20 +4796,21 @@ wheels = [
[[package]]
name = "pycryptodome"
version = "3.19.1"
version = "3.23.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/b1/38/42a8855ff1bf568c61ca6557e2203f318fb7afeadaf2eb8ecfdbde107151/pycryptodome-3.19.1.tar.gz", hash = "sha256:8ae0dd1bcfada451c35f9e29a3e5db385caabc190f98e4a80ad02a61098fb776", size = 4782144, upload-time = "2023-12-28T06:52:40.741Z" }
sdist = { url = "https://files.pythonhosted.org/packages/8e/a6/8452177684d5e906854776276ddd34eca30d1b1e15aa1ee9cefc289a33f5/pycryptodome-3.23.0.tar.gz", hash = "sha256:447700a657182d60338bab09fdb27518f8856aecd80ae4c6bdddb67ff5da44ef", size = 4921276, upload-time = "2025-05-17T17:21:45.242Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/a8/ef/4931bc30674f0de0ca0e827b58c8b0c17313a8eae2754976c610b866118b/pycryptodome-3.19.1-cp35-abi3-macosx_10_9_universal2.whl", hash = "sha256:67939a3adbe637281c611596e44500ff309d547e932c449337649921b17b6297", size = 2417027, upload-time = "2023-12-28T06:51:50.138Z" },
{ url = "https://files.pythonhosted.org/packages/67/e6/238c53267fd8d223029c0a0d3730cb1b6594d60f62e40c4184703dc490b1/pycryptodome-3.19.1-cp35-abi3-macosx_10_9_x86_64.whl", hash = "sha256:11ddf6c9b52116b62223b6a9f4741bc4f62bb265392a4463282f7f34bb287180", size = 1579728, upload-time = "2023-12-28T06:51:52.385Z" },
{ url = "https://files.pythonhosted.org/packages/7c/87/7181c42c8d5ba89822a4b824830506d0aeec02959bb893614767e3279846/pycryptodome-3.19.1-cp35-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e3e6f89480616781d2a7f981472d0cdb09b9da9e8196f43c1234eff45c915766", size = 2051440, upload-time = "2023-12-28T06:51:55.751Z" },
{ url = "https://files.pythonhosted.org/packages/34/dd/332c4c0055527d17dac317ed9f9c864fc047b627d82f4b9a56c110afc6fc/pycryptodome-3.19.1-cp35-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:27e1efcb68993b7ce5d1d047a46a601d41281bba9f1971e6be4aa27c69ab8065", size = 2125379, upload-time = "2023-12-28T06:51:58.567Z" },
{ url = "https://files.pythonhosted.org/packages/24/9e/320b885ea336c218ff54ec2b276cd70ba6904e4f5a14a771ed39a2c47d59/pycryptodome-3.19.1-cp35-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1c6273ca5a03b672e504995529b8bae56da0ebb691d8ef141c4aa68f60765700", size = 2153951, upload-time = "2023-12-28T06:52:01.699Z" },
{ url = "https://files.pythonhosted.org/packages/f4/54/8ae0c43d1257b41bc9d3277c3f875174fd8ad86b9567f0b8609b99c938ee/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:b0bfe61506795877ff974f994397f0c862d037f6f1c0bfc3572195fc00833b96", size = 2044041, upload-time = "2023-12-28T06:52:03.737Z" },
{ url = "https://files.pythonhosted.org/packages/45/93/f8450a92cc38541c3ba1f4cb4e267e15ae6d6678ca617476d52c3a3764d4/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_i686.whl", hash = "sha256:f34976c5c8eb79e14c7d970fb097482835be8d410a4220f86260695ede4c3e17", size = 2182446, upload-time = "2023-12-28T06:52:05.588Z" },
{ url = "https://files.pythonhosted.org/packages/af/cd/ed6e429fb0792ce368f66e83246264dd3a7a045b0b1e63043ed22a063ce5/pycryptodome-3.19.1-cp35-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:7c9e222d0976f68d0cf6409cfea896676ddc1d98485d601e9508f90f60e2b0a2", size = 2144914, upload-time = "2023-12-28T06:52:07.44Z" },
{ url = "https://files.pythonhosted.org/packages/f6/23/b064bd4cfbf2cc5f25afcde0e7c880df5b20798172793137ba4b62d82e72/pycryptodome-3.19.1-cp35-abi3-win32.whl", hash = "sha256:4805e053571140cb37cf153b5c72cd324bb1e3e837cbe590a19f69b6cf85fd03", size = 1713105, upload-time = "2023-12-28T06:52:09.585Z" },
{ url = "https://files.pythonhosted.org/packages/7d/e0/ded1968a5257ab34216a0f8db7433897a2337d59e6d03be113713b346ea2/pycryptodome-3.19.1-cp35-abi3-win_amd64.whl", hash = "sha256:a470237ee71a1efd63f9becebc0ad84b88ec28e6784a2047684b693f458f41b7", size = 1749222, upload-time = "2023-12-28T06:52:11.534Z" },
{ url = "https://files.pythonhosted.org/packages/db/6c/a1f71542c969912bb0e106f64f60a56cc1f0fabecf9396f45accbe63fa68/pycryptodome-3.23.0-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:187058ab80b3281b1de11c2e6842a357a1f71b42cb1e15bce373f3d238135c27", size = 2495627, upload-time = "2025-05-17T17:20:47.139Z" },
{ url = "https://files.pythonhosted.org/packages/6e/4e/a066527e079fc5002390c8acdd3aca431e6ea0a50ffd7201551175b47323/pycryptodome-3.23.0-cp37-abi3-macosx_10_9_x86_64.whl", hash = "sha256:cfb5cd445280c5b0a4e6187a7ce8de5a07b5f3f897f235caa11f1f435f182843", size = 1640362, upload-time = "2025-05-17T17:20:50.392Z" },
{ url = "https://files.pythonhosted.org/packages/50/52/adaf4c8c100a8c49d2bd058e5b551f73dfd8cb89eb4911e25a0c469b6b4e/pycryptodome-3.23.0-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:67bd81fcbe34f43ad9422ee8fd4843c8e7198dd88dd3d40e6de42ee65fbe1490", size = 2182625, upload-time = "2025-05-17T17:20:52.866Z" },
{ url = "https://files.pythonhosted.org/packages/5f/e9/a09476d436d0ff1402ac3867d933c61805ec2326c6ea557aeeac3825604e/pycryptodome-3.23.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c8987bd3307a39bc03df5c8e0e3d8be0c4c3518b7f044b0f4c15d1aa78f52575", size = 2268954, upload-time = "2025-05-17T17:20:55.027Z" },
{ url = "https://files.pythonhosted.org/packages/f9/c5/ffe6474e0c551d54cab931918127c46d70cab8f114e0c2b5a3c071c2f484/pycryptodome-3.23.0-cp37-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:aa0698f65e5b570426fc31b8162ed4603b0c2841cbb9088e2b01641e3065915b", size = 2308534, upload-time = "2025-05-17T17:20:57.279Z" },
{ url = "https://files.pythonhosted.org/packages/18/28/e199677fc15ecf43010f2463fde4c1a53015d1fe95fb03bca2890836603a/pycryptodome-3.23.0-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:53ecbafc2b55353edcebd64bf5da94a2a2cdf5090a6915bcca6eca6cc452585a", size = 2181853, upload-time = "2025-05-17T17:20:59.322Z" },
{ url = "https://files.pythonhosted.org/packages/ce/ea/4fdb09f2165ce1365c9eaefef36625583371ee514db58dc9b65d3a255c4c/pycryptodome-3.23.0-cp37-abi3-musllinux_1_2_i686.whl", hash = "sha256:156df9667ad9f2ad26255926524e1c136d6664b741547deb0a86a9acf5ea631f", size = 2342465, upload-time = "2025-05-17T17:21:03.83Z" },
{ url = "https://files.pythonhosted.org/packages/22/82/6edc3fc42fe9284aead511394bac167693fb2b0e0395b28b8bedaa07ef04/pycryptodome-3.23.0-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:dea827b4d55ee390dc89b2afe5927d4308a8b538ae91d9c6f7a5090f397af1aa", size = 2267414, upload-time = "2025-05-17T17:21:06.72Z" },
{ url = "https://files.pythonhosted.org/packages/59/fe/aae679b64363eb78326c7fdc9d06ec3de18bac68be4b612fc1fe8902693c/pycryptodome-3.23.0-cp37-abi3-win32.whl", hash = "sha256:507dbead45474b62b2bbe318eb1c4c8ee641077532067fec9c1aa82c31f84886", size = 1768484, upload-time = "2025-05-17T17:21:08.535Z" },
{ url = "https://files.pythonhosted.org/packages/54/2f/e97a1b8294db0daaa87012c24a7bb714147c7ade7656973fd6c736b484ff/pycryptodome-3.23.0-cp37-abi3-win_amd64.whl", hash = "sha256:c75b52aacc6c0c260f204cbdd834f76edc9fb0d8e0da9fbf8352ef58202564e2", size = 1799636, upload-time = "2025-05-17T17:21:10.393Z" },
{ url = "https://files.pythonhosted.org/packages/18/3d/f9441a0d798bf2b1e645adc3265e55706aead1255ccdad3856dbdcffec14/pycryptodome-3.23.0-cp37-abi3-win_arm64.whl", hash = "sha256:11eeeb6917903876f134b56ba11abe95c0b0fd5e3330def218083c7d98bbcb3c", size = 1703675, upload-time = "2025-05-17T17:21:13.146Z" },
]
[[package]]
@ -5003,11 +5004,11 @@ wheels = [
[[package]]
name = "pypdf"
version = "6.6.0"
version = "6.6.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/d8/f4/801632a8b62a805378b6af2b5a3fcbfd8923abf647e0ed1af846a83433b2/pypdf-6.6.0.tar.gz", hash = "sha256:4c887ef2ea38d86faded61141995a3c7d068c9d6ae8477be7ae5de8a8e16592f", size = 5281063, upload-time = "2026-01-09T11:20:11.786Z" }
sdist = { url = "https://files.pythonhosted.org/packages/b8/bb/a44bab1ac3c54dbcf653d7b8bcdee93dddb2d3bf025a3912cacb8149a2f2/pypdf-6.6.2.tar.gz", hash = "sha256:0a3ea3b3303982333404e22d8f75d7b3144f9cf4b2970b96856391a516f9f016", size = 5281850, upload-time = "2026-01-26T11:57:55.964Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/b2/ba/96f99276194f720e74ed99905a080f6e77810558874e8935e580331b46de/pypdf-6.6.0-py3-none-any.whl", hash = "sha256:bca9091ef6de36c7b1a81e09327c554b7ce51e88dad68f5890c2b4a4417f1fd7", size = 328963, upload-time = "2026-01-09T11:20:09.278Z" },
{ url = "https://files.pythonhosted.org/packages/7d/be/549aaf1dfa4ab4aed29b09703d2fb02c4366fc1f05e880948c296c5764b9/pypdf-6.6.2-py3-none-any.whl", hash = "sha256:44c0c9811cfb3b83b28f1c3d054531d5b8b81abaedee0d8cb403650d023832ba", size = 329132, upload-time = "2026-01-26T11:57:54.099Z" },
]
[[package]]

View File

@ -0,0 +1,178 @@
/**
* Tests for multimodal image file handling in chat hooks.
* Tests the file object conversion logic without full hook integration.
*/
describe('Multimodal File Handling', () => {
describe('File type to MIME type mapping', () => {
it('should map image to image/png', () => {
const fileType: string = 'image'
const expectedMime = 'image/png'
const mimeType = fileType === 'image' ? 'image/png' : 'application/octet-stream'
expect(mimeType).toBe(expectedMime)
})
it('should map video to video/mp4', () => {
const fileType: string = 'video'
const expectedMime = 'video/mp4'
const mimeType = fileType === 'video' ? 'video/mp4' : 'application/octet-stream'
expect(mimeType).toBe(expectedMime)
})
it('should map audio to audio/mpeg', () => {
const fileType: string = 'audio'
const expectedMime = 'audio/mpeg'
const mimeType = fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream'
expect(mimeType).toBe(expectedMime)
})
it('should map unknown to application/octet-stream', () => {
const fileType: string = 'unknown'
const expectedMime = 'application/octet-stream'
const mimeType = ['image', 'video', 'audio'].includes(fileType) ? 'image/png' : 'application/octet-stream'
expect(mimeType).toBe(expectedMime)
})
})
describe('TransferMethod selection', () => {
it('should select remote_url for images', () => {
const fileType: string = 'image'
const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file'
expect(transferMethod).toBe('remote_url')
})
it('should select local_file for non-images', () => {
const fileType: string = 'video'
const transferMethod = fileType === 'image' ? 'remote_url' : 'local_file'
expect(transferMethod).toBe('local_file')
})
})
describe('File extension mapping', () => {
it('should use .png extension for images', () => {
const fileType: string = 'image'
const expectedExtension = '.png'
const extension = fileType === 'image' ? 'png' : 'bin'
expect(extension).toBe(expectedExtension.replace('.', ''))
})
it('should use .mp4 extension for videos', () => {
const fileType: string = 'video'
const expectedExtension = '.mp4'
const extension = fileType === 'video' ? 'mp4' : 'bin'
expect(extension).toBe(expectedExtension.replace('.', ''))
})
it('should use .mp3 extension for audio', () => {
const fileType: string = 'audio'
const expectedExtension = '.mp3'
const extension = fileType === 'audio' ? 'mp3' : 'bin'
expect(extension).toBe(expectedExtension.replace('.', ''))
})
})
describe('File name generation', () => {
it('should generate correct file name for images', () => {
const fileType: string = 'image'
const expectedName = 'generated_image.png'
const fileName = `generated_${fileType}.${fileType === 'image' ? 'png' : 'bin'}`
expect(fileName).toBe(expectedName)
})
it('should generate correct file name for videos', () => {
const fileType: string = 'video'
const expectedName = 'generated_video.mp4'
const fileName = `generated_${fileType}.${fileType === 'video' ? 'mp4' : 'bin'}`
expect(fileName).toBe(expectedName)
})
it('should generate correct file name for audio', () => {
const fileType: string = 'audio'
const expectedName = 'generated_audio.mp3'
const fileName = `generated_${fileType}.${fileType === 'audio' ? 'mp3' : 'bin'}`
expect(fileName).toBe(expectedName)
})
})
describe('SupportFileType mapping', () => {
it('should map image type to image supportFileType', () => {
const fileType: string = 'image'
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
expect(supportFileType).toBe('image')
})
it('should map video type to video supportFileType', () => {
const fileType: string = 'video'
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
expect(supportFileType).toBe('video')
})
it('should map audio type to audio supportFileType', () => {
const fileType: string = 'audio'
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
expect(supportFileType).toBe('audio')
})
it('should map unknown type to document supportFileType', () => {
const fileType: string = 'unknown'
const supportFileType = fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'
expect(supportFileType).toBe('document')
})
})
describe('File conversion logic', () => {
it('should detect existing transferMethod', () => {
const fileWithTransferMethod = {
id: 'file-123',
transferMethod: 'remote_url' as const,
type: 'image/png',
name: 'test.png',
size: 1024,
supportFileType: 'image',
progress: 100,
}
const hasTransferMethod = 'transferMethod' in fileWithTransferMethod
expect(hasTransferMethod).toBe(true)
})
it('should detect missing transferMethod', () => {
const fileWithoutTransferMethod = {
id: 'file-456',
type: 'image',
url: 'http://example.com/image.png',
belongs_to: 'assistant',
}
const hasTransferMethod = 'transferMethod' in fileWithoutTransferMethod
expect(hasTransferMethod).toBe(false)
})
it('should create file with size 0 for generated files', () => {
const expectedSize = 0
expect(expectedSize).toBe(0)
})
})
describe('Agent vs Non-Agent mode logic', () => {
it('should check for agent_thoughts to determine mode', () => {
const agentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {
agent_thoughts: [{}],
}
const isAgentMode = agentResponse.agent_thoughts && agentResponse.agent_thoughts.length > 0
expect(isAgentMode).toBe(true)
})
it('should detect non-agent mode when agent_thoughts is empty', () => {
const nonAgentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {
agent_thoughts: [],
}
const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0
expect(isAgentMode).toBe(false)
})
it('should detect non-agent mode when agent_thoughts is undefined', () => {
const nonAgentResponse: { agent_thoughts?: Array<Record<string, unknown>> } = {}
const isAgentMode = nonAgentResponse.agent_thoughts && nonAgentResponse.agent_thoughts.length > 0
expect(isAgentMode).toBeFalsy()
})
})
})

View File

@ -419,9 +419,40 @@ export const useChat = (
}
},
onFile(file) {
// Convert simple file type to MIME type for non-agent mode
// Backend sends: { id, type: "image", belongs_to, url }
// Frontend expects: { id, type: "image/png", transferMethod, url, uploadedId, supportFileType, name, size }
// Determine file type for MIME conversion
const fileType = (file as { type?: string }).type || 'image'
// If file already has transferMethod, use it as base and ensure all required fields exist
// Otherwise, create a new complete file object
const baseFile = ('transferMethod' in file) ? (file as Partial<FileEntity>) : null
const convertedFile: FileEntity = {
id: baseFile?.id || (file as { id: string }).id,
type: baseFile?.type || (fileType === 'image' ? 'image/png' : fileType === 'video' ? 'video/mp4' : fileType === 'audio' ? 'audio/mpeg' : 'application/octet-stream'),
transferMethod: (baseFile?.transferMethod as FileEntity['transferMethod']) || (fileType === 'image' ? 'remote_url' : 'local_file'),
uploadedId: baseFile?.uploadedId || (file as { id: string }).id,
supportFileType: baseFile?.supportFileType || (fileType === 'image' ? 'image' : fileType === 'video' ? 'video' : fileType === 'audio' ? 'audio' : 'document'),
progress: baseFile?.progress ?? 100,
name: baseFile?.name || `generated_${fileType}.${fileType === 'image' ? 'png' : fileType === 'video' ? 'mp4' : fileType === 'audio' ? 'mp3' : 'bin'}`,
url: baseFile?.url || (file as { url?: string }).url,
size: baseFile?.size ?? 0, // Generated files don't have a known size
}
// For agent mode, add files to the last thought
const lastThought = responseItem.agent_thoughts?.[responseItem.agent_thoughts?.length - 1]
if (lastThought)
responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(lastThought as any).message_files, file]
if (lastThought) {
const thought = lastThought as { message_files?: FileEntity[] }
responseItem.agent_thoughts![responseItem.agent_thoughts!.length - 1].message_files = [...(thought.message_files ?? []), convertedFile]
}
// For non-agent mode, add files directly to responseItem.message_files
else {
const currentFiles = (responseItem.message_files as FileEntity[] | undefined) ?? []
responseItem.message_files = [...currentFiles, convertedFile]
}
updateCurrentQAOnTree({
placeholderQuestionId,

View File

@ -1,10 +1,20 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import { createReactI18nextMock } from '@/test/i18n-mock'
import InputWithCopy from './index'
// Mock navigator.clipboard for foxact/use-clipboard
const mockWriteText = vi.fn(() => Promise.resolve())
// Create a controllable mock for useClipboard
const mockCopy = vi.fn()
let mockCopied = false
const mockReset = vi.fn()
vi.mock('foxact/use-clipboard', () => ({
useClipboard: () => ({
copy: mockCopy,
copied: mockCopied,
reset: mockReset,
}),
}))
// Mock the i18n hook with custom translations for test assertions
vi.mock('react-i18next', () => createReactI18nextMock({
@ -17,13 +27,9 @@ vi.mock('react-i18next', () => createReactI18nextMock({
describe('InputWithCopy component', () => {
beforeEach(() => {
vi.clearAllMocks()
mockWriteText.mockClear()
// Setup navigator.clipboard mock
Object.assign(navigator, {
clipboard: {
writeText: mockWriteText,
},
})
mockCopy.mockClear()
mockReset.mockClear()
mockCopied = false
})
it('renders correctly with default props', () => {
@ -44,31 +50,27 @@ describe('InputWithCopy component', () => {
expect(copyButton).not.toBeInTheDocument()
})
it('copies input value when copy button is clicked', async () => {
it('calls copy function with input value when copy button is clicked', () => {
const mockOnChange = vi.fn()
render(<InputWithCopy value="test value" onChange={mockOnChange} />)
const copyButton = screen.getByRole('button')
fireEvent.click(copyButton)
await waitFor(() => {
expect(mockWriteText).toHaveBeenCalledWith('test value')
})
expect(mockCopy).toHaveBeenCalledWith('test value')
})
it('copies custom value when copyValue prop is provided', async () => {
it('calls copy function with custom value when copyValue prop is provided', () => {
const mockOnChange = vi.fn()
render(<InputWithCopy value="display value" onChange={mockOnChange} copyValue="custom copy value" />)
const copyButton = screen.getByRole('button')
fireEvent.click(copyButton)
await waitFor(() => {
expect(mockWriteText).toHaveBeenCalledWith('custom copy value')
})
expect(mockCopy).toHaveBeenCalledWith('custom copy value')
})
it('calls onCopy callback when copy button is clicked', async () => {
it('calls onCopy callback when copy button is clicked', () => {
const onCopyMock = vi.fn()
const mockOnChange = vi.fn()
render(<InputWithCopy value="test value" onChange={mockOnChange} onCopy={onCopyMock} />)
@ -76,25 +78,21 @@ describe('InputWithCopy component', () => {
const copyButton = screen.getByRole('button')
fireEvent.click(copyButton)
await waitFor(() => {
expect(onCopyMock).toHaveBeenCalledWith('test value')
})
expect(onCopyMock).toHaveBeenCalledWith('test value')
})
it('shows copied state after successful copy', async () => {
it('shows copied state when copied is true', () => {
mockCopied = true
const mockOnChange = vi.fn()
render(<InputWithCopy value="test value" onChange={mockOnChange} />)
const copyButton = screen.getByRole('button')
fireEvent.click(copyButton)
// Hover over the button to trigger tooltip
fireEvent.mouseEnter(copyButton)
// Check if the tooltip shows "Copied" state
await waitFor(() => {
expect(screen.getByText('Copied')).toBeInTheDocument()
}, { timeout: 2000 })
// The icon should change to filled version when copied
// We verify the component renders without error in copied state
expect(copyButton).toBeInTheDocument()
})
it('passes through all input props correctly', () => {
@ -117,22 +115,22 @@ describe('InputWithCopy component', () => {
expect(input).toHaveClass('custom-class')
})
it('handles empty value correctly', async () => {
it('handles empty value correctly', () => {
const mockOnChange = vi.fn()
render(<InputWithCopy value="" onChange={mockOnChange} />)
const input = screen.getByDisplayValue('')
const input = screen.getByRole('textbox')
const copyButton = screen.getByRole('button')
expect(input).toBeInTheDocument()
expect(input).toHaveValue('')
expect(copyButton).toBeInTheDocument()
// Clicking copy button with empty value should call copy with empty string
fireEvent.click(copyButton)
await waitFor(() => {
expect(mockWriteText).toHaveBeenCalledWith('')
})
expect(mockCopy).toHaveBeenCalledWith('')
})
it('maintains focus on input after copy', async () => {
it('maintains focus on input after copy', () => {
const mockOnChange = vi.fn()
render(<InputWithCopy value="test value" onChange={mockOnChange} />)

View File

@ -0,0 +1,426 @@
import type { DefaultModelResponse, Model, ModelItem } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { RetrievalConfig } from '@/types/app'
import { describe, expect, it } from 'vitest'
import { ConfigurationMethodEnum, ModelStatusEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { RerankingModeEnum } from '@/models/datasets'
import { RETRIEVE_METHOD } from '@/types/app'
import { ensureRerankModelSelected, isReRankModelSelected } from './check-rerank-model'
// Test data factory
const createRetrievalConfig = (overrides: Partial<RetrievalConfig> = {}): RetrievalConfig => ({
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
top_k: 3,
score_threshold_enabled: false,
score_threshold: 0.5,
...overrides,
})
const createModelItem = (model: string): ModelItem => ({
model,
label: { en_US: model, zh_Hans: model },
model_type: ModelTypeEnum.rerank,
fetch_from: ConfigurationMethodEnum.predefinedModel,
status: ModelStatusEnum.active,
model_properties: {},
load_balancing_enabled: false,
})
const createRerankModelList = (): Model[] => [
{
provider: 'openai',
icon_small: { en_US: '', zh_Hans: '' },
label: { en_US: 'OpenAI', zh_Hans: 'OpenAI' },
models: [
createModelItem('gpt-4-turbo'),
createModelItem('gpt-3.5-turbo'),
],
status: ModelStatusEnum.active,
},
{
provider: 'cohere',
icon_small: { en_US: '', zh_Hans: '' },
label: { en_US: 'Cohere', zh_Hans: 'Cohere' },
models: [
createModelItem('rerank-english-v2.0'),
createModelItem('rerank-multilingual-v2.0'),
],
status: ModelStatusEnum.active,
},
]
const createDefaultRerankModel = (): DefaultModelResponse => ({
model: 'rerank-english-v2.0',
model_type: ModelTypeEnum.rerank,
provider: {
provider: 'cohere',
icon_small: { en_US: '', zh_Hans: '' },
},
})
describe('check-rerank-model', () => {
describe('isReRankModelSelected', () => {
describe('Core Functionality', () => {
it('should return true when reranking is disabled', () => {
const config = createRetrievalConfig({
reranking_enable: false,
})
const result = isReRankModelSelected({
retrievalConfig: config,
rerankModelList: createRerankModelList(),
indexMethod: 'high_quality',
})
expect(result).toBe(true)
})
it('should return true for economy indexMethod', () => {
const config = createRetrievalConfig({
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: true,
})
const result = isReRankModelSelected({
retrievalConfig: config,
rerankModelList: createRerankModelList(),
indexMethod: 'economy',
})
expect(result).toBe(true)
})
it('should return true when model is selected and valid', () => {
const config = createRetrievalConfig({
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: true,
reranking_model: {
reranking_provider_name: 'cohere',
reranking_model_name: 'rerank-english-v2.0',
},
})
const result = isReRankModelSelected({
retrievalConfig: config,
rerankModelList: createRerankModelList(),
indexMethod: 'high_quality',
})
expect(result).toBe(true)
})
})
describe('Edge Cases', () => {
it('should return false when reranking enabled but no model selected for semantic search', () => {
const config = createRetrievalConfig({
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: true,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
})
const result = isReRankModelSelected({
retrievalConfig: config,
rerankModelList: createRerankModelList(),
indexMethod: 'high_quality',
})
expect(result).toBe(false)
})
it('should return false when reranking enabled but no model selected for fullText search', () => {
const config = createRetrievalConfig({
search_method: RETRIEVE_METHOD.fullText,
reranking_enable: true,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
})
const result = isReRankModelSelected({
retrievalConfig: config,
rerankModelList: createRerankModelList(),
indexMethod: 'high_quality',
})
expect(result).toBe(false)
})
it('should return false for hybrid search without WeightedScore mode and no model selected', () => {
const config = createRetrievalConfig({
search_method: RETRIEVE_METHOD.hybrid,
reranking_enable: true,
reranking_mode: RerankingModeEnum.RerankingModel,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
})
const result = isReRankModelSelected({
retrievalConfig: config,
rerankModelList: createRerankModelList(),
indexMethod: 'high_quality',
})
expect(result).toBe(false)
})
it('should return true for hybrid search with WeightedScore mode even without model', () => {
const config = createRetrievalConfig({
search_method: RETRIEVE_METHOD.hybrid,
reranking_enable: true,
reranking_mode: RerankingModeEnum.WeightedScore,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
})
const result = isReRankModelSelected({
retrievalConfig: config,
rerankModelList: createRerankModelList(),
indexMethod: 'high_quality',
})
expect(result).toBe(true)
})
it('should return false when provider exists but model not found', () => {
const config = createRetrievalConfig({
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: true,
reranking_model: {
reranking_provider_name: 'cohere',
reranking_model_name: 'non-existent-model',
},
})
const result = isReRankModelSelected({
retrievalConfig: config,
rerankModelList: createRerankModelList(),
indexMethod: 'high_quality',
})
expect(result).toBe(false)
})
it('should return false when provider not found in list', () => {
const config = createRetrievalConfig({
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: true,
reranking_model: {
reranking_provider_name: 'non-existent-provider',
reranking_model_name: 'some-model',
},
})
const result = isReRankModelSelected({
retrievalConfig: config,
rerankModelList: createRerankModelList(),
indexMethod: 'high_quality',
})
expect(result).toBe(false)
})
it('should return true with empty rerankModelList when reranking disabled', () => {
const config = createRetrievalConfig({
reranking_enable: false,
})
const result = isReRankModelSelected({
retrievalConfig: config,
rerankModelList: [],
indexMethod: 'high_quality',
})
expect(result).toBe(true)
})
it('should return true when indexMethod is undefined', () => {
const config = createRetrievalConfig({
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: true,
})
const result = isReRankModelSelected({
retrievalConfig: config,
rerankModelList: createRerankModelList(),
indexMethod: undefined,
})
expect(result).toBe(true)
})
})
})
describe('ensureRerankModelSelected', () => {
describe('Core Functionality', () => {
it('should return original config when reranking model already selected', () => {
const config = createRetrievalConfig({
reranking_enable: true,
reranking_model: {
reranking_provider_name: 'cohere',
reranking_model_name: 'rerank-english-v2.0',
},
})
const result = ensureRerankModelSelected({
retrievalConfig: config,
rerankDefaultModel: createDefaultRerankModel(),
indexMethod: 'high_quality',
})
expect(result).toEqual(config)
})
it('should apply default model when reranking enabled but no model selected', () => {
const config = createRetrievalConfig({
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: true,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
})
const result = ensureRerankModelSelected({
retrievalConfig: config,
rerankDefaultModel: createDefaultRerankModel(),
indexMethod: 'high_quality',
})
expect(result.reranking_model).toEqual({
reranking_provider_name: 'cohere',
reranking_model_name: 'rerank-english-v2.0',
})
})
it('should apply default model for hybrid search method', () => {
const config = createRetrievalConfig({
search_method: RETRIEVE_METHOD.hybrid,
reranking_enable: false,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
})
const result = ensureRerankModelSelected({
retrievalConfig: config,
rerankDefaultModel: createDefaultRerankModel(),
indexMethod: 'high_quality',
})
expect(result.reranking_model).toEqual({
reranking_provider_name: 'cohere',
reranking_model_name: 'rerank-english-v2.0',
})
})
})
describe('Edge Cases', () => {
it('should return original config when indexMethod is not high_quality', () => {
const config = createRetrievalConfig({
reranking_enable: true,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
})
const result = ensureRerankModelSelected({
retrievalConfig: config,
rerankDefaultModel: createDefaultRerankModel(),
indexMethod: 'economy',
})
expect(result).toEqual(config)
})
it('should return original config when rerankDefaultModel is null', () => {
const config = createRetrievalConfig({
reranking_enable: true,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
})
const result = ensureRerankModelSelected({
retrievalConfig: config,
rerankDefaultModel: null as unknown as DefaultModelResponse,
indexMethod: 'high_quality',
})
expect(result).toEqual(config)
})
it('should return original config when reranking disabled and not hybrid search', () => {
const config = createRetrievalConfig({
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
})
const result = ensureRerankModelSelected({
retrievalConfig: config,
rerankDefaultModel: createDefaultRerankModel(),
indexMethod: 'high_quality',
})
expect(result).toEqual(config)
})
it('should return original config when indexMethod is undefined', () => {
const config = createRetrievalConfig({
reranking_enable: true,
reranking_model: {
reranking_provider_name: '',
reranking_model_name: '',
},
})
const result = ensureRerankModelSelected({
retrievalConfig: config,
rerankDefaultModel: createDefaultRerankModel(),
indexMethod: undefined,
})
expect(result).toEqual(config)
})
it('should preserve other config properties when applying default model', () => {
const config = createRetrievalConfig({
search_method: RETRIEVE_METHOD.semantic,
reranking_enable: true,
top_k: 10,
score_threshold_enabled: true,
score_threshold: 0.8,
})
const result = ensureRerankModelSelected({
retrievalConfig: config,
rerankDefaultModel: createDefaultRerankModel(),
indexMethod: 'high_quality',
})
expect(result.top_k).toBe(10)
expect(result.score_threshold_enabled).toBe(true)
expect(result.score_threshold).toBe(0.8)
expect(result.search_method).toBe(RETRIEVE_METHOD.semantic)
})
})
})
})

View File

@ -0,0 +1,61 @@
import { render, screen } from '@testing-library/react'
import { describe, expect, it } from 'vitest'
import ChunkingModeLabel from './chunking-mode-label'
describe('ChunkingModeLabel', () => {
describe('Rendering', () => {
it('should render without crashing', () => {
render(<ChunkingModeLabel isGeneralMode={true} isQAMode={false} />)
expect(screen.getByText(/general/i)).toBeInTheDocument()
})
it('should render with Badge wrapper', () => {
const { container } = render(<ChunkingModeLabel isGeneralMode={true} isQAMode={false} />)
// Badge component renders with specific styles
expect(container.querySelector('.flex')).toBeInTheDocument()
})
})
describe('Props', () => {
it('should display general mode text when isGeneralMode is true', () => {
render(<ChunkingModeLabel isGeneralMode={true} isQAMode={false} />)
expect(screen.getByText(/general/i)).toBeInTheDocument()
})
it('should display parent-child mode text when isGeneralMode is false', () => {
render(<ChunkingModeLabel isGeneralMode={false} isQAMode={false} />)
expect(screen.getByText(/parentChild/i)).toBeInTheDocument()
})
it('should append QA suffix when isGeneralMode and isQAMode are both true', () => {
render(<ChunkingModeLabel isGeneralMode={true} isQAMode={true} />)
expect(screen.getByText(/general.*QA/i)).toBeInTheDocument()
})
it('should not append QA suffix when isGeneralMode is true but isQAMode is false', () => {
render(<ChunkingModeLabel isGeneralMode={true} isQAMode={false} />)
const text = screen.getByText(/general/i)
expect(text.textContent).not.toContain('QA')
})
it('should not display QA suffix for parent-child mode even when isQAMode is true', () => {
render(<ChunkingModeLabel isGeneralMode={false} isQAMode={true} />)
expect(screen.getByText(/parentChild/i)).toBeInTheDocument()
expect(screen.queryByText(/QA/i)).not.toBeInTheDocument()
})
})
describe('Edge Cases', () => {
it('should render icon element', () => {
const { container } = render(<ChunkingModeLabel isGeneralMode={true} isQAMode={false} />)
const iconElement = container.querySelector('svg')
expect(iconElement).toBeInTheDocument()
})
it('should apply correct icon size classes', () => {
const { container } = render(<ChunkingModeLabel isGeneralMode={true} isQAMode={false} />)
const iconElement = container.querySelector('svg')
expect(iconElement).toHaveClass('h-3', 'w-3')
})
})
})

View File

@ -0,0 +1,136 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { describe, expect, it } from 'vitest'
import { CredentialIcon } from './credential-icon'
describe('CredentialIcon', () => {
describe('Rendering', () => {
it('should render without crashing', () => {
render(<CredentialIcon name="Test" />)
expect(screen.getByText('T')).toBeInTheDocument()
})
it('should render first letter when no avatar provided', () => {
render(<CredentialIcon name="Alice" />)
expect(screen.getByText('A')).toBeInTheDocument()
})
it('should render image when avatarUrl is provided', () => {
render(<CredentialIcon name="Test" avatarUrl="https://example.com/avatar.png" />)
const img = screen.getByRole('img')
expect(img).toBeInTheDocument()
expect(img).toHaveAttribute('src', 'https://example.com/avatar.png')
})
})
describe('Props', () => {
it('should apply default size of 20px', () => {
const { container } = render(<CredentialIcon name="Test" />)
const wrapper = container.firstChild as HTMLElement
expect(wrapper).toHaveStyle({ width: '20px', height: '20px' })
})
it('should apply custom size', () => {
const { container } = render(<CredentialIcon name="Test" size={40} />)
const wrapper = container.firstChild as HTMLElement
expect(wrapper).toHaveStyle({ width: '40px', height: '40px' })
})
it('should apply custom className', () => {
const { container } = render(<CredentialIcon name="Test" className="custom-class" />)
const wrapper = container.firstChild as HTMLElement
expect(wrapper).toHaveClass('custom-class')
})
it('should uppercase the first letter', () => {
render(<CredentialIcon name="bob" />)
expect(screen.getByText('B')).toBeInTheDocument()
})
it('should render fallback when avatarUrl is "default"', () => {
render(<CredentialIcon name="Test" avatarUrl="default" />)
expect(screen.getByText('T')).toBeInTheDocument()
expect(screen.queryByRole('img')).not.toBeInTheDocument()
})
})
describe('User Interactions', () => {
it('should fallback to letter when image fails to load', () => {
render(<CredentialIcon name="Test" avatarUrl="https://example.com/broken.png" />)
// Initially shows image
const img = screen.getByRole('img')
expect(img).toBeInTheDocument()
// Trigger error event
fireEvent.error(img)
// Should now show letter fallback
expect(screen.getByText('T')).toBeInTheDocument()
expect(screen.queryByRole('img')).not.toBeInTheDocument()
})
})
describe('Edge Cases', () => {
it('should handle single character name', () => {
render(<CredentialIcon name="A" />)
expect(screen.getByText('A')).toBeInTheDocument()
})
it('should handle name starting with number', () => {
render(<CredentialIcon name="123test" />)
expect(screen.getByText('1')).toBeInTheDocument()
})
it('should handle name starting with special character', () => {
render(<CredentialIcon name="@user" />)
expect(screen.getByText('@')).toBeInTheDocument()
})
it('should assign consistent background colors based on first letter', () => {
// Same first letter should get same color
const { container: container1 } = render(<CredentialIcon name="Alice" />)
const { container: container2 } = render(<CredentialIcon name="Anna" />)
const wrapper1 = container1.firstChild as HTMLElement
const wrapper2 = container2.firstChild as HTMLElement
// Both should have the same bg class since they start with 'A'
const classes1 = wrapper1.className
const classes2 = wrapper2.className
const bgClass1 = classes1.match(/bg-components-icon-bg-\S+/)?.[0]
const bgClass2 = classes2.match(/bg-components-icon-bg-\S+/)?.[0]
expect(bgClass1).toBe(bgClass2)
})
it('should apply different background colors for different letters', () => {
// 'A' (65) % 4 = 1 → pink, 'B' (66) % 4 = 2 → indigo
const { container: container1 } = render(<CredentialIcon name="Alice" />)
const { container: container2 } = render(<CredentialIcon name="Bob" />)
const wrapper1 = container1.firstChild as HTMLElement
const wrapper2 = container2.firstChild as HTMLElement
const bgClass1 = wrapper1.className.match(/bg-components-icon-bg-\S+/)?.[0]
const bgClass2 = wrapper2.className.match(/bg-components-icon-bg-\S+/)?.[0]
expect(bgClass1).toBeDefined()
expect(bgClass2).toBeDefined()
expect(bgClass1).not.toBe(bgClass2)
})
it('should handle empty avatarUrl string', () => {
render(<CredentialIcon name="Test" avatarUrl="" />)
expect(screen.getByText('T')).toBeInTheDocument()
expect(screen.queryByRole('img')).not.toBeInTheDocument()
})
it('should render image with correct dimensions', () => {
render(<CredentialIcon name="Test" avatarUrl="https://example.com/avatar.png" size={32} />)
const img = screen.getByRole('img')
expect(img).toHaveAttribute('width', '32')
expect(img).toHaveAttribute('height', '32')
})
})
})

View File

@ -0,0 +1,115 @@
import { render } from '@testing-library/react'
import { describe, expect, it } from 'vitest'
import DocumentFileIcon from './document-file-icon'
describe('DocumentFileIcon', () => {
describe('Rendering', () => {
it('should render without crashing', () => {
const { container } = render(<DocumentFileIcon />)
expect(container.firstChild).toBeInTheDocument()
})
it('should render FileTypeIcon component', () => {
const { container } = render(<DocumentFileIcon extension="pdf" />)
// FileTypeIcon renders an svg or img element
expect(container.querySelector('svg, img')).toBeInTheDocument()
})
})
describe('Props', () => {
it('should determine type from extension prop', () => {
const { container } = render(<DocumentFileIcon extension="pdf" />)
expect(container.firstChild).toBeInTheDocument()
})
it('should determine type from name when extension not provided', () => {
const { container } = render(<DocumentFileIcon name="document.pdf" />)
expect(container.firstChild).toBeInTheDocument()
})
it('should handle uppercase extension', () => {
const { container } = render(<DocumentFileIcon extension="PDF" />)
expect(container.firstChild).toBeInTheDocument()
})
it('should handle uppercase name extension', () => {
const { container } = render(<DocumentFileIcon name="DOCUMENT.PDF" />)
expect(container.firstChild).toBeInTheDocument()
})
it('should apply custom className', () => {
const { container } = render(<DocumentFileIcon extension="pdf" className="custom-icon" />)
expect(container.querySelector('.custom-icon')).toBeInTheDocument()
})
it('should pass size prop to FileTypeIcon', () => {
// Testing different size values
const { container: smContainer } = render(<DocumentFileIcon extension="pdf" size="sm" />)
const { container: lgContainer } = render(<DocumentFileIcon extension="pdf" size="lg" />)
expect(smContainer.firstChild).toBeInTheDocument()
expect(lgContainer.firstChild).toBeInTheDocument()
})
})
describe('File Type Mapping', () => {
const testCases = [
{ extension: 'pdf', description: 'PDF files' },
{ extension: 'json', description: 'JSON files' },
{ extension: 'html', description: 'HTML files' },
{ extension: 'txt', description: 'TXT files' },
{ extension: 'markdown', description: 'Markdown files' },
{ extension: 'md', description: 'MD files' },
{ extension: 'xlsx', description: 'XLSX files' },
{ extension: 'xls', description: 'XLS files' },
{ extension: 'csv', description: 'CSV files' },
{ extension: 'doc', description: 'DOC files' },
{ extension: 'docx', description: 'DOCX files' },
]
testCases.forEach(({ extension, description }) => {
it(`should handle ${description}`, () => {
const { container } = render(<DocumentFileIcon extension={extension} />)
expect(container.firstChild).toBeInTheDocument()
})
})
})
describe('Edge Cases', () => {
it('should handle unknown extension with default document type', () => {
const { container } = render(<DocumentFileIcon extension="xyz" />)
expect(container.firstChild).toBeInTheDocument()
})
it('should handle empty extension string', () => {
const { container } = render(<DocumentFileIcon extension="" />)
expect(container.firstChild).toBeInTheDocument()
})
it('should handle name without extension', () => {
const { container } = render(<DocumentFileIcon name="document" />)
expect(container.firstChild).toBeInTheDocument()
})
it('should handle name with multiple dots', () => {
const { container } = render(<DocumentFileIcon name="my.document.file.pdf" />)
expect(container.firstChild).toBeInTheDocument()
})
it('should prioritize extension over name', () => {
// If both are provided, extension should take precedence
const { container } = render(<DocumentFileIcon extension="xlsx" name="document.pdf" />)
expect(container.firstChild).toBeInTheDocument()
})
it('should handle undefined extension and name', () => {
const { container } = render(<DocumentFileIcon />)
expect(container.firstChild).toBeInTheDocument()
})
it('should apply default size of md', () => {
const { container } = render(<DocumentFileIcon extension="pdf" />)
expect(container.firstChild).toBeInTheDocument()
})
})
})

View File

@ -0,0 +1,166 @@
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import Toast from '@/app/components/base/toast'
import { useAutoDisabledDocuments } from '@/service/knowledge/use-document'
import AutoDisabledDocument from './auto-disabled-document'
type AutoDisabledDocumentsResponse = { document_ids: string[] }
const createMockQueryResult = (
data: AutoDisabledDocumentsResponse | undefined,
isLoading: boolean,
) => ({
data,
isLoading,
}) as ReturnType<typeof useAutoDisabledDocuments>
// Mock service hooks
const mockMutateAsync = vi.fn()
const mockInvalidDisabledDocument = vi.fn()
vi.mock('@/service/knowledge/use-document', () => ({
useAutoDisabledDocuments: vi.fn(),
useDocumentEnable: vi.fn(() => ({
mutateAsync: mockMutateAsync,
})),
useInvalidDisabledDocument: vi.fn(() => mockInvalidDisabledDocument),
}))
// Mock Toast
vi.mock('@/app/components/base/toast', () => ({
default: {
notify: vi.fn(),
},
}))
const mockUseAutoDisabledDocuments = vi.mocked(useAutoDisabledDocuments)
describe('AutoDisabledDocument', () => {
beforeEach(() => {
vi.clearAllMocks()
mockMutateAsync.mockResolvedValue({})
})
describe('Rendering', () => {
it('should render nothing when loading', () => {
mockUseAutoDisabledDocuments.mockReturnValue(
createMockQueryResult(undefined, true),
)
const { container } = render(<AutoDisabledDocument datasetId="test-dataset" />)
expect(container.firstChild).toBeNull()
})
it('should render nothing when no disabled documents', () => {
mockUseAutoDisabledDocuments.mockReturnValue(
createMockQueryResult({ document_ids: [] }, false),
)
const { container } = render(<AutoDisabledDocument datasetId="test-dataset" />)
expect(container.firstChild).toBeNull()
})
it('should render nothing when document_ids is undefined', () => {
mockUseAutoDisabledDocuments.mockReturnValue(
createMockQueryResult(undefined, false),
)
const { container } = render(<AutoDisabledDocument datasetId="test-dataset" />)
expect(container.firstChild).toBeNull()
})
it('should render StatusWithAction when disabled documents exist', () => {
mockUseAutoDisabledDocuments.mockReturnValue(
createMockQueryResult({ document_ids: ['doc1', 'doc2'] }, false),
)
render(<AutoDisabledDocument datasetId="test-dataset" />)
expect(screen.getByText(/enable/i)).toBeInTheDocument()
})
})
describe('Props', () => {
it('should pass datasetId to useAutoDisabledDocuments', () => {
mockUseAutoDisabledDocuments.mockReturnValue(
createMockQueryResult({ document_ids: [] }, false),
)
render(<AutoDisabledDocument datasetId="my-dataset-id" />)
expect(mockUseAutoDisabledDocuments).toHaveBeenCalledWith('my-dataset-id')
})
})
describe('User Interactions', () => {
it('should call enableDocument when action button is clicked', async () => {
mockUseAutoDisabledDocuments.mockReturnValue(
createMockQueryResult({ document_ids: ['doc1', 'doc2'] }, false),
)
render(<AutoDisabledDocument datasetId="test-dataset" />)
const actionButton = screen.getByText(/enable/i)
fireEvent.click(actionButton)
await waitFor(() => {
expect(mockMutateAsync).toHaveBeenCalledWith({
datasetId: 'test-dataset',
documentIds: ['doc1', 'doc2'],
})
})
})
it('should invalidate cache after enabling documents', async () => {
mockUseAutoDisabledDocuments.mockReturnValue(
createMockQueryResult({ document_ids: ['doc1'] }, false),
)
render(<AutoDisabledDocument datasetId="test-dataset" />)
const actionButton = screen.getByText(/enable/i)
fireEvent.click(actionButton)
await waitFor(() => {
expect(mockInvalidDisabledDocument).toHaveBeenCalled()
})
})
it('should show success toast after enabling documents', async () => {
mockUseAutoDisabledDocuments.mockReturnValue(
createMockQueryResult({ document_ids: ['doc1'] }, false),
)
render(<AutoDisabledDocument datasetId="test-dataset" />)
const actionButton = screen.getByText(/enable/i)
fireEvent.click(actionButton)
await waitFor(() => {
expect(Toast.notify).toHaveBeenCalledWith({
type: 'success',
message: expect.any(String),
})
})
})
})
describe('Edge Cases', () => {
it('should handle single disabled document', () => {
mockUseAutoDisabledDocuments.mockReturnValue(
createMockQueryResult({ document_ids: ['doc1'] }, false),
)
render(<AutoDisabledDocument datasetId="test-dataset" />)
expect(screen.getByText(/enable/i)).toBeInTheDocument()
})
it('should handle multiple disabled documents', () => {
mockUseAutoDisabledDocuments.mockReturnValue(
createMockQueryResult({ document_ids: ['doc1', 'doc2', 'doc3', 'doc4', 'doc5'] }, false),
)
render(<AutoDisabledDocument datasetId="test-dataset" />)
expect(screen.getByText(/enable/i)).toBeInTheDocument()
})
})
})

View File

@ -0,0 +1,280 @@
import type { ErrorDocsResponse } from '@/models/datasets'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { retryErrorDocs } from '@/service/datasets'
import { useDatasetErrorDocs } from '@/service/knowledge/use-dataset'
import RetryButton from './index-failed'
// Mock service hooks
const mockRefetch = vi.fn()
vi.mock('@/service/knowledge/use-dataset', () => ({
useDatasetErrorDocs: vi.fn(),
}))
vi.mock('@/service/datasets', () => ({
retryErrorDocs: vi.fn(),
}))
const mockUseDatasetErrorDocs = vi.mocked(useDatasetErrorDocs)
const mockRetryErrorDocs = vi.mocked(retryErrorDocs)
// Helper to create mock query result
const createMockQueryResult = (
data: ErrorDocsResponse | undefined,
isLoading: boolean,
) => ({
data,
isLoading,
refetch: mockRefetch,
// Required query result properties
error: null,
isError: false,
isFetched: true,
isFetching: false,
isSuccess: !isLoading && !!data,
status: isLoading ? 'pending' : 'success',
dataUpdatedAt: Date.now(),
errorUpdatedAt: 0,
failureCount: 0,
failureReason: null,
errorUpdateCount: 0,
isLoadingError: false,
isPaused: false,
isPlaceholderData: false,
isPending: isLoading,
isRefetchError: false,
isRefetching: false,
isStale: false,
fetchStatus: 'idle',
promise: Promise.resolve(data as ErrorDocsResponse),
isFetchedAfterMount: true,
isInitialLoading: false,
}) as unknown as ReturnType<typeof useDatasetErrorDocs>
describe('RetryButton (IndexFailed)', () => {
beforeEach(() => {
vi.clearAllMocks()
mockRefetch.mockResolvedValue({})
})
describe('Rendering', () => {
it('should render nothing when loading', () => {
mockUseDatasetErrorDocs.mockReturnValue(
createMockQueryResult(undefined, true),
)
const { container } = render(<RetryButton datasetId="test-dataset" />)
expect(container.firstChild).toBeNull()
})
it('should render nothing when no error documents', () => {
mockUseDatasetErrorDocs.mockReturnValue(
createMockQueryResult({ total: 0, data: [] }, false),
)
const { container } = render(<RetryButton datasetId="test-dataset" />)
expect(container.firstChild).toBeNull()
})
it('should render StatusWithAction when error documents exist', () => {
mockUseDatasetErrorDocs.mockReturnValue(
createMockQueryResult({
total: 3,
data: [
{ id: 'doc1' },
{ id: 'doc2' },
{ id: 'doc3' },
] as ErrorDocsResponse['data'],
}, false),
)
render(<RetryButton datasetId="test-dataset" />)
expect(screen.getByText(/retry/i)).toBeInTheDocument()
})
it('should display error count in description', () => {
mockUseDatasetErrorDocs.mockReturnValue(
createMockQueryResult({
total: 5,
data: [{ id: 'doc1' }] as ErrorDocsResponse['data'],
}, false),
)
render(<RetryButton datasetId="test-dataset" />)
expect(screen.getByText(/5/)).toBeInTheDocument()
})
})
describe('Props', () => {
it('should pass datasetId to useDatasetErrorDocs', () => {
mockUseDatasetErrorDocs.mockReturnValue(
createMockQueryResult({ total: 0, data: [] }, false),
)
render(<RetryButton datasetId="my-dataset-id" />)
expect(mockUseDatasetErrorDocs).toHaveBeenCalledWith('my-dataset-id')
})
})
describe('User Interactions', () => {
it('should call retryErrorDocs when retry button is clicked', async () => {
mockUseDatasetErrorDocs.mockReturnValue(
createMockQueryResult({
total: 2,
data: [{ id: 'doc1' }, { id: 'doc2' }] as ErrorDocsResponse['data'],
}, false),
)
mockRetryErrorDocs.mockResolvedValue({ result: 'success' })
render(<RetryButton datasetId="test-dataset" />)
const retryButton = screen.getByText(/retry/i)
fireEvent.click(retryButton)
await waitFor(() => {
expect(mockRetryErrorDocs).toHaveBeenCalledWith({
datasetId: 'test-dataset',
document_ids: ['doc1', 'doc2'],
})
})
})
it('should refetch error docs after successful retry', async () => {
mockUseDatasetErrorDocs.mockReturnValue(
createMockQueryResult({
total: 1,
data: [{ id: 'doc1' }] as ErrorDocsResponse['data'],
}, false),
)
mockRetryErrorDocs.mockResolvedValue({ result: 'success' })
render(<RetryButton datasetId="test-dataset" />)
const retryButton = screen.getByText(/retry/i)
fireEvent.click(retryButton)
await waitFor(() => {
expect(mockRefetch).toHaveBeenCalled()
})
})
it('should disable button while retrying', async () => {
mockUseDatasetErrorDocs.mockReturnValue(
createMockQueryResult({
total: 1,
data: [{ id: 'doc1' }] as ErrorDocsResponse['data'],
}, false),
)
// Delay the response to test loading state
mockRetryErrorDocs.mockImplementation(() => new Promise(resolve => setTimeout(() => resolve({ result: 'success' }), 100)))
render(<RetryButton datasetId="test-dataset" />)
const retryButton = screen.getByText(/retry/i)
fireEvent.click(retryButton)
// Button should show disabled styling during retry
await waitFor(() => {
const button = screen.getByText(/retry/i)
expect(button).toHaveClass('cursor-not-allowed')
expect(button).toHaveClass('text-text-disabled')
})
})
})
describe('State Management', () => {
it('should transition to error state when retry fails', async () => {
mockUseDatasetErrorDocs.mockReturnValue(
createMockQueryResult({
total: 1,
data: [{ id: 'doc1' }] as ErrorDocsResponse['data'],
}, false),
)
mockRetryErrorDocs.mockResolvedValue({ result: 'fail' })
render(<RetryButton datasetId="test-dataset" />)
const retryButton = screen.getByText(/retry/i)
fireEvent.click(retryButton)
await waitFor(() => {
// Button should still be visible after failed retry
expect(screen.getByText(/retry/i)).toBeInTheDocument()
})
})
it('should transition to success state when total becomes 0', async () => {
const { rerender } = render(<RetryButton datasetId="test-dataset" />)
// Initially has errors
mockUseDatasetErrorDocs.mockReturnValue(
createMockQueryResult({
total: 1,
data: [{ id: 'doc1' }] as ErrorDocsResponse['data'],
}, false),
)
rerender(<RetryButton datasetId="test-dataset" />)
expect(screen.getByText(/retry/i)).toBeInTheDocument()
// Now no errors
mockUseDatasetErrorDocs.mockReturnValue(
createMockQueryResult({ total: 0, data: [] }, false),
)
rerender(<RetryButton datasetId="test-dataset" />)
await waitFor(() => {
expect(screen.queryByText(/retry/i)).not.toBeInTheDocument()
})
})
})
describe('Edge Cases', () => {
it('should handle empty data array', () => {
mockUseDatasetErrorDocs.mockReturnValue(
createMockQueryResult({ total: 0, data: [] }, false),
)
const { container } = render(<RetryButton datasetId="test-dataset" />)
expect(container.firstChild).toBeNull()
})
it('should handle undefined data by showing error state', () => {
// When data is undefined but not loading, the component shows error state
// because errorDocs?.total is not strictly equal to 0
mockUseDatasetErrorDocs.mockReturnValue(
createMockQueryResult(undefined, false),
)
render(<RetryButton datasetId="test-dataset" />)
// Component renders with undefined count
expect(screen.getByText(/retry/i)).toBeInTheDocument()
})
it('should handle retry with empty document list', async () => {
mockUseDatasetErrorDocs.mockReturnValue(
createMockQueryResult({ total: 1, data: [] }, false),
)
mockRetryErrorDocs.mockResolvedValue({ result: 'success' })
render(<RetryButton datasetId="test-dataset" />)
const retryButton = screen.getByText(/retry/i)
fireEvent.click(retryButton)
await waitFor(() => {
expect(mockRetryErrorDocs).toHaveBeenCalledWith({
datasetId: 'test-dataset',
document_ids: [],
})
})
})
})
})

View File

@ -0,0 +1,175 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import StatusWithAction from './status-with-action'
describe('StatusWithAction', () => {
describe('Rendering', () => {
it('should render without crashing', () => {
render(<StatusWithAction description="Test description" />)
expect(screen.getByText('Test description')).toBeInTheDocument()
})
it('should render description text', () => {
render(<StatusWithAction description="This is a test message" />)
expect(screen.getByText('This is a test message')).toBeInTheDocument()
})
it('should render icon based on type', () => {
const { container } = render(<StatusWithAction type="success" description="Success" />)
expect(container.querySelector('svg')).toBeInTheDocument()
})
})
describe('Props', () => {
it('should default to info type when type is not provided', () => {
const { container } = render(<StatusWithAction description="Default type" />)
const icon = container.querySelector('svg')
expect(icon).toHaveClass('text-text-accent')
})
it('should render success type with correct color', () => {
const { container } = render(<StatusWithAction type="success" description="Success" />)
const icon = container.querySelector('svg')
expect(icon).toHaveClass('text-text-success')
})
it('should render error type with correct color', () => {
const { container } = render(<StatusWithAction type="error" description="Error" />)
const icon = container.querySelector('svg')
expect(icon).toHaveClass('text-text-destructive')
})
it('should render warning type with correct color', () => {
const { container } = render(<StatusWithAction type="warning" description="Warning" />)
const icon = container.querySelector('svg')
expect(icon).toHaveClass('text-text-warning-secondary')
})
it('should render info type with correct color', () => {
const { container } = render(<StatusWithAction type="info" description="Info" />)
const icon = container.querySelector('svg')
expect(icon).toHaveClass('text-text-accent')
})
it('should render action button when actionText and onAction are provided', () => {
const onAction = vi.fn()
render(
<StatusWithAction
description="Test"
actionText="Click me"
onAction={onAction}
/>,
)
expect(screen.getByText('Click me')).toBeInTheDocument()
})
it('should not render action button when onAction is not provided', () => {
render(<StatusWithAction description="Test" actionText="Click me" />)
expect(screen.queryByText('Click me')).not.toBeInTheDocument()
})
it('should render divider when action is present', () => {
const { container } = render(
<StatusWithAction
description="Test"
actionText="Click me"
onAction={() => {}}
/>,
)
// Divider component renders a div with specific classes
expect(container.querySelector('.bg-divider-regular')).toBeInTheDocument()
})
})
describe('User Interactions', () => {
it('should call onAction when action button is clicked', () => {
const onAction = vi.fn()
render(
<StatusWithAction
description="Test"
actionText="Click me"
onAction={onAction}
/>,
)
fireEvent.click(screen.getByText('Click me'))
expect(onAction).toHaveBeenCalledTimes(1)
})
it('should call onAction even when disabled (style only)', () => {
// Note: disabled prop only affects styling, not actual click behavior
const onAction = vi.fn()
render(
<StatusWithAction
description="Test"
actionText="Click me"
onAction={onAction}
disabled
/>,
)
fireEvent.click(screen.getByText('Click me'))
expect(onAction).toHaveBeenCalledTimes(1)
})
it('should apply disabled styles when disabled prop is true', () => {
render(
<StatusWithAction
description="Test"
actionText="Click me"
onAction={() => {}}
disabled
/>,
)
const actionButton = screen.getByText('Click me')
expect(actionButton).toHaveClass('cursor-not-allowed')
expect(actionButton).toHaveClass('text-text-disabled')
})
})
describe('Status Background Gradients', () => {
it('should apply success gradient background', () => {
const { container } = render(<StatusWithAction type="success" description="Success" />)
const gradientDiv = container.querySelector('.opacity-40')
expect(gradientDiv?.className).toContain('rgba(23,178,106,0.25)')
})
it('should apply warning gradient background', () => {
const { container } = render(<StatusWithAction type="warning" description="Warning" />)
const gradientDiv = container.querySelector('.opacity-40')
expect(gradientDiv?.className).toContain('rgba(247,144,9,0.25)')
})
it('should apply error gradient background', () => {
const { container } = render(<StatusWithAction type="error" description="Error" />)
const gradientDiv = container.querySelector('.opacity-40')
expect(gradientDiv?.className).toContain('rgba(240,68,56,0.25)')
})
it('should apply info gradient background', () => {
const { container } = render(<StatusWithAction type="info" description="Info" />)
const gradientDiv = container.querySelector('.opacity-40')
expect(gradientDiv?.className).toContain('rgba(11,165,236,0.25)')
})
})
describe('Edge Cases', () => {
it('should handle empty description', () => {
const { container } = render(<StatusWithAction description="" />)
expect(container.firstChild).toBeInTheDocument()
})
it('should handle long description text', () => {
const longText = 'A'.repeat(500)
render(<StatusWithAction description={longText} />)
expect(screen.getByText(longText)).toBeInTheDocument()
})
it('should handle undefined actionText when onAction is provided', () => {
render(<StatusWithAction description="Test" onAction={() => {}} />)
// Should render without throwing
expect(screen.getByText('Test')).toBeInTheDocument()
})
})
})

View File

@ -0,0 +1,252 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import ImageList from './index'
// Track handleImageClick calls for testing
type FileEntity = {
sourceUrl: string
name: string
mimeType?: string
size?: number
extension?: string
}
let capturedOnClick: ((file: FileEntity) => void) | null = null
// Mock FileThumb to capture click handler
vi.mock('@/app/components/base/file-thumb', () => ({
default: ({ file, onClick }: { file: FileEntity, onClick?: (file: FileEntity) => void }) => {
// Capture the onClick for testing
capturedOnClick = onClick ?? null
return (
<div
data-testid={`file-thumb-${file.sourceUrl}`}
className="cursor-pointer"
onClick={() => onClick?.(file)}
>
{file.name}
</div>
)
},
}))
type ImagePreviewerProps = {
images: ImageInfo[]
initialIndex: number
onClose: () => void
}
type ImageInfo = {
url: string
name: string
size: number
}
// Mock ImagePreviewer since it uses createPortal
vi.mock('../image-previewer', () => ({
default: ({ images, initialIndex, onClose }: ImagePreviewerProps) => (
<div data-testid="image-previewer">
<span data-testid="preview-count">{images.length}</span>
<span data-testid="preview-index">{initialIndex}</span>
<button data-testid="close-preview" onClick={onClose}>Close</button>
</div>
),
}))
const createMockImages = (count: number) => {
return Array.from({ length: count }, (_, i) => ({
name: `image-${i + 1}.png`,
mimeType: 'image/png',
sourceUrl: `https://example.com/image-${i + 1}.png`,
size: 1024 * (i + 1),
extension: 'png',
}))
}
describe('ImageList', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('Rendering', () => {
it('should render without crashing', () => {
const images = createMockImages(3)
const { container } = render(<ImageList images={images} size="md" />)
expect(container.firstChild).toBeInTheDocument()
})
it('should render all images when count is below limit', () => {
const images = createMockImages(5)
render(<ImageList images={images} size="md" limit={9} />)
// Each image renders a FileThumb component
const thumbnails = document.querySelectorAll('[class*="cursor-pointer"]')
expect(thumbnails.length).toBeGreaterThanOrEqual(5)
})
it('should render limited images when count exceeds limit', () => {
const images = createMockImages(15)
render(<ImageList images={images} size="md" limit={9} />)
// More button should be visible
expect(screen.getByText(/\+6/)).toBeInTheDocument()
})
})
describe('Props', () => {
it('should apply custom className', () => {
const images = createMockImages(3)
const { container } = render(
<ImageList images={images} size="md" className="custom-class" />,
)
expect(container.firstChild).toHaveClass('custom-class')
})
it('should use default limit of 9', () => {
const images = createMockImages(12)
render(<ImageList images={images} size="md" />)
// Should show "+3" for remaining images
expect(screen.getByText(/\+3/)).toBeInTheDocument()
})
it('should respect custom limit', () => {
const images = createMockImages(10)
render(<ImageList images={images} size="md" limit={5} />)
// Should show "+5" for remaining images
expect(screen.getByText(/\+5/)).toBeInTheDocument()
})
it('should handle size prop sm', () => {
const images = createMockImages(2)
const { container } = render(<ImageList images={images} size="sm" />)
expect(container.firstChild).toBeInTheDocument()
})
it('should handle size prop md', () => {
const images = createMockImages(2)
const { container } = render(<ImageList images={images} size="md" />)
expect(container.firstChild).toBeInTheDocument()
})
})
describe('User Interactions', () => {
it('should show all images when More button is clicked', () => {
const images = createMockImages(15)
render(<ImageList images={images} size="md" limit={9} />)
// Click More button
const moreButton = screen.getByText(/\+6/)
fireEvent.click(moreButton)
// More button should disappear
expect(screen.queryByText(/\+6/)).not.toBeInTheDocument()
})
it('should open preview when image is clicked', () => {
const images = createMockImages(3)
render(<ImageList images={images} size="md" />)
// Find and click an image thumbnail
const thumbnails = document.querySelectorAll('[class*="cursor-pointer"]')
if (thumbnails.length > 0) {
fireEvent.click(thumbnails[0])
// Preview should open
expect(screen.getByTestId('image-previewer')).toBeInTheDocument()
}
})
it('should close preview when close button is clicked', () => {
const images = createMockImages(3)
render(<ImageList images={images} size="md" />)
// Open preview
const thumbnails = document.querySelectorAll('[class*="cursor-pointer"]')
if (thumbnails.length > 0) {
fireEvent.click(thumbnails[0])
// Close preview
const closeButton = screen.getByTestId('close-preview')
fireEvent.click(closeButton)
// Preview should be closed
expect(screen.queryByTestId('image-previewer')).not.toBeInTheDocument()
}
})
})
describe('Edge Cases', () => {
it('should handle empty images array', () => {
const { container } = render(<ImageList images={[]} size="md" />)
expect(container.firstChild).toBeInTheDocument()
})
it('should not open preview when clicked image not found in list (index === -1)', () => {
const images = createMockImages(3)
const { rerender } = render(<ImageList images={images} size="md" />)
// Click first image to open preview
const firstThumb = screen.getByTestId('file-thumb-https://example.com/image-1.png')
fireEvent.click(firstThumb)
// Preview should open for valid image
expect(screen.getByTestId('image-previewer')).toBeInTheDocument()
// Close preview
fireEvent.click(screen.getByTestId('close-preview'))
expect(screen.queryByTestId('image-previewer')).not.toBeInTheDocument()
// Now render with images that don't include the previously clicked one
const newImages = createMockImages(2) // Only 2 images
rerender(<ImageList images={newImages} size="md" />)
// Click on a thumbnail that exists
const validThumb = screen.getByTestId('file-thumb-https://example.com/image-1.png')
fireEvent.click(validThumb)
expect(screen.getByTestId('image-previewer')).toBeInTheDocument()
})
it('should return early when file sourceUrl is not found in limitedImages (index === -1)', () => {
const images = createMockImages(3)
render(<ImageList images={images} size="md" />)
// Call the captured onClick with a file that has a non-matching sourceUrl
// This triggers the index === -1 branch (line 44-45)
if (capturedOnClick) {
capturedOnClick({
name: 'nonexistent.png',
mimeType: 'image/png',
sourceUrl: 'https://example.com/nonexistent.png', // Not in the list
size: 1024,
extension: 'png',
})
}
// Preview should NOT open because the file was not found in limitedImages
expect(screen.queryByTestId('image-previewer')).not.toBeInTheDocument()
})
it('should handle single image', () => {
const images = createMockImages(1)
const { container } = render(<ImageList images={images} size="md" />)
expect(container.firstChild).toBeInTheDocument()
})
it('should not show More button when images count equals limit', () => {
const images = createMockImages(9)
render(<ImageList images={images} size="md" limit={9} />)
expect(screen.queryByText(/\+/)).not.toBeInTheDocument()
})
it('should handle limit of 0', () => {
const images = createMockImages(5)
render(<ImageList images={images} size="md" limit={0} />)
// Should show "+5" for all images
expect(screen.getByText(/\+5/)).toBeInTheDocument()
})
it('should handle limit larger than images count', () => {
const images = createMockImages(5)
render(<ImageList images={images} size="md" limit={100} />)
// Should not show More button
expect(screen.queryByText(/\+/)).not.toBeInTheDocument()
})
})
})

View File

@ -0,0 +1,144 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import More from './more'
describe('More', () => {
describe('Rendering', () => {
it('should render without crashing', () => {
render(<More count={5} />)
expect(screen.getByText('+5')).toBeInTheDocument()
})
it('should display count with plus sign', () => {
render(<More count={10} />)
expect(screen.getByText('+10')).toBeInTheDocument()
})
})
describe('Props', () => {
it('should format count as-is when less than 1000', () => {
render(<More count={999} />)
expect(screen.getByText('+999')).toBeInTheDocument()
})
it('should format count with k suffix when 1000 or more', () => {
render(<More count={1500} />)
expect(screen.getByText('+1.5k')).toBeInTheDocument()
})
it('should format count with M suffix when 1000000 or more', () => {
render(<More count={2500000} />)
expect(screen.getByText('+2.5M')).toBeInTheDocument()
})
it('should format 1000 as 1.0k', () => {
render(<More count={1000} />)
expect(screen.getByText('+1.0k')).toBeInTheDocument()
})
it('should format 1000000 as 1.0M', () => {
render(<More count={1000000} />)
expect(screen.getByText('+1.0M')).toBeInTheDocument()
})
})
describe('User Interactions', () => {
it('should call onClick when clicked', () => {
const onClick = vi.fn()
render(<More count={5} onClick={onClick} />)
fireEvent.click(screen.getByText('+5'))
expect(onClick).toHaveBeenCalledTimes(1)
})
it('should not throw when clicked without onClick', () => {
render(<More count={5} />)
// Should not throw
expect(() => {
fireEvent.click(screen.getByText('+5'))
}).not.toThrow()
})
it('should stop event propagation on click', () => {
const parentClick = vi.fn()
const childClick = vi.fn()
render(
<div onClick={parentClick}>
<More count={5} onClick={childClick} />
</div>,
)
fireEvent.click(screen.getByText('+5'))
expect(childClick).toHaveBeenCalled()
expect(parentClick).not.toHaveBeenCalled()
})
})
describe('Edge Cases', () => {
it('should display +0 when count is 0', () => {
render(<More count={0} />)
expect(screen.getByText('+0')).toBeInTheDocument()
})
it('should handle count of 1', () => {
render(<More count={1} />)
expect(screen.getByText('+1')).toBeInTheDocument()
})
it('should handle boundary value 999', () => {
render(<More count={999} />)
expect(screen.getByText('+999')).toBeInTheDocument()
})
it('should handle boundary value 999999', () => {
render(<More count={999999} />)
// 999999 / 1000 = 999.999 -> 1000.0k
expect(screen.getByText('+1000.0k')).toBeInTheDocument()
})
it('should apply cursor-pointer class', () => {
const { container } = render(<More count={5} />)
expect(container.firstChild).toHaveClass('cursor-pointer')
})
})
describe('formatNumber branches', () => {
it('should return "0" when num equals 0', () => {
// This covers line 11-12: if (num === 0) return '0'
render(<More count={0} />)
expect(screen.getByText('+0')).toBeInTheDocument()
})
it('should return num.toString() when num < 1000 and num > 0', () => {
// This covers line 13-14: if (num < 1000) return num.toString()
render(<More count={500} />)
expect(screen.getByText('+500')).toBeInTheDocument()
})
it('should return k format when 1000 <= num < 1000000', () => {
// This covers line 15-16
const { rerender } = render(<More count={5000} />)
expect(screen.getByText('+5.0k')).toBeInTheDocument()
rerender(<More count={999999} />)
expect(screen.getByText('+1000.0k')).toBeInTheDocument()
rerender(<More count={50000} />)
expect(screen.getByText('+50.0k')).toBeInTheDocument()
})
it('should return M format when num >= 1000000', () => {
// This covers line 17
const { rerender } = render(<More count={1000000} />)
expect(screen.getByText('+1.0M')).toBeInTheDocument()
rerender(<More count={5000000} />)
expect(screen.getByText('+5.0M')).toBeInTheDocument()
rerender(<More count={999999999} />)
expect(screen.getByText('+1000.0M')).toBeInTheDocument()
})
})
})

View File

@ -0,0 +1,525 @@
import { act, fireEvent, render, screen, waitFor } from '@testing-library/react'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import ImagePreviewer from './index'
// Mock fetch
const mockFetch = vi.fn()
globalThis.fetch = mockFetch
// Mock URL methods
const mockRevokeObjectURL = vi.fn()
const mockCreateObjectURL = vi.fn(() => 'blob:mock-url')
globalThis.URL.revokeObjectURL = mockRevokeObjectURL
globalThis.URL.createObjectURL = mockCreateObjectURL
// Mock Image
class MockImage {
onload: (() => void) | null = null
onerror: (() => void) | null = null
_src = ''
get src() {
return this._src
}
set src(value: string) {
this._src = value
// Trigger onload after a microtask
setTimeout(() => {
if (this.onload)
this.onload()
}, 0)
}
naturalWidth = 800
naturalHeight = 600
}
;(globalThis as unknown as { Image: typeof MockImage }).Image = MockImage
const createMockImages = () => [
{ url: 'https://example.com/image1.png', name: 'image1.png', size: 1024 },
{ url: 'https://example.com/image2.png', name: 'image2.png', size: 2048 },
{ url: 'https://example.com/image3.png', name: 'image3.png', size: 3072 },
]
describe('ImagePreviewer', () => {
beforeEach(() => {
vi.clearAllMocks()
// Default successful fetch mock
mockFetch.mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(['test'], { type: 'image/png' })),
})
})
afterEach(() => {
vi.restoreAllMocks()
})
describe('Rendering', () => {
it('should render without crashing', async () => {
const onClose = vi.fn()
const images = createMockImages()
await act(async () => {
render(<ImagePreviewer images={images} onClose={onClose} />)
})
// Should render in portal
expect(document.body.querySelector('.image-previewer')).toBeInTheDocument()
})
it('should render close button', async () => {
const onClose = vi.fn()
const images = createMockImages()
await act(async () => {
render(<ImagePreviewer images={images} onClose={onClose} />)
})
// Esc text should be visible
expect(screen.getByText('Esc')).toBeInTheDocument()
})
it('should show loading state initially', async () => {
const onClose = vi.fn()
const images = createMockImages()
// Delay fetch to see loading state
mockFetch.mockImplementation(() => new Promise(() => {}))
await act(async () => {
render(<ImagePreviewer images={images} onClose={onClose} />)
})
// Loading component should be visible
expect(document.body.querySelector('.image-previewer')).toBeInTheDocument()
})
})
describe('Props', () => {
it('should start at initialIndex', async () => {
const onClose = vi.fn()
const images = createMockImages()
await act(async () => {
render(<ImagePreviewer images={images} initialIndex={1} onClose={onClose} />)
})
await waitFor(() => {
// Should start at second image
expect(screen.getByText('image2.png')).toBeInTheDocument()
})
})
it('should default initialIndex to 0', async () => {
const onClose = vi.fn()
const images = createMockImages()
await act(async () => {
render(<ImagePreviewer images={images} onClose={onClose} />)
})
await waitFor(() => {
expect(screen.getByText('image1.png')).toBeInTheDocument()
})
})
})
describe('User Interactions', () => {
it('should call onClose when close button is clicked', async () => {
const onClose = vi.fn()
const images = createMockImages()
await act(async () => {
render(<ImagePreviewer images={images} onClose={onClose} />)
})
// Find and click close button (the one with RiCloseLine icon)
const closeButton = document.querySelector('.absolute.right-6 button')
if (closeButton) {
fireEvent.click(closeButton)
expect(onClose).toHaveBeenCalledTimes(1)
}
})
it('should navigate to next image when next button is clicked', async () => {
const onClose = vi.fn()
const images = createMockImages()
await act(async () => {
render(<ImagePreviewer images={images} onClose={onClose} />)
})
await waitFor(() => {
expect(screen.getByText('image1.png')).toBeInTheDocument()
})
// Find and click next button (right arrow)
const buttons = document.querySelectorAll('button')
const nextButton = Array.from(buttons).find(btn =>
btn.className.includes('right-8'),
)
if (nextButton) {
await act(async () => {
fireEvent.click(nextButton)
})
await waitFor(() => {
expect(screen.getByText('image2.png')).toBeInTheDocument()
})
}
})
it('should navigate to previous image when prev button is clicked', async () => {
const onClose = vi.fn()
const images = createMockImages()
await act(async () => {
render(<ImagePreviewer images={images} initialIndex={1} onClose={onClose} />)
})
await waitFor(() => {
expect(screen.getByText('image2.png')).toBeInTheDocument()
})
// Find and click prev button (left arrow)
const buttons = document.querySelectorAll('button')
const prevButton = Array.from(buttons).find(btn =>
btn.className.includes('left-8'),
)
if (prevButton) {
await act(async () => {
fireEvent.click(prevButton)
})
await waitFor(() => {
expect(screen.getByText('image1.png')).toBeInTheDocument()
})
}
})
it('should disable prev button at first image', async () => {
const onClose = vi.fn()
const images = createMockImages()
await act(async () => {
render(<ImagePreviewer images={images} initialIndex={0} onClose={onClose} />)
})
const buttons = document.querySelectorAll('button')
const prevButton = Array.from(buttons).find(btn =>
btn.className.includes('left-8'),
)
expect(prevButton).toBeDisabled()
})
it('should disable next button at last image', async () => {
const onClose = vi.fn()
const images = createMockImages()
await act(async () => {
render(<ImagePreviewer images={images} initialIndex={2} onClose={onClose} />)
})
const buttons = document.querySelectorAll('button')
const nextButton = Array.from(buttons).find(btn =>
btn.className.includes('right-8'),
)
expect(nextButton).toBeDisabled()
})
})
describe('Image Loading', () => {
it('should fetch images on mount', async () => {
const onClose = vi.fn()
const images = createMockImages()
await act(async () => {
render(<ImagePreviewer images={images} onClose={onClose} />)
})
await waitFor(() => {
expect(mockFetch).toHaveBeenCalled()
})
})
it('should show error state when fetch fails', async () => {
const onClose = vi.fn()
const images = createMockImages()
mockFetch.mockRejectedValue(new Error('Network error'))
await act(async () => {
render(<ImagePreviewer images={images} onClose={onClose} />)
})
await waitFor(() => {
expect(screen.getByText(/Failed to load image/)).toBeInTheDocument()
})
})
it('should show retry button on error', async () => {
const onClose = vi.fn()
const images = createMockImages()
mockFetch.mockRejectedValue(new Error('Network error'))
await act(async () => {
render(<ImagePreviewer images={images} onClose={onClose} />)
})
await waitFor(() => {
// Retry button should be visible
const retryButton = document.querySelector('button.rounded-full')
expect(retryButton).toBeInTheDocument()
})
})
})
describe('Navigation Boundary Cases', () => {
it('should not navigate past first image when prevImage is called at index 0', async () => {
const onClose = vi.fn()
const images = createMockImages()
await act(async () => {
render(<ImagePreviewer images={images} initialIndex={0} onClose={onClose} />)
})
await waitFor(() => {
expect(screen.getByText('image1.png')).toBeInTheDocument()
})
// Click prev button multiple times - should stay at first image
const buttons = document.querySelectorAll('button')
const prevButton = Array.from(buttons).find(btn =>
btn.className.includes('left-8'),
)
if (prevButton) {
await act(async () => {
fireEvent.click(prevButton)
fireEvent.click(prevButton)
})
// Should still be at first image
await waitFor(() => {
expect(screen.getByText('image1.png')).toBeInTheDocument()
})
}
})
it('should not navigate past last image when nextImage is called at last index', async () => {
const onClose = vi.fn()
const images = createMockImages()
await act(async () => {
render(<ImagePreviewer images={images} initialIndex={2} onClose={onClose} />)
})
await waitFor(() => {
expect(screen.getByText('image3.png')).toBeInTheDocument()
})
// Click next button multiple times - should stay at last image
const buttons = document.querySelectorAll('button')
const nextButton = Array.from(buttons).find(btn =>
btn.className.includes('right-8'),
)
if (nextButton) {
await act(async () => {
fireEvent.click(nextButton)
fireEvent.click(nextButton)
})
// Should still be at last image
await waitFor(() => {
expect(screen.getByText('image3.png')).toBeInTheDocument()
})
}
})
})
describe('Retry Functionality', () => {
it('should retry image load when retry button is clicked', async () => {
const onClose = vi.fn()
const images = createMockImages()
// First fail, then succeed
let callCount = 0
mockFetch.mockImplementation(() => {
callCount++
if (callCount === 1) {
return Promise.reject(new Error('Network error'))
}
return Promise.resolve({
ok: true,
blob: () => Promise.resolve(new Blob(['test'], { type: 'image/png' })),
})
})
await act(async () => {
render(<ImagePreviewer images={images} onClose={onClose} />)
})
// Wait for error state
await waitFor(() => {
expect(screen.getByText(/Failed to load image/)).toBeInTheDocument()
})
// Click retry button
const retryButton = document.querySelector('button.rounded-full')
if (retryButton) {
await act(async () => {
fireEvent.click(retryButton)
})
// Should refetch the image
await waitFor(() => {
expect(mockFetch).toHaveBeenCalledTimes(4) // 3 initial + 1 retry
})
}
})
it('should show retry button and call retryImage when clicked', async () => {
const onClose = vi.fn()
const images = createMockImages()
mockFetch.mockRejectedValue(new Error('Network error'))
await act(async () => {
render(<ImagePreviewer images={images} onClose={onClose} />)
})
await waitFor(() => {
expect(screen.getByText(/Failed to load image/)).toBeInTheDocument()
})
// Find and click the retry button (not the nav buttons)
const allButtons = document.querySelectorAll('button')
const retryButton = Array.from(allButtons).find(btn =>
btn.className.includes('rounded-full') && !btn.className.includes('left-8') && !btn.className.includes('right-8'),
)
expect(retryButton).toBeInTheDocument()
if (retryButton) {
mockFetch.mockClear()
mockFetch.mockResolvedValue({
ok: true,
blob: () => Promise.resolve(new Blob(['test'], { type: 'image/png' })),
})
await act(async () => {
fireEvent.click(retryButton)
})
await waitFor(() => {
expect(mockFetch).toHaveBeenCalled()
})
}
})
})
describe('Image Cache', () => {
it('should clean up blob URLs on unmount', async () => {
const onClose = vi.fn()
const images = createMockImages()
// First render to populate cache
const { unmount } = await act(async () => {
const result = render(<ImagePreviewer images={images} onClose={onClose} />)
return result
})
await waitFor(() => {
expect(mockFetch).toHaveBeenCalled()
})
// Store the call count for verification
const _firstCallCount = mockFetch.mock.calls.length
unmount()
// Note: The imageCache is cleared on unmount, so this test verifies
// the cleanup behavior rather than caching across mounts
expect(mockRevokeObjectURL).toHaveBeenCalled()
})
})
describe('Edge Cases', () => {
it('should handle single image', async () => {
const onClose = vi.fn()
const images = [createMockImages()[0]]
await act(async () => {
render(<ImagePreviewer images={images} onClose={onClose} />)
})
// Both navigation buttons should be disabled
const buttons = document.querySelectorAll('button')
const prevButton = Array.from(buttons).find(btn =>
btn.className.includes('left-8'),
)
const nextButton = Array.from(buttons).find(btn =>
btn.className.includes('right-8'),
)
expect(prevButton).toBeDisabled()
expect(nextButton).toBeDisabled()
})
it('should stop event propagation on container click', async () => {
const onClose = vi.fn()
const parentClick = vi.fn()
const images = createMockImages()
await act(async () => {
render(
<div onClick={parentClick}>
<ImagePreviewer images={images} onClose={onClose} />
</div>,
)
})
const container = document.querySelector('.image-previewer')
if (container) {
fireEvent.click(container)
expect(parentClick).not.toHaveBeenCalled()
}
})
it('should display image dimensions when loaded', async () => {
const onClose = vi.fn()
const images = createMockImages()
await act(async () => {
render(<ImagePreviewer images={images} onClose={onClose} />)
})
await waitFor(() => {
// Should display dimensions (800 × 600 from MockImage)
expect(screen.getByText(/800.*600/)).toBeInTheDocument()
})
})
it('should display file size', async () => {
const onClose = vi.fn()
const images = createMockImages()
await act(async () => {
render(<ImagePreviewer images={images} onClose={onClose} />)
})
await waitFor(() => {
// Should display formatted file size
expect(screen.getByText('image1.png')).toBeInTheDocument()
})
})
})
})

View File

@ -0,0 +1,922 @@
import type { PropsWithChildren } from 'react'
import type { FileEntity } from '../types'
import { act, fireEvent, render, renderHook, screen, waitFor } from '@testing-library/react'
import * as React from 'react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import Toast from '@/app/components/base/toast'
import { FileContextProvider } from '../store'
import { useUpload } from './use-upload'
// Mock dependencies
vi.mock('@/service/use-common', () => ({
useFileUploadConfig: vi.fn(() => ({
data: {
image_file_batch_limit: 10,
single_chunk_attachment_limit: 20,
attachment_image_file_size_limit: 15,
},
})),
}))
vi.mock('@/app/components/base/toast', () => ({
default: {
notify: vi.fn(),
},
}))
type FileUploadOptions = {
file: File
onProgressCallback?: (progress: number) => void
onSuccessCallback?: (res: { id: string, extension: string, mime_type: string, size: number }) => void
onErrorCallback?: (error?: Error) => void
}
const mockFileUpload = vi.fn<(options: FileUploadOptions) => void>()
const mockGetFileUploadErrorMessage = vi.fn(() => 'Upload error')
vi.mock('@/app/components/base/file-uploader/utils', () => ({
fileUpload: (options: FileUploadOptions) => mockFileUpload(options),
getFileUploadErrorMessage: () => mockGetFileUploadErrorMessage(),
}))
const createWrapper = () => {
return ({ children }: PropsWithChildren) => (
<FileContextProvider>
{children}
</FileContextProvider>
)
}
const createMockFile = (name = 'test.png', _size = 1024, type = 'image/png') => {
return new File(['test content'], name, { type })
}
// Mock FileReader
type EventCallback = () => void
class MockFileReader {
result: string | ArrayBuffer | null = null
onload: EventCallback | null = null
onerror: EventCallback | null = null
private listeners: Record<string, EventCallback[]> = {}
addEventListener(event: string, callback: EventCallback) {
if (!this.listeners[event])
this.listeners[event] = []
this.listeners[event].push(callback)
}
removeEventListener(event: string, callback: EventCallback) {
if (this.listeners[event])
this.listeners[event] = this.listeners[event].filter(cb => cb !== callback)
}
readAsDataURL(_file: File) {
setTimeout(() => {
this.result = 'data:image/png;base64,mockBase64Data'
this.listeners.load?.forEach(cb => cb())
}, 0)
}
triggerError() {
this.listeners.error?.forEach(cb => cb())
}
}
describe('useUpload hook', () => {
beforeEach(() => {
vi.clearAllMocks()
mockFileUpload.mockImplementation(({ onSuccessCallback }) => {
setTimeout(() => {
onSuccessCallback?.({ id: 'uploaded-id', extension: 'png', mime_type: 'image/png', size: 1024 })
}, 0)
})
// Mock FileReader globally
vi.stubGlobal('FileReader', MockFileReader)
})
describe('Initialization', () => {
it('should initialize with default state', () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
expect(result.current.dragging).toBe(false)
expect(result.current.uploaderRef).toBeDefined()
expect(result.current.dragRef).toBeDefined()
expect(result.current.dropRef).toBeDefined()
})
it('should return file upload config', () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
expect(result.current.fileUploadConfig).toBeDefined()
expect(result.current.fileUploadConfig.imageFileBatchLimit).toBe(10)
expect(result.current.fileUploadConfig.singleChunkAttachmentLimit).toBe(20)
expect(result.current.fileUploadConfig.imageFileSizeLimit).toBe(15)
})
})
describe('File Operations', () => {
it('should expose selectHandle function', () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
expect(typeof result.current.selectHandle).toBe('function')
})
it('should expose fileChangeHandle function', () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
expect(typeof result.current.fileChangeHandle).toBe('function')
})
it('should expose handleRemoveFile function', () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
expect(typeof result.current.handleRemoveFile).toBe('function')
})
it('should expose handleReUploadFile function', () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
expect(typeof result.current.handleReUploadFile).toBe('function')
})
it('should expose handleLocalFileUpload function', () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
expect(typeof result.current.handleLocalFileUpload).toBe('function')
})
})
describe('File Validation', () => {
it('should show error toast for invalid file type', async () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
const mockEvent = {
target: {
files: [createMockFile('test.exe', 1024, 'application/x-msdownload')],
},
} as unknown as React.ChangeEvent<HTMLInputElement>
act(() => {
result.current.fileChangeHandle(mockEvent)
})
await waitFor(() => {
expect(Toast.notify).toHaveBeenCalledWith({
type: 'error',
message: expect.any(String),
})
})
})
it('should not reject valid image file types', async () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
const mockFile = createMockFile('test.png', 1024, 'image/png')
const mockEvent = {
target: {
files: [mockFile],
},
} as unknown as React.ChangeEvent<HTMLInputElement>
// File type validation should pass for png files
// The actual upload will fail without proper FileReader mock,
// but we're testing that type validation doesn't reject valid files
act(() => {
result.current.fileChangeHandle(mockEvent)
})
// Should not show type error for valid image type
type ToastCall = [{ type: string, message: string }]
const mockNotify = vi.mocked(Toast.notify)
const calls = mockNotify.mock.calls as ToastCall[]
const typeErrorCalls = calls.filter(
(call: ToastCall) => call[0].type === 'error' && call[0].message.includes('Extension'),
)
expect(typeErrorCalls.length).toBe(0)
})
})
describe('Drag and Drop Refs', () => {
it('should provide dragRef', () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
expect(result.current.dragRef).toBeDefined()
expect(result.current.dragRef.current).toBeNull()
})
it('should provide dropRef', () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
expect(result.current.dropRef).toBeDefined()
expect(result.current.dropRef.current).toBeNull()
})
it('should provide uploaderRef', () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
expect(result.current.uploaderRef).toBeDefined()
expect(result.current.uploaderRef.current).toBeNull()
})
})
describe('Edge Cases', () => {
it('should handle empty file list', () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
const mockEvent = {
target: {
files: [],
},
} as unknown as React.ChangeEvent<HTMLInputElement>
act(() => {
result.current.fileChangeHandle(mockEvent)
})
// Should not throw and not show error
expect(Toast.notify).not.toHaveBeenCalled()
})
it('should handle null files', () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
const mockEvent = {
target: {
files: null,
},
} as unknown as React.ChangeEvent<HTMLInputElement>
act(() => {
result.current.fileChangeHandle(mockEvent)
})
// Should not throw
expect(true).toBe(true)
})
it('should respect batch limit from config', () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
// Config should have batch limit of 10
expect(result.current.fileUploadConfig.imageFileBatchLimit).toBe(10)
})
})
describe('File Size Validation', () => {
it('should show error for files exceeding size limit', async () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
// Create a file larger than 15MB limit (15 * 1024 * 1024 bytes)
const largeFile = new File(['x'.repeat(16 * 1024 * 1024)], 'large.png', { type: 'image/png' })
Object.defineProperty(largeFile, 'size', { value: 16 * 1024 * 1024 })
const mockEvent = {
target: {
files: [largeFile],
},
} as unknown as React.ChangeEvent<HTMLInputElement>
act(() => {
result.current.fileChangeHandle(mockEvent)
})
await waitFor(() => {
expect(Toast.notify).toHaveBeenCalledWith({
type: 'error',
message: expect.any(String),
})
})
})
})
describe('handleRemoveFile', () => {
it('should remove file from store', async () => {
const onChange = vi.fn()
const initialFiles: Partial<FileEntity>[] = [
{ id: 'file1', name: 'test1.png', progress: 100 },
{ id: 'file2', name: 'test2.png', progress: 100 },
]
const wrapper = ({ children }: PropsWithChildren) => (
<FileContextProvider value={initialFiles as FileEntity[]} onChange={onChange}>
{children}
</FileContextProvider>
)
const { result } = renderHook(() => useUpload(), { wrapper })
act(() => {
result.current.handleRemoveFile('file1')
})
expect(onChange).toHaveBeenCalledWith([
{ id: 'file2', name: 'test2.png', progress: 100 },
])
})
})
describe('handleReUploadFile', () => {
it('should re-upload file when called with valid fileId', async () => {
const onChange = vi.fn()
const initialFiles: Partial<FileEntity>[] = [
{ id: 'file1', name: 'test1.png', progress: -1, originalFile: new File(['test'], 'test1.png') },
]
const wrapper = ({ children }: PropsWithChildren) => (
<FileContextProvider value={initialFiles as FileEntity[]} onChange={onChange}>
{children}
</FileContextProvider>
)
const { result } = renderHook(() => useUpload(), { wrapper })
act(() => {
result.current.handleReUploadFile('file1')
})
await waitFor(() => {
expect(mockFileUpload).toHaveBeenCalled()
})
})
it('should not re-upload when fileId is not found', () => {
const onChange = vi.fn()
const initialFiles: Partial<FileEntity>[] = [
{ id: 'file1', name: 'test1.png', progress: -1, originalFile: new File(['test'], 'test1.png') },
]
const wrapper = ({ children }: PropsWithChildren) => (
<FileContextProvider value={initialFiles as FileEntity[]} onChange={onChange}>
{children}
</FileContextProvider>
)
const { result } = renderHook(() => useUpload(), { wrapper })
act(() => {
result.current.handleReUploadFile('nonexistent')
})
// fileUpload should not be called for nonexistent file
expect(mockFileUpload).not.toHaveBeenCalled()
})
it('should handle upload error during re-upload', async () => {
mockFileUpload.mockImplementation(({ onErrorCallback }: FileUploadOptions) => {
setTimeout(() => {
onErrorCallback?.(new Error('Upload failed'))
}, 0)
})
const onChange = vi.fn()
const initialFiles: Partial<FileEntity>[] = [
{ id: 'file1', name: 'test1.png', progress: -1, originalFile: new File(['test'], 'test1.png') },
]
const wrapper = ({ children }: PropsWithChildren) => (
<FileContextProvider value={initialFiles as FileEntity[]} onChange={onChange}>
{children}
</FileContextProvider>
)
const { result } = renderHook(() => useUpload(), { wrapper })
act(() => {
result.current.handleReUploadFile('file1')
})
await waitFor(() => {
expect(Toast.notify).toHaveBeenCalledWith({
type: 'error',
message: 'Upload error',
})
})
})
})
describe('handleLocalFileUpload', () => {
it('should upload file and update progress', async () => {
mockFileUpload.mockImplementation(({ onProgressCallback, onSuccessCallback }: FileUploadOptions) => {
setTimeout(() => {
onProgressCallback?.(50)
setTimeout(() => {
onSuccessCallback?.({ id: 'uploaded-id', extension: 'png', mime_type: 'image/png', size: 1024 })
}, 10)
}, 0)
})
const onChange = vi.fn()
const wrapper = ({ children }: PropsWithChildren) => (
<FileContextProvider onChange={onChange}>
{children}
</FileContextProvider>
)
const { result } = renderHook(() => useUpload(), { wrapper })
const mockFile = createMockFile('test.png', 1024, 'image/png')
await act(async () => {
result.current.handleLocalFileUpload(mockFile)
})
await waitFor(() => {
expect(mockFileUpload).toHaveBeenCalled()
})
})
it('should handle upload error', async () => {
mockFileUpload.mockImplementation(({ onErrorCallback }: FileUploadOptions) => {
setTimeout(() => {
onErrorCallback?.(new Error('Upload failed'))
}, 0)
})
const onChange = vi.fn()
const wrapper = ({ children }: PropsWithChildren) => (
<FileContextProvider onChange={onChange}>
{children}
</FileContextProvider>
)
const { result } = renderHook(() => useUpload(), { wrapper })
const mockFile = createMockFile('test.png', 1024, 'image/png')
await act(async () => {
result.current.handleLocalFileUpload(mockFile)
})
await waitFor(() => {
expect(Toast.notify).toHaveBeenCalledWith({
type: 'error',
message: 'Upload error',
})
})
})
})
describe('Attachment Limit', () => {
it('should show error when exceeding single chunk attachment limit', async () => {
const onChange = vi.fn()
// Pre-populate with 19 files (limit is 20)
const initialFiles: Partial<FileEntity>[] = Array.from({ length: 19 }, (_, i) => ({
id: `file${i}`,
name: `test${i}.png`,
progress: 100,
}))
const wrapper = ({ children }: PropsWithChildren) => (
<FileContextProvider value={initialFiles as FileEntity[]} onChange={onChange}>
{children}
</FileContextProvider>
)
const { result } = renderHook(() => useUpload(), { wrapper })
// Try to add 2 more files (would exceed limit of 20)
const mockEvent = {
target: {
files: [
createMockFile('new1.png'),
createMockFile('new2.png'),
],
},
} as unknown as React.ChangeEvent<HTMLInputElement>
act(() => {
result.current.fileChangeHandle(mockEvent)
})
await waitFor(() => {
expect(Toast.notify).toHaveBeenCalledWith({
type: 'error',
message: expect.any(String),
})
})
})
})
describe('selectHandle', () => {
it('should trigger click on uploader input when called', () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
// Create a mock input element
const mockInput = document.createElement('input')
const clickSpy = vi.spyOn(mockInput, 'click')
// Manually set the ref
Object.defineProperty(result.current.uploaderRef, 'current', {
value: mockInput,
writable: true,
})
act(() => {
result.current.selectHandle()
})
expect(clickSpy).toHaveBeenCalled()
})
it('should not throw when uploaderRef is null', () => {
const { result } = renderHook(() => useUpload(), {
wrapper: createWrapper(),
})
expect(() => {
act(() => {
result.current.selectHandle()
})
}).not.toThrow()
})
})
describe('FileReader Error Handling', () => {
it('should show error toast when FileReader encounters an error', async () => {
// Create a custom MockFileReader that triggers error
class ErrorFileReader {
result: string | ArrayBuffer | null = null
private listeners: Record<string, EventCallback[]> = {}
addEventListener(event: string, callback: EventCallback) {
if (!this.listeners[event])
this.listeners[event] = []
this.listeners[event].push(callback)
}
removeEventListener(event: string, callback: EventCallback) {
if (this.listeners[event])
this.listeners[event] = this.listeners[event].filter(cb => cb !== callback)
}
readAsDataURL(_file: File) {
// Trigger error instead of load
setTimeout(() => {
this.listeners.error?.forEach(cb => cb())
}, 0)
}
}
vi.stubGlobal('FileReader', ErrorFileReader)
const onChange = vi.fn()
const wrapper = ({ children }: PropsWithChildren) => (
<FileContextProvider onChange={onChange}>
{children}
</FileContextProvider>
)
const { result } = renderHook(() => useUpload(), { wrapper })
const mockFile = createMockFile('test.png', 1024, 'image/png')
await act(async () => {
result.current.handleLocalFileUpload(mockFile)
})
await waitFor(() => {
expect(Toast.notify).toHaveBeenCalledWith({
type: 'error',
message: expect.any(String),
})
})
// Restore original MockFileReader
vi.stubGlobal('FileReader', MockFileReader)
})
})
describe('Drag and Drop Functionality', () => {
// Test component that renders the hook with actual DOM elements
const TestComponent = ({ onStateChange }: { onStateChange?: (dragging: boolean) => void }) => {
const { dragging, dragRef, dropRef } = useUpload()
// Report dragging state changes to parent
React.useEffect(() => {
onStateChange?.(dragging)
}, [dragging, onStateChange])
return (
<div ref={dropRef} data-testid="drop-zone">
<div ref={dragRef} data-testid="drag-boundary">
<span data-testid="dragging-state">{dragging ? 'dragging' : 'not-dragging'}</span>
</div>
</div>
)
}
it('should set dragging to true on dragEnter when target is not dragRef', async () => {
const onStateChange = vi.fn()
render(
<FileContextProvider>
<TestComponent onStateChange={onStateChange} />
</FileContextProvider>,
)
const dropZone = screen.getByTestId('drop-zone')
// Fire dragenter event on dropZone (not dragRef)
await act(async () => {
fireEvent.dragEnter(dropZone, {
dataTransfer: { items: [] },
})
})
// Verify dragging state changed to true
expect(screen.getByTestId('dragging-state')).toHaveTextContent('dragging')
})
it('should set dragging to false on dragLeave when target matches dragRef', async () => {
render(
<FileContextProvider>
<TestComponent />
</FileContextProvider>,
)
const dropZone = screen.getByTestId('drop-zone')
const dragBoundary = screen.getByTestId('drag-boundary')
// First trigger dragenter to set dragging to true
await act(async () => {
fireEvent.dragEnter(dropZone, {
dataTransfer: { items: [] },
})
})
expect(screen.getByTestId('dragging-state')).toHaveTextContent('dragging')
// Then trigger dragleave on dragBoundary to set dragging to false
await act(async () => {
fireEvent.dragLeave(dragBoundary, {
dataTransfer: { items: [] },
})
})
expect(screen.getByTestId('dragging-state')).toHaveTextContent('not-dragging')
})
it('should handle drop event with files and reset dragging state', async () => {
const onChange = vi.fn()
render(
<FileContextProvider onChange={onChange}>
<TestComponent />
</FileContextProvider>,
)
const dropZone = screen.getByTestId('drop-zone')
const mockFile = new File(['test content'], 'test.png', { type: 'image/png' })
// First trigger dragenter
await act(async () => {
fireEvent.dragEnter(dropZone, {
dataTransfer: { items: [] },
})
})
expect(screen.getByTestId('dragging-state')).toHaveTextContent('dragging')
// Then trigger drop with files
await act(async () => {
fireEvent.drop(dropZone, {
dataTransfer: {
items: [{
webkitGetAsEntry: () => null,
getAsFile: () => mockFile,
}],
},
})
})
// Dragging should be reset to false after drop
expect(screen.getByTestId('dragging-state')).toHaveTextContent('not-dragging')
})
it('should return early when dataTransfer is null on drop', async () => {
render(
<FileContextProvider>
<TestComponent />
</FileContextProvider>,
)
const dropZone = screen.getByTestId('drop-zone')
// Fire dragenter first
await act(async () => {
fireEvent.dragEnter(dropZone)
})
// Fire drop without dataTransfer
await act(async () => {
fireEvent.drop(dropZone)
})
// Should still reset dragging state
expect(screen.getByTestId('dragging-state')).toHaveTextContent('not-dragging')
})
it('should not trigger file upload for invalid file types on drop', async () => {
render(
<FileContextProvider>
<TestComponent />
</FileContextProvider>,
)
const dropZone = screen.getByTestId('drop-zone')
const invalidFile = new File(['test'], 'test.exe', { type: 'application/x-msdownload' })
await act(async () => {
fireEvent.drop(dropZone, {
dataTransfer: {
items: [{
webkitGetAsEntry: () => null,
getAsFile: () => invalidFile,
}],
},
})
})
// Should show error toast for invalid file type
await waitFor(() => {
expect(Toast.notify).toHaveBeenCalledWith({
type: 'error',
message: expect.any(String),
})
})
})
it('should handle drop with webkitGetAsEntry for file entries', async () => {
const onChange = vi.fn()
const mockFile = new File(['test'], 'test.png', { type: 'image/png' })
render(
<FileContextProvider onChange={onChange}>
<TestComponent />
</FileContextProvider>,
)
const dropZone = screen.getByTestId('drop-zone')
// Create a mock file entry that simulates webkitGetAsEntry behavior
const mockFileEntry = {
isFile: true,
isDirectory: false,
file: (callback: (file: File) => void) => callback(mockFile),
}
await act(async () => {
fireEvent.drop(dropZone, {
dataTransfer: {
items: [{
webkitGetAsEntry: () => mockFileEntry,
getAsFile: () => mockFile,
}],
},
})
})
// Dragging should be reset
expect(screen.getByTestId('dragging-state')).toHaveTextContent('not-dragging')
})
})
describe('Drag Events', () => {
const TestComponent = () => {
const { dragging, dragRef, dropRef } = useUpload()
return (
<div ref={dropRef} data-testid="drop-zone">
<div ref={dragRef} data-testid="drag-boundary">
<span data-testid="dragging-state">{dragging ? 'dragging' : 'not-dragging'}</span>
</div>
</div>
)
}
it('should handle dragEnter event and update dragging state', async () => {
render(
<FileContextProvider>
<TestComponent />
</FileContextProvider>,
)
const dropZone = screen.getByTestId('drop-zone')
// Initially not dragging
expect(screen.getByTestId('dragging-state')).toHaveTextContent('not-dragging')
// Fire dragEnter
await act(async () => {
fireEvent.dragEnter(dropZone, {
dataTransfer: { items: [] },
})
})
// Should be dragging now
expect(screen.getByTestId('dragging-state')).toHaveTextContent('dragging')
})
it('should handle dragOver event without changing state', async () => {
render(
<FileContextProvider>
<TestComponent />
</FileContextProvider>,
)
const dropZone = screen.getByTestId('drop-zone')
// First trigger dragenter to set dragging
await act(async () => {
fireEvent.dragEnter(dropZone)
})
expect(screen.getByTestId('dragging-state')).toHaveTextContent('dragging')
// dragOver should not change the dragging state
await act(async () => {
fireEvent.dragOver(dropZone)
})
// Should still be dragging
expect(screen.getByTestId('dragging-state')).toHaveTextContent('dragging')
})
it('should not set dragging to true when dragEnter target is dragRef', async () => {
render(
<FileContextProvider>
<TestComponent />
</FileContextProvider>,
)
const dragBoundary = screen.getByTestId('drag-boundary')
// Fire dragEnter directly on dragRef
await act(async () => {
fireEvent.dragEnter(dragBoundary)
})
// Should not be dragging when target is dragRef itself
expect(screen.getByTestId('dragging-state')).toHaveTextContent('not-dragging')
})
it('should not set dragging to false when dragLeave target is not dragRef', async () => {
render(
<FileContextProvider>
<TestComponent />
</FileContextProvider>,
)
const dropZone = screen.getByTestId('drop-zone')
// First trigger dragenter on dropZone to set dragging
await act(async () => {
fireEvent.dragEnter(dropZone)
})
expect(screen.getByTestId('dragging-state')).toHaveTextContent('dragging')
// dragLeave on dropZone (not dragRef) should not change dragging state
await act(async () => {
fireEvent.dragLeave(dropZone)
})
// Should still be dragging (only dragLeave on dragRef resets)
expect(screen.getByTestId('dragging-state')).toHaveTextContent('dragging')
})
})
})

View File

@ -0,0 +1,107 @@
import { fireEvent, render, screen } from '@testing-library/react'
import { describe, expect, it, vi } from 'vitest'
import { FileContextProvider } from '../store'
import ImageInput from './image-input'
// Mock dependencies
vi.mock('@/service/use-common', () => ({
useFileUploadConfig: vi.fn(() => ({
data: {
image_file_batch_limit: 10,
single_chunk_attachment_limit: 20,
attachment_image_file_size_limit: 15,
},
})),
}))
const renderWithProvider = (ui: React.ReactElement) => {
return render(
<FileContextProvider>
{ui}
</FileContextProvider>,
)
}
describe('ImageInput (image-uploader-in-chunk)', () => {
describe('Rendering', () => {
it('should render without crashing', () => {
const { container } = renderWithProvider(<ImageInput />)
expect(container.firstChild).toBeInTheDocument()
})
it('should render file input element', () => {
renderWithProvider(<ImageInput />)
const input = document.querySelector('input[type="file"]')
expect(input).toBeInTheDocument()
})
it('should have hidden file input', () => {
renderWithProvider(<ImageInput />)
const input = document.querySelector('input[type="file"]')
expect(input).toHaveClass('hidden')
})
it('should render upload icon', () => {
const { container } = renderWithProvider(<ImageInput />)
const icon = container.querySelector('svg')
expect(icon).toBeInTheDocument()
})
it('should render browse text', () => {
renderWithProvider(<ImageInput />)
expect(screen.getByText(/browse/i)).toBeInTheDocument()
})
})
describe('File Input Props', () => {
it('should accept multiple files', () => {
renderWithProvider(<ImageInput />)
const input = document.querySelector('input[type="file"]')
expect(input).toHaveAttribute('multiple')
})
it('should have accept attribute for images', () => {
renderWithProvider(<ImageInput />)
const input = document.querySelector('input[type="file"]')
expect(input).toHaveAttribute('accept')
})
})
describe('User Interactions', () => {
it('should open file dialog when browse is clicked', () => {
renderWithProvider(<ImageInput />)
const browseText = screen.getByText(/browse/i)
const input = document.querySelector('input[type="file"]') as HTMLInputElement
const clickSpy = vi.spyOn(input, 'click')
fireEvent.click(browseText)
expect(clickSpy).toHaveBeenCalled()
})
})
describe('Drag and Drop', () => {
it('should have drop zone area', () => {
const { container } = renderWithProvider(<ImageInput />)
// The drop zone has dashed border styling
expect(container.querySelector('.border-dashed')).toBeInTheDocument()
})
it('should apply accent styles when dragging', () => {
// This would require simulating drag events
// Just verify the base structure exists
const { container } = renderWithProvider(<ImageInput />)
expect(container.querySelector('.border-components-dropzone-border')).toBeInTheDocument()
})
})
describe('Edge Cases', () => {
it('should display file size limit from config', () => {
renderWithProvider(<ImageInput />)
// The tip text should contain the size limit (15 from mock)
const tipText = document.querySelector('.system-xs-regular')
expect(tipText).toBeInTheDocument()
})
})
})

Some files were not shown because too many files have changed in this diff Show More