diff --git a/api/commands.py b/api/commands.py index c4f2c9edbb..75b17df78e 100644 --- a/api/commands.py +++ b/api/commands.py @@ -30,6 +30,7 @@ from extensions.ext_redis import redis_client from extensions.ext_storage import storage from extensions.storage.opendal_storage import OpenDALStorage from extensions.storage.storage_type import StorageType +from libs.db_migration_lock import DbMigrationAutoRenewLock from libs.helper import email as email_validate from libs.password import hash_password, password_pattern, valid_password from libs.rsa import generate_key_pair @@ -54,6 +55,8 @@ from tasks.remove_app_and_related_data_task import delete_draft_variables_batch logger = logging.getLogger(__name__) +DB_UPGRADE_LOCK_TTL_SECONDS = 60 + @click.command("reset-password", help="Reset the account password.") @click.option("--email", prompt=True, help="Account email to reset password for") @@ -727,8 +730,15 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No @click.command("upgrade-db", help="Upgrade the database") def upgrade_db(): click.echo("Preparing database migration...") - lock = redis_client.lock(name="db_upgrade_lock", timeout=60) + lock = DbMigrationAutoRenewLock( + redis_client=redis_client, + name="db_upgrade_lock", + ttl_seconds=DB_UPGRADE_LOCK_TTL_SECONDS, + logger=logger, + log_context="db_migration", + ) if lock.acquire(blocking=False): + migration_succeeded = False try: click.echo(click.style("Starting database migration.", fg="green")) @@ -737,12 +747,16 @@ def upgrade_db(): flask_migrate.upgrade() + migration_succeeded = True click.echo(click.style("Database migration successful!", fg="green")) - except Exception: + except Exception as e: logger.exception("Failed to execute database migration") + click.echo(click.style(f"Database migration failed: {e}", fg="red")) + raise SystemExit(1) finally: - lock.release() + status = "successful" if migration_succeeded else "failed" + lock.release_safely(status=status) else: 
click.echo("Database migration skipped") diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 8c371da596..42901ab590 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,3 +1,4 @@ +import logging import uuid from datetime import datetime from typing import Any, Literal, TypeAlias @@ -54,6 +55,8 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co register_enum_models(console_ns, IconType) +_logger = logging.getLogger(__name__) + class AppListQuery(BaseModel): page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)") @@ -499,6 +502,7 @@ class AppListApi(Resource): select(Workflow).where( Workflow.version == Workflow.VERSION_DRAFT, Workflow.app_id.in_(workflow_capable_app_ids), + Workflow.tenant_id == current_tenant_id, ) ) .scalars() @@ -510,12 +514,14 @@ class AppListApi(Resource): NodeType.TRIGGER_PLUGIN, } for workflow in draft_workflows: + node_id = None try: - for _, node_data in workflow.walk_nodes(): + for node_id, node_data in workflow.walk_nodes(): if node_data.get("type") in trigger_node_types: draft_trigger_app_ids.add(str(workflow.app_id)) break except Exception: + _logger.exception("error while walking nodes, workflow_id=%s, node_id=%s", workflow.id, node_id) continue for app in app_pagination.items: @@ -654,6 +660,19 @@ class AppCopyApi(Resource): ) session.commit() + # Inherit web app permission from original app + if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: + try: + # Get the original app's access mode + original_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_model.id) + access_mode = original_settings.access_mode + except Exception: + # If original app has no settings (old app), default to public to match fallback behavior + access_mode = "public" + + # Apply the same access mode to the copied app + 
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, access_mode) + stmt = select(App).where(App.id == result.app_id) app = session.scalar(stmt) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index 5bfa895849..f5de6709dd 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -878,7 +878,11 @@ class ToolBuiltinProviderSetDefaultApi(Resource): current_user, current_tenant_id = current_account_with_tenant() payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {}) return BuiltinToolManageService.set_default_provider( - tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id + tenant_id=current_tenant_id, + user_id=current_user.id, + provider=provider, + id=payload.id, + account=current_user, ) diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index 2f6f5cc5db..08d3dec770 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -45,6 +45,8 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk +from core.file import helpers as file_helpers +from core.file.enums import FileTransferMethod from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( @@ -57,10 +59,11 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import 
PromptTemplateParser from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName from core.telemetry import emit as telemetry_emit +from core.tools.signature import sign_tool_file from events.message_event import message_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now -from models.model import AppMode, Conversation, Message, MessageAgentThought +from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile logger = logging.getLogger(__name__) @@ -473,6 +476,85 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): metadata=metadata_dict, ) + def _record_files(self): + with Session(db.engine, expire_on_commit=False) as session: + message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all() + if not message_files: + return None + + files_list = [] + upload_file_ids = [ + mf.upload_file_id + for mf in message_files + if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id + ] + upload_files_map = {} + if upload_file_ids: + upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all() + upload_files_map = {uf.id: uf for uf in upload_files} + + for message_file in message_files: + upload_file = None + if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id: + upload_file = upload_files_map.get(message_file.upload_file_id) + + url = None + filename = "file" + mime_type = "application/octet-stream" + size = 0 + extension = "" + + if message_file.transfer_method == FileTransferMethod.REMOTE_URL: + url = message_file.url + if message_file.url: + filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params + elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE: + if upload_file: + url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id)) + filename = upload_file.name 
+ mime_type = upload_file.mime_type or "application/octet-stream" + size = upload_file.size or 0 + extension = f".{upload_file.extension}" if upload_file.extension else "" + elif message_file.upload_file_id: + # Fallback: generate URL even if upload_file not found + url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id)) + elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url: + # For tool files, use URL directly if it's HTTP, otherwise sign it + if message_file.url.startswith("http"): + url = message_file.url + filename = message_file.url.split("/")[-1].split("?")[0] + else: + # Extract tool file id and extension from URL + url_parts = message_file.url.split("/") + if url_parts: + file_part = url_parts[-1].split("?")[0] # Remove query params first + # Use rsplit to correctly handle filenames with multiple dots + if "." in file_part: + tool_file_id, ext = file_part.rsplit(".", 1) + extension = f".{ext}" + else: + tool_file_id = file_part + extension = ".bin" + url = sign_tool_file(tool_file_id=tool_file_id, extension=extension) + filename = file_part + + transfer_method_value = message_file.transfer_method + remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else "" + file_dict = { + "related_id": message_file.id, + "extension": extension, + "filename": filename, + "size": size, + "mime_type": mime_type, + "transfer_method": transfer_method_value, + "type": message_file.type, + "url": url or "", + "upload_file_id": message_file.upload_file_id or message_file.id, + "remote_url": remote_url, + } + files_list.append(file_dict) + return files_list or None + def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: """ Agent message to stream response. 
diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py index 2d4ee08daf..2b37436983 100644 --- a/api/core/app/task_pipeline/message_cycle_manager.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -64,7 +64,13 @@ class MessageCycleManager: # Use SQLAlchemy 2.x style session.scalar(select(...)) with session_factory.create_session() as session: - message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id)) + message_file = session.scalar( + select(MessageFile) + .where( + MessageFile.message_id == message_id, + ) + .where(MessageFile.belongs_to == "assistant") + ) if message_file: self._message_has_file.add(message_id) diff --git a/api/libs/db_migration_lock.py b/api/libs/db_migration_lock.py new file mode 100644 index 0000000000..1d3a81e0a2 --- /dev/null +++ b/api/libs/db_migration_lock.py @@ -0,0 +1,213 @@ +""" +DB migration Redis lock with heartbeat renewal. + +This is intentionally migration-specific. Background renewal is a trade-off that makes sense +for unbounded, blocking operations like DB migrations (DDL/DML) where the main thread cannot +periodically refresh the lock TTL. + +Do NOT use this as a general-purpose lock primitive for normal application code. Prefer explicit +lock lifecycle management (e.g. redis-py Lock context manager + `extend()` / `reacquire()` from +the same thread) when execution flow is under control. +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any + +from redis.exceptions import LockNotOwnedError, RedisError + +logger = logging.getLogger(__name__) + +MIN_RENEW_INTERVAL_SECONDS = 0.1 +DEFAULT_RENEW_INTERVAL_DIVISOR = 3 +MIN_JOIN_TIMEOUT_SECONDS = 0.5 +MAX_JOIN_TIMEOUT_SECONDS = 5.0 +JOIN_TIMEOUT_MULTIPLIER = 2.0 + + +class DbMigrationAutoRenewLock: + """ + Redis lock wrapper that automatically renews TTL while held (migration-only). 
+ + Notes: + - We force `thread_local=False` when creating the underlying redis-py lock, because the + lock token must be accessible from the heartbeat thread for `reacquire()` to work. + - `release_safely()` is best-effort: it never raises, so it won't mask the caller's + primary error/exit code. + """ + + _redis_client: Any + _name: str + _ttl_seconds: float + _renew_interval_seconds: float + _log_context: str | None + _logger: logging.Logger + + _lock: Any + _stop_event: threading.Event | None + _thread: threading.Thread | None + _acquired: bool + + def __init__( + self, + redis_client: Any, + name: str, + ttl_seconds: float = 60, + renew_interval_seconds: float | None = None, + *, + logger: logging.Logger | None = None, + log_context: str | None = None, + ) -> None: + self._redis_client = redis_client + self._name = name + self._ttl_seconds = float(ttl_seconds) + self._renew_interval_seconds = ( + float(renew_interval_seconds) + if renew_interval_seconds is not None + else max(MIN_RENEW_INTERVAL_SECONDS, self._ttl_seconds / DEFAULT_RENEW_INTERVAL_DIVISOR) + ) + self._logger = logger or logging.getLogger(__name__) + self._log_context = log_context + + self._lock = None + self._stop_event = None + self._thread = None + self._acquired = False + + @property + def name(self) -> str: + return self._name + + def acquire(self, *args: Any, **kwargs: Any) -> bool: + """ + Acquire the lock and start heartbeat renewal on success. + + Accepts the same args/kwargs as redis-py `Lock.acquire()`. + """ + # Prevent accidental double-acquire which could leave the previous heartbeat thread running. + if self._acquired: + raise RuntimeError("DB migration lock is already acquired; call release_safely() before acquiring again.") + + # Reuse the lock object if we already created one. 
+ if self._lock is None: + self._lock = self._redis_client.lock( + name=self._name, + timeout=self._ttl_seconds, + thread_local=False, + ) + acquired = bool(self._lock.acquire(*args, **kwargs)) + self._acquired = acquired + if acquired: + self._start_heartbeat() + return acquired + + def owned(self) -> bool: + if self._lock is None: + return False + try: + return bool(self._lock.owned()) + except Exception: + # Ownership checks are best-effort and must not break callers. + return False + + def _start_heartbeat(self) -> None: + if self._lock is None: + return + if self._stop_event is not None: + return + + self._stop_event = threading.Event() + self._thread = threading.Thread( + target=self._heartbeat_loop, + args=(self._lock, self._stop_event), + daemon=True, + name=f"DbMigrationAutoRenewLock({self._name})", + ) + self._thread.start() + + def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None: + while not stop_event.wait(self._renew_interval_seconds): + try: + lock.reacquire() + except LockNotOwnedError: + self._logger.warning( + "DB migration lock is no longer owned during heartbeat; stop renewing. log_context=%s", + self._log_context, + exc_info=True, + ) + return + except RedisError: + self._logger.warning( + "Failed to renew DB migration lock due to Redis error; will retry. log_context=%s", + self._log_context, + exc_info=True, + ) + except Exception: + self._logger.warning( + "Unexpected error while renewing DB migration lock; will retry. log_context=%s", + self._log_context, + exc_info=True, + ) + + def release_safely(self, *, status: str | None = None) -> None: + """ + Stop heartbeat and release lock. Never raises. + + Args: + status: Optional caller-provided status (e.g. 'successful'/'failed') to add context to logs. + """ + lock = self._lock + if lock is None: + return + + self._stop_heartbeat() + + # Lock release errors should never mask the real error/exit code. 
+ try: + lock.release() + except LockNotOwnedError: + self._logger.warning( + "DB migration lock not owned on release; ignoring. status=%s log_context=%s", + status, + self._log_context, + exc_info=True, + ) + except RedisError: + self._logger.warning( + "Failed to release DB migration lock due to Redis error; ignoring. status=%s log_context=%s", + status, + self._log_context, + exc_info=True, + ) + except Exception: + self._logger.warning( + "Unexpected error while releasing DB migration lock; ignoring. status=%s log_context=%s", + status, + self._log_context, + exc_info=True, + ) + finally: + self._acquired = False + self._lock = None + + def _stop_heartbeat(self) -> None: + if self._stop_event is None: + return + self._stop_event.set() + if self._thread is not None: + # Best-effort join: if Redis calls are blocked, the daemon thread may remain alive. + join_timeout_seconds = max( + MIN_JOIN_TIMEOUT_SECONDS, + min(MAX_JOIN_TIMEOUT_SECONDS, self._renew_interval_seconds * JOIN_TIMEOUT_MULTIPLIER), + ) + self._thread.join(timeout=join_timeout_seconds) + if self._thread.is_alive(): + self._logger.warning( + "DB migration lock heartbeat thread did not stop within %.2fs; ignoring. log_context=%s", + join_timeout_seconds, + self._log_context, + ) + self._stop_event = None + self._thread = None diff --git a/api/migrations/versions/2025_12_25_1039-7df29de0f6be_add_credit_pool.py b/api/migrations/versions/2025_12_25_1039-7df29de0f6be_add_credit_pool.py index e89fcee7e5..0d42de6a3a 100644 --- a/api/migrations/versions/2025_12_25_1039-7df29de0f6be_add_credit_pool.py +++ b/api/migrations/versions/2025_12_25_1039-7df29de0f6be_add_credit_pool.py @@ -8,7 +8,6 @@ Create Date: 2025-12-25 10:39:15.139304 from alembic import op import models as models import sqlalchemy as sa -from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
revision = '7df29de0f6be' @@ -20,7 +19,7 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! ### op.create_table('tenant_credit_pools', - sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('id', models.types.StringUUID(), nullable=False), sa.Column('tenant_id', models.types.StringUUID(), nullable=False), sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False), sa.Column('quota_limit', sa.BigInteger(), nullable=False), diff --git a/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py b/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py index b99ca04e3f..52672e8db6 100644 --- a/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py +++ b/api/migrations/versions/2026_01_17_1110-f9f6d18a37f9_add_table_explore_banner_and_trial.py @@ -8,7 +8,6 @@ Create Date: 2026-01-017 11:10:18.079355 from alembic import op import models as models import sqlalchemy as sa -from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = 'f9f6d18a37f9' @@ -20,7 +19,7 @@ depends_on = None def upgrade(): # ### commands auto generated by Alembic - please adjust! 
### op.create_table('account_trial_app_records', - sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('id', models.types.StringUUID(), nullable=False), sa.Column('account_id', models.types.StringUUID(), nullable=False), sa.Column('app_id', models.types.StringUUID(), nullable=False), sa.Column('count', sa.Integer(), nullable=False), @@ -33,17 +32,17 @@ def upgrade(): batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False) op.create_table('exporle_banners', - sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('id', models.types.StringUUID(), nullable=False), sa.Column('content', sa.JSON(), nullable=False), sa.Column('link', sa.String(length=255), nullable=False), sa.Column('sort', sa.Integer(), nullable=False), - sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False), + sa.Column('status', sa.String(length=255), server_default='enabled', nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), - sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False), + sa.Column('language', sa.String(length=255), server_default='en-US', nullable=False), sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey') ) op.create_table('trial_apps', - sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('id', models.types.StringUUID(), nullable=False), sa.Column('app_id', models.types.StringUUID(), nullable=False), sa.Column('tenant_id', models.types.StringUUID(), nullable=False), sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False), diff --git a/api/models/model.py b/api/models/model.py index c1c6e04ce9..429c46bd85 100644 --- 
a/api/models/model.py +++ b/api/models/model.py @@ -620,7 +620,7 @@ class TrialApp(Base): sa.UniqueConstraint("app_id", name="unique_trail_app_id"), ) - id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, default=lambda: str(uuid4())) app_id = mapped_column(StringUUID, nullable=False) tenant_id = mapped_column(StringUUID, nullable=False) created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp()) @@ -640,7 +640,7 @@ class AccountTrialAppRecord(Base): sa.Index("account_trial_app_record_app_id_idx", "app_id"), sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"), ) - id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()")) + id = mapped_column(StringUUID, default=lambda: str(uuid4())) account_id = mapped_column(StringUUID, nullable=False) app_id = mapped_column(StringUUID, nullable=False) count = mapped_column(sa.Integer, nullable=False, default=0) @@ -660,18 +660,18 @@ class AccountTrialAppRecord(Base): class ExporleBanner(TypeBase): __tablename__ = "exporle_banners" __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),) - id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False) + id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False) content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False) link: Mapped[str] = mapped_column(String(255), nullable=False) sort: Mapped[int] = mapped_column(sa.Integer, nullable=False) status: Mapped[str] = mapped_column( - sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled" + sa.String(255), nullable=False, server_default='enabled', default="enabled" ) created_at: Mapped[datetime] = mapped_column( sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False ) language: Mapped[str] = mapped_column( - 
String(255), nullable=False, server_default=sa.text("'en-US'::character varying"), default="en-US" + String(255), nullable=False, server_default='en-US', default="en-US" ) @@ -2166,7 +2166,7 @@ class TenantCreditPool(TypeBase): sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"), ) - id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"), init=False) + id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False) tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial") quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) diff --git a/api/pyproject.toml b/api/pyproject.toml index c05e884271..2a7c946e6e 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.12.0" +version = "1.12.1" requires-python = ">=3.11,<3.13" dependencies = [ @@ -81,7 +81,7 @@ dependencies = [ "starlette==0.49.1", "tiktoken~=0.9.0", "transformers~=4.56.1", - "unstructured[docx,epub,md,ppt,pptx]~=0.16.1", + "unstructured[docx,epub,md,ppt,pptx]~=0.18.18", "yarl~=1.18.3", "webvtt-py~=0.5.1", "sseclient-py~=1.8.0", diff --git a/api/services/account_service.py b/api/services/account_service.py index 35e4a505af..8f8604f0f3 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -327,6 +327,12 @@ class AccountService: @staticmethod def delete_account(account: Account): """Delete account. 
This method only adds a task to the queue for deletion.""" + # Queue account deletion sync tasks for all workspaces BEFORE account deletion (enterprise only) + from services.enterprise.account_deletion_sync import sync_account_deletion + + sync_account_deletion(account_id=account.id, source="account_deleted") + + # Now proceed with async account deletion delete_account_task.delay(account.id) @staticmethod @@ -1230,6 +1236,11 @@ class TenantService: if dify_config.BILLING_ENABLED: BillingService.clean_billing_info_cache(tenant.id) + # Queue account deletion sync task for enterprise backend to reassign resources (enterprise only) + from services.enterprise.account_deletion_sync import sync_workspace_member_removal + + sync_workspace_member_removal(workspace_id=tenant.id, member_id=account.id, source="workspace_member_removed") + @staticmethod def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account): """Update member role""" diff --git a/api/services/enterprise/account_deletion_sync.py b/api/services/enterprise/account_deletion_sync.py new file mode 100644 index 0000000000..f8f8189891 --- /dev/null +++ b/api/services/enterprise/account_deletion_sync.py @@ -0,0 +1,115 @@ +import json +import logging +import uuid +from datetime import UTC, datetime + +from redis import RedisError + +from configs import dify_config +from extensions.ext_database import db +from extensions.ext_redis import redis_client +from models.account import TenantAccountJoin + +logger = logging.getLogger(__name__) + +ACCOUNT_DELETION_SYNC_QUEUE = "enterprise:member:sync:queue" +ACCOUNT_DELETION_SYNC_TASK_TYPE = "sync_member_deletion_from_workspace" + + +def _queue_task(workspace_id: str, member_id: str, *, source: str) -> bool: + """ + Queue an account deletion sync task to Redis. + + Internal helper function. Do not call directly - use the public functions instead. 
+ + Args: + workspace_id: The workspace/tenant ID to sync + member_id: The member/account ID that was removed + source: Source of the sync request (for debugging/tracking) + + Returns: + bool: True if task was queued successfully, False otherwise + """ + try: + task = { + "task_id": str(uuid.uuid4()), + "workspace_id": workspace_id, + "member_id": member_id, + "retry_count": 0, + "created_at": datetime.now(UTC).isoformat(), + "source": source, + "type": ACCOUNT_DELETION_SYNC_TASK_TYPE, + } + + # Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP + redis_client.lpush(ACCOUNT_DELETION_SYNC_QUEUE, json.dumps(task)) + + logger.info( + "Queued account deletion sync task for workspace %s, member %s, task_id: %s, source: %s", + workspace_id, + member_id, + task["task_id"], + source, + ) + return True + + except (RedisError, TypeError) as e: + logger.error( + "Failed to queue account deletion sync for workspace %s, member %s: %s", + workspace_id, + member_id, + str(e), + exc_info=True, + ) + # Don't raise - we don't want to fail member deletion if queueing fails + return False + + +def sync_workspace_member_removal(workspace_id: str, member_id: str, *, source: str) -> bool: + """ + Sync a single workspace member removal (enterprise only). + + Queues a task for the enterprise backend to reassign resources from the removed member. + Handles enterprise edition check internally. Safe to call in community edition (no-op). 
+ + Args: + workspace_id: The workspace/tenant ID + member_id: The member/account ID that was removed + source: Source of the sync request (e.g., "workspace_member_removed") + + Returns: + bool: True if task was queued (or skipped in community), False if queueing failed + """ + if not dify_config.ENTERPRISE_ENABLED: + return True + + return _queue_task(workspace_id=workspace_id, member_id=member_id, source=source) + + +def sync_account_deletion(account_id: str, *, source: str) -> bool: + """ + Sync full account deletion across all workspaces (enterprise only). + + Fetches all workspace memberships for the account and queues a sync task for each. + Handles enterprise edition check internally. Safe to call in community edition (no-op). + + Args: + account_id: The account ID being deleted + source: Source of the sync request (e.g., "account_deleted") + + Returns: + bool: True if all tasks were queued (or skipped in community), False if any queueing failed + """ + if not dify_config.ENTERPRISE_ENABLED: + return True + + # Fetch all workspaces the account belongs to + workspace_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).all() + + # Queue sync task for each workspace + success = True + for join in workspace_joins: + if not _queue_task(workspace_id=join.tenant_id, member_id=account_id, source=source): + success = False + + return success diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index a5133dfcb4..9930c6bf7c 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -4,6 +4,8 @@ from pydantic import BaseModel, Field from services.enterprise.base import EnterpriseRequest +ALLOWED_ACCESS_MODES = ["public", "private", "private_all", "sso_verified"] + class WebAppSettings(BaseModel): access_mode: str = Field( @@ -123,8 +125,8 @@ class EnterpriseService: def update_app_access_mode(cls, app_id: str, access_mode: str): if not 
app_id: raise ValueError("app_id must be provided.") - if access_mode not in ["public", "private", "private_all"]: - raise ValueError("access_mode must be either 'public', 'private', or 'private_all'") + if access_mode not in ALLOWED_ACCESS_MODES: + raise ValueError(f"access_mode must be one of: {', '.join(ALLOWED_ACCESS_MODES)}") data = {"appId": app_id, "accessMode": access_mode} diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 6797a67dde..f895e88e6b 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -2,7 +2,10 @@ import json import logging from collections.abc import Mapping from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from models.account import Account from sqlalchemy import exists, select from sqlalchemy.orm import Session @@ -406,20 +409,37 @@ class BuiltinToolManageService: return {"result": "success"} @staticmethod - def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str): + def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str, account: "Account | None" = None): """ set default provider """ with Session(db.engine) as session: - # get provider - target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first() + # get provider (verify tenant ownership to prevent IDOR) + target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first() if target_provider is None: raise ValueError("provider not found") # clear default provider - session.query(BuiltinToolProvider).filter_by( - tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True - ).update({"is_default": False}) + if dify_config.ENTERPRISE_ENABLED: + # Enterprise: verify admin permission for tenant-wide operation + from models.account import TenantAccountRole + + if account is None: + # In 
enterprise mode, an account context is required to perform permission checks + raise ValueError("Account is required to set default credentials in enterprise mode") + + if not TenantAccountRole.is_privileged_role(account.current_role): + raise ValueError("Only workspace admins/owners can set default credentials in enterprise mode") + # Enterprise: clear ALL defaults for this provider in the tenant + # (regardless of user_id, since enterprise credentials may have different user_id) + session.query(BuiltinToolProvider).filter_by( + tenant_id=tenant_id, provider=provider, is_default=True + ).update({"is_default": False}) + else: + # Non-enterprise: only clear defaults for the current user + session.query(BuiltinToolProvider).filter_by( + tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True + ).update({"is_default": False}) # set new default provider target_provider.is_default = True diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index 23c49f2742..a9a8b892c2 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -6,7 +6,6 @@ from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document -from extensions.ext_database import db from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -58,5 +57,3 @@ def add_annotation_to_index_task( ) except Exception: logger.exception("Build index for annotation failed") - finally: - db.session.close() diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index e928c25546..432732af95 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -5,7 +5,6 @@ import click from celery import shared_task from 
core.rag.datasource.vdb.vector_factory import Vector -from extensions.ext_database import db from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -40,5 +39,3 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green")) except Exception: logger.exception("Annotation deleted index failed") - finally: - db.session.close() diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index 957d8f7e45..6ff34c0e74 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -6,7 +6,6 @@ from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document -from extensions.ext_database import db from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -59,5 +58,3 @@ def update_annotation_to_index_task( ) except Exception: logger.exception("Build index for annotation failed") - finally: - db.session.close() diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index d388284980..747106d373 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -14,6 +14,9 @@ from models.model import UploadFile logger = logging.getLogger(__name__) +# Batch size for database operations to keep transactions short +BATCH_SIZE = 1000 + @shared_task(queue="dataset") def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]): @@ -31,63 +34,179 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form if not doc_form: raise ValueError("doc_form is required") - with session_factory.create_session() as 
session: - try: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() - - if not dataset: - raise Exception("Document has no dataset") - - session.query(DatasetMetadataBinding).where( - DatasetMetadataBinding.dataset_id == dataset_id, - DatasetMetadataBinding.document_id.in_(document_ids), - ).delete(synchronize_session=False) + storage_keys_to_delete: list[str] = [] + index_node_ids: list[str] = [] + segment_ids: list[str] = [] + total_image_upload_file_ids: list[str] = [] + try: + # ============ Step 1: Query segment and file data (short read-only transaction) ============ + with session_factory.create_session() as session: + # Get segments info segments = session.scalars( select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) ).all() - # check segment is exist + if segments: index_node_ids = [segment.index_node_id for segment in segments] - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean( - dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True - ) + segment_ids = [segment.id for segment in segments] + # Collect image file IDs from segment content for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) - image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all() - for image_file in image_files: - try: - if image_file and image_file.key: - storage.delete(image_file.key) - except Exception: - logger.exception( - "Delete image_files failed when storage deleted, \ - image_upload_file_is: %s", - image_file.id, - ) - stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) - session.execute(stmt) - session.delete(segment) + total_image_upload_file_ids.extend(image_upload_file_ids) + + # Query storage keys for image files + if total_image_upload_file_ids: + image_files = session.scalars( + 
select(UploadFile).where(UploadFile.id.in_(total_image_upload_file_ids)) + ).all() + storage_keys_to_delete.extend([f.key for f in image_files if f and f.key]) + + # Query storage keys for document files if file_ids: files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all() - for file in files: - try: - storage.delete(file.key) - except Exception: - logger.exception("Delete file failed when document deleted, file_id: %s", file.id) - stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids)) - session.execute(stmt) + storage_keys_to_delete.extend([f.key for f in files if f and f.key]) - session.commit() - - end_at = time.perf_counter() - logger.info( - click.style( - f"Cleaned documents when documents deleted latency: {end_at - start_at}", - fg="green", + # ============ Step 2: Clean vector index (external service, fresh session for dataset) ============ + if index_node_ids: + try: + # Fetch dataset in a fresh session to avoid DetachedInstanceError + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.warning("Dataset not found for vector index cleanup, dataset_id: %s", dataset_id) + else: + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) + except Exception: + logger.exception( + "Failed to clean vector index for dataset_id: %s, document_ids: %s, index_node_ids count: %d", + dataset_id, + document_ids, + len(index_node_ids), ) - ) + + # ============ Step 3: Delete metadata binding (separate short transaction) ============ + try: + with session_factory.create_session() as session: + deleted_count = ( + session.query(DatasetMetadataBinding) + .where( + DatasetMetadataBinding.dataset_id == dataset_id, + DatasetMetadataBinding.document_id.in_(document_ids), + ) + .delete(synchronize_session=False) 
+ ) + session.commit() + logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id) except Exception: - logger.exception("Cleaned documents when documents deleted failed") + logger.exception( + "Failed to delete metadata bindings for dataset_id: %s, document_ids: %s", + dataset_id, + document_ids, + ) + + # ============ Step 4: Batch delete UploadFile records (multiple short transactions) ============ + if total_image_upload_file_ids: + failed_batches = 0 + total_batches = (len(total_image_upload_file_ids) + BATCH_SIZE - 1) // BATCH_SIZE + for i in range(0, len(total_image_upload_file_ids), BATCH_SIZE): + batch = total_image_upload_file_ids[i : i + BATCH_SIZE] + try: + with session_factory.create_session() as session: + stmt = delete(UploadFile).where(UploadFile.id.in_(batch)) + session.execute(stmt) + session.commit() + except Exception: + failed_batches += 1 + logger.exception( + "Failed to delete image UploadFile batch %d-%d for dataset_id: %s", + i, + i + len(batch), + dataset_id, + ) + if failed_batches > 0: + logger.warning( + "Image UploadFile deletion: %d/%d batches failed for dataset_id: %s", + failed_batches, + total_batches, + dataset_id, + ) + + # ============ Step 5: Batch delete DocumentSegment records (multiple short transactions) ============ + if segment_ids: + failed_batches = 0 + total_batches = (len(segment_ids) + BATCH_SIZE - 1) // BATCH_SIZE + for i in range(0, len(segment_ids), BATCH_SIZE): + batch = segment_ids[i : i + BATCH_SIZE] + try: + with session_factory.create_session() as session: + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(batch)) + session.execute(segment_delete_stmt) + session.commit() + except Exception: + failed_batches += 1 + logger.exception( + "Failed to delete DocumentSegment batch %d-%d for dataset_id: %s, document_ids: %s", + i, + i + len(batch), + dataset_id, + document_ids, + ) + if failed_batches > 0: + logger.warning( + "DocumentSegment deletion: %d/%d 
batches failed, document_ids: %s", + failed_batches, + total_batches, + document_ids, + ) + + # ============ Step 6: Delete document-associated files (separate short transaction) ============ + if file_ids: + try: + with session_factory.create_session() as session: + stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids)) + session.execute(stmt) + session.commit() + except Exception: + logger.exception( + "Failed to delete document UploadFile records for dataset_id: %s, file_ids: %s", + dataset_id, + file_ids, + ) + + # ============ Step 7: Delete storage files (I/O operations, no DB transaction) ============ + storage_delete_failures = 0 + for storage_key in storage_keys_to_delete: + try: + storage.delete(storage_key) + except Exception: + storage_delete_failures += 1 + logger.exception("Failed to delete file from storage, key: %s", storage_key) + if storage_delete_failures > 0: + logger.warning( + "Storage file deletion completed with %d failures out of %d total files for dataset_id: %s", + storage_delete_failures, + len(storage_keys_to_delete), + dataset_id, + ) + + end_at = time.perf_counter() + logger.info( + click.style( + f"Cleaned documents when documents deleted latency: {end_at - start_at:.2f}s, " + f"dataset_id: {dataset_id}, document_ids: {document_ids}, " + f"segments: {len(segment_ids)}, image_files: {len(total_image_upload_file_ids)}, " + f"storage_files: {len(storage_keys_to_delete)}", + fg="green", + ) + ) + except Exception: + logger.exception( + "Batch clean documents failed for dataset_id: %s, document_ids: %s", + dataset_id, + document_ids, + ) diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 8ee09d5738..f69f17b16d 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -48,6 +48,11 @@ def batch_create_segment_to_index_task( indexing_cache_key = f"segment_batch_import_{job_id}" + # Initialize variables with 
default values + upload_file_key: str | None = None + dataset_config: dict | None = None + document_config: dict | None = None + with session_factory.create_session() as session: try: dataset = session.get(Dataset, dataset_id) @@ -69,86 +74,115 @@ def batch_create_segment_to_index_task( if not upload_file: raise ValueError("UploadFile not found.") - with tempfile.TemporaryDirectory() as temp_dir: - suffix = Path(upload_file.key).suffix - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore - storage.download(upload_file.key, file_path) + dataset_config = { + "id": dataset.id, + "indexing_technique": dataset.indexing_technique, + "tenant_id": dataset.tenant_id, + "embedding_model_provider": dataset.embedding_model_provider, + "embedding_model": dataset.embedding_model, + } - df = pd.read_csv(file_path) - content = [] - for _, row in df.iterrows(): - if dataset_document.doc_form == "qa_model": - data = {"content": row.iloc[0], "answer": row.iloc[1]} - else: - data = {"content": row.iloc[0]} - content.append(data) - if len(content) == 0: - raise ValueError("The CSV file is empty.") + document_config = { + "id": dataset_document.id, + "doc_form": dataset_document.doc_form, + "word_count": dataset_document.word_count or 0, + } - document_segments = [] - embedding_model = None - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model, - ) + upload_file_key = upload_file.key - word_count_change = 0 - if embedding_model: - tokens_list = embedding_model.get_text_embedding_num_tokens( - texts=[segment["content"] for segment in content] - ) + except Exception: + logger.exception("Segments batch created index failed") + redis_client.setex(indexing_cache_key, 600, "error") + return + + # Ensure required variables are 
set before proceeding + if upload_file_key is None or dataset_config is None or document_config is None: + logger.error("Required configuration not set due to session error") + redis_client.setex(indexing_cache_key, 600, "error") + return + + with tempfile.TemporaryDirectory() as temp_dir: + suffix = Path(upload_file_key).suffix + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore + storage.download(upload_file_key, file_path) + + df = pd.read_csv(file_path) + content = [] + for _, row in df.iterrows(): + if document_config["doc_form"] == "qa_model": + data = {"content": row.iloc[0], "answer": row.iloc[1]} else: - tokens_list = [0] * len(content) + data = {"content": row.iloc[0]} + content.append(data) + if len(content) == 0: + raise ValueError("The CSV file is empty.") - for segment, tokens in zip(content, tokens_list): - content = segment["content"] - doc_id = str(uuid.uuid4()) - segment_hash = helper.generate_text_hash(content) - max_position = ( - session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == dataset_document.id) - .scalar() - ) - segment_document = DocumentSegment( - tenant_id=tenant_id, - dataset_id=dataset_id, - document_id=document_id, - index_node_id=doc_id, - index_node_hash=segment_hash, - position=max_position + 1 if max_position else 1, - content=content, - word_count=len(content), - tokens=tokens, - created_by=user_id, - indexing_at=naive_utc_now(), - status="completed", - completed_at=naive_utc_now(), - ) - if dataset_document.doc_form == "qa_model": - segment_document.answer = segment["answer"] - segment_document.word_count += len(segment["answer"]) - word_count_change += segment_document.word_count - session.add(segment_document) - document_segments.append(segment_document) + document_segments = [] + embedding_model = None + if dataset_config["indexing_technique"] == "high_quality": + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( 
+ tenant_id=dataset_config["tenant_id"], + provider=dataset_config["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=dataset_config["embedding_model"], + ) + word_count_change = 0 + if embedding_model: + tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content]) + else: + tokens_list = [0] * len(content) + + with session_factory.create_session() as session, session.begin(): + for segment, tokens in zip(content, tokens_list): + content = segment["content"] + doc_id = str(uuid.uuid4()) + segment_hash = helper.generate_text_hash(content) + max_position = ( + session.query(func.max(DocumentSegment.position)) + .where(DocumentSegment.document_id == document_config["id"]) + .scalar() + ) + segment_document = DocumentSegment( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + index_node_id=doc_id, + index_node_hash=segment_hash, + position=max_position + 1 if max_position else 1, + content=content, + word_count=len(content), + tokens=tokens, + created_by=user_id, + indexing_at=naive_utc_now(), + status="completed", + completed_at=naive_utc_now(), + ) + if document_config["doc_form"] == "qa_model": + segment_document.answer = segment["answer"] + segment_document.word_count += len(segment["answer"]) + word_count_change += segment_document.word_count + session.add(segment_document) + document_segments.append(segment_document) + + with session_factory.create_session() as session, session.begin(): + dataset_document = session.get(Document, document_id) + if dataset_document: assert dataset_document.word_count is not None dataset_document.word_count += word_count_change session.add(dataset_document) - VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form) - session.commit() - redis_client.setex(indexing_cache_key, 600, "completed") - end_at = time.perf_counter() - logger.info( - click.style( - f"Segment batch created job: {job_id} 
latency: {end_at - start_at}", - fg="green", - ) - ) - except Exception: - logger.exception("Segments batch created index failed") - redis_client.setex(indexing_cache_key, 600, "error") + with session_factory.create_session() as session: + dataset = session.get(Dataset, dataset_id) + if dataset: + VectorService.create_segments_vector(None, document_segments, dataset, document_config["doc_form"]) + + redis_client.setex(indexing_cache_key, 600, "completed") + end_at = time.perf_counter() + logger.info( + click.style( + f"Segment batch created job: {job_id} latency: {end_at - start_at}", + fg="green", + ) + ) diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 91ace6be02..a017e9114b 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -28,6 +28,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i """ logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green")) start_at = time.perf_counter() + total_attachment_files = [] with session_factory.create_session() as session: try: @@ -47,78 +48,91 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i SegmentAttachmentBinding.document_id == document_id, ) ).all() - # check segment is exist - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - index_processor = IndexProcessorFactory(doc_form).init_index_processor() + + attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings] + binding_ids = [binding.id for binding, _ in attachments_with_bindings] + total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings]) + + index_node_ids = [segment.index_node_id for segment in segments] + segment_contents = [segment.content for segment in segments] + except Exception: + logger.exception("Cleaned document when document deleted failed") + return + + # check segment 
is exist + if index_node_ids: + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if dataset: index_processor.clean( dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True ) - for segment in segments: - image_upload_file_ids = get_image_upload_file_ids(segment.content) - image_files = session.scalars( - select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) - ).all() - for image_file in image_files: - if image_file is None: - continue - try: - storage.delete(image_file.key) - except Exception: - logger.exception( - "Delete image_files failed when storage deleted, \ - image_upload_file_is: %s", - image_file.id, - ) + total_image_files = [] + with session_factory.create_session() as session, session.begin(): + for segment_content in segment_contents: + image_upload_file_ids = get_image_upload_file_ids(segment_content) + image_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))).all() + total_image_files.extend([image_file.key for image_file in image_files]) + image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) + session.execute(image_file_delete_stmt) - image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) - session.execute(image_file_delete_stmt) - session.delete(segment) + with session_factory.create_session() as session, session.begin(): + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id) + session.execute(segment_delete_stmt) - session.commit() - if file_id: - file = session.query(UploadFile).where(UploadFile.id == file_id).first() - if file: - try: - storage.delete(file.key) - except Exception: - logger.exception("Delete file failed when document deleted, file_id: %s", file_id) - session.delete(file) - # 
delete segment attachments - if attachments_with_bindings: - attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings] - binding_ids = [binding.id for binding, _ in attachments_with_bindings] - for binding, attachment_file in attachments_with_bindings: - try: - storage.delete(attachment_file.key) - except Exception: - logger.exception( - "Delete attachment_file failed when storage deleted, \ - attachment_file_id: %s", - binding.attachment_id, - ) - attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids)) - session.execute(attachment_file_delete_stmt) - - binding_delete_stmt = delete(SegmentAttachmentBinding).where( - SegmentAttachmentBinding.id.in_(binding_ids) - ) - session.execute(binding_delete_stmt) - - # delete dataset metadata binding - session.query(DatasetMetadataBinding).where( - DatasetMetadataBinding.dataset_id == dataset_id, - DatasetMetadataBinding.document_id == document_id, - ).delete() - session.commit() - - end_at = time.perf_counter() - logger.info( - click.style( - f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}", - fg="green", - ) - ) + for image_file_key in total_image_files: + try: + storage.delete(image_file_key) except Exception: - logger.exception("Cleaned document when document deleted failed") + logger.exception( + "Delete image_files failed when storage deleted, \ + image_upload_file_is: %s", + image_file_key, + ) + + with session_factory.create_session() as session, session.begin(): + if file_id: + file = session.query(UploadFile).where(UploadFile.id == file_id).first() + if file: + try: + storage.delete(file.key) + except Exception: + logger.exception("Delete file failed when document deleted, file_id: %s", file_id) + session.delete(file) + + with session_factory.create_session() as session, session.begin(): + # delete segment attachments + if attachment_ids: + attachment_file_delete_stmt = 
delete(UploadFile).where(UploadFile.id.in_(attachment_ids)) + session.execute(attachment_file_delete_stmt) + + if binding_ids: + binding_delete_stmt = delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.id.in_(binding_ids)) + session.execute(binding_delete_stmt) + + for attachment_file_key in total_attachment_files: + try: + storage.delete(attachment_file_key) + except Exception: + logger.exception( + "Delete attachment_file failed when storage deleted, \ + attachment_file_id: %s", + attachment_file_key, + ) + + with session_factory.create_session() as session, session.begin(): + # delete dataset metadata binding + session.query(DatasetMetadataBinding).where( + DatasetMetadataBinding.dataset_id == dataset_id, + DatasetMetadataBinding.document_id == document_id, + ).delete() + + end_at = time.perf_counter() + logger.info( + click.style( + f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}", + fg="green", + ) + ) diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py index 4214f043e0..c22ee761d8 100644 --- a/api/tasks/clean_notion_document_task.py +++ b/api/tasks/clean_notion_document_task.py @@ -23,40 +23,40 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str): """ logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green")) start_at = time.perf_counter() + total_index_node_ids = [] with session_factory.create_session() as session: - try: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - raise Exception("Document has no dataset") - index_type = dataset.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() + if not dataset: + raise Exception("Document has no dataset") + index_type = dataset.doc_form + index_processor = 
IndexProcessorFactory(index_type).init_index_processor() - document_delete_stmt = delete(Document).where(Document.id.in_(document_ids)) - session.execute(document_delete_stmt) + document_delete_stmt = delete(Document).where(Document.id.in_(document_ids)) + session.execute(document_delete_stmt) - for document_id in document_ids: - segments = session.scalars( - select(DocumentSegment).where(DocumentSegment.document_id == document_id) - ).all() - index_node_ids = [segment.index_node_id for segment in segments] + for document_id in document_ids: + segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() + total_index_node_ids.extend([segment.index_node_id for segment in segments]) - index_processor.clean( - dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True - ) - segment_ids = [segment.id for segment in segments] - segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) - session.execute(segment_delete_stmt) - session.commit() - end_at = time.perf_counter() - logger.info( - click.style( - "Clean document when import form notion document deleted end :: {} latency: {}".format( - dataset_id, end_at - start_at - ), - fg="green", - ) + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if dataset: + index_processor.clean( + dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True ) - except Exception: - logger.exception("Cleaned document when import form notion document deleted failed") + + with session_factory.create_session() as session, session.begin(): + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) + session.execute(segment_delete_stmt) + + end_at = time.perf_counter() + logger.info( + click.style( + "Clean document when import form notion document deleted end :: {} latency: 
{}".format( + dataset_id, end_at - start_at + ), + fg="green", + ) + ) diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index 764c635d83..a6a2dcebc8 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import delete from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -67,8 +68,14 @@ def delete_segment_from_index_task( if segment_attachment_bindings: attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False) - for binding in segment_attachment_bindings: - session.delete(binding) + segment_attachment_bind_ids = [i.id for i in segment_attachment_bindings] + + for i in range(0, len(segment_attachment_bind_ids), 1000): + segment_attachment_bind_delete_stmt = delete(SegmentAttachmentBinding).where( + SegmentAttachmentBinding.id.in_(segment_attachment_bind_ids[i : i + 1000]) + ) + session.execute(segment_attachment_bind_delete_stmt) + # delete upload file session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False) session.commit() diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 149185f6e2..45b44438e7 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -27,104 +27,129 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): """ logger.info(click.style(f"Start sync document: {document_id}", fg="green")) start_at = time.perf_counter() + tenant_id = None - with session_factory.create_session() as session: + with session_factory.create_session() as session, session.begin(): document = session.query(Document).where(Document.id == 
document_id, Document.dataset_id == dataset_id).first() if not document: logger.info(click.style(f"Document not found: {document_id}", fg="red")) return + if document.indexing_status == "parsing": + logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow")) + return + + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + raise Exception("Dataset not found") + data_source_info = document.data_source_info_dict - if document.data_source_type == "notion_import": - if ( - not data_source_info - or "notion_page_id" not in data_source_info - or "notion_workspace_id" not in data_source_info - ): - raise ValueError("no notion page found") - workspace_id = data_source_info["notion_workspace_id"] - page_id = data_source_info["notion_page_id"] - page_type = data_source_info["type"] - page_edited_time = data_source_info["last_edited_time"] - credential_id = data_source_info.get("credential_id") + if document.data_source_type != "notion_import": + logger.info(click.style(f"Document {document_id} is not a notion_import, skipping", fg="yellow")) + return - # Get credentials from datasource provider - datasource_provider_service = DatasourceProviderService() - credential = datasource_provider_service.get_datasource_credentials( - tenant_id=document.tenant_id, - credential_id=credential_id, - provider="notion_datasource", - plugin_id="langgenius/notion_datasource", - ) + if ( + not data_source_info + or "notion_page_id" not in data_source_info + or "notion_workspace_id" not in data_source_info + ): + raise ValueError("no notion page found") - if not credential: - logger.error( - "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s", - document_id, - document.tenant_id, - credential_id, - ) + workspace_id = data_source_info["notion_workspace_id"] + page_id = data_source_info["notion_page_id"] + page_type = data_source_info["type"] + page_edited_time = 
data_source_info["last_edited_time"] + credential_id = data_source_info.get("credential_id") + tenant_id = document.tenant_id + index_type = document.doc_form + + segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() + index_node_ids = [segment.index_node_id for segment in segments] + + # Get credentials from datasource provider + datasource_provider_service = DatasourceProviderService() + credential = datasource_provider_service.get_datasource_credentials( + tenant_id=tenant_id, + credential_id=credential_id, + provider="notion_datasource", + plugin_id="langgenius/notion_datasource", + ) + + if not credential: + logger.error( + "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s", + document_id, + tenant_id, + credential_id, + ) + + with session_factory.create_session() as session, session.begin(): + document = session.query(Document).filter_by(id=document_id).first() + if document: document.indexing_status = "error" document.error = "Datasource credential not found. Please reconnect your Notion workspace." 
document.stopped_at = naive_utc_now() - session.commit() - return + return - loader = NotionExtractor( - notion_workspace_id=workspace_id, - notion_obj_id=page_id, - notion_page_type=page_type, - notion_access_token=credential.get("integration_secret"), - tenant_id=document.tenant_id, - ) + loader = NotionExtractor( + notion_workspace_id=workspace_id, + notion_obj_id=page_id, + notion_page_type=page_type, + notion_access_token=credential.get("integration_secret"), + tenant_id=tenant_id, + ) - last_edited_time = loader.get_notion_last_edited_time() + last_edited_time = loader.get_notion_last_edited_time() + if last_edited_time == page_edited_time: + logger.info(click.style(f"Document {document_id} content unchanged, skipping sync", fg="yellow")) + return - # check the page is updated - if last_edited_time != page_edited_time: - document.indexing_status = "parsing" - document.processing_started_at = naive_utc_now() - session.commit() + logger.info(click.style(f"Document {document_id} content changed, starting sync", fg="green")) - # delete all document segment and index - try: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - raise Exception("Dataset not found") - index_type = document.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() + try: + index_processor = IndexProcessorFactory(index_type).init_index_processor() + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if dataset: + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green")) + except Exception: + logger.exception("Failed to clean vector index for document %s", document_id) - segments = session.scalars( - select(DocumentSegment).where(DocumentSegment.document_id == document_id) - ).all() - index_node_ids = 
[segment.index_node_id for segment in segments] + with session_factory.create_session() as session, session.begin(): + document = session.query(Document).filter_by(id=document_id).first() + if not document: + logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow")) + return - # delete from vector index - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) + data_source_info = document.data_source_info_dict + data_source_info["last_edited_time"] = last_edited_time + document.data_source_info = data_source_info - segment_ids = [segment.id for segment in segments] - segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) - session.execute(segment_delete_stmt) + document.indexing_status = "parsing" + document.processing_started_at = naive_utc_now() - end_at = time.perf_counter() - logger.info( - click.style( - "Cleaned document when document update data source or process rule: {} latency: {}".format( - document_id, end_at - start_at - ), - fg="green", - ) - ) - except Exception: - logger.exception("Cleaned document when document update data source or process rule failed") + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id) + session.execute(segment_delete_stmt) - try: - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - end_at = time.perf_counter() - logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("document_indexing_sync_task failed, document_id: %s", document_id) + logger.info(click.style(f"Deleted segments for document {document_id}", fg="green")) + + try: + indexing_runner = IndexingRunner() + with session_factory.create_session() as session: + document = session.query(Document).filter_by(id=document_id).first() + if 
document: + indexing_runner.run([document]) + end_at = time.perf_counter() + logger.info(click.style(f"Sync completed for document {document_id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + except Exception as e: + logger.exception("document_indexing_sync_task failed for document_id: %s", document_id) + with session_factory.create_session() as session, session.begin(): + document = session.query(Document).filter_by(id=document_id).first() + if document: + document.indexing_status = "error" + document.error = str(e) + document.stopped_at = naive_utc_now() diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 34496e9c6f..11edcf151f 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -81,26 +81,35 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): session.commit() return - for document_id in document_ids: - logger.info(click.style(f"Start process document: {document_id}", fg="green")) - - document = ( - session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() - ) + # Phase 1: Update status to parsing (short transaction) + with session_factory.create_session() as session, session.begin(): + documents = ( + session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all() + ) + for document in documents: if document: document.indexing_status = "parsing" document.processing_started_at = naive_utc_now() - documents.append(document) session.add(document) - session.commit() + # Transaction committed and closed - try: - indexing_runner = IndexingRunner() - indexing_runner.run(documents) - end_at = time.perf_counter() - logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + # Phase 2: Execute indexing (no transaction - IndexingRunner creates its own sessions) + 
has_error = False + try: + indexing_runner = IndexingRunner() + indexing_runner.run(documents) + end_at = time.perf_counter() + logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + has_error = True + except Exception: + logger.exception("Document indexing task failed, dataset_id: %s", dataset_id) + has_error = True + if not has_error: + with session_factory.create_session() as session: # Trigger summary index generation for completed documents if enabled # Only generate for high_quality indexing technique and when summary_index_setting is enabled # Re-query dataset to get latest summary_index_setting (in case it was updated) @@ -115,17 +124,18 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): # expire all session to get latest document's indexing status session.expire_all() # Check each document's indexing status and trigger summary generation if completed - for document_id in document_ids: - # Re-query document to get latest status (IndexingRunner may have updated it) - document = ( - session.query(Document) - .where(Document.id == document_id, Document.dataset_id == dataset_id) - .first() - ) + + documents = ( + session.query(Document) + .where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) + .all() + ) + + for document in documents: if document: logger.info( "Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s", - document_id, + document.id, document.indexing_status, document.doc_form, document.need_summary, @@ -136,46 +146,36 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): and document.need_summary is True ): try: - generate_summary_index_task.delay(dataset.id, document_id, None) + generate_summary_index_task.delay(dataset.id, document.id, None) logger.info( "Queued summary index generation task for document %s in dataset %s " 
"after indexing completed", - document_id, + document.id, dataset.id, ) except Exception: logger.exception( "Failed to queue summary index generation task for document %s", - document_id, + document.id, ) # Don't fail the entire indexing process if summary task queuing fails else: logger.info( "Skipping summary generation for document %s: " "status=%s, doc_form=%s, need_summary=%s", - document_id, + document.id, document.indexing_status, document.doc_form, document.need_summary, ) else: - logger.warning("Document %s not found after indexing", document_id) - else: - logger.info( - "Summary index generation skipped for dataset %s: summary_index_setting.enable=%s", - dataset.id, - summary_index_setting.get("enable") if summary_index_setting else None, - ) + logger.warning("Document %s not found after indexing", document.id) else: logger.info( "Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')", dataset.id, dataset.indexing_technique, ) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("Document indexing task failed, dataset_id: %s", dataset_id) def _document_indexing_with_tenant_queue( diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 67a23be952..c7508c6d05 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -8,7 +8,6 @@ from sqlalchemy import delete, select from core.db.session_factory import session_factory from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment @@ -27,7 +26,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): logger.info(click.style(f"Start update document: 
{document_id}", fg="green")) start_at = time.perf_counter() - with session_factory.create_session() as session: + with session_factory.create_session() as session, session.begin(): document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: @@ -36,27 +35,20 @@ def document_indexing_update_task(dataset_id: str, document_id: str): document.indexing_status = "parsing" document.processing_started_at = naive_utc_now() - session.commit() - # delete all document segment and index - try: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - raise Exception("Dataset not found") + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + return - index_type = document.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() + index_type = document.doc_form + segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() + index_node_ids = [segment.index_node_id for segment in segments] - segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - - # delete from vector index - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - segment_ids = [segment.id for segment in segments] - segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) - session.execute(segment_delete_stmt) - db.session.commit() + clean_success = False + try: + index_processor = IndexProcessorFactory(index_type).init_index_processor() + if index_node_ids: + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) end_at = time.perf_counter() logger.info( click.style( @@ -66,15 +58,21 @@ def document_indexing_update_task(dataset_id: str, document_id: str): 
fg="green", ) ) - except Exception: - logger.exception("Cleaned document when document update data source or process rule failed") + clean_success = True + except Exception: + logger.exception("Failed to clean document index during update, document_id: %s", document_id) - try: - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - end_at = time.perf_counter() - logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("document_indexing_update_task failed, document_id: %s", document_id) + if clean_success: + with session_factory.create_session() as session, session.begin(): + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id) + session.execute(segment_delete_stmt) + + try: + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + end_at = time.perf_counter() + logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + except Exception: + logger.exception("document_indexing_update_task failed, document_id: %s", document_id) diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 817249845a..6240f2200f 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -259,8 +259,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str): - def del_workflow_archive_log(workflow_archive_log_id: str): - db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( + def del_workflow_archive_log(session, workflow_archive_log_id: str): + 
session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete( synchronize_session=False ) @@ -420,7 +420,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int: total_files_deleted = 0 while True: - with session_factory.create_session() as session: + with session_factory.create_session() as session, session.begin(): # Get a batch of draft variable IDs along with their file_ids query_sql = """ SELECT id, file_id FROM workflow_draft_variables diff --git a/api/tasks/workflow_draft_var_tasks.py b/api/tasks/workflow_draft_var_tasks.py index fcb98ec39e..26f8f7c29e 100644 --- a/api/tasks/workflow_draft_var_tasks.py +++ b/api/tasks/workflow_draft_var_tasks.py @@ -6,9 +6,8 @@ improving performance by offloading storage operations to background workers. """ from celery import shared_task # type: ignore[import-untyped] -from sqlalchemy.orm import Session -from extensions.ext_database import db +from core.db.session_factory import session_factory from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService @@ -17,6 +16,6 @@ def save_workflow_execution_task( self, deletions: list[DraftVarFileDeletion], ): - with Session(bind=db.engine) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): srv = WorkflowDraftVariableService(session=session) srv.delete_workflow_draft_variable_file(deletions=deletions) diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index f46d1bf5db..d020233620 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -10,7 +10,10 @@ from models import Tenant from models.enums import CreatorUserRole from models.model import App, UploadFile from models.workflow import WorkflowDraftVariable, 
WorkflowDraftVariableFile -from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch +from tasks.remove_app_and_related_data_task import ( + _delete_draft_variables, + delete_draft_variables_batch, +) @pytest.fixture @@ -297,12 +300,18 @@ class TestDeleteDraftVariablesWithOffloadIntegration: def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data): data = setup_offload_test_data app_id = data["app"].id + upload_file_ids = [uf.id for uf in data["upload_files"]] + variable_file_ids = [vf.id for vf in data["variable_files"]] mock_storage.delete.return_value = None with session_factory.create_session() as session: draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() - var_files_before = session.query(WorkflowDraftVariableFile).count() - upload_files_before = session.query(UploadFile).count() + var_files_before = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + .count() + ) + upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert draft_vars_before == 3 assert var_files_before == 2 assert upload_files_before == 2 @@ -315,8 +324,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration: assert draft_vars_after == 0 with session_factory.create_session() as session: - var_files_after = session.query(WorkflowDraftVariableFile).count() - upload_files_after = session.query(UploadFile).count() + var_files_after = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + .count() + ) + upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert var_files_after == 0 assert upload_files_after == 0 @@ -329,6 +342,8 @@ class TestDeleteDraftVariablesWithOffloadIntegration: def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, 
setup_offload_test_data): data = setup_offload_test_data app_id = data["app"].id + upload_file_ids = [uf.id for uf in data["upload_files"]] + variable_file_ids = [vf.id for vf in data["variable_files"]] mock_storage.delete.side_effect = [Exception("Storage error"), None] deleted_count = delete_draft_variables_batch(app_id, batch_size=10) @@ -339,8 +354,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration: assert draft_vars_after == 0 with session_factory.create_session() as session: - var_files_after = session.query(WorkflowDraftVariableFile).count() - upload_files_after = session.query(UploadFile).count() + var_files_after = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_(variable_file_ids)) + .count() + ) + upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() assert var_files_after == 0 assert upload_files_after == 0 @@ -395,3 +414,275 @@ class TestDeleteDraftVariablesWithOffloadIntegration: if app2_obj: session.delete(app2_obj) session.commit() + + +class TestDeleteDraftVariablesSessionCommit: + """Test suite to verify session commit behavior in delete_draft_variables_batch.""" + + @pytest.fixture + def setup_offload_test_data(self, app_and_tenant): + """Create test data with offload files for session commit tests.""" + from core.variables.types import SegmentType + from libs.datetime_utils import naive_utc_now + + tenant, app = app_and_tenant + + with session_factory.create_session() as session: + upload_file1 = UploadFile( + tenant_id=tenant.id, + storage_type="local", + key="test/file1.json", + name="file1.json", + size=1024, + extension="json", + mime_type="application/json", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + created_at=naive_utc_now(), + used=False, + ) + upload_file2 = UploadFile( + tenant_id=tenant.id, + storage_type="local", + key="test/file2.json", + name="file2.json", + size=2048, + extension="json", + 
mime_type="application/json", + created_by_role=CreatorUserRole.ACCOUNT, + created_by=str(uuid.uuid4()), + created_at=naive_utc_now(), + used=False, + ) + session.add(upload_file1) + session.add(upload_file2) + session.flush() + + var_file1 = WorkflowDraftVariableFile( + tenant_id=tenant.id, + app_id=app.id, + user_id=str(uuid.uuid4()), + upload_file_id=upload_file1.id, + size=1024, + length=10, + value_type=SegmentType.STRING, + ) + var_file2 = WorkflowDraftVariableFile( + tenant_id=tenant.id, + app_id=app.id, + user_id=str(uuid.uuid4()), + upload_file_id=upload_file2.id, + size=2048, + length=20, + value_type=SegmentType.OBJECT, + ) + session.add(var_file1) + session.add(var_file2) + session.flush() + + draft_var1 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_1", + name="large_var_1", + value=StringSegment(value="truncated..."), + node_execution_id=str(uuid.uuid4()), + file_id=var_file1.id, + ) + draft_var2 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_2", + name="large_var_2", + value=StringSegment(value="truncated..."), + node_execution_id=str(uuid.uuid4()), + file_id=var_file2.id, + ) + draft_var3 = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id="node_3", + name="regular_var", + value=StringSegment(value="regular_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(draft_var1) + session.add(draft_var2) + session.add(draft_var3) + session.commit() + + data = { + "app": app, + "tenant": tenant, + "upload_files": [upload_file1, upload_file2], + "variable_files": [var_file1, var_file2], + "draft_variables": [draft_var1, draft_var2, draft_var3], + } + + yield data + + with session_factory.create_session() as session: + for table, ids in [ + (WorkflowDraftVariable, [v.id for v in data["draft_variables"]]), + (WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]), + (UploadFile, [uf.id for uf in data["upload_files"]]), + ]: + cleanup_query = 
delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False) + session.execute(cleanup_query) + session.commit() + + @pytest.fixture + def setup_commit_test_data(self, app_and_tenant): + """Create test data for session commit tests.""" + tenant, app = app_and_tenant + variable_ids: list[str] = [] + + with session_factory.create_session() as session: + variables = [] + for i in range(10): + var = WorkflowDraftVariable.new_node_variable( + app_id=app.id, + node_id=f"node_{i}", + name=f"var_{i}", + value=StringSegment(value="test_value"), + node_execution_id=str(uuid.uuid4()), + ) + session.add(var) + variables.append(var) + session.commit() + variable_ids = [v.id for v in variables] + + yield { + "app": app, + "tenant": tenant, + "variable_ids": variable_ids, + } + + with session_factory.create_session() as session: + cleanup_query = ( + delete(WorkflowDraftVariable) + .where(WorkflowDraftVariable.id.in_(variable_ids)) + .execution_options(synchronize_session=False) + ) + session.execute(cleanup_query) + session.commit() + + def test_session_commit_is_called_after_each_batch(self, setup_commit_test_data): + """Test that session.begin() is used for automatic transaction management.""" + data = setup_commit_test_data + app_id = data["app"].id + + # Since session.begin() is used, the transaction is automatically committed + # when the with block exits successfully. We verify this by checking that + # data is actually persisted. 
+ deleted_count = delete_draft_variables_batch(app_id, batch_size=3) + + # Verify all data was deleted (proves transaction was committed) + with session_factory.create_session() as session: + remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + + assert deleted_count == 10 + assert remaining_count == 0 + + def test_data_persisted_after_batch_deletion(self, setup_commit_test_data): + """Test that data is actually persisted to database after batch deletion with commits.""" + data = setup_commit_test_data + app_id = data["app"].id + variable_ids = data["variable_ids"] + + # Verify initial state + with session_factory.create_session() as session: + initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert initial_count == 10 + + # Perform deletion with small batch size to force multiple commits + deleted_count = delete_draft_variables_batch(app_id, batch_size=3) + + assert deleted_count == 10 + + # Verify all data is deleted in a new session (proves commits worked) + with session_factory.create_session() as session: + final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert final_count == 0 + + # Verify specific IDs are deleted + with session_factory.create_session() as session: + remaining_vars = ( + session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count() + ) + assert remaining_vars == 0 + + def test_session_commit_with_empty_dataset(self, setup_commit_test_data): + """Test session behavior when deleting from an empty dataset.""" + nonexistent_app_id = str(uuid.uuid4()) + + # Should not raise any errors and should return 0 + deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=10) + assert deleted_count == 0 + + def test_session_commit_with_single_batch(self, setup_commit_test_data): + """Test that commit happens correctly when all data fits in a single batch.""" + data = setup_commit_test_data + 
app_id = data["app"].id + + with session_factory.create_session() as session: + initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert initial_count == 10 + + # Delete all in a single batch + deleted_count = delete_draft_variables_batch(app_id, batch_size=100) + assert deleted_count == 10 + + # Verify data is persisted + with session_factory.create_session() as session: + final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + assert final_count == 0 + + def test_invalid_batch_size_raises_error(self, setup_commit_test_data): + """Test that invalid batch size raises ValueError.""" + data = setup_commit_test_data + app_id = data["app"].id + + with pytest.raises(ValueError, match="batch_size must be positive"): + delete_draft_variables_batch(app_id, batch_size=0) + + with pytest.raises(ValueError, match="batch_size must be positive"): + delete_draft_variables_batch(app_id, batch_size=-1) + + @patch("extensions.ext_storage.storage") + def test_session_commit_with_offload_data_cleanup(self, mock_storage, setup_offload_test_data): + """Test that session commits correctly when cleaning up offload data.""" + data = setup_offload_test_data + app_id = data["app"].id + upload_file_ids = [uf.id for uf in data["upload_files"]] + mock_storage.delete.return_value = None + + # Verify initial state + with session_factory.create_session() as session: + draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + var_files_before = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) + .count() + ) + upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() + assert draft_vars_before == 3 + assert var_files_before == 2 + assert upload_files_before == 2 + + # Delete variables with offload data + deleted_count = delete_draft_variables_batch(app_id, batch_size=10) + assert 
deleted_count == 3 + + # Verify all data is persisted (deleted) in new session + with session_factory.create_session() as session: + draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count() + var_files_after = ( + session.query(WorkflowDraftVariableFile) + .where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]])) + .count() + ) + upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count() + assert draft_vars_after == 0 + assert var_files_after == 0 + assert upload_files_after == 0 + + # Verify storage cleanup was called + assert mock_storage.delete.call_count == 2 diff --git a/api/tests/test_containers_integration_tests/libs/test_auto_renew_redis_lock_integration.py b/api/tests/test_containers_integration_tests/libs/test_auto_renew_redis_lock_integration.py new file mode 100644 index 0000000000..eb055ca332 --- /dev/null +++ b/api/tests/test_containers_integration_tests/libs/test_auto_renew_redis_lock_integration.py @@ -0,0 +1,38 @@ +""" +Integration tests for DbMigrationAutoRenewLock using real Redis via TestContainers. +""" + +import time +import uuid + +import pytest + +from extensions.ext_redis import redis_client +from libs.db_migration_lock import DbMigrationAutoRenewLock + + +@pytest.mark.usefixtures("flask_app_with_containers") +def test_db_migration_lock_renews_ttl_and_releases(): + lock_name = f"test:db_migration_auto_renew_lock:{uuid.uuid4().hex}" + + # Keep base TTL very small, and renew frequently so the test is stable even on slower CI. + lock = DbMigrationAutoRenewLock( + redis_client=redis_client, + name=lock_name, + ttl_seconds=1.0, + renew_interval_seconds=0.2, + log_context="test_db_migration_lock", + ) + + acquired = lock.acquire(blocking=True, blocking_timeout=5) + assert acquired is True + + # Wait beyond the base TTL; key should still exist due to renewal. 
+ time.sleep(1.5) + ttl = redis_client.ttl(lock_name) + assert ttl > 0 + + lock.release_safely(status="successful") + + # After release, the key should not exist. + assert redis_client.exists(lock_name) == 0 diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 1b844d6357..61f6b75b10 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -605,26 +605,20 @@ class TestBatchCreateSegmentToIndexTask: mock_storage.download.side_effect = mock_download - # Execute the task + # Execute the task - should raise ValueError for empty CSV job_id = str(uuid.uuid4()) - batch_create_segment_to_index_task( - job_id=job_id, - upload_file_id=upload_file.id, - dataset_id=dataset.id, - document_id=document.id, - tenant_id=tenant.id, - user_id=account.id, - ) + with pytest.raises(ValueError, match="The CSV file is empty"): + batch_create_segment_to_index_task( + job_id=job_id, + upload_file_id=upload_file.id, + dataset_id=dataset.id, + document_id=document.id, + tenant_id=tenant.id, + user_id=account.id, + ) # Verify error handling - # Check Redis cache was set to error status - from extensions.ext_redis import redis_client - - cache_key = f"segment_batch_import_{job_id}" - cache_value = redis_client.get(cache_key) - assert cache_value == b"error" - - # Verify no segments were created + # Since exception was raised, no segments should be created from extensions.ext_database import db segments = db.session.query(DocumentSegment).all() diff --git a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py index eec6929925..379986c191 100644 --- 
a/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_clean_notion_document_task.py @@ -153,8 +153,7 @@ class TestCleanNotionDocumentTask: # Execute cleanup task clean_notion_document_task(document_ids, dataset.id) - # Verify documents and segments are deleted - assert db_session_with_containers.query(Document).filter(Document.id.in_(document_ids)).count() == 0 + # Verify segments are deleted assert ( db_session_with_containers.query(DocumentSegment) .filter(DocumentSegment.document_id.in_(document_ids)) @@ -162,9 +161,9 @@ class TestCleanNotionDocumentTask: == 0 ) - # Verify index processor was called for each document + # Verify index processor was called mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value - assert mock_processor.clean.call_count == len(document_ids) + mock_processor.clean.assert_called_once() # This test successfully verifies: # 1. Document records are properly deleted from the database @@ -186,12 +185,12 @@ class TestCleanNotionDocumentTask: non_existent_dataset_id = str(uuid.uuid4()) document_ids = [str(uuid.uuid4()), str(uuid.uuid4())] - # Execute cleanup task with non-existent dataset - clean_notion_document_task(document_ids, non_existent_dataset_id) + # Execute cleanup task with non-existent dataset - expect exception + with pytest.raises(Exception, match="Document has no dataset"): + clean_notion_document_task(document_ids, non_existent_dataset_id) - # Verify that the index processor was not called - mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value - mock_processor.clean.assert_not_called() + # Verify that the index processor factory was not used + mock_index_processor_factory.return_value.init_index_processor.assert_not_called() def test_clean_notion_document_task_empty_document_list( self, db_session_with_containers, mock_index_processor_factory, 
mock_external_service_dependencies @@ -229,9 +228,13 @@ class TestCleanNotionDocumentTask: # Execute cleanup task with empty document list clean_notion_document_task([], dataset.id) - # Verify that the index processor was not called + # Verify that the index processor was called once with empty node list mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value - mock_processor.clean.assert_not_called() + assert mock_processor.clean.call_count == 1 + args, kwargs = mock_processor.clean.call_args + # args: (dataset, total_index_node_ids) + assert isinstance(args[0], Dataset) + assert args[1] == [] def test_clean_notion_document_task_with_different_index_types( self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies @@ -315,8 +318,7 @@ class TestCleanNotionDocumentTask: # Note: This test successfully verifies cleanup with different document types. # The task properly handles various index types and document configurations. 
- # Verify documents and segments are deleted - assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0 + # Verify segments are deleted assert ( db_session_with_containers.query(DocumentSegment) .filter(DocumentSegment.document_id == document.id) @@ -404,8 +406,7 @@ class TestCleanNotionDocumentTask: # Execute cleanup task clean_notion_document_task([document.id], dataset.id) - # Verify documents and segments are deleted - assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0 + # Verify segments are deleted assert ( db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() == 0 @@ -508,8 +509,7 @@ class TestCleanNotionDocumentTask: clean_notion_document_task(documents_to_clean, dataset.id) - # Verify only specified documents and segments are deleted - assert db_session_with_containers.query(Document).filter(Document.id.in_(documents_to_clean)).count() == 0 + # Verify only specified documents' segments are deleted assert ( db_session_with_containers.query(DocumentSegment) .filter(DocumentSegment.document_id.in_(documents_to_clean)) @@ -697,11 +697,12 @@ class TestCleanNotionDocumentTask: db_session_with_containers.commit() # Mock index processor to raise an exception - mock_index_processor = mock_index_processor_factory.init_index_processor.return_value + mock_index_processor = mock_index_processor_factory.return_value.init_index_processor.return_value mock_index_processor.clean.side_effect = Exception("Index processor error") - # Execute cleanup task - it should handle the exception gracefully - clean_notion_document_task([document.id], dataset.id) + # Execute cleanup task - current implementation propagates the exception + with pytest.raises(Exception, match="Index processor error"): + clean_notion_document_task([document.id], dataset.id) # Note: This test demonstrates the task's error handling capability. 
# Even with external service errors, the database operations complete successfully. @@ -803,8 +804,7 @@ class TestCleanNotionDocumentTask: all_document_ids = [doc.id for doc in documents] clean_notion_document_task(all_document_ids, dataset.id) - # Verify all documents and segments are deleted - assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0 + # Verify all segments are deleted assert ( db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() == 0 @@ -914,8 +914,7 @@ class TestCleanNotionDocumentTask: clean_notion_document_task([target_document.id], target_dataset.id) - # Verify only documents from target dataset are deleted - assert db_session_with_containers.query(Document).filter(Document.id == target_document.id).count() == 0 + # Verify only documents' segments from target dataset are deleted assert ( db_session_with_containers.query(DocumentSegment) .filter(DocumentSegment.document_id == target_document.id) @@ -1030,8 +1029,7 @@ class TestCleanNotionDocumentTask: all_document_ids = [doc.id for doc in documents] clean_notion_document_task(all_document_ids, dataset.id) - # Verify all documents and segments are deleted regardless of status - assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0 + # Verify all segments are deleted regardless of status assert ( db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count() == 0 @@ -1142,8 +1140,7 @@ class TestCleanNotionDocumentTask: # Execute cleanup task clean_notion_document_task([document.id], dataset.id) - # Verify documents and segments are deleted - assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0 + # Verify segments are deleted assert ( db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count() == 0 diff --git 
a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py new file mode 100644 index 0000000000..7f37f84113 --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -0,0 +1,182 @@ +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document, DocumentSegment +from tasks.document_indexing_update_task import document_indexing_update_task + + +class TestDocumentIndexingUpdateTask: + @pytest.fixture + def mock_external_dependencies(self): + """Patch external collaborators used by the update task. + - IndexProcessorFactory.init_index_processor().clean(...) + - IndexingRunner.run([...]) + """ + with ( + patch("tasks.document_indexing_update_task.IndexProcessorFactory") as mock_factory, + patch("tasks.document_indexing_update_task.IndexingRunner") as mock_runner, + ): + processor_instance = MagicMock() + mock_factory.return_value.init_index_processor.return_value = processor_instance + + runner_instance = MagicMock() + mock_runner.return_value = runner_instance + + yield { + "factory": mock_factory, + "processor": processor_instance, + "runner": mock_runner, + "runner_instance": runner_instance, + } + + def _create_dataset_document_with_segments(self, db_session_with_containers, *, segment_count: int = 2): + fake = Faker() + + # Account and tenant + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + tenant = Tenant(name=fake.company(), status="normal") + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + 
role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + # Dataset and document + dataset = Dataset( + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=64), + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + + document = Document( + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=account.id, + indexing_status="waiting", + enabled=True, + doc_form="text_model", + ) + db_session_with_containers.add(document) + db_session_with_containers.commit() + + # Segments + node_ids = [] + for i in range(segment_count): + node_id = f"node-{i + 1}" + seg = DocumentSegment( + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=i, + content=fake.text(max_nb_chars=32), + answer=None, + word_count=10, + tokens=5, + index_node_id=node_id, + status="completed", + created_by=account.id, + ) + db_session_with_containers.add(seg) + node_ids.append(node_id) + db_session_with_containers.commit() + + # Refresh to ensure ORM state + db_session_with_containers.refresh(dataset) + db_session_with_containers.refresh(document) + + return dataset, document, node_ids + + def test_cleans_segments_and_reindexes(self, db_session_with_containers, mock_external_dependencies): + dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers) + + # Act + document_indexing_update_task(dataset.id, document.id) + + # Ensure we see committed changes from another session + db_session_with_containers.expire_all() + + # Assert document status updated before reindex + updated = db_session_with_containers.query(Document).where(Document.id == document.id).first() + assert 
updated.indexing_status == "parsing" + assert updated.processing_started_at is not None + + # Segments should be deleted + remaining = ( + db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count() + ) + assert remaining == 0 + + # Assert index processor clean was called with expected args + clean_call = mock_external_dependencies["processor"].clean.call_args + assert clean_call is not None + args, kwargs = clean_call + # args[0] is a Dataset instance (from another session) — validate by id + assert getattr(args[0], "id", None) == dataset.id + # args[1] should contain our node_ids + assert set(args[1]) == set(node_ids) + assert kwargs.get("with_keywords") is True + assert kwargs.get("delete_child_chunks") is True + + # Assert indexing runner invoked with the updated document + run_call = mock_external_dependencies["runner_instance"].run.call_args + assert run_call is not None + run_docs = run_call[0][0] + assert len(run_docs) == 1 + first = run_docs[0] + assert getattr(first, "id", None) == document.id + + def test_clean_error_is_logged_and_indexing_continues(self, db_session_with_containers, mock_external_dependencies): + dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers) + + # Force clean to raise; task should continue to indexing + mock_external_dependencies["processor"].clean.side_effect = Exception("boom") + + document_indexing_update_task(dataset.id, document.id) + + # Ensure we see committed changes from another session + db_session_with_containers.expire_all() + + # Indexing should still be triggered + mock_external_dependencies["runner_instance"].run.assert_called_once() + + # Segments should remain (since clean failed before DB delete) + remaining = ( + db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count() + ) + assert remaining > 0 + + def test_document_not_found_noop(self, db_session_with_containers, 
mock_external_dependencies): + fake = Faker() + # Act with non-existent document id + document_indexing_update_task(dataset_id=fake.uuid4(), document_id=fake.uuid4()) + + # Neither processor nor runner should be called + mock_external_dependencies["processor"].clean.assert_not_called() + mock_external_dependencies["runner_instance"].run.assert_not_called() diff --git a/api/tests/unit_tests/commands/test_upgrade_db.py b/api/tests/unit_tests/commands/test_upgrade_db.py new file mode 100644 index 0000000000..80173f5d46 --- /dev/null +++ b/api/tests/unit_tests/commands/test_upgrade_db.py @@ -0,0 +1,146 @@ +import sys +import threading +import types +from unittest.mock import MagicMock + +import commands +from libs.db_migration_lock import LockNotOwnedError, RedisError + +HEARTBEAT_WAIT_TIMEOUT_SECONDS = 5.0 + + +def _install_fake_flask_migrate(monkeypatch, upgrade_impl) -> None: + module = types.ModuleType("flask_migrate") + module.upgrade = upgrade_impl + monkeypatch.setitem(sys.modules, "flask_migrate", module) + + +def _invoke_upgrade_db() -> int: + try: + commands.upgrade_db.callback() + except SystemExit as e: + return int(e.code or 0) + return 0 + + +def test_upgrade_db_skips_when_lock_not_acquired(monkeypatch, capsys): + monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 1234) + + lock = MagicMock() + lock.acquire.return_value = False + commands.redis_client.lock.return_value = lock + + exit_code = _invoke_upgrade_db() + captured = capsys.readouterr() + + assert exit_code == 0 + assert "Database migration skipped" in captured.out + + commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=1234, thread_local=False) + lock.acquire.assert_called_once_with(blocking=False) + lock.release.assert_not_called() + + +def test_upgrade_db_failure_not_masked_by_lock_release(monkeypatch, capsys): + monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 321) + + lock = MagicMock() + lock.acquire.return_value = True + 
lock.release.side_effect = LockNotOwnedError("simulated") + commands.redis_client.lock.return_value = lock + + def _upgrade(): + raise RuntimeError("boom") + + _install_fake_flask_migrate(monkeypatch, _upgrade) + + exit_code = _invoke_upgrade_db() + captured = capsys.readouterr() + + assert exit_code == 1 + assert "Database migration failed: boom" in captured.out + + commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=321, thread_local=False) + lock.acquire.assert_called_once_with(blocking=False) + lock.release.assert_called_once() + + +def test_upgrade_db_success_ignores_lock_not_owned_on_release(monkeypatch, capsys): + monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 999) + + lock = MagicMock() + lock.acquire.return_value = True + lock.release.side_effect = LockNotOwnedError("simulated") + commands.redis_client.lock.return_value = lock + + _install_fake_flask_migrate(monkeypatch, lambda: None) + + exit_code = _invoke_upgrade_db() + captured = capsys.readouterr() + + assert exit_code == 0 + assert "Database migration successful!" in captured.out + + commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=999, thread_local=False) + lock.acquire.assert_called_once_with(blocking=False) + lock.release.assert_called_once() + + +def test_upgrade_db_renews_lock_during_migration(monkeypatch, capsys): + """ + Ensure the lock is renewed while migrations are running, so the base TTL can stay short. + """ + + # Use a small TTL so the heartbeat interval triggers quickly. 
+ monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3) + + lock = MagicMock() + lock.acquire.return_value = True + commands.redis_client.lock.return_value = lock + + renewed = threading.Event() + + def _reacquire(): + renewed.set() + return True + + lock.reacquire.side_effect = _reacquire + + def _upgrade(): + assert renewed.wait(HEARTBEAT_WAIT_TIMEOUT_SECONDS) + + _install_fake_flask_migrate(monkeypatch, _upgrade) + + exit_code = _invoke_upgrade_db() + _ = capsys.readouterr() + + assert exit_code == 0 + assert lock.reacquire.call_count >= 1 + + +def test_upgrade_db_ignores_reacquire_errors(monkeypatch, capsys): + # Use a small TTL so heartbeat runs during the upgrade call. + monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3) + + lock = MagicMock() + lock.acquire.return_value = True + commands.redis_client.lock.return_value = lock + + attempted = threading.Event() + + def _reacquire(): + attempted.set() + raise RedisError("simulated") + + lock.reacquire.side_effect = _reacquire + + def _upgrade(): + assert attempted.wait(HEARTBEAT_WAIT_TIMEOUT_SECONDS) + + _install_fake_flask_migrate(monkeypatch, _upgrade) + + exit_code = _invoke_upgrade_db() + _ = capsys.readouterr() + + assert exit_code == 0 + assert lock.reacquire.call_count >= 1 diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py index 5a43a247e3..c0c636715d 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_message_cycle_manager_optimization.py @@ -25,15 +25,19 @@ class TestMessageCycleManagerOptimization: task_state = Mock() return MessageCycleManager(application_generate_entity=mock_application_generate_entity, task_state=task_state) - def test_get_message_event_type_with_message_file(self, message_cycle_manager): - """Test 
get_message_event_type returns MESSAGE_FILE when message has files.""" + def test_get_message_event_type_with_assistant_file(self, message_cycle_manager): + """Test get_message_event_type returns MESSAGE_FILE when message has assistant-generated files. + + This ensures that AI-generated images (belongs_to='assistant') trigger the MESSAGE_FILE event, + allowing the frontend to properly display generated image files with url field. + """ with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: # Setup mock session and message file mock_session = Mock() mock_session_factory.create_session.return_value.__enter__.return_value = mock_session mock_message_file = Mock() - # Current implementation uses session.scalar(select(...)) + mock_message_file.belongs_to = "assistant" mock_session.scalar.return_value = mock_message_file # Execute @@ -44,6 +48,31 @@ class TestMessageCycleManagerOptimization: assert result == StreamEvent.MESSAGE_FILE mock_session.scalar.assert_called_once() + def test_get_message_event_type_with_user_file(self, message_cycle_manager): + """Test get_message_event_type returns MESSAGE when message only has user-uploaded files. + + This is a regression test for the issue where user-uploaded images (belongs_to='user') + caused the LLM text response to be incorrectly tagged with MESSAGE_FILE event, + resulting in broken images in the chat UI. The query filters for belongs_to='assistant', + so when only user files exist, the database query returns None, resulting in MESSAGE event type. 
+ """ + with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: + # Setup mock session and message file + mock_session = Mock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + + # When querying for assistant files with only user files present, return None + # (simulates database query with belongs_to='assistant' filter returning no results) + mock_session.scalar.return_value = None + + # Execute + with current_app.app_context(): + result = message_cycle_manager.get_message_event_type("test-message-id") + + # Assert + assert result == StreamEvent.MESSAGE + mock_session.scalar.assert_called_once() + def test_get_message_event_type_without_message_file(self, message_cycle_manager): """Test get_message_event_type returns MESSAGE when message has no files.""" with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory: @@ -69,7 +98,7 @@ class TestMessageCycleManagerOptimization: mock_session_factory.create_session.return_value.__enter__.return_value = mock_session mock_message_file = Mock() - # Current implementation uses session.scalar(select(...)) + mock_message_file.belongs_to = "assistant" mock_session.scalar.return_value = mock_message_file # Execute: compute event type once, then pass to message_to_stream_response diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/__init__.py b/api/tests/unit_tests/core/workflow/nodes/agent/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py index 7c0eccbb8b..f12e5993dc 100644 --- a/api/tests/unit_tests/factories/test_variable_factory.py +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -4,7 +4,7 @@ from typing import Any from uuid import uuid4 import pytest -from hypothesis import given, settings +from hypothesis import HealthCheck, given, settings 
from hypothesis import strategies as st from core.file import File, FileTransferMethod, FileType @@ -493,7 +493,7 @@ def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]: ) -@settings(max_examples=50) +@settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None) @given(_scalar_value()) def test_build_segment_and_extract_values_for_scalar_types(value): seg = variable_factory.build_segment(value) @@ -504,7 +504,7 @@ def test_build_segment_and_extract_values_for_scalar_types(value): assert seg.value == value -@settings(max_examples=50) +@settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None) @given(values=st.lists(_scalar_value(), max_size=20)) def test_build_segment_and_extract_values_for_array_types(values): seg = variable_factory.build_segment(values) diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index e24ef32a24..8d8e2b0db0 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -83,23 +83,127 @@ def mock_documents(document_ids, dataset_id): def mock_db_session(): """Mock database session via session_factory.create_session().""" with patch("tasks.document_indexing_task.session_factory") as mock_sf: - session = MagicMock() - # Ensure tests that expect session.close() to be called can observe it via the context manager - session.close = MagicMock() - cm = MagicMock() - cm.__enter__.return_value = session - # Link __exit__ to session.close so "close" expectations reflect context manager teardown + sessions = [] # Track all created sessions + # Shared mock data that all sessions will access + shared_mock_data = {"dataset": None, "documents": None, "doc_iter": None} - def _exit_side_effect(*args, **kwargs): - session.close() + def create_session_side_effect(): + session = 
MagicMock() + session.close = MagicMock() - cm.__exit__.side_effect = _exit_side_effect - mock_sf.create_session.return_value = cm + # Track commit calls + commit_mock = MagicMock() + session.commit = commit_mock + cm = MagicMock() + cm.__enter__.return_value = session - query = MagicMock() - session.query.return_value = query - query.where.return_value = query - yield session + def _exit_side_effect(*args, **kwargs): + session.close() + + cm.__exit__.side_effect = _exit_side_effect + + # Support session.begin() for transactions + begin_cm = MagicMock() + begin_cm.__enter__.return_value = session + + def begin_exit_side_effect(*args, **kwargs): + # Auto-commit on transaction exit (like SQLAlchemy) + session.commit() + # Also mark wrapper's commit as called + if sessions: + sessions[0].commit() + + begin_cm.__exit__ = MagicMock(side_effect=begin_exit_side_effect) + session.begin = MagicMock(return_value=begin_cm) + + sessions.append(session) + + # Setup query with side_effect to handle both Dataset and Document queries + def query_side_effect(*args): + query = MagicMock() + if args and args[0] == Dataset and shared_mock_data["dataset"] is not None: + where_result = MagicMock() + where_result.first.return_value = shared_mock_data["dataset"] + query.where = MagicMock(return_value=where_result) + elif args and args[0] == Document and shared_mock_data["documents"] is not None: + # Support both .first() and .all() calls with chaining + where_result = MagicMock() + where_result.where = MagicMock(return_value=where_result) + + # Create an iterator for .first() calls if not exists + if shared_mock_data["doc_iter"] is None: + docs = shared_mock_data["documents"] or [None] + shared_mock_data["doc_iter"] = iter(docs) + + where_result.first = lambda: next(shared_mock_data["doc_iter"], None) + docs_or_empty = shared_mock_data["documents"] or [] + where_result.all = MagicMock(return_value=docs_or_empty) + query.where = MagicMock(return_value=where_result) + else: + query.where = 
MagicMock(return_value=query) + return query + + session.query = MagicMock(side_effect=query_side_effect) + return cm + + mock_sf.create_session.side_effect = create_session_side_effect + + # Create a wrapper that behaves like the first session but has access to all sessions + class SessionWrapper: + def __init__(self): + self._sessions = sessions + self._shared_data = shared_mock_data + # Create a default session for setup phase + self._default_session = MagicMock() + self._default_session.close = MagicMock() + self._default_session.commit = MagicMock() + + # Support session.begin() for default session too + begin_cm = MagicMock() + begin_cm.__enter__.return_value = self._default_session + + def default_begin_exit_side_effect(*args, **kwargs): + self._default_session.commit() + + begin_cm.__exit__ = MagicMock(side_effect=default_begin_exit_side_effect) + self._default_session.begin = MagicMock(return_value=begin_cm) + + def default_query_side_effect(*args): + query = MagicMock() + if args and args[0] == Dataset and shared_mock_data["dataset"] is not None: + where_result = MagicMock() + where_result.first.return_value = shared_mock_data["dataset"] + query.where = MagicMock(return_value=where_result) + elif args and args[0] == Document and shared_mock_data["documents"] is not None: + where_result = MagicMock() + where_result.where = MagicMock(return_value=where_result) + + if shared_mock_data["doc_iter"] is None: + docs = shared_mock_data["documents"] or [None] + shared_mock_data["doc_iter"] = iter(docs) + + where_result.first = lambda: next(shared_mock_data["doc_iter"], None) + docs_or_empty = shared_mock_data["documents"] or [] + where_result.all = MagicMock(return_value=docs_or_empty) + query.where = MagicMock(return_value=where_result) + else: + query.where = MagicMock(return_value=query) + return query + + self._default_session.query = MagicMock(side_effect=default_query_side_effect) + + def __getattr__(self, name): + # Forward all attribute access to the first 
session, or default if none created yet + target_session = self._sessions[0] if self._sessions else self._default_session + return getattr(target_session, name) + + @property + def all_sessions(self): + """Access all created sessions for testing.""" + return self._sessions + + wrapper = SessionWrapper() + yield wrapper @pytest.fixture @@ -252,18 +356,9 @@ class TestTaskEnqueuing: use the deprecated function. """ # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - # Return documents one by one for each call - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -304,21 +399,9 @@ class TestBatchProcessing: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - # Create an iterator for documents - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - # Return documents one by one for each call - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + 
mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -357,19 +440,9 @@ class TestBatchProcessing: doc.stopped_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents mock_feature_service.get_features.return_value.billing.enabled = True mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL @@ -407,19 +480,9 @@ class TestBatchProcessing: doc.stopped_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents mock_feature_service.get_features.return_value.billing.enabled = True mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX @@ 
-444,7 +507,10 @@ class TestBatchProcessing: """ # Arrange document_ids = [] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + # Set shared mock data with empty documents list + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = [] with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -482,19 +548,9 @@ class TestProgressTracking: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -528,19 +584,9 @@ class TestProgressTracking: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + 
mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -635,19 +681,9 @@ class TestErrorHandling: doc.stopped_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Set up to trigger vector space limit error mock_feature_service.get_features.return_value.billing.enabled = True @@ -674,17 +710,9 @@ class TestErrorHandling: Errors during indexing should be caught and logged, but not crash the task. 
""" # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Make IndexingRunner raise an exception mock_indexing_runner.run.side_effect = Exception("Indexing failed") @@ -708,17 +736,9 @@ class TestErrorHandling: but not treated as a failure. """ # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Make IndexingRunner raise DocumentIsPausedError mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused") @@ -853,17 +873,9 @@ class TestTaskCancellation: Session cleanup should happen in finally block. 
""" # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -883,17 +895,9 @@ class TestTaskCancellation: Session cleanup should happen even when errors occur. """ # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Make IndexingRunner raise an exception mock_indexing_runner.run.side_effect = Exception("Test error") @@ -962,6 +966,7 @@ class TestAdvancedScenarios: document_ids = [str(uuid.uuid4()) for _ in range(3)] # Create only 2 documents (simulate one missing) + # The new code uses .all() which will only return existing documents mock_documents = [] for i, doc_id in enumerate([document_ids[0], document_ids[2]]): # Skip middle one doc = MagicMock(spec=Document) @@ -971,21 +976,9 @@ class TestAdvancedScenarios: 
doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - # Create iterator that returns None for missing document - doc_responses = [mock_documents[0], None, mock_documents[1]] - doc_iter = iter(doc_responses) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data - .all() will only return existing documents + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1075,19 +1068,9 @@ class TestAdvancedScenarios: doc.stopped_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Set vector space exactly at limit mock_feature_service.get_features.return_value.billing.enabled = True @@ -1219,19 +1202,9 @@ class TestAdvancedScenarios: doc.processing_started_at = None mock_documents.append(doc) - 
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Billing disabled - limits should not be checked mock_feature_service.get_features.return_value.billing.enabled = False @@ -1273,19 +1246,9 @@ class TestIntegration: # Set up rpop to return None for concurrency check (no more tasks) mock_redis.rpop.side_effect = [None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1321,19 +1284,9 @@ class TestIntegration: # Set up rpop to return None for concurrency check (no more tasks) mock_redis.rpop.side_effect = [None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - 
mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1415,17 +1368,9 @@ class TestEdgeCases: mock_document.indexing_status = "waiting" mock_document.processing_started_at = None - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: mock_document - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = [mock_document] with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1465,17 +1410,9 @@ class TestEdgeCases: mock_document.indexing_status = "waiting" mock_document.processing_started_at = None - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: mock_document - return mock_query - - mock_db_session.query.side_effect = 
mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = [mock_document] with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1555,19 +1492,9 @@ class TestEdgeCases: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Set vector space limit to 0 (unlimited) mock_feature_service.get_features.return_value.billing.enabled = True @@ -1612,19 +1539,9 @@ class TestEdgeCases: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Set negative vector space limit 
mock_feature_service.get_features.return_value.billing.enabled = True @@ -1675,19 +1592,9 @@ class TestPerformanceScenarios: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Configure billing with sufficient limits mock_feature_service.get_features.return_value.billing.enabled = True @@ -1826,19 +1733,9 @@ class TestRobustness: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Make IndexingRunner raise an exception mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error") @@ -1866,7 +1763,7 @@ class TestRobustness: - No exceptions occur Expected behavior: - - Database session is closed + - All database sessions are closed - No connection leaks """ # Arrange @@ -1879,19 
+1776,9 @@ class TestRobustness: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1899,10 +1786,11 @@ class TestRobustness: # Act _document_indexing(dataset_id, document_ids) - # Assert - assert mock_db_session.close.called - # Verify close is called exactly once - assert mock_db_session.close.call_count == 1 + # Assert - All created sessions should be closed + # The code creates multiple sessions: validation, Phase 1 (parsing), Phase 3 (summary) + assert len(mock_db_session.all_sessions) >= 1 + for session in mock_db_session.all_sessions: + assert session.close.called, "All sessions should be closed" def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis): """ diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py index fa33034f40..549f2c6c9b 100644 --- a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -109,25 +109,87 @@ def mock_document_segments(document_id): @pytest.fixture def mock_db_session(): - """Mock database session via session_factory.create_session().""" + 
"""Mock database session via session_factory.create_session(). + + After session split refactor, the code calls create_session() multiple times. + This fixture creates shared query mocks so all sessions use the same + query configuration, simulating database persistence across sessions. + + The fixture automatically converts side_effect to cycle to prevent StopIteration. + Tests configure mocks the same way as before, but behind the scenes the values + are cycled infinitely for all sessions. + """ + from itertools import cycle + with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf: - session = MagicMock() - # Ensure tests can observe session.close() via context manager teardown - session.close = MagicMock() - cm = MagicMock() - cm.__enter__.return_value = session + sessions = [] - def _exit_side_effect(*args, **kwargs): - session.close() + # Shared query mocks - all sessions use these + shared_query = MagicMock() + shared_filter_by = MagicMock() + shared_scalars_result = MagicMock() - cm.__exit__.side_effect = _exit_side_effect - mock_sf.create_session.return_value = cm + # Create custom first mock that auto-cycles side_effect + class CyclicMock(MagicMock): + def __setattr__(self, name, value): + if name == "side_effect" and value is not None: + # Convert list/tuple to infinite cycle + if isinstance(value, (list, tuple)): + value = cycle(value) + super().__setattr__(name, value) - query = MagicMock() - session.query.return_value = query - query.where.return_value = query - session.scalars.return_value = MagicMock() - yield session + shared_query.where.return_value.first = CyclicMock() + shared_filter_by.first = CyclicMock() + + def _create_session(): + """Create a new mock session for each create_session() call.""" + session = MagicMock() + session.close = MagicMock() + session.commit = MagicMock() + + # Mock session.begin() context manager + begin_cm = MagicMock() + begin_cm.__enter__.return_value = session + + def 
_begin_exit_side_effect(exc_type, exc, tb): + # commit on success + if exc_type is None: + session.commit() + # return False to propagate exceptions + return False + + begin_cm.__exit__.side_effect = _begin_exit_side_effect + session.begin.return_value = begin_cm + + # Mock create_session() context manager + cm = MagicMock() + cm.__enter__.return_value = session + + def _exit_side_effect(exc_type, exc, tb): + session.close() + return False + + cm.__exit__.side_effect = _exit_side_effect + + # All sessions use the same shared query mocks + session.query.return_value = shared_query + shared_query.where.return_value = shared_query + shared_query.filter_by.return_value = shared_filter_by + session.scalars.return_value = shared_scalars_result + + sessions.append(session) + # Attach helpers on the first created session for assertions across all sessions + if len(sessions) == 1: + session.get_all_sessions = lambda: sessions + session.any_close_called = lambda: any(s.close.called for s in sessions) + session.any_commit_called = lambda: any(s.commit.called for s in sessions) + return cm + + mock_sf.create_session.side_effect = _create_session + + # Create first session and return it + _create_session() + yield sessions[0] @pytest.fixture @@ -186,8 +248,8 @@ class TestDocumentIndexingSyncTask: # Act document_indexing_sync_task(dataset_id, document_id) - # Assert - mock_db_session.close.assert_called_once() + # Assert - at least one session should have been closed + assert mock_db_session.any_close_called() def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id): """Test that task raises error when notion_workspace_id is missing.""" @@ -230,6 +292,7 @@ class TestDocumentIndexingSyncTask: """Test that task handles missing credentials by updating document status.""" # Arrange mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + 
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document mock_datasource_provider_service.get_datasource_credentials.return_value = None # Act @@ -239,8 +302,8 @@ class TestDocumentIndexingSyncTask: assert mock_document.indexing_status == "error" assert "Datasource credential not found" in mock_document.error assert mock_document.stopped_at is not None - mock_db_session.commit.assert_called() - mock_db_session.close.assert_called() + assert mock_db_session.any_commit_called() + assert mock_db_session.any_close_called() def test_page_not_updated( self, @@ -254,6 +317,7 @@ class TestDocumentIndexingSyncTask: """Test that task does nothing when page has not been updated.""" # Arrange mock_db_session.query.return_value.where.return_value.first.return_value = mock_document + mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document # Return same time as stored in document mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" @@ -263,8 +327,8 @@ class TestDocumentIndexingSyncTask: # Assert # Document status should remain unchanged assert mock_document.indexing_status == "completed" - # Session should still be closed via context manager teardown - assert mock_db_session.close.called + # At least one session should have been closed via context manager teardown + assert mock_db_session.any_close_called() def test_successful_sync_when_page_updated( self, @@ -281,7 +345,20 @@ class TestDocumentIndexingSyncTask: ): """Test successful sync flow when Notion page has been updated.""" # Arrange - mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + # Set exact sequence of returns across calls to `.first()`: + # 1) document (initial fetch) + # 2) dataset (pre-check) + # 3) dataset (cleaning phase) + # 4) document (pre-indexing update) + # 5) document (indexing runner fetch) + 
mock_db_session.query.return_value.where.return_value.first.side_effect = [ + mock_document, + mock_dataset, + mock_dataset, + mock_document, + mock_document, + ] + mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document mock_db_session.scalars.return_value.all.return_value = mock_document_segments # NotionExtractor returns updated time mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" @@ -299,28 +376,40 @@ class TestDocumentIndexingSyncTask: mock_processor.clean.assert_called_once() # Verify segments were deleted from database in batch (DELETE FROM document_segments) - execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list] + # Aggregate execute calls across all created sessions + execute_sqls = [] + for s in mock_db_session.get_all_sessions(): + execute_sqls.extend([" ".join(str(c[0][0]).split()) for c in s.execute.call_args_list]) assert any("DELETE FROM document_segments" in sql for sql in execute_sqls) # Verify indexing runner was called mock_indexing_runner.run.assert_called_once_with([mock_document]) - # Verify session operations - assert mock_db_session.commit.called - mock_db_session.close.assert_called_once() + # Verify session operations (across any created session) + assert mock_db_session.any_commit_called() + assert mock_db_session.any_close_called() def test_dataset_not_found_during_cleaning( self, mock_db_session, mock_datasource_provider_service, mock_notion_extractor, + mock_indexing_runner, mock_document, dataset_id, document_id, ): """Test that task handles dataset not found during cleaning phase.""" # Arrange - mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None] + # Sequence: document (initial), dataset (pre-check), None (cleaning), document (update), document (indexing) + mock_db_session.query.return_value.where.return_value.first.side_effect = [ + mock_document, + mock_dataset, + None, 
+ mock_document, + mock_document, + ] + mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" # Act @@ -329,8 +418,8 @@ class TestDocumentIndexingSyncTask: # Assert # Document should still be set to parsing assert mock_document.indexing_status == "parsing" - # Session should be closed after error - mock_db_session.close.assert_called_once() + # At least one session should be closed after error + assert mock_db_session.any_close_called() def test_cleaning_error_continues_to_indexing( self, @@ -346,8 +435,14 @@ class TestDocumentIndexingSyncTask: ): """Test that indexing continues even if cleaning fails.""" # Arrange - mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] - mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error") + from itertools import cycle + + mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset]) + mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document + # Make the cleaning step fail but not the segment fetch + processor = mock_index_processor_factory.return_value.init_index_processor.return_value + processor.clean.side_effect = Exception("Cleaning error") + mock_db_session.scalars.return_value.all.return_value = [] mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" # Act @@ -356,7 +451,7 @@ class TestDocumentIndexingSyncTask: # Assert # Indexing should still be attempted despite cleaning error mock_indexing_runner.run.assert_called_once_with([mock_document]) - mock_db_session.close.assert_called_once() + assert mock_db_session.any_close_called() def test_indexing_runner_document_paused_error( self, @@ -373,7 +468,10 @@ class TestDocumentIndexingSyncTask: ): """Test that DocumentIsPausedError is handled gracefully.""" 
# Arrange - mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + from itertools import cycle + + mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset]) + mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document mock_db_session.scalars.return_value.all.return_value = mock_document_segments mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused") @@ -383,7 +481,7 @@ class TestDocumentIndexingSyncTask: # Assert # Session should be closed after handling error - mock_db_session.close.assert_called_once() + assert mock_db_session.any_close_called() def test_indexing_runner_general_error( self, @@ -400,7 +498,10 @@ class TestDocumentIndexingSyncTask: ): """Test that general exceptions during indexing are handled.""" # Arrange - mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + from itertools import cycle + + mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset]) + mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document mock_db_session.scalars.return_value.all.return_value = mock_document_segments mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" mock_indexing_runner.run.side_effect = Exception("Indexing error") @@ -410,7 +511,7 @@ class TestDocumentIndexingSyncTask: # Assert # Session should be closed after error - mock_db_session.close.assert_called_once() + assert mock_db_session.any_close_called() def test_notion_extractor_initialized_with_correct_params( self, @@ -517,7 +618,14 @@ class TestDocumentIndexingSyncTask: ): """Test that index processor clean is called with correct parameters.""" # Arrange - 
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] + # Sequence: document (initial), dataset (pre-check), dataset (cleaning), document (update), document (indexing) + mock_db_session.query.return_value.where.return_value.first.side_effect = [ + mock_document, + mock_dataset, + mock_dataset, + mock_document, + mock_document, + ] mock_db_session.scalars.return_value.all.return_value = mock_document_segments mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" diff --git a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py index a14bbb01d0..2b11e42cd5 100644 --- a/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/unit_tests/tasks/test_remove_app_and_related_data_task.py @@ -350,7 +350,7 @@ class TestDeleteWorkflowArchiveLogs: mock_query.where.return_value = mock_delete_query mock_db.session.query.return_value = mock_query - delete_func("log-1") + delete_func(mock_db.session, "log-1") mock_db.session.query.assert_called_once_with(WorkflowArchiveLog) mock_query.where.assert_called_once() diff --git a/api/uv.lock b/api/uv.lock index aefb8e91f0..a3b5433952 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1368,7 +1368,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.12.0" +version = "1.12.1" source = { virtual = "." 
} dependencies = [ { name = "aliyun-log-python-sdk" }, @@ -1653,7 +1653,7 @@ requires-dist = [ { name = "starlette", specifier = "==0.49.1" }, { name = "tiktoken", specifier = "~=0.9.0" }, { name = "transformers", specifier = "~=4.56.1" }, - { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" }, + { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.18.18" }, { name = "weave", specifier = ">=0.52.16" }, { name = "weaviate-client", specifier = "==4.17.0" }, { name = "webvtt-py", specifier = "~=0.5.1" }, @@ -4433,15 +4433,15 @@ wheels = [ [[package]] name = "pdfminer-six" -version = "20251230" +version = "20260107" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "charset-normalizer" }, { name = "cryptography" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/46/9a/d79d8fa6d47a0338846bb558b39b9963b8eb2dfedec61867c138c1b17eeb/pdfminer_six-20251230.tar.gz", hash = "sha256:e8f68a14c57e00c2d7276d26519ea64be1b48f91db1cdc776faa80528ca06c1e", size = 8511285, upload-time = "2025-12-30T15:49:13.104Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/a4/5cec1112009f0439a5ca6afa8ace321f0ab2f48da3255b7a1c8953014670/pdfminer_six-20260107.tar.gz", hash = "sha256:96bfd431e3577a55a0efd25676968ca4ce8fd5b53f14565f85716ff363889602", size = 8512094, upload-time = "2026-01-07T13:29:12.937Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/65/d7/b288ea32deb752a09aab73c75e1e7572ab2a2b56c3124a5d1eb24c62ceb3/pdfminer_six-20251230-py3-none-any.whl", hash = "sha256:9ff2e3466a7dfc6de6fd779478850b6b7c2d9e9405aa2a5869376a822771f485", size = 6591909, upload-time = "2025-12-30T15:49:10.76Z" }, + { url = "https://files.pythonhosted.org/packages/20/8b/28c4eaec9d6b036a52cb44720408f26b1a143ca9bce76cc19e8f5de00ab4/pdfminer_six-20260107-py3-none-any.whl", hash = "sha256:366585ba97e80dffa8f00cebe303d2f381884d8637af4ce422f1df3ef38111a9", size = 6592252, 
upload-time = "2026-01-07T13:29:10.742Z" }, ] [[package]] @@ -6814,12 +6814,12 @@ wheels = [ [[package]] name = "unstructured" -version = "0.16.25" +version = "0.18.31" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "backoff" }, { name = "beautifulsoup4" }, - { name = "chardet" }, + { name = "charset-normalizer" }, { name = "dataclasses-json" }, { name = "emoji" }, { name = "filetype" }, @@ -6827,6 +6827,7 @@ dependencies = [ { name = "langdetect" }, { name = "lxml" }, { name = "nltk" }, + { name = "numba" }, { name = "numpy" }, { name = "psutil" }, { name = "python-iso639" }, @@ -6839,9 +6840,9 @@ dependencies = [ { name = "unstructured-client" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/64/31/98c4c78e305d1294888adf87fd5ee30577a4c393951341ca32b43f167f1e/unstructured-0.16.25.tar.gz", hash = "sha256:73b9b0f51dbb687af572ecdb849a6811710b9cac797ddeab8ee80fa07d8aa5e6", size = 1683097, upload-time = "2025-03-07T11:19:39.507Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a9/5f/64285bd69a538bc28753f1423fcaa9d64cd79a9e7c097171b1f0d27e9cdb/unstructured-0.18.31.tar.gz", hash = "sha256:af4bbe32d1894ae6e755f0da6fc0dd307a1d0adeebe0e7cc6278f6cf744339ca", size = 1707700, upload-time = "2026-01-27T15:33:05.378Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/12/4f/ad08585b5c8a33c82ea119494c4d3023f4796958c56e668b15cc282ec0a0/unstructured-0.16.25-py3-none-any.whl", hash = "sha256:14719ccef2830216cf1c5bf654f75e2bf07b17ca5dcee9da5ac74618130fd337", size = 1769286, upload-time = "2025-03-07T11:19:37.299Z" }, + { url = "https://files.pythonhosted.org/packages/c8/4a/9c43f39d9e443c9bc3f2e379b305bca27110adc653b071221b3132c18de5/unstructured-0.18.31-py3-none-any.whl", hash = "sha256:fab4641176cb9b192ed38048758aa0d9843121d03626d18f42275afb31e5b2d3", size = 1794889, upload-time = "2026-01-27T15:33:03.136Z" }, ] [package.optional-dependencies] diff --git a/docker/docker-compose-template.yaml 
b/docker/docker-compose-template.yaml index aacb551933..cb5e2c47f7 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -63,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -102,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -132,7 +132,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.12.0 + image: langgenius/dify-web:1.12.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 1cb327cfe4..1886f848e0 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -707,7 +707,7 @@ services: # API service api: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -749,7 +749,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. @@ -788,7 +788,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.12.0 + image: langgenius/dify-api:1.12.1 restart: always environment: # Use the shared environment variables. 
@@ -818,7 +818,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.12.0 + image: langgenius/dify-web:1.12.1 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index 486c0a8ac9..b97aa6e775 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -109,6 +109,7 @@ const AgentTools: FC = () => { tool_parameters: paramsWithDefaultValue, notAuthor: !tool.is_team_authorization, enabled: true, + type: tool.provider_type as CollectionType, } } const handleSelectTool = (tool: ToolDefaultValue) => { diff --git a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx index f5ebaac3ca..44ce5cde52 100644 --- a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx +++ b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx @@ -19,19 +19,21 @@ import { useBoolean, useSessionStorageState } from 'ahooks' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' +import { useShallow } from 'zustand/react/shallow' import Button from '@/app/components/base/button' import Confirm from '@/app/components/base/confirm' import { Generator } from '@/app/components/base/icons/src/vender/other' -import Loading from '@/app/components/base/loading' +import Loading from '@/app/components/base/loading' import Modal from '@/app/components/base/modal' import Toast from '@/app/components/base/toast' -import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { ModelTypeEnum } from 
'@/app/components/header/account-setting/model-provider-page/declarations' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' import { generateBasicAppFirstTimeRule, generateRule } from '@/service/debug' import { useGenerateRuleTemplate } from '@/service/use-apps' +import { useStore } from '../../../store' import IdeaOutput from './idea-output' import InstructionEditorInBasic from './instruction-editor' import InstructionEditorInWorkflow from './instruction-editor-in-workflow' @@ -83,6 +85,9 @@ const GetAutomaticRes: FC = ({ onFinished, }) => { const { t } = useTranslation() + const { appDetail } = useStore(useShallow(state => ({ + appDetail: state.appDetail, + }))) const localModel = localStorage.getItem('auto-gen-model') ? JSON.parse(localStorage.getItem('auto-gen-model') as string) as Model : null @@ -235,6 +240,7 @@ const GetAutomaticRes: FC = ({ instruction, model_config: model, no_variable: false, + app_id: appDetail?.id, }) apiRes = { ...res, @@ -256,6 +262,7 @@ const GetAutomaticRes: FC = ({ instruction, ideal_output: ideaOutput, model_config: model, + app_id: appDetail?.id, }) apiRes = res if (error) { diff --git a/web/app/components/datasets/create/step-two/components/general-chunking-options.tsx b/web/app/components/datasets/create/step-two/components/general-chunking-options.tsx index 84d742d734..0beda8f5c8 100644 --- a/web/app/components/datasets/create/step-two/components/general-chunking-options.tsx +++ b/web/app/components/datasets/create/step-two/components/general-chunking-options.tsx @@ -154,7 +154,7 @@ export const GeneralChunkingOptions: FC = ({ ))} { - showSummaryIndexSetting && ( + showSummaryIndexSetting && IS_CE_EDITION && (
= ({
))} { - showSummaryIndexSetting && ( + showSummaryIndexSetting && IS_CE_EDITION && (
{t('list.action.sync', { ns: 'datasetDocuments' })}
)} -
onOperate('summary')}> - - {t('list.action.summary', { ns: 'datasetDocuments' })} -
+ { + IS_CE_EDITION && ( +
onOperate('summary')}> + + {t('list.action.summary', { ns: 'datasetDocuments' })} +
+ ) + } )} diff --git a/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx b/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx index 486ba2ffdf..ca5a56ec2a 100644 --- a/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx +++ b/web/app/components/datasets/documents/detail/completed/common/batch-action.tsx @@ -7,6 +7,7 @@ import Button from '@/app/components/base/button' import Confirm from '@/app/components/base/confirm' import Divider from '@/app/components/base/divider' import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge' +import { IS_CE_EDITION } from '@/config' import { cn } from '@/utils/classnames' const i18nPrefix = 'batchAction' @@ -87,7 +88,7 @@ const BatchAction: FC = ({ {t('metadata.metadata', { ns: 'dataset' })} )} - {onBatchSummary && ( + {onBatchSummary && IS_CE_EDITION && (