diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index ef9e9c103a..1d439323f2 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -3,8 +3,8 @@ from __future__ import annotations import base64 import json import logging -from collections.abc import Generator -from typing import Any +from collections.abc import Generator, Mapping +from typing import Any, cast from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPConnectionError @@ -17,6 +17,7 @@ from core.mcp.types import ( TextContent, TextResourceContents, ) +from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType @@ -46,6 +47,7 @@ class MCPTool(Tool): self.headers = headers or {} self.timeout = timeout self.sse_read_timeout = sse_read_timeout + self._latest_usage = LLMUsage.empty_usage() def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.MCP @@ -59,6 +61,10 @@ class MCPTool(Tool): message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: result = self.invoke_remote_mcp_tool(tool_parameters) + + # Extract usage metadata from MCP protocol's _meta field + self._latest_usage = self._derive_usage_from_result(result) + # handle dify tool output for content in result.content: if isinstance(content, TextContent): @@ -120,6 +126,99 @@ class MCPTool(Tool): for item in json_list: yield self.create_json_message(item) + @property + def latest_usage(self) -> LLMUsage: + return self._latest_usage + + @classmethod + def _derive_usage_from_result(cls, result: CallToolResult) -> LLMUsage: + """ + Extract usage metadata from MCP tool result's _meta field. + + The MCP protocol's _meta field (aliased as 'meta' in Python) can contain + usage information such as token counts, costs, and other metadata. + + Args: + result: The CallToolResult from MCP tool invocation + + Returns: + LLMUsage instance with values from meta or empty_usage if not found + """ + # Extract usage from the meta field if present + if result.meta: + usage_dict = cls._extract_usage_dict(result.meta) + if usage_dict is not None: + return LLMUsage.from_metadata(cast(LLMUsageMetadata, cast(object, dict(usage_dict)))) + + return LLMUsage.empty_usage() + + @classmethod + def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None: + """ + Recursively search for usage dictionary in the payload. 
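+
+        Example (illustrative payload shape; the MCP spec does not fix where
+        servers report usage, so this is just one plausible nesting):
+
+            >>> MCPTool._extract_usage_dict({"metadata": {"usage": {"total_tokens": 5}}})
+            {'total_tokens': 5}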
+ + The MCP protocol's _meta field can contain usage data in various formats: + - Direct usage field: {"usage": {...}} + - Nested in metadata: {"metadata": {"usage": {...}}} + - Or nested within other fields + + Args: + payload: The payload to search for usage data + + Returns: + The usage dictionary if found, None otherwise + """ + # Check for direct usage field + usage_candidate = payload.get("usage") + if isinstance(usage_candidate, Mapping): + return usage_candidate + + # Check for metadata nested usage + metadata_candidate = payload.get("metadata") + if isinstance(metadata_candidate, Mapping): + usage_candidate = metadata_candidate.get("usage") + if isinstance(usage_candidate, Mapping): + return usage_candidate + + # Check for common token counting fields directly in payload + # Some MCP servers may include token counts directly + if "total_tokens" in payload or "prompt_tokens" in payload or "completion_tokens" in payload: + usage_dict: dict[str, Any] = {} + for key in ( + "prompt_tokens", + "completion_tokens", + "total_tokens", + "prompt_unit_price", + "completion_unit_price", + "total_price", + "currency", + "prompt_price_unit", + "completion_price_unit", + "prompt_price", + "completion_price", + "latency", + "time_to_first_token", + "time_to_generate", + ): + if key in payload: + usage_dict[key] = payload[key] + if usage_dict: + return usage_dict + + # Recursively search through nested structures + for value in payload.values(): + if isinstance(value, Mapping): + found = cls._extract_usage_dict(value) + if found is not None: + return found + elif isinstance(value, list) and not isinstance(value, (str, bytes, bytearray)): + for item in value: + if isinstance(item, Mapping): + found = cls._extract_usage_dict(item) + if found is not None: + return found + return None + def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool: return MCPTool( entity=self.entity, diff --git a/api/pyproject.toml b/api/pyproject.toml index 4be7afff26..2a7c946e6e 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -81,7 +81,7 @@ dependencies = [ "starlette==0.49.1", "tiktoken~=0.9.0", "transformers~=4.56.1", - "unstructured[docx,epub,md,ppt,pptx]~=0.16.1", + "unstructured[docx,epub,md,ppt,pptx]~=0.18.18", "yarl~=1.18.3", "webvtt-py~=0.5.1", "sseclient-py~=1.8.0", diff --git a/api/services/async_workflow_service.py b/api/services/async_workflow_service.py index bc73b7c8c2..94452482b3 100644 --- a/api/services/async_workflow_service.py +++ b/api/services/async_workflow_service.py @@ -155,11 +155,11 @@ class AsyncWorkflowService: task: AsyncResult[Any] | None = None if queue_name == QueuePriority.PROFESSIONAL: - task = execute_workflow_professional.delay(task_data_dict) # type: ignore + task = execute_workflow_professional.delay(task_data_dict) elif queue_name == QueuePriority.TEAM: - task = execute_workflow_team.delay(task_data_dict) # type: ignore + task = execute_workflow_team.delay(task_data_dict) else: # SANDBOX - task = execute_workflow_sandbox.delay(task_data_dict) # type: ignore + task = execute_workflow_sandbox.delay(task_data_dict) # 10. 
Update trigger log with task info trigger_log.status = WorkflowTriggerStatus.QUEUED @@ -170,7 +170,7 @@ class AsyncWorkflowService: return AsyncTriggerResponse( workflow_trigger_log_id=trigger_log.id, - task_id=task.id, # type: ignore + task_id=task.id, status="queued", queue=queue_name, ) diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index 23c49f2742..a9a8b892c2 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -6,7 +6,6 @@ from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document -from extensions.ext_database import db from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -58,5 +57,3 @@ def add_annotation_to_index_task( ) except Exception: logger.exception("Build index for annotation failed") - finally: - db.session.close() diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index e928c25546..432732af95 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -5,7 +5,6 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector -from extensions.ext_database import db from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -40,5 +39,3 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green")) except Exception: logger.exception("Annotation deleted index failed") - finally: - db.session.close() diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index 957d8f7e45..6ff34c0e74 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -6,7 +6,6 @@ from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document -from extensions.ext_database import db from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -59,5 +58,3 @@ def update_annotation_to_index_task( ) except Exception: logger.exception("Build index for annotation failed") - finally: - db.session.close() diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 8ee09d5738..f69f17b16d 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -48,6 +48,11 @@ def batch_create_segment_to_index_task( indexing_cache_key = f"segment_batch_import_{job_id}" + # Initialize variables with default values + upload_file_key: str | None = None + dataset_config: dict | None = None + document_config: dict | None = None + with session_factory.create_session() as session: try: dataset = session.get(Dataset, dataset_id) @@ -69,86 +74,115 @@ def batch_create_segment_to_index_task( if not upload_file: raise ValueError("UploadFile not found.") - with tempfile.TemporaryDirectory() as temp_dir: - suffix = Path(upload_file.key).suffix - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore - 
storage.download(upload_file.key, file_path) + dataset_config = { + "id": dataset.id, + "indexing_technique": dataset.indexing_technique, + "tenant_id": dataset.tenant_id, + "embedding_model_provider": dataset.embedding_model_provider, + "embedding_model": dataset.embedding_model, + } - df = pd.read_csv(file_path) - content = [] - for _, row in df.iterrows(): - if dataset_document.doc_form == "qa_model": - data = {"content": row.iloc[0], "answer": row.iloc[1]} - else: - data = {"content": row.iloc[0]} - content.append(data) - if len(content) == 0: - raise ValueError("The CSV file is empty.") + document_config = { + "id": dataset_document.id, + "doc_form": dataset_document.doc_form, + "word_count": dataset_document.word_count or 0, + } - document_segments = [] - embedding_model = None - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model, - ) + upload_file_key = upload_file.key - word_count_change = 0 - if embedding_model: - tokens_list = embedding_model.get_text_embedding_num_tokens( - texts=[segment["content"] for segment in content] - ) + except Exception: + logger.exception("Segments batch created index failed") + redis_client.setex(indexing_cache_key, 600, "error") + return + + # Ensure required variables are set before proceeding + if upload_file_key is None or dataset_config is None or document_config is None: + logger.error("Required configuration not set due to session error") + redis_client.setex(indexing_cache_key, 600, "error") + return + + with tempfile.TemporaryDirectory() as temp_dir: + suffix = Path(upload_file_key).suffix + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore + storage.download(upload_file_key, file_path) + + df = pd.read_csv(file_path) + content = [] + for _, row in df.iterrows(): + if document_config["doc_form"] == "qa_model": + data = {"content": row.iloc[0], "answer": row.iloc[1]} else: - tokens_list = [0] * len(content) + data = {"content": row.iloc[0]} + content.append(data) + if len(content) == 0: + raise ValueError("The CSV file is empty.") - for segment, tokens in zip(content, tokens_list): - content = segment["content"] - doc_id = str(uuid.uuid4()) - segment_hash = helper.generate_text_hash(content) - max_position = ( - session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == dataset_document.id) - .scalar() - ) - segment_document = DocumentSegment( - tenant_id=tenant_id, - dataset_id=dataset_id, - document_id=document_id, - index_node_id=doc_id, - index_node_hash=segment_hash, - position=max_position + 1 if max_position else 1, - content=content, - word_count=len(content), - tokens=tokens, - created_by=user_id, - indexing_at=naive_utc_now(), - status="completed", - completed_at=naive_utc_now(), - ) - if dataset_document.doc_form == "qa_model": - segment_document.answer = segment["answer"] - segment_document.word_count += len(segment["answer"]) - word_count_change += segment_document.word_count - session.add(segment_document) - document_segments.append(segment_document) + document_segments = [] + embedding_model = None + if dataset_config["indexing_technique"] == "high_quality": + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=dataset_config["tenant_id"], + 
provider=dataset_config["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=dataset_config["embedding_model"], + ) + word_count_change = 0 + if embedding_model: + tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content]) + else: + tokens_list = [0] * len(content) + + with session_factory.create_session() as session, session.begin(): + for segment, tokens in zip(content, tokens_list): + content = segment["content"] + doc_id = str(uuid.uuid4()) + segment_hash = helper.generate_text_hash(content) + max_position = ( + session.query(func.max(DocumentSegment.position)) + .where(DocumentSegment.document_id == document_config["id"]) + .scalar() + ) + segment_document = DocumentSegment( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + index_node_id=doc_id, + index_node_hash=segment_hash, + position=max_position + 1 if max_position else 1, + content=content, + word_count=len(content), + tokens=tokens, + created_by=user_id, + indexing_at=naive_utc_now(), + status="completed", + completed_at=naive_utc_now(), + ) + if document_config["doc_form"] == "qa_model": + segment_document.answer = segment["answer"] + segment_document.word_count += len(segment["answer"]) + word_count_change += segment_document.word_count + session.add(segment_document) + document_segments.append(segment_document) + + with session_factory.create_session() as session, session.begin(): + dataset_document = session.get(Document, document_id) + if dataset_document: assert dataset_document.word_count is not None dataset_document.word_count += word_count_change session.add(dataset_document) - VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form) - session.commit() - redis_client.setex(indexing_cache_key, 600, "completed") - end_at = time.perf_counter() - logger.info( - click.style( - f"Segment batch created job: {job_id} latency: {end_at - start_at}", - fg="green", - ) - ) - except Exception: - logger.exception("Segments batch created index failed") - redis_client.setex(indexing_cache_key, 600, "error") + with session_factory.create_session() as session: + dataset = session.get(Dataset, dataset_id) + if dataset: + VectorService.create_segments_vector(None, document_segments, dataset, document_config["doc_form"]) + + redis_client.setex(indexing_cache_key, 600, "completed") + end_at = time.perf_counter() + logger.info( + click.style( + f"Segment batch created job: {job_id} latency: {end_at - start_at}", + fg="green", + ) + ) diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 91ace6be02..a017e9114b 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -28,6 +28,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i """ logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green")) start_at = time.perf_counter() + total_attachment_files = [] with session_factory.create_session() as session: try: @@ -47,78 +48,91 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i SegmentAttachmentBinding.document_id == document_id, ) ).all() - # check segment is exist - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - index_processor = IndexProcessorFactory(doc_form).init_index_processor() + + attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings] + binding_ids = [binding.id for 
binding, _ in attachments_with_bindings] + total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings]) + + index_node_ids = [segment.index_node_id for segment in segments] + segment_contents = [segment.content for segment in segments] + except Exception: + logger.exception("Cleaned document when document deleted failed") + return + + # check segment is exist + if index_node_ids: + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if dataset: index_processor.clean( dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True ) - for segment in segments: - image_upload_file_ids = get_image_upload_file_ids(segment.content) - image_files = session.scalars( - select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) - ).all() - for image_file in image_files: - if image_file is None: - continue - try: - storage.delete(image_file.key) - except Exception: - logger.exception( - "Delete image_files failed when storage deleted, \ - image_upload_file_is: %s", - image_file.id, - ) + total_image_files = [] + with session_factory.create_session() as session, session.begin(): + for segment_content in segment_contents: + image_upload_file_ids = get_image_upload_file_ids(segment_content) + image_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))).all() + total_image_files.extend([image_file.key for image_file in image_files]) + image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) + session.execute(image_file_delete_stmt) - image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) - session.execute(image_file_delete_stmt) - session.delete(segment) + with session_factory.create_session() as session, session.begin(): + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id) + session.execute(segment_delete_stmt) - session.commit() - if file_id: - file = session.query(UploadFile).where(UploadFile.id == file_id).first() - if file: - try: - storage.delete(file.key) - except Exception: - logger.exception("Delete file failed when document deleted, file_id: %s", file_id) - session.delete(file) - # delete segment attachments - if attachments_with_bindings: - attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings] - binding_ids = [binding.id for binding, _ in attachments_with_bindings] - for binding, attachment_file in attachments_with_bindings: - try: - storage.delete(attachment_file.key) - except Exception: - logger.exception( - "Delete attachment_file failed when storage deleted, \ - attachment_file_id: %s", - binding.attachment_id, - ) - attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids)) - session.execute(attachment_file_delete_stmt) - - binding_delete_stmt = delete(SegmentAttachmentBinding).where( - SegmentAttachmentBinding.id.in_(binding_ids) - ) - session.execute(binding_delete_stmt) - - # delete dataset metadata binding - session.query(DatasetMetadataBinding).where( - DatasetMetadataBinding.dataset_id == dataset_id, - DatasetMetadataBinding.document_id == document_id, - ).delete() - session.commit() - - end_at = time.perf_counter() - logger.info( - click.style( - f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}", - 
fg="green", - ) - ) + for image_file_key in total_image_files: + try: + storage.delete(image_file_key) except Exception: - logger.exception("Cleaned document when document deleted failed") + logger.exception( + "Delete image_files failed when storage deleted, \ + image_upload_file_is: %s", + image_file_key, + ) + + with session_factory.create_session() as session, session.begin(): + if file_id: + file = session.query(UploadFile).where(UploadFile.id == file_id).first() + if file: + try: + storage.delete(file.key) + except Exception: + logger.exception("Delete file failed when document deleted, file_id: %s", file_id) + session.delete(file) + + with session_factory.create_session() as session, session.begin(): + # delete segment attachments + if attachment_ids: + attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids)) + session.execute(attachment_file_delete_stmt) + + if binding_ids: + binding_delete_stmt = delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.id.in_(binding_ids)) + session.execute(binding_delete_stmt) + + for attachment_file_key in total_attachment_files: + try: + storage.delete(attachment_file_key) + except Exception: + logger.exception( + "Delete attachment_file failed when storage deleted, \ + attachment_file_id: %s", + attachment_file_key, + ) + + with session_factory.create_session() as session, session.begin(): + # delete dataset metadata binding + session.query(DatasetMetadataBinding).where( + DatasetMetadataBinding.dataset_id == dataset_id, + DatasetMetadataBinding.document_id == document_id, + ).delete() + + end_at = time.perf_counter() + logger.info( + click.style( + f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}", + fg="green", + ) + ) diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 34496e9c6f..11edcf151f 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -81,26 +81,35 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): session.commit() return - for document_id in document_ids: - logger.info(click.style(f"Start process document: {document_id}", fg="green")) - - document = ( - session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() - ) + # Phase 1: Update status to parsing (short transaction) + with session_factory.create_session() as session, session.begin(): + documents = ( + session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all() + ) + for document in documents: if document: document.indexing_status = "parsing" document.processing_started_at = naive_utc_now() - documents.append(document) session.add(document) - session.commit() + # Transaction committed and closed - try: - indexing_runner = IndexingRunner() - indexing_runner.run(documents) - end_at = time.perf_counter() - logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + # Phase 2: Execute indexing (no transaction - IndexingRunner creates its own sessions) + has_error = False + try: + indexing_runner = IndexingRunner() + indexing_runner.run(documents) + end_at = time.perf_counter() + logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + has_error = True + except Exception: + logger.exception("Document indexing task failed, dataset_id: %s", 
dataset_id) + has_error = True + if not has_error: + with session_factory.create_session() as session: # Trigger summary index generation for completed documents if enabled # Only generate for high_quality indexing technique and when summary_index_setting is enabled # Re-query dataset to get latest summary_index_setting (in case it was updated) @@ -115,17 +124,18 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): # expire all session to get latest document's indexing status session.expire_all() # Check each document's indexing status and trigger summary generation if completed - for document_id in document_ids: - # Re-query document to get latest status (IndexingRunner may have updated it) - document = ( - session.query(Document) - .where(Document.id == document_id, Document.dataset_id == dataset_id) - .first() - ) + + documents = ( + session.query(Document) + .where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) + .all() + ) + + for document in documents: if document: logger.info( "Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s", - document_id, + document.id, document.indexing_status, document.doc_form, document.need_summary, @@ -136,46 +146,36 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): and document.need_summary is True ): try: - generate_summary_index_task.delay(dataset.id, document_id, None) + generate_summary_index_task.delay(dataset.id, document.id, None) logger.info( "Queued summary index generation task for document %s in dataset %s " "after indexing completed", - document_id, + document.id, dataset.id, ) except Exception: logger.exception( "Failed to queue summary index generation task for document %s", - document_id, + document.id, ) # Don't fail the entire indexing process if summary task queuing fails else: logger.info( "Skipping summary generation for document %s: " "status=%s, doc_form=%s, need_summary=%s", - document_id, + document.id, document.indexing_status, document.doc_form, document.need_summary, ) else: - logger.warning("Document %s not found after indexing", document_id) - else: - logger.info( - "Summary index generation skipped for dataset %s: summary_index_setting.enable=%s", - dataset.id, - summary_index_setting.get("enable") if summary_index_setting else None, - ) + logger.warning("Document %s not found after indexing", document.id) else: logger.info( "Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')", dataset.id, dataset.indexing_technique, ) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("Document indexing task failed, dataset_id: %s", dataset_id) def _document_indexing_with_tenant_queue( diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 45d58c92ec..c7508c6d05 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -36,25 +36,19 @@ def document_indexing_update_task(dataset_id: str, document_id: str): document.indexing_status = "parsing" document.processing_started_at = naive_utc_now() - # delete all document segment and index - try: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - raise Exception("Dataset not found") + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + return - index_type = document.doc_form - index_processor = 
IndexProcessorFactory(index_type).init_index_processor() - - segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - - # delete from vector index - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - segment_ids = [segment.id for segment in segments] - segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) - session.execute(segment_delete_stmt) + index_type = document.doc_form + segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() + index_node_ids = [segment.index_node_id for segment in segments] + clean_success = False + try: + index_processor = IndexProcessorFactory(index_type).init_index_processor() + if index_node_ids: + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) end_at = time.perf_counter() logger.info( click.style( @@ -64,15 +58,21 @@ def document_indexing_update_task(dataset_id: str, document_id: str): fg="green", ) ) - except Exception: - logger.exception("Cleaned document when document update data source or process rule failed") + clean_success = True + except Exception: + logger.exception("Failed to clean document index during update, document_id: %s", document_id) - try: - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - end_at = time.perf_counter() - logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("document_indexing_update_task failed, document_id: %s", document_id) + if clean_success: + with session_factory.create_session() as session, session.begin(): + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id) + session.execute(segment_delete_stmt) + + try: + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + end_at = time.perf_counter() + logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + except Exception: + logger.exception("document_indexing_update_task failed, document_id: %s", document_id) diff --git a/api/tasks/workflow_draft_var_tasks.py b/api/tasks/workflow_draft_var_tasks.py index fcb98ec39e..26f8f7c29e 100644 --- a/api/tasks/workflow_draft_var_tasks.py +++ b/api/tasks/workflow_draft_var_tasks.py @@ -6,9 +6,8 @@ improving performance by offloading storage operations to background workers. 
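+
+A minimal enqueue sketch (assuming the usual Celery delay() convention for
+this bound shared_task; the actual call sites are not shown in this diff):
+
+    save_workflow_execution_task.delay(deletions=[...])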
""" from celery import shared_task # type: ignore[import-untyped] -from sqlalchemy.orm import Session -from extensions.ext_database import db +from core.db.session_factory import session_factory from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService @@ -17,6 +16,6 @@ def save_workflow_execution_task( self, deletions: list[DraftVarFileDeletion], ): - with Session(bind=db.engine) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): srv = WorkflowDraftVariableService(session=session) srv.delete_workflow_draft_variable_file(deletions=deletions) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 1b844d6357..61f6b75b10 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -605,26 +605,20 @@ class TestBatchCreateSegmentToIndexTask: mock_storage.download.side_effect = mock_download - # Execute the task + # Execute the task - should raise ValueError for empty CSV job_id = str(uuid.uuid4()) - batch_create_segment_to_index_task( - job_id=job_id, - upload_file_id=upload_file.id, - dataset_id=dataset.id, - document_id=document.id, - tenant_id=tenant.id, - user_id=account.id, - ) + with pytest.raises(ValueError, match="The CSV file is empty"): + batch_create_segment_to_index_task( + job_id=job_id, + upload_file_id=upload_file.id, + dataset_id=dataset.id, + document_id=document.id, + tenant_id=tenant.id, + user_id=account.id, + ) # Verify error handling - # Check Redis cache was set to error status - from extensions.ext_redis import redis_client - - cache_key = f"segment_batch_import_{job_id}" - cache_value = redis_client.get(cache_key) - assert cache_value == b"error" - - # Verify no segments were created + # Since exception was raised, no segments should be created from extensions.ext_database import db segments = db.session.query(DocumentSegment).all() diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index e24ef32a24..8d8e2b0db0 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -83,23 +83,127 @@ def mock_documents(document_ids, dataset_id): def mock_db_session(): """Mock database session via session_factory.create_session().""" with patch("tasks.document_indexing_task.session_factory") as mock_sf: - session = MagicMock() - # Ensure tests that expect session.close() to be called can observe it via the context manager - session.close = MagicMock() - cm = MagicMock() - cm.__enter__.return_value = session - # Link __exit__ to session.close so "close" expectations reflect context manager teardown + sessions = [] # Track all created sessions + # Shared mock data that all sessions will access + shared_mock_data = {"dataset": None, "documents": None, "doc_iter": None} - def _exit_side_effect(*args, **kwargs): - session.close() + def create_session_side_effect(): + session = MagicMock() + session.close = MagicMock() - cm.__exit__.side_effect = _exit_side_effect - mock_sf.create_session.return_value = cm + # Track commit calls + commit_mock = MagicMock() + session.commit = commit_mock + cm = MagicMock() + cm.__enter__.return_value = session - query = 
MagicMock() - session.query.return_value = query - query.where.return_value = query - yield session + def _exit_side_effect(*args, **kwargs): + session.close() + + cm.__exit__.side_effect = _exit_side_effect + + # Support session.begin() for transactions + begin_cm = MagicMock() + begin_cm.__enter__.return_value = session + + def begin_exit_side_effect(*args, **kwargs): + # Auto-commit on transaction exit (like SQLAlchemy) + session.commit() + # Also mark wrapper's commit as called + if sessions: + sessions[0].commit() + + begin_cm.__exit__ = MagicMock(side_effect=begin_exit_side_effect) + session.begin = MagicMock(return_value=begin_cm) + + sessions.append(session) + + # Setup query with side_effect to handle both Dataset and Document queries + def query_side_effect(*args): + query = MagicMock() + if args and args[0] == Dataset and shared_mock_data["dataset"] is not None: + where_result = MagicMock() + where_result.first.return_value = shared_mock_data["dataset"] + query.where = MagicMock(return_value=where_result) + elif args and args[0] == Document and shared_mock_data["documents"] is not None: + # Support both .first() and .all() calls with chaining + where_result = MagicMock() + where_result.where = MagicMock(return_value=where_result) + + # Create an iterator for .first() calls if not exists + if shared_mock_data["doc_iter"] is None: + docs = shared_mock_data["documents"] or [None] + shared_mock_data["doc_iter"] = iter(docs) + + where_result.first = lambda: next(shared_mock_data["doc_iter"], None) + docs_or_empty = shared_mock_data["documents"] or [] + where_result.all = MagicMock(return_value=docs_or_empty) + query.where = MagicMock(return_value=where_result) + else: + query.where = MagicMock(return_value=query) + return query + + session.query = MagicMock(side_effect=query_side_effect) + return cm + + mock_sf.create_session.side_effect = create_session_side_effect + + # Create a wrapper that behaves like the first session but has access to all sessions + class SessionWrapper: + def __init__(self): + self._sessions = sessions + self._shared_data = shared_mock_data + # Create a default session for setup phase + self._default_session = MagicMock() + self._default_session.close = MagicMock() + self._default_session.commit = MagicMock() + + # Support session.begin() for default session too + begin_cm = MagicMock() + begin_cm.__enter__.return_value = self._default_session + + def default_begin_exit_side_effect(*args, **kwargs): + self._default_session.commit() + + begin_cm.__exit__ = MagicMock(side_effect=default_begin_exit_side_effect) + self._default_session.begin = MagicMock(return_value=begin_cm) + + def default_query_side_effect(*args): + query = MagicMock() + if args and args[0] == Dataset and shared_mock_data["dataset"] is not None: + where_result = MagicMock() + where_result.first.return_value = shared_mock_data["dataset"] + query.where = MagicMock(return_value=where_result) + elif args and args[0] == Document and shared_mock_data["documents"] is not None: + where_result = MagicMock() + where_result.where = MagicMock(return_value=where_result) + + if shared_mock_data["doc_iter"] is None: + docs = shared_mock_data["documents"] or [None] + shared_mock_data["doc_iter"] = iter(docs) + + where_result.first = lambda: next(shared_mock_data["doc_iter"], None) + docs_or_empty = shared_mock_data["documents"] or [] + where_result.all = MagicMock(return_value=docs_or_empty) + query.where = MagicMock(return_value=where_result) + else: + query.where = MagicMock(return_value=query) + return 
query + + self._default_session.query = MagicMock(side_effect=default_query_side_effect) + + def __getattr__(self, name): + # Forward all attribute access to the first session, or default if none created yet + target_session = self._sessions[0] if self._sessions else self._default_session + return getattr(target_session, name) + + @property + def all_sessions(self): + """Access all created sessions for testing.""" + return self._sessions + + wrapper = SessionWrapper() + yield wrapper @pytest.fixture @@ -252,18 +356,9 @@ class TestTaskEnqueuing: use the deprecated function. """ # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - # Return documents one by one for each call - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -304,21 +399,9 @@ class TestBatchProcessing: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - # Create an iterator for documents - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - # Return documents one by one for each call - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -357,19 +440,9 @@ class TestBatchProcessing: doc.stopped_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents mock_feature_service.get_features.return_value.billing.enabled = True mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL @@ -407,19 +480,9 @@ class TestBatchProcessing: doc.stopped_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == 
Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents mock_feature_service.get_features.return_value.billing.enabled = True mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX @@ -444,7 +507,10 @@ class TestBatchProcessing: """ # Arrange document_ids = [] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + # Set shared mock data with empty documents list + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = [] with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -482,19 +548,9 @@ class TestProgressTracking: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -528,19 +584,9 @@ class TestProgressTracking: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -635,19 +681,9 @@ class TestErrorHandling: doc.stopped_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = 
mock_documents # Set up to trigger vector space limit error mock_feature_service.get_features.return_value.billing.enabled = True @@ -674,17 +710,9 @@ class TestErrorHandling: Errors during indexing should be caught and logged, but not crash the task. """ # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Make IndexingRunner raise an exception mock_indexing_runner.run.side_effect = Exception("Indexing failed") @@ -708,17 +736,9 @@ class TestErrorHandling: but not treated as a failure. """ # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Make IndexingRunner raise DocumentIsPausedError mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused") @@ -853,17 +873,9 @@ class TestTaskCancellation: Session cleanup should happen in finally block. """ # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -883,17 +895,9 @@ class TestTaskCancellation: Session cleanup should happen even when errors occur. 
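+        With the session_factory mocks, close() is driven by the context
+        manager's __exit__ side effect rather than an explicit close() call.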
""" # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Make IndexingRunner raise an exception mock_indexing_runner.run.side_effect = Exception("Test error") @@ -962,6 +966,7 @@ class TestAdvancedScenarios: document_ids = [str(uuid.uuid4()) for _ in range(3)] # Create only 2 documents (simulate one missing) + # The new code uses .all() which will only return existing documents mock_documents = [] for i, doc_id in enumerate([document_ids[0], document_ids[2]]): # Skip middle one doc = MagicMock(spec=Document) @@ -971,21 +976,9 @@ class TestAdvancedScenarios: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - # Create iterator that returns None for missing document - doc_responses = [mock_documents[0], None, mock_documents[1]] - doc_iter = iter(doc_responses) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data - .all() will only return existing documents + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1075,19 +1068,9 @@ class TestAdvancedScenarios: doc.stopped_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Set vector space exactly at limit mock_feature_service.get_features.return_value.billing.enabled = True @@ -1219,19 +1202,9 @@ class TestAdvancedScenarios: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions 
can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Billing disabled - limits should not be checked mock_feature_service.get_features.return_value.billing.enabled = False @@ -1273,19 +1246,9 @@ class TestIntegration: # Set up rpop to return None for concurrency check (no more tasks) mock_redis.rpop.side_effect = [None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1321,19 +1284,9 @@ class TestIntegration: # Set up rpop to return None for concurrency check (no more tasks) mock_redis.rpop.side_effect = [None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1415,17 +1368,9 @@ class TestEdgeCases: mock_document.indexing_status = "waiting" mock_document.processing_started_at = None - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: mock_document - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = [mock_document] with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1465,17 +1410,9 @@ class TestEdgeCases: mock_document.indexing_status = "waiting" mock_document.processing_started_at = None - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: mock_document - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all 
sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = [mock_document] with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1555,19 +1492,9 @@ class TestEdgeCases: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Set vector space limit to 0 (unlimited) mock_feature_service.get_features.return_value.billing.enabled = True @@ -1612,19 +1539,9 @@ class TestEdgeCases: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Set negative vector space limit mock_feature_service.get_features.return_value.billing.enabled = True @@ -1675,19 +1592,9 @@ class TestPerformanceScenarios: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Configure billing with sufficient limits mock_feature_service.get_features.return_value.billing.enabled = True @@ -1826,19 +1733,9 @@ class TestRobustness: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + 
+        mock_db_session._shared_data["documents"] = mock_documents
 
         # Make IndexingRunner raise an exception
         mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error")
@@ -1866,7 +1763,7 @@ class TestRobustness:
         - No exceptions occur
 
         Expected behavior:
-        - Database session is closed
+        - All database sessions are closed
         - No connection leaks
         """
         # Arrange
@@ -1879,19 +1776,9 @@ class TestRobustness:
             doc.processing_started_at = None
             mock_documents.append(doc)
 
-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        doc_iter = iter(mock_documents)
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: next(doc_iter, None)
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents
 
         with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
             mock_features.return_value.billing.enabled = False
@@ -1899,10 +1786,11 @@ class TestRobustness:
             # Act
             _document_indexing(dataset_id, document_ids)
 
-            # Assert
-            assert mock_db_session.close.called
-            # Verify close is called exactly once
-            assert mock_db_session.close.call_count == 1
+            # Assert - All created sessions should be closed
+            # The code creates multiple sessions: validation, Phase 1 (parsing), Phase 3 (summary)
+            assert len(mock_db_session.all_sessions) >= 1
+            for session in mock_db_session.all_sessions:
+                assert session.close.called, "All sessions should be closed"
 
     def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis):
         """
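Note on the fixture contract assumed above: `mock_db_session` now behaves as a session factory that exposes `_shared_data` (seeded once, visible to every session it produces) and `all_sessions` (every session it has handed out). A minimal sketch of such a fixture helper, assuming it lives in the test module's conftest; the attribute names mirror the tests, while the lookup logic is illustrative only, not the repository's actual implementation:

from unittest.mock import MagicMock


def make_mock_session_factory() -> MagicMock:
    """Session factory mock: every call returns a fresh session backed by shared data."""
    factory = MagicMock()
    factory._shared_data = {}  # tests seed {"dataset": ..., "documents": [...]}
    factory.all_sessions = []  # every session handed out, for close() assertions
    factory._doc_index = 0  # document lookups advance through the seeded list

    def _next_document():
        documents = factory._shared_data.get("documents", [])
        index, factory._doc_index = factory._doc_index, factory._doc_index + 1
        return documents[index] if index < len(documents) else None

    def _new_session(*args, **kwargs):
        session = MagicMock()

        def _query(model, *rest):
            query = MagicMock()
            if getattr(model, "__name__", None) == "Dataset":
                query.where.return_value.first.return_value = factory._shared_data.get("dataset")
            else:
                # Document lookups consume the seeded list one document at a time
                query.where.return_value.first.side_effect = _next_document
            return query

        session.query.side_effect = _query
        factory.all_sessions.append(session)
        return session

    factory.side_effect = _new_session
    return factory

Seeding `_shared_data` once then works no matter how many sessions the task opens (validation, parsing, summary), and the close-all assertion simply walks `all_sessions`.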
diff --git a/api/tests/unit_tests/tools/test_mcp_tool.py b/api/tests/unit_tests/tools/test_mcp_tool.py
index a527773e4e..5930b63f58 100644
--- a/api/tests/unit_tests/tools/test_mcp_tool.py
+++ b/api/tests/unit_tests/tools/test_mcp_tool.py
@@ -1,4 +1,5 @@
 import base64
+from decimal import Decimal
 from unittest.mock import Mock, patch
 
 import pytest
@@ -9,8 +10,10 @@ from core.mcp.types import (
     CallToolResult,
     EmbeddedResource,
     ImageContent,
+    TextContent,
     TextResourceContents,
 )
+from core.model_runtime.entities.llm_entities import LLMUsage
 from core.tools.__base.tool_runtime import ToolRuntime
 from core.tools.entities.common_entities import I18nObject
 from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
@@ -120,3 +123,229 @@ class TestMCPToolInvoke:
         # Validate values
         values = {m.message.variable_name: m.message.variable_value for m in var_msgs}
         assert values == {"a": 1, "b": "x"}
+
+
+class TestMCPToolUsageExtraction:
+    """Test usage metadata extraction from MCP tool results."""
+
+    def test_extract_usage_dict_from_direct_usage_field(self) -> None:
+        """Test extraction when usage is directly in meta.usage field."""
+        meta = {
+            "usage": {
+                "prompt_tokens": 100,
+                "completion_tokens": 50,
+                "total_tokens": 150,
+                "total_price": "0.001",
+                "currency": "USD",
+            }
+        }
+        usage_dict = MCPTool._extract_usage_dict(meta)
+        assert usage_dict is not None
+        assert usage_dict["prompt_tokens"] == 100
+        assert usage_dict["completion_tokens"] == 50
+        assert usage_dict["total_tokens"] == 150
+        assert usage_dict["total_price"] == "0.001"
+        assert usage_dict["currency"] == "USD"
+
+    def test_extract_usage_dict_from_nested_metadata(self) -> None:
+        """Test extraction when usage is nested in meta.metadata.usage."""
+        meta = {
+            "metadata": {
+                "usage": {
+                    "prompt_tokens": 200,
+                    "completion_tokens": 100,
+                    "total_tokens": 300,
+                }
+            }
+        }
+        usage_dict = MCPTool._extract_usage_dict(meta)
+        assert usage_dict is not None
+        assert usage_dict["prompt_tokens"] == 200
+        assert usage_dict["total_tokens"] == 300
+
+    def test_extract_usage_dict_from_flat_token_fields(self) -> None:
+        """Test extraction when token counts are directly in meta."""
+        meta = {
+            "prompt_tokens": 150,
+            "completion_tokens": 75,
+            "total_tokens": 225,
+            "currency": "EUR",
+        }
+        usage_dict = MCPTool._extract_usage_dict(meta)
+        assert usage_dict is not None
+        assert usage_dict["prompt_tokens"] == 150
+        assert usage_dict["completion_tokens"] == 75
+        assert usage_dict["total_tokens"] == 225
+        assert usage_dict["currency"] == "EUR"
+
+    def test_extract_usage_dict_recursive(self) -> None:
+        """Test recursive search through nested structures."""
+        meta = {
+            "custom": {
+                "nested": {
+                    "usage": {
+                        "total_tokens": 500,
+                        "prompt_tokens": 300,
+                        "completion_tokens": 200,
+                    }
+                }
+            }
+        }
+        usage_dict = MCPTool._extract_usage_dict(meta)
+        assert usage_dict is not None
+        assert usage_dict["total_tokens"] == 500
+
+    def test_extract_usage_dict_from_list(self) -> None:
+        """Test extraction from nested list structures."""
+        meta = {
+            "items": [
+                {"usage": {"total_tokens": 100}},
+                {"other": "data"},
+            ]
+        }
+        usage_dict = MCPTool._extract_usage_dict(meta)
+        assert usage_dict is not None
+        assert usage_dict["total_tokens"] == 100
+
+    def test_extract_usage_dict_returns_none_when_missing(self) -> None:
+        """Test that None is returned when no usage data is present."""
+        meta = {"other": "data", "custom": {"nested": {"value": 123}}}
+        usage_dict = MCPTool._extract_usage_dict(meta)
+        assert usage_dict is None
+
+    def test_extract_usage_dict_empty_meta(self) -> None:
+        """Test with empty meta dict."""
+        usage_dict = MCPTool._extract_usage_dict({})
+        assert usage_dict is None
+
+    def test_derive_usage_from_result_with_meta(self) -> None:
+        """Test _derive_usage_from_result with populated meta."""
+        meta = {
+            "usage": {
+                "prompt_tokens": 100,
+                "completion_tokens": 50,
+                "total_tokens": 150,
+                "total_price": "0.0015",
+                "currency": "USD",
+            }
+        }
+        result = CallToolResult(content=[], _meta=meta)
+        usage = MCPTool._derive_usage_from_result(result)
+
+        assert isinstance(usage, LLMUsage)
+        assert usage.prompt_tokens == 100
+        assert usage.completion_tokens == 50
+        assert usage.total_tokens == 150
+        assert usage.total_price == Decimal("0.0015")
+        assert usage.currency == "USD"
+
+    def test_derive_usage_from_result_without_meta(self) -> None:
+        """Test _derive_usage_from_result with no meta returns empty usage."""
+        result = CallToolResult(content=[], meta=None)
+        usage = MCPTool._derive_usage_from_result(result)
+
+        assert isinstance(usage, LLMUsage)
+        assert usage.total_tokens == 0
+        assert usage.prompt_tokens == 0
+        assert usage.completion_tokens == 0
+
+    def test_derive_usage_from_result_calculates_total_tokens(self) -> None:
+        """Test that total_tokens is calculated when missing."""
+        meta = {
+            "usage": {
+                "prompt_tokens": 100,
+                "completion_tokens": 50,
+                # total_tokens is missing
+            }
+        }
+        result = CallToolResult(content=[], _meta=meta)
+        usage = MCPTool._derive_usage_from_result(result)
+
+        assert usage.total_tokens == 150  # 100 + 50
+        assert usage.prompt_tokens == 100
+        assert usage.completion_tokens == 50
+
+    def test_invoke_sets_latest_usage_from_meta(self) -> None:
+        """Test that _invoke sets _latest_usage from result meta."""
+        tool = _make_mcp_tool()
+        meta = {
+            "usage": {
+                "prompt_tokens": 200,
+                "completion_tokens": 100,
+                "total_tokens": 300,
+                "total_price": "0.003",
+                "currency": "USD",
+            }
+        }
+        result = CallToolResult(content=[TextContent(type="text", text="test")], _meta=meta)
+
+        with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
+            list(tool._invoke(user_id="test_user", tool_parameters={}))
+
+        # Verify latest_usage was set correctly
+        assert tool.latest_usage.prompt_tokens == 200
+        assert tool.latest_usage.completion_tokens == 100
+        assert tool.latest_usage.total_tokens == 300
+        assert tool.latest_usage.total_price == Decimal("0.003")
+
+    def test_invoke_with_no_meta_returns_empty_usage(self) -> None:
+        """Test that _invoke returns empty usage when no meta is present."""
+        tool = _make_mcp_tool()
+        result = CallToolResult(content=[TextContent(type="text", text="test")], _meta=None)
+
+        with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
+            list(tool._invoke(user_id="test_user", tool_parameters={}))
+
+        # Verify latest_usage is empty
+        assert tool.latest_usage.total_tokens == 0
+        assert tool.latest_usage.prompt_tokens == 0
+        assert tool.latest_usage.completion_tokens == 0
+
+    def test_latest_usage_property_returns_llm_usage(self) -> None:
+        """Test that latest_usage property returns LLMUsage instance."""
+        tool = _make_mcp_tool()
+        assert isinstance(tool.latest_usage, LLMUsage)
+
+    def test_initial_usage_is_empty(self) -> None:
+        """Test that MCPTool is initialized with empty usage."""
+        tool = _make_mcp_tool()
+        assert tool.latest_usage.total_tokens == 0
+        assert tool.latest_usage.prompt_tokens == 0
+        assert tool.latest_usage.completion_tokens == 0
+        assert tool.latest_usage.total_price == Decimal(0)
+
+    @pytest.mark.parametrize(
+        "meta_data",
+        [
+            # Direct usage field
+            {"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}},
+            # Nested metadata
+            {"metadata": {"usage": {"total_tokens": 100}}},
+            # Flat token fields
+            {"total_tokens": 50, "prompt_tokens": 30, "completion_tokens": 20},
+            # With price info
+            {
+                "usage": {
+                    "total_tokens": 150,
+                    "total_price": "0.002",
+                    "currency": "EUR",
+                }
+            },
+            # Deep nested
+            {"level1": {"level2": {"usage": {"total_tokens": 200}}}},
+        ],
+    )
+    def test_various_meta_formats(self, meta_data) -> None:
+        """Test that various meta formats are correctly parsed."""
+        result = CallToolResult(content=[], _meta=meta_data)
+        usage = MCPTool._derive_usage_from_result(result)
+
+        assert isinstance(usage, LLMUsage)
+        # Every parametrized case carries total_tokens somewhere; resolve it and compare
+        expected_total = (
+            meta_data.get("usage", {}).get("total_tokens")
+            or meta_data.get("total_tokens")
+            or meta_data.get("metadata", {}).get("usage", {}).get("total_tokens")
+            or meta_data.get("level1", {}).get("level2", {}).get("usage", {}).get("total_tokens")
+        )
+        assert usage.total_tokens == expected_total
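Taken together, these tests pin down the whole round trip: a server attaches usage to the result's `_meta`, `_extract_usage_dict` locates it, and `LLMUsage.from_metadata` fills in anything derivable. A compact sketch of that flow, using only names that appear in this diff; the `server_meta` payload is a hypothetical server response, not output from any particular MCP server:

from core.mcp.types import CallToolResult, TextContent
from core.tools.mcp_tool.tool import MCPTool

# Hypothetical payload a usage-reporting MCP server might attach to _meta
server_meta = {
    "usage": {
        "prompt_tokens": 120,
        "completion_tokens": 30,
        # total_tokens deliberately omitted; it is derived as 120 + 30
    }
}

result = CallToolResult(
    content=[TextContent(type="text", text="ok")],
    _meta=server_meta,
)
usage = MCPTool._derive_usage_from_result(result)
assert usage.total_tokens == 150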
["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.18.18" }, { name = "weave", specifier = ">=0.52.16" }, { name = "weaviate-client", specifier = "==4.17.0" }, { name = "webvtt-py", specifier = "~=0.5.1" }, @@ -6814,12 +6814,12 @@ wheels = [ [[package]] name = "unstructured" -version = "0.16.25" +version = "0.18.31" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "backoff" }, { name = "beautifulsoup4" }, - { name = "chardet" }, + { name = "charset-normalizer" }, { name = "dataclasses-json" }, { name = "emoji" }, { name = "filetype" }, @@ -6827,6 +6827,7 @@ dependencies = [ { name = "langdetect" }, { name = "lxml" }, { name = "nltk" }, + { name = "numba" }, { name = "numpy" }, { name = "psutil" }, { name = "python-iso639" }, @@ -6839,9 +6840,9 @@ dependencies = [ { name = "unstructured-client" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/64/31/98c4c78e305d1294888adf87fd5ee30577a4c393951341ca32b43f167f1e/unstructured-0.16.25.tar.gz", hash = "sha256:73b9b0f51dbb687af572ecdb849a6811710b9cac797ddeab8ee80fa07d8aa5e6", size = 1683097, upload-time = "2025-03-07T11:19:39.507Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a9/5f/64285bd69a538bc28753f1423fcaa9d64cd79a9e7c097171b1f0d27e9cdb/unstructured-0.18.31.tar.gz", hash = "sha256:af4bbe32d1894ae6e755f0da6fc0dd307a1d0adeebe0e7cc6278f6cf744339ca", size = 1707700, upload-time = "2026-01-27T15:33:05.378Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/12/4f/ad08585b5c8a33c82ea119494c4d3023f4796958c56e668b15cc282ec0a0/unstructured-0.16.25-py3-none-any.whl", hash = "sha256:14719ccef2830216cf1c5bf654f75e2bf07b17ca5dcee9da5ac74618130fd337", size = 1769286, upload-time = "2025-03-07T11:19:37.299Z" }, + { url = "https://files.pythonhosted.org/packages/c8/4a/9c43f39d9e443c9bc3f2e379b305bca27110adc653b071221b3132c18de5/unstructured-0.18.31-py3-none-any.whl", hash = "sha256:fab4641176cb9b192ed38048758aa0d9843121d03626d18f42275afb31e5b2d3", size = 1794889, upload-time = "2026-01-27T15:33:03.136Z" }, ] [package.optional-dependencies] diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx index 7fbb1bc7c4..69032b4743 100644 --- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx +++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx @@ -194,11 +194,11 @@ const ConfigContent: FC = ({ {type === RETRIEVE_TYPE.multiWay && ( <> -
diff --git a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx
index 7fbb1bc7c4..69032b4743 100644
--- a/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx
+++ b/web/app/components/app/configuration/dataset-config/params-config/config-content.tsx
@@ -194,11 +194,11 @@ const ConfigContent: FC = ({
       {type === RETRIEVE_TYPE.multiWay && (
         <>
[JSX tag markup lost in extraction: two removed and two added wrapper lines precede the heading {t('rerankSettings', { ns: 'dataset' })}, and one removed/added wrapper pair follows it]
           {
             selectedDatasetsMode.inconsistentEmbeddingModel
diff --git a/web/app/components/header/account-setting/model-provider-page/hooks.ts b/web/app/components/header/account-setting/model-provider-page/hooks.ts
index 6aba41d4e4..8db964cc27 100644
--- a/web/app/components/header/account-setting/model-provider-page/hooks.ts
+++ b/web/app/components/header/account-setting/model-provider-page/hooks.ts
@@ -308,7 +308,7 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText:
   }, [plugins, collectionPlugins, exclude])
 
   return {
-    plugins: allPlugins,
+    plugins: searchText ? plugins : allPlugins,
     isLoading: isCollectionLoading || isPluginsLoading,
   }
 }