Merge origin/release/e-1.12.1 into 1.12.1-otel-ee

Sync enterprise 1.12.1 changes:
- feat: implement heartbeat mechanism for database migration lock
- refactor: replace AutoRenewRedisLock with DbMigrationAutoRenewLock
- fix: improve logging for database migration lock release
- fix: make flask upgrade-db fail on error
- fix: include sso_verified in access_mode validation
- fix: inherit web app permission from original app
- fix: make e-1.12.1 enterprise migrations database-agnostic
- fix: get_message_event_type return wrong message type
- refactor: document_indexing_sync_task split db session
- fix: trigger output schema miss
- test: remove unrelated enterprise service test

Conflict resolution:
- Combined OTEL telemetry imports with tool signature import in easy_ui_based_generate_task_pipeline.py
GareArc committed 2026-03-01 00:18:46 -08:00
56 changed files with 3291 additions and 1980 deletions

View File

@@ -30,6 +30,7 @@ from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
 from extensions.storage.opendal_storage import OpenDALStorage
 from extensions.storage.storage_type import StorageType
+from libs.db_migration_lock import DbMigrationAutoRenewLock
 from libs.helper import email as email_validate
 from libs.password import hash_password, password_pattern, valid_password
 from libs.rsa import generate_key_pair
@@ -54,6 +55,8 @@ from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
 logger = logging.getLogger(__name__)
+DB_UPGRADE_LOCK_TTL_SECONDS = 60
 @click.command("reset-password", help="Reset the account password.")
 @click.option("--email", prompt=True, help="Account email to reset password for")
@@ -727,8 +730,15 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
 @click.command("upgrade-db", help="Upgrade the database")
 def upgrade_db():
     click.echo("Preparing database migration...")
-    lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
+    lock = DbMigrationAutoRenewLock(
+        redis_client=redis_client,
+        name="db_upgrade_lock",
+        ttl_seconds=DB_UPGRADE_LOCK_TTL_SECONDS,
+        logger=logger,
+        log_context="db_migration",
+    )
     if lock.acquire(blocking=False):
+        migration_succeeded = False
         try:
             click.echo(click.style("Starting database migration.", fg="green"))
@@ -737,12 +747,16 @@ def upgrade_db():
             flask_migrate.upgrade()
+            migration_succeeded = True
             click.echo(click.style("Database migration successful!", fg="green"))
-        except Exception:
+        except Exception as e:
             logger.exception("Failed to execute database migration")
+            click.echo(click.style(f"Database migration failed: {e}", fg="red"))
+            raise SystemExit(1)
         finally:
-            lock.release()
+            status = "successful" if migration_succeeded else "failed"
+            lock.release_safely(status=status)
     else:
         click.echo("Database migration skipped")

View File

@@ -1,3 +1,4 @@
+import logging
 import uuid
 from datetime import datetime
 from typing import Any, Literal, TypeAlias
@@ -54,6 +55,8 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co
 register_enum_models(console_ns, IconType)
+_logger = logging.getLogger(__name__)
 class AppListQuery(BaseModel):
     page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
@@ -499,6 +502,7 @@ class AppListApi(Resource):
                 select(Workflow).where(
                     Workflow.version == Workflow.VERSION_DRAFT,
                     Workflow.app_id.in_(workflow_capable_app_ids),
+                    Workflow.tenant_id == current_tenant_id,
                 )
             )
             .scalars()
@@ -510,12 +514,14 @@ class AppListApi(Resource):
                 NodeType.TRIGGER_PLUGIN,
             }
             for workflow in draft_workflows:
+                node_id = None
                 try:
-                    for _, node_data in workflow.walk_nodes():
+                    for node_id, node_data in workflow.walk_nodes():
                         if node_data.get("type") in trigger_node_types:
                             draft_trigger_app_ids.add(str(workflow.app_id))
                             break
                 except Exception:
+                    _logger.exception("error while walking nodes, workflow_id=%s, node_id=%s", workflow.id, node_id)
                     continue
             for app in app_pagination.items:
@@ -654,6 +660,19 @@ class AppCopyApi(Resource):
             )
             session.commit()
+            # Inherit web app permission from original app
+            if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
+                try:
+                    # Get the original app's access mode
+                    original_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_model.id)
+                    access_mode = original_settings.access_mode
+                except Exception:
+                    # If original app has no settings (old app), default to public to match fallback behavior
+                    access_mode = "public"
+                # Apply the same access mode to the copied app
+                EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, access_mode)
             stmt = select(App).where(App.id == result.app_id)
             app = session.scalar(stmt)

View File

@@ -878,7 +878,11 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
         current_user, current_tenant_id = current_account_with_tenant()
         payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
         return BuiltinToolManageService.set_default_provider(
-            tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id
+            tenant_id=current_tenant_id,
+            user_id=current_user.id,
+            provider=provider,
+            id=args["id"],
+            account=current_user,
         )

View File

@@ -45,6 +45,8 @@ from core.app.entities.task_entities import (
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
 from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
+from core.file import helpers as file_helpers
+from core.file.enums import FileTransferMethod
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (
@@ -57,10 +59,11 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
 from core.telemetry import emit as telemetry_emit
+from core.tools.signature import sign_tool_file
 from events.message_event import message_was_created
 from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
-from models.model import AppMode, Conversation, Message, MessageAgentThought
+from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile
 logger = logging.getLogger(__name__)
@@ -473,6 +476,85 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
             metadata=metadata_dict,
         )
def _record_files(self):
with Session(db.engine, expire_on_commit=False) as session:
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
if not message_files:
return None
files_list = []
upload_file_ids = [
mf.upload_file_id
for mf in message_files
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
]
upload_files_map = {}
if upload_file_ids:
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
upload_files_map = {uf.id: uf for uf in upload_files}
for message_file in message_files:
upload_file = None
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
upload_file = upload_files_map.get(message_file.upload_file_id)
url = None
filename = "file"
mime_type = "application/octet-stream"
size = 0
extension = ""
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
# Fallback: generate URL even if upload_file not found
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
# For tool files, use URL directly if it's HTTP, otherwise sign it
if message_file.url.startswith("http"):
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
else:
# Extract tool file id and extension from URL
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0] # Remove query params first
# Use rsplit to correctly handle filenames with multiple dots
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
else:
tool_file_id = file_part
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
transfer_method_value = message_file.transfer_method
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
file_dict = {
"related_id": message_file.id,
"extension": extension,
"filename": filename,
"size": size,
"mime_type": mime_type,
"transfer_method": transfer_method_value,
"type": message_file.type,
"url": url or "",
"upload_file_id": message_file.upload_file_id or message_file.id,
"remote_url": remote_url,
}
files_list.append(file_dict)
return files_list or None
    def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
        """
        Agent message to stream response.

View File

@@ -64,7 +64,13 @@ class MessageCycleManager:
         # Use SQLAlchemy 2.x style session.scalar(select(...))
         with session_factory.create_session() as session:
-            message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id))
+            message_file = session.scalar(
+                select(MessageFile)
+                .where(
+                    MessageFile.message_id == message_id,
+                )
+                .where(MessageFile.belongs_to == "assistant")
+            )
             if message_file:
                 self._message_has_file.add(message_id)

View File

@@ -0,0 +1,213 @@
"""
DB migration Redis lock with heartbeat renewal.
This is intentionally migration-specific. Background renewal is a trade-off that makes sense
for unbounded, blocking operations like DB migrations (DDL/DML) where the main thread cannot
periodically refresh the lock TTL.
Do NOT use this as a general-purpose lock primitive for normal application code. Prefer explicit
lock lifecycle management (e.g. redis-py Lock context manager + `extend()` / `reacquire()` from
the same thread) when execution flow is under control.
"""
from __future__ import annotations
import logging
import threading
from typing import Any
from redis.exceptions import LockNotOwnedError, RedisError
logger = logging.getLogger(__name__)
MIN_RENEW_INTERVAL_SECONDS = 0.1
DEFAULT_RENEW_INTERVAL_DIVISOR = 3
MIN_JOIN_TIMEOUT_SECONDS = 0.5
MAX_JOIN_TIMEOUT_SECONDS = 5.0
JOIN_TIMEOUT_MULTIPLIER = 2.0
class DbMigrationAutoRenewLock:
"""
Redis lock wrapper that automatically renews TTL while held (migration-only).
Notes:
- We force `thread_local=False` when creating the underlying redis-py lock, because the
lock token must be accessible from the heartbeat thread for `reacquire()` to work.
- `release_safely()` is best-effort: it never raises, so it won't mask the caller's
primary error/exit code.
"""
_redis_client: Any
_name: str
_ttl_seconds: float
_renew_interval_seconds: float
_log_context: str | None
_logger: logging.Logger
_lock: Any
_stop_event: threading.Event | None
_thread: threading.Thread | None
_acquired: bool
def __init__(
self,
redis_client: Any,
name: str,
ttl_seconds: float = 60,
renew_interval_seconds: float | None = None,
*,
logger: logging.Logger | None = None,
log_context: str | None = None,
) -> None:
self._redis_client = redis_client
self._name = name
self._ttl_seconds = float(ttl_seconds)
self._renew_interval_seconds = (
float(renew_interval_seconds)
if renew_interval_seconds is not None
else max(MIN_RENEW_INTERVAL_SECONDS, self._ttl_seconds / DEFAULT_RENEW_INTERVAL_DIVISOR)
)
self._logger = logger or logging.getLogger(__name__)
self._log_context = log_context
self._lock = None
self._stop_event = None
self._thread = None
self._acquired = False
@property
def name(self) -> str:
return self._name
def acquire(self, *args: Any, **kwargs: Any) -> bool:
"""
Acquire the lock and start heartbeat renewal on success.
Accepts the same args/kwargs as redis-py `Lock.acquire()`.
"""
# Prevent accidental double-acquire which could leave the previous heartbeat thread running.
if self._acquired:
raise RuntimeError("DB migration lock is already acquired; call release_safely() before acquiring again.")
# Reuse the lock object if we already created one.
if self._lock is None:
self._lock = self._redis_client.lock(
name=self._name,
timeout=self._ttl_seconds,
thread_local=False,
)
acquired = bool(self._lock.acquire(*args, **kwargs))
self._acquired = acquired
if acquired:
self._start_heartbeat()
return acquired
def owned(self) -> bool:
if self._lock is None:
return False
try:
return bool(self._lock.owned())
except Exception:
# Ownership checks are best-effort and must not break callers.
return False
def _start_heartbeat(self) -> None:
if self._lock is None:
return
if self._stop_event is not None:
return
self._stop_event = threading.Event()
self._thread = threading.Thread(
target=self._heartbeat_loop,
args=(self._lock, self._stop_event),
daemon=True,
name=f"DbMigrationAutoRenewLock({self._name})",
)
self._thread.start()
def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None:
while not stop_event.wait(self._renew_interval_seconds):
try:
lock.reacquire()
except LockNotOwnedError:
self._logger.warning(
"DB migration lock is no longer owned during heartbeat; stop renewing. log_context=%s",
self._log_context,
exc_info=True,
)
return
except RedisError:
self._logger.warning(
"Failed to renew DB migration lock due to Redis error; will retry. log_context=%s",
self._log_context,
exc_info=True,
)
except Exception:
self._logger.warning(
"Unexpected error while renewing DB migration lock; will retry. log_context=%s",
self._log_context,
exc_info=True,
)
def release_safely(self, *, status: str | None = None) -> None:
"""
Stop heartbeat and release lock. Never raises.
Args:
status: Optional caller-provided status (e.g. 'successful'/'failed') to add context to logs.
"""
lock = self._lock
if lock is None:
return
self._stop_heartbeat()
# Lock release errors should never mask the real error/exit code.
try:
lock.release()
except LockNotOwnedError:
self._logger.warning(
"DB migration lock not owned on release; ignoring. status=%s log_context=%s",
status,
self._log_context,
exc_info=True,
)
except RedisError:
self._logger.warning(
"Failed to release DB migration lock due to Redis error; ignoring. status=%s log_context=%s",
status,
self._log_context,
exc_info=True,
)
except Exception:
self._logger.warning(
"Unexpected error while releasing DB migration lock; ignoring. status=%s log_context=%s",
status,
self._log_context,
exc_info=True,
)
finally:
self._acquired = False
self._lock = None
def _stop_heartbeat(self) -> None:
if self._stop_event is None:
return
self._stop_event.set()
if self._thread is not None:
# Best-effort join: if Redis calls are blocked, the daemon thread may remain alive.
join_timeout_seconds = max(
MIN_JOIN_TIMEOUT_SECONDS,
min(MAX_JOIN_TIMEOUT_SECONDS, self._renew_interval_seconds * JOIN_TIMEOUT_MULTIPLIER),
)
self._thread.join(timeout=join_timeout_seconds)
if self._thread.is_alive():
self._logger.warning(
"DB migration lock heartbeat thread did not stop within %.2fs; ignoring. log_context=%s",
join_timeout_seconds,
self._log_context,
)
self._stop_event = None
self._thread = None
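
For reference, a minimal usage sketch of DbMigrationAutoRenewLock, assuming the application's configured redis_client from extensions.ext_redis; run_long_migration() is a hypothetical placeholder for the blocking work (in the upgrade-db command above it is flask_migrate.upgrade()):

    import logging

    from extensions.ext_redis import redis_client
    from libs.db_migration_lock import DbMigrationAutoRenewLock

    logger = logging.getLogger(__name__)

    def run_long_migration() -> None:
        # Placeholder for an unbounded, blocking operation (DDL/DML) that cannot
        # periodically refresh the lock TTL from the main thread.
        ...

    lock = DbMigrationAutoRenewLock(
        redis_client=redis_client,
        name="db_upgrade_lock",
        ttl_seconds=60,
        logger=logger,
        log_context="db_migration",
    )
    if lock.acquire(blocking=False):  # non-blocking: skip if another node holds the lock
        succeeded = False
        try:
            run_long_migration()
            succeeded = True
        finally:
            # release_safely() never raises, so it cannot mask the migration's own error or exit code.
            lock.release_safely(status="successful" if succeeded else "failed")
    else:
        logger.info("Migration already running elsewhere; skipping.")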

View File

@@ -8,7 +8,6 @@ Create Date: 2025-12-25 10:39:15.139304
 from alembic import op
 import models as models
 import sqlalchemy as sa
-from sqlalchemy.dialects import postgresql
 # revision identifiers, used by Alembic.
 revision = '7df29de0f6be'
@@ -20,7 +19,7 @@ depends_on = None
 def upgrade():
     # ### commands auto generated by Alembic - please adjust! ###
     op.create_table('tenant_credit_pools',
-    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('id', models.types.StringUUID(), nullable=False),
     sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
    sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False),
    sa.Column('quota_limit', sa.BigInteger(), nullable=False),

View File

@@ -8,7 +8,6 @@ Create Date: 2026-01-017 11:10:18.079355
 from alembic import op
 import models as models
 import sqlalchemy as sa
-from sqlalchemy.dialects import postgresql
 # revision identifiers, used by Alembic.
 revision = 'f9f6d18a37f9'
@@ -20,7 +19,7 @@ depends_on = None
 def upgrade():
     # ### commands auto generated by Alembic - please adjust! ###
     op.create_table('account_trial_app_records',
-    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('id', models.types.StringUUID(), nullable=False),
     sa.Column('account_id', models.types.StringUUID(), nullable=False),
     sa.Column('app_id', models.types.StringUUID(), nullable=False),
     sa.Column('count', sa.Integer(), nullable=False),
@@ -33,17 +32,17 @@ def upgrade():
         batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)
     op.create_table('exporle_banners',
-    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('id', models.types.StringUUID(), nullable=False),
     sa.Column('content', sa.JSON(), nullable=False),
     sa.Column('link', sa.String(length=255), nullable=False),
     sa.Column('sort', sa.Integer(), nullable=False),
-    sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
+    sa.Column('status', sa.String(length=255), server_default='enabled', nullable=False),
     sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
-    sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False),
+    sa.Column('language', sa.String(length=255), server_default='en-US', nullable=False),
     sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
     )
     op.create_table('trial_apps',
-    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('id', models.types.StringUUID(), nullable=False),
     sa.Column('app_id', models.types.StringUUID(), nullable=False),
     sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
     sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),

View File

@@ -620,7 +620,7 @@ class TrialApp(Base):
         sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
     )
-    id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
+    id = mapped_column(StringUUID, default=lambda: str(uuid4()))
     app_id = mapped_column(StringUUID, nullable=False)
     tenant_id = mapped_column(StringUUID, nullable=False)
     created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@@ -640,7 +640,7 @@ class AccountTrialAppRecord(Base):
         sa.Index("account_trial_app_record_app_id_idx", "app_id"),
         sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
     )
-    id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
+    id = mapped_column(StringUUID, default=lambda: str(uuid4()))
     account_id = mapped_column(StringUUID, nullable=False)
     app_id = mapped_column(StringUUID, nullable=False)
     count = mapped_column(sa.Integer, nullable=False, default=0)
@@ -660,18 +660,18 @@ class AccountTrialAppRecord(Base):
 class ExporleBanner(TypeBase):
     __tablename__ = "exporle_banners"
     __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
-    id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
+    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
     content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
     link: Mapped[str] = mapped_column(String(255), nullable=False)
     sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
     status: Mapped[str] = mapped_column(
-        sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled"
+        sa.String(255), nullable=False, server_default='enabled', default="enabled"
     )
     created_at: Mapped[datetime] = mapped_column(
         sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
     )
     language: Mapped[str] = mapped_column(
-        String(255), nullable=False, server_default=sa.text("'en-US'::character varying"), default="en-US"
+        String(255), nullable=False, server_default='en-US', default="en-US"
     )
@@ -2166,7 +2166,7 @@ class TenantCreditPool(TypeBase):
         sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"),
     )
-    id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"), init=False)
+    id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
     tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
     quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)

View File

@@ -1,6 +1,6 @@
 [project]
 name = "dify-api"
-version = "1.12.0"
+version = "1.12.1"
 requires-python = ">=3.11,<3.13"
 dependencies = [
@@ -81,7 +81,7 @@ dependencies = [
     "starlette==0.49.1",
     "tiktoken~=0.9.0",
     "transformers~=4.56.1",
-    "unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
+    "unstructured[docx,epub,md,ppt,pptx]~=0.18.18",
     "yarl~=1.18.3",
     "webvtt-py~=0.5.1",
     "sseclient-py~=1.8.0",

View File

@@ -327,6 +327,12 @@ class AccountService:
     @staticmethod
     def delete_account(account: Account):
         """Delete account. This method only adds a task to the queue for deletion."""
+        # Queue account deletion sync tasks for all workspaces BEFORE account deletion (enterprise only)
+        from services.enterprise.account_deletion_sync import sync_account_deletion
+        sync_account_deletion(account_id=account.id, source="account_deleted")
+        # Now proceed with async account deletion
         delete_account_task.delay(account.id)
     @staticmethod
@@ -1230,6 +1236,11 @@ class TenantService:
         if dify_config.BILLING_ENABLED:
             BillingService.clean_billing_info_cache(tenant.id)
+        # Queue account deletion sync task for enterprise backend to reassign resources (enterprise only)
+        from services.enterprise.account_deletion_sync import sync_workspace_member_removal
+        sync_workspace_member_removal(workspace_id=tenant.id, member_id=account.id, source="workspace_member_removed")
     @staticmethod
     def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
         """Update member role"""

View File

@@ -0,0 +1,115 @@
import json
import logging
import uuid
from datetime import UTC, datetime
from redis import RedisError
from configs import dify_config
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import TenantAccountJoin
logger = logging.getLogger(__name__)
ACCOUNT_DELETION_SYNC_QUEUE = "enterprise:member:sync:queue"
ACCOUNT_DELETION_SYNC_TASK_TYPE = "sync_member_deletion_from_workspace"
def _queue_task(workspace_id: str, member_id: str, *, source: str) -> bool:
"""
Queue an account deletion sync task to Redis.
Internal helper function. Do not call directly - use the public functions instead.
Args:
workspace_id: The workspace/tenant ID to sync
member_id: The member/account ID that was removed
source: Source of the sync request (for debugging/tracking)
Returns:
bool: True if task was queued successfully, False otherwise
"""
try:
task = {
"task_id": str(uuid.uuid4()),
"workspace_id": workspace_id,
"member_id": member_id,
"retry_count": 0,
"created_at": datetime.now(UTC).isoformat(),
"source": source,
"type": ACCOUNT_DELETION_SYNC_TASK_TYPE,
}
# Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP
redis_client.lpush(ACCOUNT_DELETION_SYNC_QUEUE, json.dumps(task))
logger.info(
"Queued account deletion sync task for workspace %s, member %s, task_id: %s, source: %s",
workspace_id,
member_id,
task["task_id"],
source,
)
return True
except (RedisError, TypeError) as e:
logger.error(
"Failed to queue account deletion sync for workspace %s, member %s: %s",
workspace_id,
member_id,
str(e),
exc_info=True,
)
# Don't raise - we don't want to fail member deletion if queueing fails
return False
def sync_workspace_member_removal(workspace_id: str, member_id: str, *, source: str) -> bool:
"""
Sync a single workspace member removal (enterprise only).
Queues a task for the enterprise backend to reassign resources from the removed member.
Handles enterprise edition check internally. Safe to call in community edition (no-op).
Args:
workspace_id: The workspace/tenant ID
member_id: The member/account ID that was removed
source: Source of the sync request (e.g., "workspace_member_removed")
Returns:
bool: True if task was queued (or skipped in community), False if queueing failed
"""
if not dify_config.ENTERPRISE_ENABLED:
return True
return _queue_task(workspace_id=workspace_id, member_id=member_id, source=source)
def sync_account_deletion(account_id: str, *, source: str) -> bool:
"""
Sync full account deletion across all workspaces (enterprise only).
Fetches all workspace memberships for the account and queues a sync task for each.
Handles enterprise edition check internally. Safe to call in community edition (no-op).
Args:
account_id: The account ID being deleted
source: Source of the sync request (e.g., "account_deleted")
Returns:
bool: True if all tasks were queued (or skipped in community), False if any queueing failed
"""
if not dify_config.ENTERPRISE_ENABLED:
return True
# Fetch all workspaces the account belongs to
workspace_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).all()
# Queue sync task for each workspace
success = True
for join in workspace_joins:
if not _queue_task(workspace_id=join.tenant_id, member_id=account_id, source=source):
success = False
return success
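
For context, a hypothetical consumer-side sketch of the queue contract (the real worker lives in the enterprise backend and is not part of this repository); it only reflects the LPUSH/RPOP convention and payload fields defined in _queue_task above:

    import json

    from extensions.ext_redis import redis_client

    ACCOUNT_DELETION_SYNC_QUEUE = "enterprise:member:sync:queue"

    def drain_one() -> dict | None:
        # Producer LPUSHes to the head; the worker consumes from the tail with RPOP.
        raw = redis_client.rpop(ACCOUNT_DELETION_SYNC_QUEUE)
        if raw is None:
            return None
        task = json.loads(raw)
        # task carries: task_id, workspace_id, member_id, retry_count, created_at, source, type
        return task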

View File

@@ -4,6 +4,8 @@ from pydantic import BaseModel, Field
 from services.enterprise.base import EnterpriseRequest
+ALLOWED_ACCESS_MODES = ["public", "private", "private_all", "sso_verified"]
 class WebAppSettings(BaseModel):
     access_mode: str = Field(
@@ -123,8 +125,8 @@ class EnterpriseService:
     def update_app_access_mode(cls, app_id: str, access_mode: str):
         if not app_id:
             raise ValueError("app_id must be provided.")
-        if access_mode not in ["public", "private", "private_all"]:
-            raise ValueError("access_mode must be either 'public', 'private', or 'private_all'")
+        if access_mode not in ALLOWED_ACCESS_MODES:
+            raise ValueError(f"access_mode must be one of: {', '.join(ALLOWED_ACCESS_MODES)}")
         data = {"appId": app_id, "accessMode": access_mode}

View File

@@ -2,7 +2,10 @@ import json
 import logging
 from collections.abc import Mapping
 from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any
+if TYPE_CHECKING:
+    from models.account import Account
 from sqlalchemy import exists, select
 from sqlalchemy.orm import Session
@@ -406,20 +409,37 @@ class BuiltinToolManageService:
         return {"result": "success"}
     @staticmethod
-    def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
+    def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str, account: "Account | None" = None):
         """
         set default provider
         """
         with Session(db.engine) as session:
-            # get provider
-            target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first()
+            # get provider (verify tenant ownership to prevent IDOR)
+            target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
             if target_provider is None:
                 raise ValueError("provider not found")
             # clear default provider
-            session.query(BuiltinToolProvider).filter_by(
-                tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
-            ).update({"is_default": False})
+            if dify_config.ENTERPRISE_ENABLED:
+                # Enterprise: verify admin permission for tenant-wide operation
+                from models.account import TenantAccountRole
+                if account is None:
+                    # In enterprise mode, an account context is required to perform permission checks
+                    raise ValueError("Account is required to set default credentials in enterprise mode")
+                if not TenantAccountRole.is_privileged_role(account.current_role):
+                    raise ValueError("Only workspace admins/owners can set default credentials in enterprise mode")
+                # Enterprise: clear ALL defaults for this provider in the tenant
+                # (regardless of user_id, since enterprise credentials may have different user_id)
+                session.query(BuiltinToolProvider).filter_by(
+                    tenant_id=tenant_id, provider=provider, is_default=True
+                ).update({"is_default": False})
+            else:
+                # Non-enterprise: only clear defaults for the current user
+                session.query(BuiltinToolProvider).filter_by(
+                    tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
+                ).update({"is_default": False})
             # set new default provider
             target_provider.is_default = True

View File

@@ -6,7 +6,6 @@ from celery import shared_task
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.models.document import Document
-from extensions.ext_database import db
 from models.dataset import Dataset
 from services.dataset_service import DatasetCollectionBindingService
@@ -58,5 +57,3 @@ def add_annotation_to_index_task(
         )
     except Exception:
         logger.exception("Build index for annotation failed")
-    finally:
-        db.session.close()

View File

@@ -5,7 +5,6 @@ import click
 from celery import shared_task
 from core.rag.datasource.vdb.vector_factory import Vector
-from extensions.ext_database import db
 from models.dataset import Dataset
 from services.dataset_service import DatasetCollectionBindingService
@@ -40,5 +39,3 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
         logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
     except Exception:
         logger.exception("Annotation deleted index failed")
-    finally:
-        db.session.close()

View File

@@ -6,7 +6,6 @@ from celery import shared_task
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.models.document import Document
-from extensions.ext_database import db
 from models.dataset import Dataset
 from services.dataset_service import DatasetCollectionBindingService
@@ -59,5 +58,3 @@ def update_annotation_to_index_task(
         )
     except Exception:
         logger.exception("Build index for annotation failed")
-    finally:
-        db.session.close()

View File

@@ -14,6 +14,9 @@ from models.model import UploadFile
 logger = logging.getLogger(__name__)
+# Batch size for database operations to keep transactions short
+BATCH_SIZE = 1000
 @shared_task(queue="dataset")
 def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
@@ -31,63 +34,179 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
     if not doc_form:
         raise ValueError("doc_form is required")
-    with session_factory.create_session() as session:
-        try:
-            dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
-            if not dataset:
-                raise Exception("Document has no dataset")
-            session.query(DatasetMetadataBinding).where(
-                DatasetMetadataBinding.dataset_id == dataset_id,
-                DatasetMetadataBinding.document_id.in_(document_ids),
-            ).delete(synchronize_session=False)
-            segments = session.scalars(
-                select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
-            ).all()
-            # check segment is exist
-            if segments:
-                index_node_ids = [segment.index_node_id for segment in segments]
-                index_processor = IndexProcessorFactory(doc_form).init_index_processor()
-                index_processor.clean(
-                    dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
-                )
-                for segment in segments:
-                    image_upload_file_ids = get_image_upload_file_ids(segment.content)
-                    image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
-                    for image_file in image_files:
-                        try:
-                            if image_file and image_file.key:
-                                storage.delete(image_file.key)
-                        except Exception:
-                            logger.exception(
-                                "Delete image_files failed when storage deleted, \
-                                image_upload_file_is: %s",
-                                image_file.id,
-                            )
-                    stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
-                    session.execute(stmt)
-                    session.delete(segment)
-            if file_ids:
-                files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
-                for file in files:
-                    try:
-                        storage.delete(file.key)
-                    except Exception:
-                        logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
-                stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
-                session.execute(stmt)
-            session.commit()
-            end_at = time.perf_counter()
-            logger.info(
-                click.style(
-                    f"Cleaned documents when documents deleted latency: {end_at - start_at}",
-                    fg="green",
-                )
-            )
-        except Exception:
-            logger.exception("Cleaned documents when documents deleted failed")
+    storage_keys_to_delete: list[str] = []
+    index_node_ids: list[str] = []
+    segment_ids: list[str] = []
+    total_image_upload_file_ids: list[str] = []
+    try:
+        # ============ Step 1: Query segment and file data (short read-only transaction) ============
+        with session_factory.create_session() as session:
+            # Get segments info
+            segments = session.scalars(
+                select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
+            ).all()
+            if segments:
+                index_node_ids = [segment.index_node_id for segment in segments]
+                segment_ids = [segment.id for segment in segments]
+                # Collect image file IDs from segment content
+                for segment in segments:
+                    image_upload_file_ids = get_image_upload_file_ids(segment.content)
+                    total_image_upload_file_ids.extend(image_upload_file_ids)
+            # Query storage keys for image files
+            if total_image_upload_file_ids:
+                image_files = session.scalars(
+                    select(UploadFile).where(UploadFile.id.in_(total_image_upload_file_ids))
+                ).all()
+                storage_keys_to_delete.extend([f.key for f in image_files if f and f.key])
+            # Query storage keys for document files
+            if file_ids:
+                files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
+                storage_keys_to_delete.extend([f.key for f in files if f and f.key])
+        # ============ Step 2: Clean vector index (external service, fresh session for dataset) ============
+        if index_node_ids:
+            try:
+                # Fetch dataset in a fresh session to avoid DetachedInstanceError
+                with session_factory.create_session() as session:
+                    dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+                    if not dataset:
+                        logger.warning("Dataset not found for vector index cleanup, dataset_id: %s", dataset_id)
+                    else:
+                        index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+                        index_processor.clean(
+                            dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
+                        )
+            except Exception:
+                logger.exception(
+                    "Failed to clean vector index for dataset_id: %s, document_ids: %s, index_node_ids count: %d",
+                    dataset_id,
+                    document_ids,
+                    len(index_node_ids),
+                )
+        # ============ Step 3: Delete metadata binding (separate short transaction) ============
+        try:
+            with session_factory.create_session() as session:
+                deleted_count = (
+                    session.query(DatasetMetadataBinding)
+                    .where(
+                        DatasetMetadataBinding.dataset_id == dataset_id,
+                        DatasetMetadataBinding.document_id.in_(document_ids),
+                    )
+                    .delete(synchronize_session=False)
+                )
+                session.commit()
+                logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id)
+        except Exception:
+            logger.exception(
+                "Failed to delete metadata bindings for dataset_id: %s, document_ids: %s",
+                dataset_id,
+                document_ids,
+            )
+        # ============ Step 4: Batch delete UploadFile records (multiple short transactions) ============
+        if total_image_upload_file_ids:
+            failed_batches = 0
+            total_batches = (len(total_image_upload_file_ids) + BATCH_SIZE - 1) // BATCH_SIZE
+            for i in range(0, len(total_image_upload_file_ids), BATCH_SIZE):
+                batch = total_image_upload_file_ids[i : i + BATCH_SIZE]
+                try:
+                    with session_factory.create_session() as session:
+                        stmt = delete(UploadFile).where(UploadFile.id.in_(batch))
+                        session.execute(stmt)
+                        session.commit()
+                except Exception:
+                    failed_batches += 1
+                    logger.exception(
+                        "Failed to delete image UploadFile batch %d-%d for dataset_id: %s",
+                        i,
+                        i + len(batch),
+                        dataset_id,
+                    )
+            if failed_batches > 0:
+                logger.warning(
+                    "Image UploadFile deletion: %d/%d batches failed for dataset_id: %s",
+                    failed_batches,
+                    total_batches,
+                    dataset_id,
+                )
+        # ============ Step 5: Batch delete DocumentSegment records (multiple short transactions) ============
+        if segment_ids:
+            failed_batches = 0
+            total_batches = (len(segment_ids) + BATCH_SIZE - 1) // BATCH_SIZE
+            for i in range(0, len(segment_ids), BATCH_SIZE):
+                batch = segment_ids[i : i + BATCH_SIZE]
+                try:
+                    with session_factory.create_session() as session:
+                        segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(batch))
+                        session.execute(segment_delete_stmt)
+                        session.commit()
+                except Exception:
+                    failed_batches += 1
+                    logger.exception(
+                        "Failed to delete DocumentSegment batch %d-%d for dataset_id: %s, document_ids: %s",
+                        i,
+                        i + len(batch),
+                        dataset_id,
+                        document_ids,
+                    )
+            if failed_batches > 0:
+                logger.warning(
+                    "DocumentSegment deletion: %d/%d batches failed, document_ids: %s",
+                    failed_batches,
+                    total_batches,
+                    document_ids,
+                )
+        # ============ Step 6: Delete document-associated files (separate short transaction) ============
+        if file_ids:
+            try:
+                with session_factory.create_session() as session:
+                    stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
+                    session.execute(stmt)
+                    session.commit()
+            except Exception:
+                logger.exception(
+                    "Failed to delete document UploadFile records for dataset_id: %s, file_ids: %s",
+                    dataset_id,
+                    file_ids,
+                )
+        # ============ Step 7: Delete storage files (I/O operations, no DB transaction) ============
+        storage_delete_failures = 0
+        for storage_key in storage_keys_to_delete:
+            try:
+                storage.delete(storage_key)
+            except Exception:
+                storage_delete_failures += 1
+                logger.exception("Failed to delete file from storage, key: %s", storage_key)
+        if storage_delete_failures > 0:
+            logger.warning(
+                "Storage file deletion completed with %d failures out of %d total files for dataset_id: %s",
+                storage_delete_failures,
+                len(storage_keys_to_delete),
+                dataset_id,
+            )
+        end_at = time.perf_counter()
+        logger.info(
+            click.style(
+                f"Cleaned documents when documents deleted latency: {end_at - start_at:.2f}s, "
+                f"dataset_id: {dataset_id}, document_ids: {document_ids}, "
+                f"segments: {len(segment_ids)}, image_files: {len(total_image_upload_file_ids)}, "
+                f"storage_files: {len(storage_keys_to_delete)}",
+                fg="green",
+            )
+        )
+    except Exception:
+        logger.exception(
+            "Batch clean documents failed for dataset_id: %s, document_ids: %s",
+            dataset_id,
+            document_ids,
+        )

View File

@@ -48,6 +48,11 @@ def batch_create_segment_to_index_task(
     indexing_cache_key = f"segment_batch_import_{job_id}"
+    # Initialize variables with default values
+    upload_file_key: str | None = None
+    dataset_config: dict | None = None
+    document_config: dict | None = None
     with session_factory.create_session() as session:
         try:
             dataset = session.get(Dataset, dataset_id)
@@ -69,86 +74,115 @@ def batch_create_segment_to_index_task(
             if not upload_file:
                 raise ValueError("UploadFile not found.")
-            with tempfile.TemporaryDirectory() as temp_dir:
-                suffix = Path(upload_file.key).suffix
-                file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"  # type: ignore
-                storage.download(upload_file.key, file_path)
-                df = pd.read_csv(file_path)
-                content = []
-                for _, row in df.iterrows():
-                    if dataset_document.doc_form == "qa_model":
-                        data = {"content": row.iloc[0], "answer": row.iloc[1]}
-                    else:
-                        data = {"content": row.iloc[0]}
-                    content.append(data)
-                if len(content) == 0:
-                    raise ValueError("The CSV file is empty.")
-            document_segments = []
-            embedding_model = None
-            if dataset.indexing_technique == "high_quality":
-                model_manager = ModelManager()
-                embedding_model = model_manager.get_model_instance(
-                    tenant_id=dataset.tenant_id,
-                    provider=dataset.embedding_model_provider,
-                    model_type=ModelType.TEXT_EMBEDDING,
-                    model=dataset.embedding_model,
-                )
-            word_count_change = 0
-            if embedding_model:
-                tokens_list = embedding_model.get_text_embedding_num_tokens(
-                    texts=[segment["content"] for segment in content]
-                )
-            else:
-                tokens_list = [0] * len(content)
-            for segment, tokens in zip(content, tokens_list):
-                content = segment["content"]
-                doc_id = str(uuid.uuid4())
-                segment_hash = helper.generate_text_hash(content)
-                max_position = (
-                    session.query(func.max(DocumentSegment.position))
-                    .where(DocumentSegment.document_id == dataset_document.id)
-                    .scalar()
-                )
-                segment_document = DocumentSegment(
-                    tenant_id=tenant_id,
-                    dataset_id=dataset_id,
-                    document_id=document_id,
-                    index_node_id=doc_id,
-                    index_node_hash=segment_hash,
-                    position=max_position + 1 if max_position else 1,
-                    content=content,
-                    word_count=len(content),
-                    tokens=tokens,
-                    created_by=user_id,
-                    indexing_at=naive_utc_now(),
-                    status="completed",
-                    completed_at=naive_utc_now(),
-                )
-                if dataset_document.doc_form == "qa_model":
-                    segment_document.answer = segment["answer"]
-                    segment_document.word_count += len(segment["answer"])
-                word_count_change += segment_document.word_count
-                session.add(segment_document)
-                document_segments.append(segment_document)
-            assert dataset_document.word_count is not None
-            dataset_document.word_count += word_count_change
-            session.add(dataset_document)
-            VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
-            session.commit()
-            redis_client.setex(indexing_cache_key, 600, "completed")
-            end_at = time.perf_counter()
-            logger.info(
-                click.style(
-                    f"Segment batch created job: {job_id} latency: {end_at - start_at}",
-                    fg="green",
-                )
-            )
-        except Exception:
-            logger.exception("Segments batch created index failed")
-            redis_client.setex(indexing_cache_key, 600, "error")
+            dataset_config = {
+                "id": dataset.id,
+                "indexing_technique": dataset.indexing_technique,
+                "tenant_id": dataset.tenant_id,
+                "embedding_model_provider": dataset.embedding_model_provider,
+                "embedding_model": dataset.embedding_model,
+            }
+            document_config = {
+                "id": dataset_document.id,
+                "doc_form": dataset_document.doc_form,
+                "word_count": dataset_document.word_count or 0,
+            }
+            upload_file_key = upload_file.key
+        except Exception:
+            logger.exception("Segments batch created index failed")
+            redis_client.setex(indexing_cache_key, 600, "error")
+            return
+    # Ensure required variables are set before proceeding
+    if upload_file_key is None or dataset_config is None or document_config is None:
+        logger.error("Required configuration not set due to session error")
+        redis_client.setex(indexing_cache_key, 600, "error")
+        return
+    with tempfile.TemporaryDirectory() as temp_dir:
+        suffix = Path(upload_file_key).suffix
+        file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"  # type: ignore
+        storage.download(upload_file_key, file_path)
+        df = pd.read_csv(file_path)
+        content = []
+        for _, row in df.iterrows():
+            if document_config["doc_form"] == "qa_model":
+                data = {"content": row.iloc[0], "answer": row.iloc[1]}
+            else:
+                data = {"content": row.iloc[0]}
+            content.append(data)
+        if len(content) == 0:
+            raise ValueError("The CSV file is empty.")
+    document_segments = []
+    embedding_model = None
+    if dataset_config["indexing_technique"] == "high_quality":
+        model_manager = ModelManager()
+        embedding_model = model_manager.get_model_instance(
+            tenant_id=dataset_config["tenant_id"],
+            provider=dataset_config["embedding_model_provider"],
+            model_type=ModelType.TEXT_EMBEDDING,
+            model=dataset_config["embedding_model"],
+        )
+    word_count_change = 0
+    if embedding_model:
+        tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content])
+    else:
+        tokens_list = [0] * len(content)
+    with session_factory.create_session() as session, session.begin():
+        for segment, tokens in zip(content, tokens_list):
+            content = segment["content"]
+            doc_id = str(uuid.uuid4())
+            segment_hash = helper.generate_text_hash(content)
+            max_position = (
+                session.query(func.max(DocumentSegment.position))
+                .where(DocumentSegment.document_id == document_config["id"])
+                .scalar()
+            )
+            segment_document = DocumentSegment(
+                tenant_id=tenant_id,
+                dataset_id=dataset_id,
+                document_id=document_id,
+                index_node_id=doc_id,
+                index_node_hash=segment_hash,
+                position=max_position + 1 if max_position else 1,
+                content=content,
+                word_count=len(content),
+                tokens=tokens,
+                created_by=user_id,
+                indexing_at=naive_utc_now(),
+                status="completed",
+                completed_at=naive_utc_now(),
+            )
+            if document_config["doc_form"] == "qa_model":
+                segment_document.answer = segment["answer"]
+                segment_document.word_count += len(segment["answer"])
+            word_count_change += segment_document.word_count
+            session.add(segment_document)
+            document_segments.append(segment_document)
+    with session_factory.create_session() as session, session.begin():
+        dataset_document = session.get(Document, document_id)
+        if dataset_document:
+            assert dataset_document.word_count is not None
+            dataset_document.word_count += word_count_change
+            session.add(dataset_document)
+    with session_factory.create_session() as session:
+        dataset = session.get(Dataset, dataset_id)
+        if dataset:
+            VectorService.create_segments_vector(None, document_segments, dataset, document_config["doc_form"])
+    redis_client.setex(indexing_cache_key, 600, "completed")
+    end_at = time.perf_counter()
+    logger.info(
+        click.style(
+            f"Segment batch created job: {job_id} latency: {end_at - start_at}",
+            fg="green",
+        )
+    )

View File

@@ -28,6 +28,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
     """
     logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
     start_at = time.perf_counter()
+    total_attachment_files = []
     with session_factory.create_session() as session:
         try:
@@ -47,78 +48,91 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
                     SegmentAttachmentBinding.document_id == document_id,
                 )
             ).all()
-            # check segment is exist
-            if segments:
-                index_node_ids = [segment.index_node_id for segment in segments]
-                index_processor = IndexProcessorFactory(doc_form).init_index_processor()
-                index_processor.clean(
-                    dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
-                )
-                for segment in segments:
-                    image_upload_file_ids = get_image_upload_file_ids(segment.content)
-                    image_files = session.scalars(
-                        select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
-                    ).all()
-                    for image_file in image_files:
-                        if image_file is None:
-                            continue
-                        try:
-                            storage.delete(image_file.key)
-                        except Exception:
-                            logger.exception(
-                                "Delete image_files failed when storage deleted, \
-                                image_upload_file_is: %s",
-                                image_file.id,
-                            )
-                    image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
-                    session.execute(image_file_delete_stmt)
-                    session.delete(segment)
-            session.commit()
-            if file_id:
-                file = session.query(UploadFile).where(UploadFile.id == file_id).first()
-                if file:
-                    try:
-                        storage.delete(file.key)
-                    except Exception:
-                        logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
-                    session.delete(file)
-            # delete segment attachments
-            if attachments_with_bindings:
-                attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
-                binding_ids = [binding.id for binding, _ in attachments_with_bindings]
-                for binding, attachment_file in attachments_with_bindings:
-                    try:
-                        storage.delete(attachment_file.key)
-                    except Exception:
-                        logger.exception(
-                            "Delete attachment_file failed when storage deleted, \
-                            attachment_file_id: %s",
-                            binding.attachment_id,
-                        )
-                attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
-                session.execute(attachment_file_delete_stmt)
-                binding_delete_stmt = delete(SegmentAttachmentBinding).where(
-                    SegmentAttachmentBinding.id.in_(binding_ids)
-                )
-                session.execute(binding_delete_stmt)
-            # delete dataset metadata binding
-            session.query(DatasetMetadataBinding).where(
-                DatasetMetadataBinding.dataset_id == dataset_id,
-                DatasetMetadataBinding.document_id == document_id,
-            ).delete()
-            session.commit()
-            end_at = time.perf_counter()
-            logger.info(
-                click.style(
-                    f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
-                    fg="green",
-                )
-            )
-        except Exception:
-            logger.exception("Cleaned document when document deleted failed")
+            attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
+            binding_ids = [binding.id for binding, _ in attachments_with_bindings]
+            total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings])
+            index_node_ids = [segment.index_node_id for segment in segments]
+            segment_contents = [segment.content for segment in segments]
+        except Exception:
+            logger.exception("Cleaned document when document deleted failed")
+            return
+    # check segment is exist
+    if index_node_ids:
+        index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+        with session_factory.create_session() as session:
+            dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+            if dataset:
+                index_processor.clean(
+                    dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
+                )
+    total_image_files = []
+    with session_factory.create_session() as session, session.begin():
+        for segment_content in segment_contents:
+            image_upload_file_ids = get_image_upload_file_ids(segment_content)
+            image_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))).all()
+            total_image_files.extend([image_file.key for image_file in image_files])
+            image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
+            session.execute(image_file_delete_stmt)
+    with session_factory.create_session() as session, session.begin():
+        segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
+        session.execute(segment_delete_stmt)
+    for image_file_key in total_image_files:
+        try:
+            storage.delete(image_file_key)
+        except Exception:
+            logger.exception(
+                "Delete image_files failed when storage deleted, \
+                image_upload_file_is: %s",
+                image_file_key,
+            )
+    with session_factory.create_session() as session, session.begin():
+        if file_id:
+            file = session.query(UploadFile).where(UploadFile.id == file_id).first()
+            if file:
+                try:
+                    storage.delete(file.key)
+                except Exception:
+                    logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
+                session.delete(file)
+    with session_factory.create_session() as session, session.begin():
+        # delete segment attachments
+        if attachment_ids:
+            attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
+            session.execute(attachment_file_delete_stmt)
+        if binding_ids:
+            binding_delete_stmt = delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.id.in_(binding_ids))
+            session.execute(binding_delete_stmt)
+    for attachment_file_key in total_attachment_files:
+        try:
+            storage.delete(attachment_file_key)
+        except Exception:
+            logger.exception(
+                "Delete attachment_file failed when storage deleted, \
+                attachment_file_id: %s",
+                attachment_file_key,
+            )
+    with session_factory.create_session() as session, session.begin():
+        # delete dataset metadata binding
+        session.query(DatasetMetadataBinding).where(
+            DatasetMetadataBinding.dataset_id == dataset_id,
+            DatasetMetadataBinding.document_id == document_id,
+        ).delete()
+    end_at = time.perf_counter()
+    logger.info(
+        click.style(
+            f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
+            fg="green",
+        )
+    )

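The rewritten task above replaces one long-lived transaction with several short `session.begin()` blocks and defers storage deletes until after the rows are gone. A minimal sketch of that split, assuming a session factory with the same `create_session()` context manager used in this diff (the model, id list, and storage key names below are illustrative, not from this repository):

    from sqlalchemy import delete

    def clean_in_short_transactions(session_factory, storage, MyRow, row_ids, storage_keys):
        # Transaction 1: delete the rows; the block commits on successful exit.
        with session_factory.create_session() as session, session.begin():
            session.execute(delete(MyRow).where(MyRow.id.in_(row_ids)))

        # Storage I/O runs outside any DB transaction, so a slow or failing
        # delete cannot pin a connection or roll back the row deletes.
        for key in storage_keys:
            try:
                storage.delete(key)
            except Exception:
                # log and continue; the database state is already consistent
                pass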
View File

@@ -23,40 +23,40 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
    """
    logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
    start_at = time.perf_counter()
    total_index_node_ids = []

    with session_factory.create_session() as session:
        dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()

        if not dataset:
            raise Exception("Document has no dataset")

        index_type = dataset.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()

        document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
        session.execute(document_delete_stmt)

        for document_id in document_ids:
            segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
            total_index_node_ids.extend([segment.index_node_id for segment in segments])

    with session_factory.create_session() as session:
        dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
        if dataset:
            index_processor.clean(
                dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
            )

    with session_factory.create_session() as session, session.begin():
        segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
        session.execute(segment_delete_stmt)

    end_at = time.perf_counter()
    logger.info(
        click.style(
            "Clean document when import form notion document deleted end :: {} latency: {}".format(
                dataset_id, end_at - start_at
            ),
            fg="green",
        )
    )

View File

@@ -3,6 +3,7 @@ import time

import click
from celery import shared_task
from sqlalchemy import delete

from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -67,8 +68,14 @@ def delete_segment_from_index_task(
            if segment_attachment_bindings:
                attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
                index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
                segment_attachment_bind_ids = [i.id for i in segment_attachment_bindings]

                for i in range(0, len(segment_attachment_bind_ids), 1000):
                    segment_attachment_bind_delete_stmt = delete(SegmentAttachmentBinding).where(
                        SegmentAttachmentBinding.id.in_(segment_attachment_bind_ids[i : i + 1000])
                    )
                    session.execute(segment_attachment_bind_delete_stmt)
                # delete upload file
                session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
        session.commit()

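The hunk above issues the binding deletes in slices of 1000 ids so the `IN (...)` list stays bounded. A standalone sketch of the same chunking pattern (the `Binding` model name and chunk size here are assumptions for illustration):

    from sqlalchemy import delete

    CHUNK_SIZE = 1000  # matches the slice size used in the hunk above

    def delete_in_chunks(session, Binding, ids: list[str]) -> None:
        # One bulk DELETE per slice, so a very large id list never becomes
        # a single oversized IN clause.
        for start in range(0, len(ids), CHUNK_SIZE):
            chunk = ids[start : start + CHUNK_SIZE]
            session.execute(delete(Binding).where(Binding.id.in_(chunk)))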
View File

@@ -27,104 +27,129 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
    """
    logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
    start_at = time.perf_counter()
    tenant_id = None

    with session_factory.create_session() as session, session.begin():
        document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()

        if not document:
            logger.info(click.style(f"Document not found: {document_id}", fg="red"))
            return

        if document.indexing_status == "parsing":
            logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
            return

        dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
        if not dataset:
            raise Exception("Dataset not found")

        data_source_info = document.data_source_info_dict
        if document.data_source_type != "notion_import":
            logger.info(click.style(f"Document {document_id} is not a notion_import, skipping", fg="yellow"))
            return

        if (
            not data_source_info
            or "notion_page_id" not in data_source_info
            or "notion_workspace_id" not in data_source_info
        ):
            raise ValueError("no notion page found")

        workspace_id = data_source_info["notion_workspace_id"]
        page_id = data_source_info["notion_page_id"]
        page_type = data_source_info["type"]
        page_edited_time = data_source_info["last_edited_time"]
        credential_id = data_source_info.get("credential_id")
        tenant_id = document.tenant_id
        index_type = document.doc_form
        segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
        index_node_ids = [segment.index_node_id for segment in segments]

        # Get credentials from datasource provider
        datasource_provider_service = DatasourceProviderService()
        credential = datasource_provider_service.get_datasource_credentials(
            tenant_id=tenant_id,
            credential_id=credential_id,
            provider="notion_datasource",
            plugin_id="langgenius/notion_datasource",
        )

    if not credential:
        logger.error(
            "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
            document_id,
            tenant_id,
            credential_id,
        )
        with session_factory.create_session() as session, session.begin():
            document = session.query(Document).filter_by(id=document_id).first()
            if document:
                document.indexing_status = "error"
                document.error = "Datasource credential not found. Please reconnect your Notion workspace."
                document.stopped_at = naive_utc_now()
        return

    loader = NotionExtractor(
        notion_workspace_id=workspace_id,
        notion_obj_id=page_id,
        notion_page_type=page_type,
        notion_access_token=credential.get("integration_secret"),
        tenant_id=tenant_id,
    )

    last_edited_time = loader.get_notion_last_edited_time()
    if last_edited_time == page_edited_time:
        logger.info(click.style(f"Document {document_id} content unchanged, skipping sync", fg="yellow"))
        return

    logger.info(click.style(f"Document {document_id} content changed, starting sync", fg="green"))

    try:
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        with session_factory.create_session() as session:
            dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
            if dataset:
                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
                logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green"))
    except Exception:
        logger.exception("Failed to clean vector index for document %s", document_id)

    with session_factory.create_session() as session, session.begin():
        document = session.query(Document).filter_by(id=document_id).first()
        if not document:
            logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow"))
            return

        data_source_info = document.data_source_info_dict
        data_source_info["last_edited_time"] = last_edited_time
        document.data_source_info = data_source_info
        document.indexing_status = "parsing"
        document.processing_started_at = naive_utc_now()

        segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
        session.execute(segment_delete_stmt)

    logger.info(click.style(f"Deleted segments for document {document_id}", fg="green"))

    try:
        indexing_runner = IndexingRunner()
        with session_factory.create_session() as session:
            document = session.query(Document).filter_by(id=document_id).first()
            if document:
                indexing_runner.run([document])
        end_at = time.perf_counter()
        logger.info(click.style(f"Sync completed for document {document_id} latency: {end_at - start_at}", fg="green"))
    except DocumentIsPausedError as ex:
        logger.info(click.style(str(ex), fg="yellow"))
    except Exception as e:
        logger.exception("document_indexing_sync_task failed for document_id: %s", document_id)
        with session_factory.create_session() as session, session.begin():
            document = session.query(Document).filter_by(id=document_id).first()
            if document:
                document.indexing_status = "error"
                document.error = str(e)
                document.stopped_at = naive_utc_now()

View File

@@ -81,26 +81,35 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
            session.commit()
            return

    # Phase 1: Update status to parsing (short transaction)
    with session_factory.create_session() as session, session.begin():
        documents = (
            session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all()
        )

        for document in documents:
            if document:
                document.indexing_status = "parsing"
                document.processing_started_at = naive_utc_now()
                session.add(document)
    # Transaction committed and closed

    # Phase 2: Execute indexing (no transaction - IndexingRunner creates its own sessions)
    has_error = False
    try:
        indexing_runner = IndexingRunner()
        indexing_runner.run(documents)
        end_at = time.perf_counter()
        logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
    except DocumentIsPausedError as ex:
        logger.info(click.style(str(ex), fg="yellow"))
        has_error = True
    except Exception:
        logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
        has_error = True

    if not has_error:
        with session_factory.create_session() as session:
            # Trigger summary index generation for completed documents if enabled
            # Only generate for high_quality indexing technique and when summary_index_setting is enabled
            # Re-query dataset to get latest summary_index_setting (in case it was updated)
@@ -115,17 +124,18 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
                    # expire all session to get latest document's indexing status
                    session.expire_all()
                    # Check each document's indexing status and trigger summary generation if completed
                    documents = (
                        session.query(Document)
                        .where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
                        .all()
                    )

                    for document in documents:
                        if document:
                            logger.info(
                                "Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s",
                                document.id,
                                document.indexing_status,
                                document.doc_form,
                                document.need_summary,
@@ -136,46 +146,36 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
                                and document.need_summary is True
                            ):
                                try:
                                    generate_summary_index_task.delay(dataset.id, document.id, None)
                                    logger.info(
                                        "Queued summary index generation task for document %s in dataset %s "
                                        "after indexing completed",
                                        document.id,
                                        dataset.id,
                                    )
                                except Exception:
                                    logger.exception(
                                        "Failed to queue summary index generation task for document %s",
                                        document.id,
                                    )
                                    # Don't fail the entire indexing process if summary task queuing fails
                            else:
                                logger.info(
                                    "Skipping summary generation for document %s: "
                                    "status=%s, doc_form=%s, need_summary=%s",
                                    document.id,
                                    document.indexing_status,
                                    document.doc_form,
                                    document.need_summary,
                                )
                        else:
                            logger.warning("Document %s not found after indexing", document.id)
            else:
                logger.info(
                    "Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
                    dataset.id,
                    dataset.indexing_technique,
                )


def _document_indexing_with_tenant_queue(

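The refactor above splits indexing into a short status-update transaction followed by an indexing phase that runs with no transaction open. A rough sketch of that shape, assuming the same `create_session()` factory API; `run_indexing` here stands in for `IndexingRunner.run` and is purely illustrative:

    def index_documents(session_factory, Document, document_ids, run_indexing) -> bool:
        # Phase 1: a short transaction that only flips status flags.
        with session_factory.create_session() as session, session.begin():
            docs = session.query(Document).where(Document.id.in_(document_ids)).all()
            for doc in docs:
                doc.indexing_status = "parsing"

        # Phase 2: the heavy work runs outside any transaction, so no
        # connection stays pinned for the duration of the indexing run.
        has_error = False
        try:
            run_indexing(docs)
        except Exception:
            has_error = True
        return not has_error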
View File

@@ -8,7 +8,6 @@ from sqlalchemy import delete, select

from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
@@ -27,7 +26,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
    logger.info(click.style(f"Start update document: {document_id}", fg="green"))
    start_at = time.perf_counter()

    with session_factory.create_session() as session, session.begin():
        document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()

        if not document:
@@ -36,27 +35,20 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
        document.indexing_status = "parsing"
        document.processing_started_at = naive_utc_now()

        dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
        if not dataset:
            return

        index_type = document.doc_form
        segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
        index_node_ids = [segment.index_node_id for segment in segments]

    clean_success = False
    try:
        index_processor = IndexProcessorFactory(index_type).init_index_processor()

        index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

        end_at = time.perf_counter()
        logger.info(
            click.style(
@@ -66,15 +58,21 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
                fg="green",
            )
        )
        clean_success = True
    except Exception:
        logger.exception("Failed to clean document index during update, document_id: %s", document_id)

    if clean_success:
        with session_factory.create_session() as session, session.begin():
            segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
            session.execute(segment_delete_stmt)

    try:
        indexing_runner = IndexingRunner()
        indexing_runner.run([document])
        end_at = time.perf_counter()
        logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
    except DocumentIsPausedError as ex:
        logger.info(click.style(str(ex), fg="yellow"))
    except Exception:
        logger.exception("document_indexing_update_task failed, document_id: %s", document_id)

View File

@@ -259,8 +259,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):

def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
    def del_workflow_archive_log(session, workflow_archive_log_id: str):
        session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
            synchronize_session=False
        )
@@ -420,7 +420,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
    total_files_deleted = 0

    while True:
        with session_factory.create_session() as session, session.begin():
            # Get a batch of draft variable IDs along with their file_ids
            query_sql = """
                SELECT id, file_id FROM workflow_draft_variables

View File

@@ -6,9 +6,8 @@ improving performance by offloading storage operations to background workers.
"""

from celery import shared_task  # type: ignore[import-untyped]

from core.db.session_factory import session_factory
from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService
@@ -17,6 +16,6 @@ def save_workflow_execution_task(
    self,
    deletions: list[DraftVarFileDeletion],
):
    with session_factory.create_session() as session, session.begin():
        srv = WorkflowDraftVariableService(session=session)
        srv.delete_workflow_draft_variable_file(deletions=deletions)

View File

@@ -10,7 +10,10 @@ from models import Tenant
from models.enums import CreatorUserRole
from models.model import App, UploadFile
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
from tasks.remove_app_and_related_data_task import (
    _delete_draft_variables,
    delete_draft_variables_batch,
)


@pytest.fixture
@@ -297,12 +300,18 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
    def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
        data = setup_offload_test_data
        app_id = data["app"].id
        upload_file_ids = [uf.id for uf in data["upload_files"]]
        variable_file_ids = [vf.id for vf in data["variable_files"]]

        mock_storage.delete.return_value = None

        with session_factory.create_session() as session:
            draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
            var_files_before = (
                session.query(WorkflowDraftVariableFile)
                .where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
                .count()
            )
            upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()

            assert draft_vars_before == 3
            assert var_files_before == 2
            assert upload_files_before == 2
@@ -315,8 +324,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
        assert draft_vars_after == 0

        with session_factory.create_session() as session:
            var_files_after = (
                session.query(WorkflowDraftVariableFile)
                .where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
                .count()
            )
            upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()

            assert var_files_after == 0
            assert upload_files_after == 0
@@ -329,6 +342,8 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
    def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
        data = setup_offload_test_data
        app_id = data["app"].id
        upload_file_ids = [uf.id for uf in data["upload_files"]]
        variable_file_ids = [vf.id for vf in data["variable_files"]]

        mock_storage.delete.side_effect = [Exception("Storage error"), None]

        deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
@@ -339,8 +354,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
        assert draft_vars_after == 0

        with session_factory.create_session() as session:
            var_files_after = (
                session.query(WorkflowDraftVariableFile)
                .where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
                .count()
            )
            upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()

            assert var_files_after == 0
            assert upload_files_after == 0
@@ -395,3 +414,275 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
            if app2_obj:
                session.delete(app2_obj)
            session.commit()
class TestDeleteDraftVariablesSessionCommit:
"""Test suite to verify session commit behavior in delete_draft_variables_batch."""
@pytest.fixture
def setup_offload_test_data(self, app_and_tenant):
"""Create test data with offload files for session commit tests."""
from core.variables.types import SegmentType
from libs.datetime_utils import naive_utc_now
tenant, app = app_and_tenant
with session_factory.create_session() as session:
upload_file1 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file1.json",
name="file1.json",
size=1024,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
upload_file2 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file2.json",
name="file2.json",
size=2048,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
session.add(upload_file1)
session.add(upload_file2)
session.flush()
var_file1 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file1.id,
size=1024,
length=10,
value_type=SegmentType.STRING,
)
var_file2 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file2.id,
size=2048,
length=20,
value_type=SegmentType.OBJECT,
)
session.add(var_file1)
session.add(var_file2)
session.flush()
draft_var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_1",
name="large_var_1",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file1.id,
)
draft_var2 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_2",
name="large_var_2",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file2.id,
)
draft_var3 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_3",
name="regular_var",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(draft_var1)
session.add(draft_var2)
session.add(draft_var3)
session.commit()
data = {
"app": app,
"tenant": tenant,
"upload_files": [upload_file1, upload_file2],
"variable_files": [var_file1, var_file2],
"draft_variables": [draft_var1, draft_var2, draft_var3],
}
yield data
with session_factory.create_session() as session:
for table, ids in [
(WorkflowDraftVariable, [v.id for v in data["draft_variables"]]),
(WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]),
(UploadFile, [uf.id for uf in data["upload_files"]]),
]:
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
session.execute(cleanup_query)
session.commit()
@pytest.fixture
def setup_commit_test_data(self, app_and_tenant):
"""Create test data for session commit tests."""
tenant, app = app_and_tenant
variable_ids: list[str] = []
with session_factory.create_session() as session:
variables = []
for i in range(10):
var = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var)
variables.append(var)
session.commit()
variable_ids = [v.id for v in variables]
yield {
"app": app,
"tenant": tenant,
"variable_ids": variable_ids,
}
with session_factory.create_session() as session:
cleanup_query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.id.in_(variable_ids))
.execution_options(synchronize_session=False)
)
session.execute(cleanup_query)
session.commit()
def test_session_commit_is_called_after_each_batch(self, setup_commit_test_data):
"""Test that session.begin() is used for automatic transaction management."""
data = setup_commit_test_data
app_id = data["app"].id
# Since session.begin() is used, the transaction is automatically committed
# when the with block exits successfully. We verify this by checking that
# data is actually persisted.
deleted_count = delete_draft_variables_batch(app_id, batch_size=3)
# Verify all data was deleted (proves transaction was committed)
with session_factory.create_session() as session:
remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert deleted_count == 10
assert remaining_count == 0
def test_data_persisted_after_batch_deletion(self, setup_commit_test_data):
"""Test that data is actually persisted to database after batch deletion with commits."""
data = setup_commit_test_data
app_id = data["app"].id
variable_ids = data["variable_ids"]
# Verify initial state
with session_factory.create_session() as session:
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert initial_count == 10
# Perform deletion with small batch size to force multiple commits
deleted_count = delete_draft_variables_batch(app_id, batch_size=3)
assert deleted_count == 10
# Verify all data is deleted in a new session (proves commits worked)
with session_factory.create_session() as session:
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert final_count == 0
# Verify specific IDs are deleted
with session_factory.create_session() as session:
remaining_vars = (
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count()
)
assert remaining_vars == 0
def test_session_commit_with_empty_dataset(self, setup_commit_test_data):
"""Test session behavior when deleting from an empty dataset."""
nonexistent_app_id = str(uuid.uuid4())
# Should not raise any errors and should return 0
deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=10)
assert deleted_count == 0
def test_session_commit_with_single_batch(self, setup_commit_test_data):
"""Test that commit happens correctly when all data fits in a single batch."""
data = setup_commit_test_data
app_id = data["app"].id
with session_factory.create_session() as session:
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert initial_count == 10
# Delete all in a single batch
deleted_count = delete_draft_variables_batch(app_id, batch_size=100)
assert deleted_count == 10
# Verify data is persisted
with session_factory.create_session() as session:
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert final_count == 0
def test_invalid_batch_size_raises_error(self, setup_commit_test_data):
"""Test that invalid batch size raises ValueError."""
data = setup_commit_test_data
app_id = data["app"].id
with pytest.raises(ValueError, match="batch_size must be positive"):
delete_draft_variables_batch(app_id, batch_size=0)
with pytest.raises(ValueError, match="batch_size must be positive"):
delete_draft_variables_batch(app_id, batch_size=-1)
@patch("extensions.ext_storage.storage")
def test_session_commit_with_offload_data_cleanup(self, mock_storage, setup_offload_test_data):
"""Test that session commits correctly when cleaning up offload data."""
data = setup_offload_test_data
app_id = data["app"].id
upload_file_ids = [uf.id for uf in data["upload_files"]]
mock_storage.delete.return_value = None
# Verify initial state
with session_factory.create_session() as session:
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
.count()
)
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_before == 3
assert var_files_before == 2
assert upload_files_before == 2
# Delete variables with offload data
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
assert deleted_count == 3
# Verify all data is persisted (deleted) in new session
with session_factory.create_session() as session:
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_after = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
.count()
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_after == 0
assert var_files_after == 0
assert upload_files_after == 0
# Verify storage cleanup was called
assert mock_storage.delete.call_count == 2

View File

@@ -0,0 +1,38 @@
"""
Integration tests for DbMigrationAutoRenewLock using real Redis via TestContainers.
"""
import time
import uuid
import pytest
from extensions.ext_redis import redis_client
from libs.db_migration_lock import DbMigrationAutoRenewLock
@pytest.mark.usefixtures("flask_app_with_containers")
def test_db_migration_lock_renews_ttl_and_releases():
lock_name = f"test:db_migration_auto_renew_lock:{uuid.uuid4().hex}"
# Keep base TTL very small, and renew frequently so the test is stable even on slower CI.
lock = DbMigrationAutoRenewLock(
redis_client=redis_client,
name=lock_name,
ttl_seconds=1.0,
renew_interval_seconds=0.2,
log_context="test_db_migration_lock",
)
acquired = lock.acquire(blocking=True, blocking_timeout=5)
assert acquired is True
# Wait beyond the base TTL; key should still exist due to renewal.
time.sleep(1.5)
ttl = redis_client.ttl(lock_name)
assert ttl > 0
lock.release_safely(status="successful")
# After release, the key should not exist.
assert redis_client.exists(lock_name) == 0

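DbMigrationAutoRenewLock itself is not shown in this diff; the test above only exercises its constructor, acquire(), and release_safely(). For orientation, a minimal heartbeat lock can be sketched with redis-py's built-in lock plus a renewal thread; this is an illustration of the idea under those assumptions, not the enterprise implementation:

    import threading

    import redis

    def hold_with_heartbeat(client: redis.Redis, name: str, ttl: float, interval: float):
        # Acquire a non-thread-local lock so the renewal thread may extend it.
        lock = client.lock(name, timeout=ttl, thread_local=False)
        if not lock.acquire(blocking=False):
            return None

        stop = threading.Event()

        def _renew() -> None:
            # Reset the TTL to the full value on every beat until released.
            while not stop.wait(interval):
                try:
                    lock.extend(ttl, replace_ttl=True)
                except redis.exceptions.LockNotOwnedError:
                    break

        threading.Thread(target=_renew, daemon=True).start()
        return lock, stop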
View File

@@ -605,26 +605,20 @@ class TestBatchCreateSegmentToIndexTask:
        mock_storage.download.side_effect = mock_download

        # Execute the task - should raise ValueError for empty CSV
        job_id = str(uuid.uuid4())
        with pytest.raises(ValueError, match="The CSV file is empty"):
            batch_create_segment_to_index_task(
                job_id=job_id,
                upload_file_id=upload_file.id,
                dataset_id=dataset.id,
                document_id=document.id,
                tenant_id=tenant.id,
                user_id=account.id,
            )

        # Verify error handling
        # Since exception was raised, no segments should be created
        from extensions.ext_database import db

        segments = db.session.query(DocumentSegment).all()

View File

@@ -153,8 +153,7 @@ class TestCleanNotionDocumentTask:
        # Execute cleanup task
        clean_notion_document_task(document_ids, dataset.id)

        # Verify segments are deleted
        assert (
            db_session_with_containers.query(DocumentSegment)
            .filter(DocumentSegment.document_id.in_(document_ids))
@@ -162,9 +161,9 @@ class TestCleanNotionDocumentTask:
            == 0
        )

        # Verify index processor was called
        mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
        mock_processor.clean.assert_called_once()

        # This test successfully verifies:
        # 1. Document records are properly deleted from the database
@@ -186,12 +185,12 @@ class TestCleanNotionDocumentTask:
        non_existent_dataset_id = str(uuid.uuid4())
        document_ids = [str(uuid.uuid4()), str(uuid.uuid4())]

        # Execute cleanup task with non-existent dataset - expect exception
        with pytest.raises(Exception, match="Document has no dataset"):
            clean_notion_document_task(document_ids, non_existent_dataset_id)

        # Verify that the index processor factory was not used
        mock_index_processor_factory.return_value.init_index_processor.assert_not_called()

    def test_clean_notion_document_task_empty_document_list(
        self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
@@ -229,9 +228,13 @@ class TestCleanNotionDocumentTask:
        # Execute cleanup task with empty document list
        clean_notion_document_task([], dataset.id)

        # Verify that the index processor was called once with empty node list
        mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
        assert mock_processor.clean.call_count == 1
        args, kwargs = mock_processor.clean.call_args
        # args: (dataset, total_index_node_ids)
        assert isinstance(args[0], Dataset)
        assert args[1] == []

    def test_clean_notion_document_task_with_different_index_types(
        self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
@@ -315,8 +318,7 @@ class TestCleanNotionDocumentTask:
        # Note: This test successfully verifies cleanup with different document types.
        # The task properly handles various index types and document configurations.

        # Verify segments are deleted
        assert (
            db_session_with_containers.query(DocumentSegment)
            .filter(DocumentSegment.document_id == document.id)
@@ -404,8 +406,7 @@ class TestCleanNotionDocumentTask:
        # Execute cleanup task
        clean_notion_document_task([document.id], dataset.id)

        # Verify segments are deleted
        assert (
            db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
            == 0
@@ -508,8 +509,7 @@ class TestCleanNotionDocumentTask:
        clean_notion_document_task(documents_to_clean, dataset.id)

        # Verify only specified documents' segments are deleted
        assert (
            db_session_with_containers.query(DocumentSegment)
            .filter(DocumentSegment.document_id.in_(documents_to_clean))
@@ -697,11 +697,12 @@ class TestCleanNotionDocumentTask:
        db_session_with_containers.commit()

        # Mock index processor to raise an exception
        mock_index_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
        mock_index_processor.clean.side_effect = Exception("Index processor error")

        # Execute cleanup task - current implementation propagates the exception
        with pytest.raises(Exception, match="Index processor error"):
            clean_notion_document_task([document.id], dataset.id)

        # Note: This test demonstrates the task's error handling capability.
        # Even with external service errors, the database operations complete successfully.
@@ -803,8 +804,7 @@ class TestCleanNotionDocumentTask:
        all_document_ids = [doc.id for doc in documents]
        clean_notion_document_task(all_document_ids, dataset.id)

        # Verify all segments are deleted
        assert (
            db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
            == 0
@@ -914,8 +914,7 @@ class TestCleanNotionDocumentTask:
        clean_notion_document_task([target_document.id], target_dataset.id)

        # Verify only documents' segments from target dataset are deleted
        assert (
            db_session_with_containers.query(DocumentSegment)
            .filter(DocumentSegment.document_id == target_document.id)
@@ -1030,8 +1029,7 @@ class TestCleanNotionDocumentTask:
        all_document_ids = [doc.id for doc in documents]
        clean_notion_document_task(all_document_ids, dataset.id)

        # Verify all segments are deleted regardless of status
        assert (
            db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
            == 0
@@ -1142,8 +1140,7 @@ class TestCleanNotionDocumentTask:
        # Execute cleanup task
        clean_notion_document_task([document.id], dataset.id)

        # Verify segments are deleted
        assert (
            db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
            == 0

View File

@@ -0,0 +1,182 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from tasks.document_indexing_update_task import document_indexing_update_task
class TestDocumentIndexingUpdateTask:
@pytest.fixture
def mock_external_dependencies(self):
"""Patch external collaborators used by the update task.
- IndexProcessorFactory.init_index_processor().clean(...)
- IndexingRunner.run([...])
"""
with (
patch("tasks.document_indexing_update_task.IndexProcessorFactory") as mock_factory,
patch("tasks.document_indexing_update_task.IndexingRunner") as mock_runner,
):
processor_instance = MagicMock()
mock_factory.return_value.init_index_processor.return_value = processor_instance
runner_instance = MagicMock()
mock_runner.return_value = runner_instance
yield {
"factory": mock_factory,
"processor": processor_instance,
"runner": mock_runner,
"runner_instance": runner_instance,
}
def _create_dataset_document_with_segments(self, db_session_with_containers, *, segment_count: int = 2):
fake = Faker()
# Account and tenant
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(name=fake.company(), status="normal")
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Dataset and document
dataset = Dataset(
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=64),
data_source_type="upload_file",
indexing_technique="high_quality",
created_by=account.id,
)
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
document = Document(
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_by=account.id,
indexing_status="waiting",
enabled=True,
doc_form="text_model",
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Segments
node_ids = []
for i in range(segment_count):
node_id = f"node-{i + 1}"
seg = DocumentSegment(
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
position=i,
content=fake.text(max_nb_chars=32),
answer=None,
word_count=10,
tokens=5,
index_node_id=node_id,
status="completed",
created_by=account.id,
)
db_session_with_containers.add(seg)
node_ids.append(node_id)
db_session_with_containers.commit()
# Refresh to ensure ORM state
db_session_with_containers.refresh(dataset)
db_session_with_containers.refresh(document)
return dataset, document, node_ids
def test_cleans_segments_and_reindexes(self, db_session_with_containers, mock_external_dependencies):
dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)
# Act
document_indexing_update_task(dataset.id, document.id)
# Ensure we see committed changes from another session
db_session_with_containers.expire_all()
# Assert document status updated before reindex
updated = db_session_with_containers.query(Document).where(Document.id == document.id).first()
assert updated.indexing_status == "parsing"
assert updated.processing_started_at is not None
# Segments should be deleted
remaining = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
)
assert remaining == 0
# Assert index processor clean was called with expected args
clean_call = mock_external_dependencies["processor"].clean.call_args
assert clean_call is not None
args, kwargs = clean_call
# args[0] is a Dataset instance (from another session) — validate by id
assert getattr(args[0], "id", None) == dataset.id
# args[1] should contain our node_ids
assert set(args[1]) == set(node_ids)
assert kwargs.get("with_keywords") is True
assert kwargs.get("delete_child_chunks") is True
# Assert indexing runner invoked with the updated document
run_call = mock_external_dependencies["runner_instance"].run.call_args
assert run_call is not None
run_docs = run_call[0][0]
assert len(run_docs) == 1
first = run_docs[0]
assert getattr(first, "id", None) == document.id
def test_clean_error_is_logged_and_indexing_continues(self, db_session_with_containers, mock_external_dependencies):
dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)
# Force clean to raise; task should continue to indexing
mock_external_dependencies["processor"].clean.side_effect = Exception("boom")
document_indexing_update_task(dataset.id, document.id)
# Ensure we see committed changes from another session
db_session_with_containers.expire_all()
# Indexing should still be triggered
mock_external_dependencies["runner_instance"].run.assert_called_once()
# Segments should remain (since clean failed before DB delete)
remaining = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
)
assert remaining > 0
def test_document_not_found_noop(self, db_session_with_containers, mock_external_dependencies):
fake = Faker()
# Act with non-existent document id
document_indexing_update_task(dataset_id=fake.uuid4(), document_id=fake.uuid4())
# Neither processor nor runner should be called
mock_external_dependencies["processor"].clean.assert_not_called()
mock_external_dependencies["runner_instance"].run.assert_not_called()
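Read together, these tests pin down the behaviour of document_indexing_update_task: mark the document as parsing, clean the old index nodes, drop the stale segments, and re-run indexing, with a failed clean logged but not fatal. The sketch below only illustrates that contract and is not the task's actual code; the import path, timestamp handling, and helper wiring are assumptions.

import logging
from datetime import datetime

from models.dataset import Dataset, Document, DocumentSegment  # import path assumed

logger = logging.getLogger(__name__)


def update_task_flow_sketch(session, index_processor, indexing_runner, dataset_id, document_id):
    """Illustrative outline of the behaviour asserted above; not the real task."""
    document = session.query(Document).where(Document.id == document_id).first()
    if document is None:
        return  # missing document: the task is a no-op

    document.indexing_status = "parsing"
    document.processing_started_at = datetime.utcnow()  # timestamp handling simplified
    session.commit()

    segments = session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
    node_ids = [segment.index_node_id for segment in segments]
    try:
        dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
        index_processor.clean(dataset, node_ids, with_keywords=True, delete_child_chunks=True)
        for segment in segments:
            session.delete(segment)
        session.commit()
    except Exception:
        # A failed clean is logged and indexing still proceeds (see the error test above).
        logger.exception("Cleaning old index failed; continuing with re-indexing")

    indexing_runner.run([document])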

View File

@ -0,0 +1,146 @@
import sys
import threading
import types
from unittest.mock import MagicMock
import commands
from libs.db_migration_lock import LockNotOwnedError, RedisError
HEARTBEAT_WAIT_TIMEOUT_SECONDS = 5.0
def _install_fake_flask_migrate(monkeypatch, upgrade_impl) -> None:
module = types.ModuleType("flask_migrate")
module.upgrade = upgrade_impl
monkeypatch.setitem(sys.modules, "flask_migrate", module)
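# Illustrative usage sketch (not part of the original suite): once the fake module is
# registered in sys.modules, any later `import flask_migrate` resolves to the stub, so
# the test fully controls what "running the migrations" does.
def test_fake_flask_migrate_module_is_importable(monkeypatch):
    calls = []
    _install_fake_flask_migrate(monkeypatch, lambda: calls.append("upgrade"))

    import flask_migrate  # resolves to the fake module installed above

    flask_migrate.upgrade()
    assert calls == ["upgrade"]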
def _invoke_upgrade_db() -> int:
try:
commands.upgrade_db.callback()
except SystemExit as e:
return int(e.code or 0)
return 0
def test_upgrade_db_skips_when_lock_not_acquired(monkeypatch, capsys):
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 1234)
lock = MagicMock()
lock.acquire.return_value = False
commands.redis_client.lock.return_value = lock
exit_code = _invoke_upgrade_db()
captured = capsys.readouterr()
assert exit_code == 0
assert "Database migration skipped" in captured.out
commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=1234, thread_local=False)
lock.acquire.assert_called_once_with(blocking=False)
lock.release.assert_not_called()
def test_upgrade_db_failure_not_masked_by_lock_release(monkeypatch, capsys):
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 321)
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = LockNotOwnedError("simulated")
commands.redis_client.lock.return_value = lock
def _upgrade():
raise RuntimeError("boom")
_install_fake_flask_migrate(monkeypatch, _upgrade)
exit_code = _invoke_upgrade_db()
captured = capsys.readouterr()
assert exit_code == 1
assert "Database migration failed: boom" in captured.out
commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=321, thread_local=False)
lock.acquire.assert_called_once_with(blocking=False)
lock.release.assert_called_once()
def test_upgrade_db_success_ignores_lock_not_owned_on_release(monkeypatch, capsys):
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 999)
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = LockNotOwnedError("simulated")
commands.redis_client.lock.return_value = lock
_install_fake_flask_migrate(monkeypatch, lambda: None)
exit_code = _invoke_upgrade_db()
captured = capsys.readouterr()
assert exit_code == 0
assert "Database migration successful!" in captured.out
commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=999, thread_local=False)
lock.acquire.assert_called_once_with(blocking=False)
lock.release.assert_called_once()
def test_upgrade_db_renews_lock_during_migration(monkeypatch, capsys):
"""
Ensure the lock is renewed while migrations are running, so the base TTL can stay short.
"""
# Use a small TTL so the heartbeat interval triggers quickly.
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3)
lock = MagicMock()
lock.acquire.return_value = True
commands.redis_client.lock.return_value = lock
renewed = threading.Event()
def _reacquire():
renewed.set()
return True
lock.reacquire.side_effect = _reacquire
def _upgrade():
assert renewed.wait(HEARTBEAT_WAIT_TIMEOUT_SECONDS)
_install_fake_flask_migrate(monkeypatch, _upgrade)
exit_code = _invoke_upgrade_db()
_ = capsys.readouterr()
assert exit_code == 0
assert lock.reacquire.call_count >= 1
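# The behaviour pinned down here and in the next test could be provided by a small
# background heartbeat along the lines of the sketch below. This is illustrative only
# and not the actual DbMigrationAutoRenewLock implementation; the renewal interval and
# error handling are assumptions inferred from these tests.
def _heartbeat_sketch(lock, ttl_seconds: float, stop_event: threading.Event) -> None:
    interval = ttl_seconds / 3  # assumed: renew well before the TTL can expire
    while not stop_event.wait(interval):
        try:
            lock.reacquire()  # reset the lock's TTL while the migration is still running
        except RedisError:
            # A transient Redis error must not abort a running migration; try again next tick.
            pass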
def test_upgrade_db_ignores_reacquire_errors(monkeypatch, capsys):
# Use a small TTL so heartbeat runs during the upgrade call.
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3)
lock = MagicMock()
lock.acquire.return_value = True
commands.redis_client.lock.return_value = lock
attempted = threading.Event()
def _reacquire():
attempted.set()
raise RedisError("simulated")
lock.reacquire.side_effect = _reacquire
def _upgrade():
assert attempted.wait(HEARTBEAT_WAIT_TIMEOUT_SECONDS)
_install_fake_flask_migrate(monkeypatch, _upgrade)
exit_code = _invoke_upgrade_db()
_ = capsys.readouterr()
assert exit_code == 0
assert lock.reacquire.call_count >= 1

View File

@@ -25,15 +25,19 @@ class TestMessageCycleManagerOptimization:
task_state = Mock()
return MessageCycleManager(application_generate_entity=mock_application_generate_entity, task_state=task_state)
-def test_get_message_event_type_with_message_file(self, message_cycle_manager):
-"""Test get_message_event_type returns MESSAGE_FILE when message has files."""
def test_get_message_event_type_with_assistant_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE_FILE when message has assistant-generated files.
This ensures that AI-generated images (belongs_to='assistant') trigger the MESSAGE_FILE event,
allowing the frontend to properly display generated image files with url field.
"""
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
# Setup mock session and message file
mock_session = Mock()
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
mock_message_file = Mock()
-# Current implementation uses session.scalar(select(...))
mock_message_file.belongs_to = "assistant"
mock_session.scalar.return_value = mock_message_file
# Execute
@@ -44,6 +48,31 @@ class TestMessageCycleManagerOptimization:
assert result == StreamEvent.MESSAGE_FILE
mock_session.scalar.assert_called_once()
def test_get_message_event_type_with_user_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE when message only has user-uploaded files.
This is a regression test for the issue where user-uploaded images (belongs_to='user')
caused the LLM text response to be incorrectly tagged with MESSAGE_FILE event,
resulting in broken images in the chat UI. The query filters for belongs_to='assistant',
so when only user files exist, the database query returns None, resulting in MESSAGE event type.
"""
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
# Setup mock session and message file
mock_session = Mock()
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
# When querying for assistant files with only user files present, return None
# (simulates database query with belongs_to='assistant' filter returning no results)
mock_session.scalar.return_value = None
# Execute
with current_app.app_context():
result = message_cycle_manager.get_message_event_type("test-message-id")
# Assert
assert result == StreamEvent.MESSAGE
mock_session.scalar.assert_called_once()
def test_get_message_event_type_without_message_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE when message has no files."""
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
@@ -69,7 +98,7 @@ class TestMessageCycleManagerOptimization:
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
mock_message_file = Mock()
-# Current implementation uses session.scalar(select(...))
mock_message_file.belongs_to = "assistant"
mock_session.scalar.return_value = mock_message_file
# Execute: compute event type once, then pass to message_to_stream_response
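The two new tests above encode the selection rule the fix relies on: only files attached by the assistant switch the stream event to MESSAGE_FILE. A minimal sketch of that rule follows; the MessageFile model, its import path, and the session handling are assumptions for illustration, not the actual get_message_event_type code.

from sqlalchemy import select

from models.model import MessageFile  # import path assumed


def event_type_for_message_sketch(session, message_id: str) -> StreamEvent:
    # Only assistant-generated files should produce MESSAGE_FILE; user uploads
    # (belongs_to="user") are filtered out, so the scalar comes back as None.
    assistant_file = session.scalar(
        select(MessageFile).where(
            MessageFile.message_id == message_id,
            MessageFile.belongs_to == "assistant",
        )
    )
    return StreamEvent.MESSAGE_FILE if assistant_file is not None else StreamEvent.MESSAGE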

View File

@@ -4,7 +4,7 @@ from typing import Any
from uuid import uuid4
import pytest
-from hypothesis import given, settings
from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st
from core.file import File, FileTransferMethod, FileType
@@ -493,7 +493,7 @@ def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]:
)
-@settings(max_examples=50)
@settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None)
@given(_scalar_value())
def test_build_segment_and_extract_values_for_scalar_types(value):
seg = variable_factory.build_segment(value)
@@ -504,7 +504,7 @@ def test_build_segment_and_extract_values_for_scalar_types(value):
assert seg.value == value
-@settings(max_examples=50)
@settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None)
@given(values=st.lists(_scalar_value(), max_size=20))
def test_build_segment_and_extract_values_for_array_types(values):
seg = variable_factory.build_segment(values)
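For reference, the relaxed settings used above follow Hypothesis's standard API and can be applied to any property-based test; the example below is a generic, self-contained illustration rather than part of the changed file.

from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st


@settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None)
@given(st.integers())
def test_int_str_roundtrip(value):
    # A trivial property: converting to text and back preserves the integer.
    assert int(str(value)) == value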

View File

@@ -83,23 +83,127 @@ def mock_documents(document_ids, dataset_id):
def mock_db_session():
"""Mock database session via session_factory.create_session()."""
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
-session = MagicMock()
-# Ensure tests that expect session.close() to be called can observe it via the context manager
-session.close = MagicMock()
-cm = MagicMock()
-cm.__enter__.return_value = session
-# Link __exit__ to session.close so "close" expectations reflect context manager teardown
-def _exit_side_effect(*args, **kwargs):
-session.close()
-cm.__exit__.side_effect = _exit_side_effect
-mock_sf.create_session.return_value = cm
-query = MagicMock()
-session.query.return_value = query
-query.where.return_value = query
-yield session
sessions = [] # Track all created sessions
# Shared mock data that all sessions will access
shared_mock_data = {"dataset": None, "documents": None, "doc_iter": None}
def create_session_side_effect():
session = MagicMock()
session.close = MagicMock()
# Track commit calls
commit_mock = MagicMock()
session.commit = commit_mock
cm = MagicMock()
cm.__enter__.return_value = session
def _exit_side_effect(*args, **kwargs):
session.close()
cm.__exit__.side_effect = _exit_side_effect
# Support session.begin() for transactions
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
def begin_exit_side_effect(*args, **kwargs):
# Auto-commit on transaction exit (like SQLAlchemy)
session.commit()
# Also mark wrapper's commit as called
if sessions:
sessions[0].commit()
begin_cm.__exit__ = MagicMock(side_effect=begin_exit_side_effect)
session.begin = MagicMock(return_value=begin_cm)
sessions.append(session)
# Setup query with side_effect to handle both Dataset and Document queries
def query_side_effect(*args):
query = MagicMock()
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
where_result = MagicMock()
where_result.first.return_value = shared_mock_data["dataset"]
query.where = MagicMock(return_value=where_result)
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
# Support both .first() and .all() calls with chaining
where_result = MagicMock()
where_result.where = MagicMock(return_value=where_result)
# Create an iterator for .first() calls if not exists
if shared_mock_data["doc_iter"] is None:
docs = shared_mock_data["documents"] or [None]
shared_mock_data["doc_iter"] = iter(docs)
where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
docs_or_empty = shared_mock_data["documents"] or []
where_result.all = MagicMock(return_value=docs_or_empty)
query.where = MagicMock(return_value=where_result)
else:
query.where = MagicMock(return_value=query)
return query
session.query = MagicMock(side_effect=query_side_effect)
return cm
mock_sf.create_session.side_effect = create_session_side_effect
# Create a wrapper that behaves like the first session but has access to all sessions
class SessionWrapper:
def __init__(self):
self._sessions = sessions
self._shared_data = shared_mock_data
# Create a default session for setup phase
self._default_session = MagicMock()
self._default_session.close = MagicMock()
self._default_session.commit = MagicMock()
# Support session.begin() for default session too
begin_cm = MagicMock()
begin_cm.__enter__.return_value = self._default_session
def default_begin_exit_side_effect(*args, **kwargs):
self._default_session.commit()
begin_cm.__exit__ = MagicMock(side_effect=default_begin_exit_side_effect)
self._default_session.begin = MagicMock(return_value=begin_cm)
def default_query_side_effect(*args):
query = MagicMock()
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
where_result = MagicMock()
where_result.first.return_value = shared_mock_data["dataset"]
query.where = MagicMock(return_value=where_result)
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
where_result = MagicMock()
where_result.where = MagicMock(return_value=where_result)
if shared_mock_data["doc_iter"] is None:
docs = shared_mock_data["documents"] or [None]
shared_mock_data["doc_iter"] = iter(docs)
where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
docs_or_empty = shared_mock_data["documents"] or []
where_result.all = MagicMock(return_value=docs_or_empty)
query.where = MagicMock(return_value=where_result)
else:
query.where = MagicMock(return_value=query)
return query
self._default_session.query = MagicMock(side_effect=default_query_side_effect)
def __getattr__(self, name):
# Forward all attribute access to the first session, or default if none created yet
target_session = self._sessions[0] if self._sessions else self._default_session
return getattr(target_session, name)
@property
def all_sessions(self):
"""Access all created sessions for testing."""
return self._sessions
wrapper = SessionWrapper()
yield wrapper
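# Aside (illustrative only, not used by the tests): the wrapper above relies on plain
# __getattr__ delegation, so anything not defined on the wrapper itself is looked up on
# the first real session once one exists. A stripped-down version of that pattern:
class _ForwardingProxySketch:
    def __init__(self, targets: list):
        self._targets = targets  # the list may be populated after the proxy is created

    def __getattr__(self, name):
        # Only called for attributes the proxy does not define itself.
        if not self._targets:
            raise AttributeError(name)
        return getattr(self._targets[0], name)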
@pytest.fixture
@ -252,18 +356,9 @@ class TestTaskEnqueuing:
use the deprecated function. use the deprecated function.
""" """
# Arrange # Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
def mock_query_side_effect(*args): mock_db_session._shared_data["documents"] = mock_documents
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
# Return documents one by one for each call
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False mock_features.return_value.billing.enabled = False
@ -304,21 +399,9 @@ class TestBatchProcessing:
doc.processing_started_at = None doc.processing_started_at = None
mock_documents.append(doc) mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
# Create an iterator for documents mock_db_session._shared_data["documents"] = mock_documents
doc_iter = iter(mock_documents)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
# Return documents one by one for each call
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False mock_features.return_value.billing.enabled = False
@ -357,19 +440,9 @@ class TestBatchProcessing:
doc.stopped_at = None doc.stopped_at = None
mock_documents.append(doc) mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
doc_iter = iter(mock_documents) mock_db_session._shared_data["documents"] = mock_documents
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
mock_feature_service.get_features.return_value.billing.enabled = True mock_feature_service.get_features.return_value.billing.enabled = True
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL
@ -407,19 +480,9 @@ class TestBatchProcessing:
doc.stopped_at = None doc.stopped_at = None
mock_documents.append(doc) mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
doc_iter = iter(mock_documents) mock_db_session._shared_data["documents"] = mock_documents
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
mock_feature_service.get_features.return_value.billing.enabled = True mock_feature_service.get_features.return_value.billing.enabled = True
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX
@ -444,7 +507,10 @@ class TestBatchProcessing:
""" """
# Arrange # Arrange
document_ids = [] document_ids = []
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# Set shared mock data with empty documents list
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = []
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False mock_features.return_value.billing.enabled = False
@ -482,19 +548,9 @@ class TestProgressTracking:
doc.processing_started_at = None doc.processing_started_at = None
mock_documents.append(doc) mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
doc_iter = iter(mock_documents) mock_db_session._shared_data["documents"] = mock_documents
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False mock_features.return_value.billing.enabled = False
@ -528,19 +584,9 @@ class TestProgressTracking:
doc.processing_started_at = None doc.processing_started_at = None
mock_documents.append(doc) mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
doc_iter = iter(mock_documents) mock_db_session._shared_data["documents"] = mock_documents
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False mock_features.return_value.billing.enabled = False
@ -635,19 +681,9 @@ class TestErrorHandling:
doc.stopped_at = None doc.stopped_at = None
mock_documents.append(doc) mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
doc_iter = iter(mock_documents) mock_db_session._shared_data["documents"] = mock_documents
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set up to trigger vector space limit error # Set up to trigger vector space limit error
mock_feature_service.get_features.return_value.billing.enabled = True mock_feature_service.get_features.return_value.billing.enabled = True
@ -674,17 +710,9 @@ class TestErrorHandling:
Errors during indexing should be caught and logged, but not crash the task. Errors during indexing should be caught and logged, but not crash the task.
""" """
# Arrange # Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
def mock_query_side_effect(*args): mock_db_session._shared_data["documents"] = mock_documents
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Make IndexingRunner raise an exception # Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = Exception("Indexing failed") mock_indexing_runner.run.side_effect = Exception("Indexing failed")
@ -708,17 +736,9 @@ class TestErrorHandling:
but not treated as a failure. but not treated as a failure.
""" """
# Arrange # Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
def mock_query_side_effect(*args): mock_db_session._shared_data["documents"] = mock_documents
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Make IndexingRunner raise DocumentIsPausedError # Make IndexingRunner raise DocumentIsPausedError
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused") mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused")
@ -853,17 +873,9 @@ class TestTaskCancellation:
Session cleanup should happen in finally block. Session cleanup should happen in finally block.
""" """
# Arrange # Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
def mock_query_side_effect(*args): mock_db_session._shared_data["documents"] = mock_documents
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False mock_features.return_value.billing.enabled = False
@ -883,17 +895,9 @@ class TestTaskCancellation:
Session cleanup should happen even when errors occur. Session cleanup should happen even when errors occur.
""" """
# Arrange # Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
def mock_query_side_effect(*args): mock_db_session._shared_data["documents"] = mock_documents
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first.side_effect = mock_documents
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Make IndexingRunner raise an exception # Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = Exception("Test error") mock_indexing_runner.run.side_effect = Exception("Test error")
@ -962,6 +966,7 @@ class TestAdvancedScenarios:
document_ids = [str(uuid.uuid4()) for _ in range(3)] document_ids = [str(uuid.uuid4()) for _ in range(3)]
# Create only 2 documents (simulate one missing) # Create only 2 documents (simulate one missing)
# The new code uses .all() which will only return existing documents
mock_documents = [] mock_documents = []
for i, doc_id in enumerate([document_ids[0], document_ids[2]]): # Skip middle one for i, doc_id in enumerate([document_ids[0], document_ids[2]]): # Skip middle one
doc = MagicMock(spec=Document) doc = MagicMock(spec=Document)
@ -971,21 +976,9 @@ class TestAdvancedScenarios:
doc.processing_started_at = None doc.processing_started_at = None
mock_documents.append(doc) mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data - .all() will only return existing documents
mock_db_session._shared_data["dataset"] = mock_dataset
# Create iterator that returns None for missing document mock_db_session._shared_data["documents"] = mock_documents
doc_responses = [mock_documents[0], None, mock_documents[1]]
doc_iter = iter(doc_responses)
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False mock_features.return_value.billing.enabled = False
@ -1075,19 +1068,9 @@ class TestAdvancedScenarios:
doc.stopped_at = None doc.stopped_at = None
mock_documents.append(doc) mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
doc_iter = iter(mock_documents) mock_db_session._shared_data["documents"] = mock_documents
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set vector space exactly at limit # Set vector space exactly at limit
mock_feature_service.get_features.return_value.billing.enabled = True mock_feature_service.get_features.return_value.billing.enabled = True
@ -1219,19 +1202,9 @@ class TestAdvancedScenarios:
doc.processing_started_at = None doc.processing_started_at = None
mock_documents.append(doc) mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
doc_iter = iter(mock_documents) mock_db_session._shared_data["documents"] = mock_documents
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Billing disabled - limits should not be checked # Billing disabled - limits should not be checked
mock_feature_service.get_features.return_value.billing.enabled = False mock_feature_service.get_features.return_value.billing.enabled = False
@ -1273,19 +1246,9 @@ class TestIntegration:
# Set up rpop to return None for concurrency check (no more tasks) # Set up rpop to return None for concurrency check (no more tasks)
mock_redis.rpop.side_effect = [None] mock_redis.rpop.side_effect = [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
doc_iter = iter(mock_documents) mock_db_session._shared_data["documents"] = mock_documents
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False mock_features.return_value.billing.enabled = False
@ -1321,19 +1284,9 @@ class TestIntegration:
# Set up rpop to return None for concurrency check (no more tasks) # Set up rpop to return None for concurrency check (no more tasks)
mock_redis.rpop.side_effect = [None] mock_redis.rpop.side_effect = [None]
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
doc_iter = iter(mock_documents) mock_db_session._shared_data["documents"] = mock_documents
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False mock_features.return_value.billing.enabled = False
@ -1415,17 +1368,9 @@ class TestEdgeCases:
mock_document.indexing_status = "waiting" mock_document.indexing_status = "waiting"
mock_document.processing_started_at = None mock_document.processing_started_at = None
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
def mock_query_side_effect(*args): mock_db_session._shared_data["documents"] = [mock_document]
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: mock_document
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False mock_features.return_value.billing.enabled = False
@ -1465,17 +1410,9 @@ class TestEdgeCases:
mock_document.indexing_status = "waiting" mock_document.indexing_status = "waiting"
mock_document.processing_started_at = None mock_document.processing_started_at = None
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
def mock_query_side_effect(*args): mock_db_session._shared_data["documents"] = [mock_document]
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: mock_document
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False mock_features.return_value.billing.enabled = False
@ -1555,19 +1492,9 @@ class TestEdgeCases:
doc.processing_started_at = None doc.processing_started_at = None
mock_documents.append(doc) mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
doc_iter = iter(mock_documents) mock_db_session._shared_data["documents"] = mock_documents
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set vector space limit to 0 (unlimited) # Set vector space limit to 0 (unlimited)
mock_feature_service.get_features.return_value.billing.enabled = True mock_feature_service.get_features.return_value.billing.enabled = True
@ -1612,19 +1539,9 @@ class TestEdgeCases:
doc.processing_started_at = None doc.processing_started_at = None
mock_documents.append(doc) mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
doc_iter = iter(mock_documents) mock_db_session._shared_data["documents"] = mock_documents
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Set negative vector space limit # Set negative vector space limit
mock_feature_service.get_features.return_value.billing.enabled = True mock_feature_service.get_features.return_value.billing.enabled = True
@ -1675,19 +1592,9 @@ class TestPerformanceScenarios:
doc.processing_started_at = None doc.processing_started_at = None
mock_documents.append(doc) mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
doc_iter = iter(mock_documents) mock_db_session._shared_data["documents"] = mock_documents
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Configure billing with sufficient limits # Configure billing with sufficient limits
mock_feature_service.get_features.return_value.billing.enabled = True mock_feature_service.get_features.return_value.billing.enabled = True
@ -1826,19 +1733,9 @@ class TestRobustness:
doc.processing_started_at = None doc.processing_started_at = None
mock_documents.append(doc) mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
doc_iter = iter(mock_documents) mock_db_session._shared_data["documents"] = mock_documents
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
# Make IndexingRunner raise an exception # Make IndexingRunner raise an exception
mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error") mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error")
@ -1866,7 +1763,7 @@ class TestRobustness:
- No exceptions occur - No exceptions occur
Expected behavior: Expected behavior:
- Database session is closed - All database sessions are closed
- No connection leaks - No connection leaks
""" """
# Arrange # Arrange
@ -1879,19 +1776,9 @@ class TestRobustness:
doc.processing_started_at = None doc.processing_started_at = None
mock_documents.append(doc) mock_documents.append(doc)
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset # Set shared mock data so all sessions can access it
mock_db_session._shared_data["dataset"] = mock_dataset
doc_iter = iter(mock_documents) mock_db_session._shared_data["documents"] = mock_documents
def mock_query_side_effect(*args):
mock_query = MagicMock()
if args[0] == Dataset:
mock_query.where.return_value.first.return_value = mock_dataset
elif args[0] == Document:
mock_query.where.return_value.first = lambda: next(doc_iter, None)
return mock_query
mock_db_session.query.side_effect = mock_query_side_effect
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False mock_features.return_value.billing.enabled = False
@ -1899,10 +1786,11 @@ class TestRobustness:
# Act # Act
_document_indexing(dataset_id, document_ids) _document_indexing(dataset_id, document_ids)
# Assert # Assert - All created sessions should be closed
assert mock_db_session.close.called # The code creates multiple sessions: validation, Phase 1 (parsing), Phase 3 (summary)
# Verify close is called exactly once assert len(mock_db_session.all_sessions) >= 1
assert mock_db_session.close.call_count == 1 for session in mock_db_session.all_sessions:
assert session.close.called, "All sessions should be closed"
def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis): def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis):
""" """

View File

@@ -109,25 +109,87 @@ def mock_document_segments(document_id):
@pytest.fixture
def mock_db_session():
-"""Mock database session via session_factory.create_session()."""
"""Mock database session via session_factory.create_session().
After session split refactor, the code calls create_session() multiple times.
This fixture creates shared query mocks so all sessions use the same
query configuration, simulating database persistence across sessions.
The fixture automatically converts side_effect to cycle to prevent StopIteration.
Tests configure mocks the same way as before, but behind the scenes the values
are cycled infinitely for all sessions.
"""
from itertools import cycle
with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf:
-session = MagicMock()
-# Ensure tests can observe session.close() via context manager teardown
-session.close = MagicMock()
-cm = MagicMock()
-cm.__enter__.return_value = session
-def _exit_side_effect(*args, **kwargs):
-session.close()
-cm.__exit__.side_effect = _exit_side_effect
-mock_sf.create_session.return_value = cm
-query = MagicMock()
-session.query.return_value = query
-query.where.return_value = query
-session.scalars.return_value = MagicMock()
-yield session
sessions = []
# Shared query mocks - all sessions use these
shared_query = MagicMock()
shared_filter_by = MagicMock()
shared_scalars_result = MagicMock()
# Create custom first mock that auto-cycles side_effect
class CyclicMock(MagicMock):
def __setattr__(self, name, value):
if name == "side_effect" and value is not None:
# Convert list/tuple to infinite cycle
if isinstance(value, (list, tuple)):
value = cycle(value)
super().__setattr__(name, value)
shared_query.where.return_value.first = CyclicMock()
shared_filter_by.first = CyclicMock()
def _create_session():
"""Create a new mock session for each create_session() call."""
session = MagicMock()
session.close = MagicMock()
session.commit = MagicMock()
# Mock session.begin() context manager
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
def _begin_exit_side_effect(exc_type, exc, tb):
# commit on success
if exc_type is None:
session.commit()
# return False to propagate exceptions
return False
begin_cm.__exit__.side_effect = _begin_exit_side_effect
session.begin.return_value = begin_cm
# Mock create_session() context manager
cm = MagicMock()
cm.__enter__.return_value = session
def _exit_side_effect(exc_type, exc, tb):
session.close()
return False
cm.__exit__.side_effect = _exit_side_effect
# All sessions use the same shared query mocks
session.query.return_value = shared_query
shared_query.where.return_value = shared_query
shared_query.filter_by.return_value = shared_filter_by
session.scalars.return_value = shared_scalars_result
sessions.append(session)
# Attach helpers on the first created session for assertions across all sessions
if len(sessions) == 1:
session.get_all_sessions = lambda: sessions
session.any_close_called = lambda: any(s.close.called for s in sessions)
session.any_commit_called = lambda: any(s.commit.called for s in sessions)
return cm
mock_sf.create_session.side_effect = _create_session
# Create first session and return it
_create_session()
yield sessions[0]
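# Illustrative aside (not used by the tests): cycling a finite side_effect is what keeps
# repeated .first() calls from raising StopIteration once the code under test opens more
# sessions than a test anticipated. The same idea in isolation:
def _demo_cyclic_side_effect():
    from itertools import cycle
    from unittest.mock import MagicMock

    first = MagicMock()
    first.side_effect = cycle(["document", "dataset"])
    # Each call consumes the next value and wraps around indefinitely.
    return [first() for _ in range(5)]  # -> ["document", "dataset", "document", "dataset", "document"]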
@pytest.fixture
@ -186,8 +248,8 @@ class TestDocumentIndexingSyncTask:
# Act # Act
document_indexing_sync_task(dataset_id, document_id) document_indexing_sync_task(dataset_id, document_id)
# Assert # Assert - at least one session should have been closed
mock_db_session.close.assert_called_once() assert mock_db_session.any_close_called()
def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id): def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
"""Test that task raises error when notion_workspace_id is missing.""" """Test that task raises error when notion_workspace_id is missing."""
@ -230,6 +292,7 @@ class TestDocumentIndexingSyncTask:
"""Test that task handles missing credentials by updating document status.""" """Test that task handles missing credentials by updating document status."""
# Arrange # Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_datasource_provider_service.get_datasource_credentials.return_value = None mock_datasource_provider_service.get_datasource_credentials.return_value = None
# Act # Act
@ -239,8 +302,8 @@ class TestDocumentIndexingSyncTask:
assert mock_document.indexing_status == "error" assert mock_document.indexing_status == "error"
assert "Datasource credential not found" in mock_document.error assert "Datasource credential not found" in mock_document.error
assert mock_document.stopped_at is not None assert mock_document.stopped_at is not None
mock_db_session.commit.assert_called() assert mock_db_session.any_commit_called()
mock_db_session.close.assert_called() assert mock_db_session.any_close_called()
def test_page_not_updated( def test_page_not_updated(
self, self,
@ -254,6 +317,7 @@ class TestDocumentIndexingSyncTask:
"""Test that task does nothing when page has not been updated.""" """Test that task does nothing when page has not been updated."""
# Arrange # Arrange
mock_db_session.query.return_value.where.return_value.first.return_value = mock_document mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
# Return same time as stored in document # Return same time as stored in document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z" mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
@ -263,8 +327,8 @@ class TestDocumentIndexingSyncTask:
# Assert # Assert
# Document status should remain unchanged # Document status should remain unchanged
assert mock_document.indexing_status == "completed" assert mock_document.indexing_status == "completed"
# Session should still be closed via context manager teardown # At least one session should have been closed via context manager teardown
assert mock_db_session.close.called assert mock_db_session.any_close_called()
def test_successful_sync_when_page_updated( def test_successful_sync_when_page_updated(
self, self,
@ -281,7 +345,20 @@ class TestDocumentIndexingSyncTask:
): ):
"""Test successful sync flow when Notion page has been updated.""" """Test successful sync flow when Notion page has been updated."""
# Arrange # Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset] # Set exact sequence of returns across calls to `.first()`:
# 1) document (initial fetch)
# 2) dataset (pre-check)
# 3) dataset (cleaning phase)
# 4) document (pre-indexing update)
# 5) document (indexing runner fetch)
mock_db_session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
mock_dataset,
mock_document,
mock_document,
]
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_db_session.scalars.return_value.all.return_value = mock_document_segments mock_db_session.scalars.return_value.all.return_value = mock_document_segments
# NotionExtractor returns updated time # NotionExtractor returns updated time
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
@ -299,28 +376,40 @@ class TestDocumentIndexingSyncTask:
mock_processor.clean.assert_called_once() mock_processor.clean.assert_called_once()
# Verify segments were deleted from database in batch (DELETE FROM document_segments) # Verify segments were deleted from database in batch (DELETE FROM document_segments)
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list] # Aggregate execute calls across all created sessions
execute_sqls = []
for s in mock_db_session.get_all_sessions():
execute_sqls.extend([" ".join(str(c[0][0]).split()) for c in s.execute.call_args_list])
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls) assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
# Verify indexing runner was called # Verify indexing runner was called
mock_indexing_runner.run.assert_called_once_with([mock_document]) mock_indexing_runner.run.assert_called_once_with([mock_document])
# Verify session operations # Verify session operations (across any created session)
assert mock_db_session.commit.called assert mock_db_session.any_commit_called()
mock_db_session.close.assert_called_once() assert mock_db_session.any_close_called()
def test_dataset_not_found_during_cleaning( def test_dataset_not_found_during_cleaning(
self, self,
mock_db_session, mock_db_session,
mock_datasource_provider_service, mock_datasource_provider_service,
mock_notion_extractor, mock_notion_extractor,
mock_indexing_runner,
mock_dataset,
mock_document, mock_document,
dataset_id, dataset_id,
document_id, document_id,
): ):
"""Test that task handles dataset not found during cleaning phase.""" """Test that task handles dataset not found during cleaning phase."""
# Arrange # Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None] # Sequence: document (initial), dataset (pre-check), None (cleaning), document (update), document (indexing)
mock_db_session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
None,
mock_document,
mock_document,
]
mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z" mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
# Act # Act
@ -329,8 +418,8 @@ class TestDocumentIndexingSyncTask:
# Assert # Assert
# Document should still be set to parsing # Document should still be set to parsing
assert mock_document.indexing_status == "parsing" assert mock_document.indexing_status == "parsing"
# Session should be closed after error # At least one session should be closed after error
mock_db_session.close.assert_called_once() assert mock_db_session.any_close_called()
    def test_cleaning_error_continues_to_indexing(
        self,
@@ -346,8 +435,14 @@ class TestDocumentIndexingSyncTask:
    ):
        """Test that indexing continues even if cleaning fails."""
        # Arrange
-        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
-        mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error")
+        from itertools import cycle
+
+        mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
+        mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
+        # Make the cleaning step fail but not the segment fetch
+        processor = mock_index_processor_factory.return_value.init_index_processor.return_value
+        processor.clean.side_effect = Exception("Cleaning error")
+        mock_db_session.scalars.return_value.all.return_value = []
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"

        # Act
@@ -356,7 +451,7 @@ class TestDocumentIndexingSyncTask:
        # Assert
        # Indexing should still be attempted despite cleaning error
        mock_indexing_runner.run.assert_called_once_with([mock_document])
-        mock_db_session.close.assert_called_once()
+        assert mock_db_session.any_close_called()

    def test_indexing_runner_document_paused_error(
        self,
@@ -373,7 +468,10 @@ class TestDocumentIndexingSyncTask:
    ):
        """Test that DocumentIsPausedError is handled gracefully."""
        # Arrange
-        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
+        from itertools import cycle
+
+        mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
+        mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
        mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
@@ -383,7 +481,7 @@ class TestDocumentIndexingSyncTask:
        # Assert
        # Session should be closed after handling error
-        mock_db_session.close.assert_called_once()
+        assert mock_db_session.any_close_called()

    def test_indexing_runner_general_error(
        self,
@@ -400,7 +498,10 @@ class TestDocumentIndexingSyncTask:
    ):
        """Test that general exceptions during indexing are handled."""
        # Arrange
-        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
+        from itertools import cycle
+
+        mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
+        mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
        mock_indexing_runner.run.side_effect = Exception("Indexing error")
@@ -410,7 +511,7 @@ class TestDocumentIndexingSyncTask:
        # Assert
        # Session should be closed after error
-        mock_db_session.close.assert_called_once()
+        assert mock_db_session.any_close_called()

    def test_notion_extractor_initialized_with_correct_params(
        self,
@@ -517,7 +618,14 @@ class TestDocumentIndexingSyncTask:
    ):
        """Test that index processor clean is called with correct parameters."""
        # Arrange
-        mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
+        # Sequence: document (initial), dataset (pre-check), dataset (cleaning), document (update), document (indexing)
+        mock_db_session.query.return_value.where.return_value.first.side_effect = [
+            mock_document,
+            mock_dataset,
+            mock_dataset,
+            mock_document,
+            mock_document,
+        ]
        mock_db_session.scalars.return_value.all.return_value = mock_document_segments
        mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
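The switch from a plain list to cycle([...]) in the arrangements above matters because the split-session task now re-queries the document and dataset from each freshly opened session, so a finite side_effect would raise StopIteration on the third lookup; the assertions likewise relax from close.assert_called_once() to a helper that only checks that some close happened. A minimal standalone sketch of that side_effect behavior, using only the standard library (the names finite/repeating/doc/dataset are illustrative, not from the repository):

from itertools import cycle
from unittest.mock import MagicMock

doc, dataset = object(), object()

# A finite side_effect list is exhausted after two calls, which was enough
# when the task queried each object once inside a single session.
finite = MagicMock()
finite.first.side_effect = [doc, dataset]
assert finite.first() is doc
assert finite.first() is dataset
# A third finite.first() would raise StopIteration.

# cycle() keeps yielding doc, dataset, doc, ... so the stub survives however
# many times the task re-queries across its split sessions.
repeating = MagicMock()
repeating.first.side_effect = cycle([doc, dataset])
for _ in range(6):
    assert repeating.first() in (doc, dataset)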

View File

@@ -350,7 +350,7 @@ class TestDeleteWorkflowArchiveLogs:
        mock_query.where.return_value = mock_delete_query
        mock_db.session.query.return_value = mock_query

-        delete_func("log-1")
+        delete_func(mock_db.session, "log-1")

        mock_db.session.query.assert_called_once_with(WorkflowArchiveLog)
        mock_query.where.assert_called_once()
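The updated call delete_func(mock_db.session, "log-1") follows the same session-discipline theme: the per-log delete helper now receives the session it should use instead of reaching for a shared one. A rough sketch of that call shape under assumed names (WorkflowArchiveLog here is a stand-in model, not the repository's definition):

from sqlalchemy import String, delete
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class WorkflowArchiveLog(Base):  # stand-in for the real model
    __tablename__ = "workflow_archive_logs"

    id: Mapped[str] = mapped_column(String, primary_key=True)


def delete_workflow_archive_log(session: Session, log_id: str) -> None:
    # The caller owns the session (and its transaction); the helper only
    # issues the statement against whatever session it was handed.
    session.execute(delete(WorkflowArchiveLog).where(WorkflowArchiveLog.id == log_id))

Passing the session explicitly lets the surrounding task split its work across short-lived sessions without the helper silently binding to a different one.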

api/uv.lock (generated)
View File

@@ -1368,7 +1368,7 @@ wheels = [
[[package]]
name = "dify-api"
-version = "1.12.0"
+version = "1.12.1"
source = { virtual = "." }
dependencies = [
    { name = "aliyun-log-python-sdk" },
@@ -1653,7 +1653,7 @@ requires-dist = [
    { name = "starlette", specifier = "==0.49.1" },
    { name = "tiktoken", specifier = "~=0.9.0" },
    { name = "transformers", specifier = "~=4.56.1" },
-    { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" },
+    { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.18.18" },
    { name = "weave", specifier = ">=0.52.16" },
    { name = "weaviate-client", specifier = "==4.17.0" },
    { name = "webvtt-py", specifier = "~=0.5.1" },
@@ -4433,15 +4433,15 @@ wheels = [
[[package]]
name = "pdfminer-six"
-version = "20251230"
+version = "20260107"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "charset-normalizer" },
    { name = "cryptography" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/46/9a/d79d8fa6d47a0338846bb558b39b9963b8eb2dfedec61867c138c1b17eeb/pdfminer_six-20251230.tar.gz", hash = "sha256:e8f68a14c57e00c2d7276d26519ea64be1b48f91db1cdc776faa80528ca06c1e", size = 8511285, upload-time = "2025-12-30T15:49:13.104Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/34/a4/5cec1112009f0439a5ca6afa8ace321f0ab2f48da3255b7a1c8953014670/pdfminer_six-20260107.tar.gz", hash = "sha256:96bfd431e3577a55a0efd25676968ca4ce8fd5b53f14565f85716ff363889602", size = 8512094, upload-time = "2026-01-07T13:29:12.937Z" }
wheels = [
-    { url = "https://files.pythonhosted.org/packages/65/d7/b288ea32deb752a09aab73c75e1e7572ab2a2b56c3124a5d1eb24c62ceb3/pdfminer_six-20251230-py3-none-any.whl", hash = "sha256:9ff2e3466a7dfc6de6fd779478850b6b7c2d9e9405aa2a5869376a822771f485", size = 6591909, upload-time = "2025-12-30T15:49:10.76Z" },
+    { url = "https://files.pythonhosted.org/packages/20/8b/28c4eaec9d6b036a52cb44720408f26b1a143ca9bce76cc19e8f5de00ab4/pdfminer_six-20260107-py3-none-any.whl", hash = "sha256:366585ba97e80dffa8f00cebe303d2f381884d8637af4ce422f1df3ef38111a9", size = 6592252, upload-time = "2026-01-07T13:29:10.742Z" },
]

[[package]]
@@ -6814,12 +6814,12 @@ wheels = [
[[package]]
name = "unstructured"
-version = "0.16.25"
+version = "0.18.31"
source = { registry = "https://pypi.org/simple" }
dependencies = [
    { name = "backoff" },
    { name = "beautifulsoup4" },
-    { name = "chardet" },
+    { name = "charset-normalizer" },
    { name = "dataclasses-json" },
    { name = "emoji" },
    { name = "filetype" },
@@ -6827,6 +6827,7 @@ dependencies = [
    { name = "langdetect" },
    { name = "lxml" },
    { name = "nltk" },
+    { name = "numba" },
    { name = "numpy" },
    { name = "psutil" },
    { name = "python-iso639" },
@@ -6839,9 +6840,9 @@ dependencies = [
    { name = "unstructured-client" },
    { name = "wrapt" },
]
-sdist = { url = "https://files.pythonhosted.org/packages/64/31/98c4c78e305d1294888adf87fd5ee30577a4c393951341ca32b43f167f1e/unstructured-0.16.25.tar.gz", hash = "sha256:73b9b0f51dbb687af572ecdb849a6811710b9cac797ddeab8ee80fa07d8aa5e6", size = 1683097, upload-time = "2025-03-07T11:19:39.507Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/a9/5f/64285bd69a538bc28753f1423fcaa9d64cd79a9e7c097171b1f0d27e9cdb/unstructured-0.18.31.tar.gz", hash = "sha256:af4bbe32d1894ae6e755f0da6fc0dd307a1d0adeebe0e7cc6278f6cf744339ca", size = 1707700, upload-time = "2026-01-27T15:33:05.378Z" }
wheels = [
-    { url = "https://files.pythonhosted.org/packages/12/4f/ad08585b5c8a33c82ea119494c4d3023f4796958c56e668b15cc282ec0a0/unstructured-0.16.25-py3-none-any.whl", hash = "sha256:14719ccef2830216cf1c5bf654f75e2bf07b17ca5dcee9da5ac74618130fd337", size = 1769286, upload-time = "2025-03-07T11:19:37.299Z" },
+    { url = "https://files.pythonhosted.org/packages/c8/4a/9c43f39d9e443c9bc3f2e379b305bca27110adc653b071221b3132c18de5/unstructured-0.18.31-py3-none-any.whl", hash = "sha256:fab4641176cb9b192ed38048758aa0d9843121d03626d18f42275afb31e5b2d3", size = 1794889, upload-time = "2026-01-27T15:33:03.136Z" },
]

[package.optional-dependencies]

View File

@@ -21,7 +21,7 @@ services:
  # API service
  api:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.12.1
    restart: always
    environment:
      # Use the shared environment variables.
@@ -63,7 +63,7 @@ services:
  # worker service
  # The Celery worker for processing all queues (dataset, workflow, mail, etc.)
  worker:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.12.1
    restart: always
    environment:
      # Use the shared environment variables.
@@ -102,7 +102,7 @@
  # worker_beat service
  # Celery beat for scheduling periodic tasks.
  worker_beat:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.12.1
    restart: always
    environment:
      # Use the shared environment variables.
@@ -132,7 +132,7 @@
  # Frontend web application.
  web:
-    image: langgenius/dify-web:1.12.0
+    image: langgenius/dify-web:1.12.1
    restart: always
    environment:
      CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@@ -707,7 +707,7 @@ services:
  # API service
  api:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.12.1
    restart: always
    environment:
      # Use the shared environment variables.
@@ -749,7 +749,7 @@
  # worker service
  # The Celery worker for processing all queues (dataset, workflow, mail, etc.)
  worker:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.12.1
    restart: always
    environment:
      # Use the shared environment variables.
@@ -788,7 +788,7 @@
  # worker_beat service
  # Celery beat for scheduling periodic tasks.
  worker_beat:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.12.1
    restart: always
    environment:
      # Use the shared environment variables.
@@ -818,7 +818,7 @@
  # Frontend web application.
  web:
-    image: langgenius/dify-web:1.12.0
+    image: langgenius/dify-web:1.12.1
    restart: always
    environment:
      CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@@ -109,6 +109,7 @@ const AgentTools: FC = () => {
      tool_parameters: paramsWithDefaultValue,
      notAuthor: !tool.is_team_authorization,
      enabled: true,
+      type: tool.provider_type as CollectionType,
    }
  }

  const handleSelectTool = (tool: ToolDefaultValue) => {

View File

@@ -19,19 +19,21 @@ import { useBoolean, useSessionStorageState } from 'ahooks'
import * as React from 'react'
import { useCallback, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
+import { useShallow } from 'zustand/react/shallow'
import Button from '@/app/components/base/button'
import Confirm from '@/app/components/base/confirm'
import { Generator } from '@/app/components/base/icons/src/vender/other'
import Loading from '@/app/components/base/loading'
import Modal from '@/app/components/base/modal'
import Toast from '@/app/components/base/toast'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
import { generateBasicAppFirstTimeRule, generateRule } from '@/service/debug'
import { useGenerateRuleTemplate } from '@/service/use-apps'
+import { useStore } from '../../../store'
import IdeaOutput from './idea-output'
import InstructionEditorInBasic from './instruction-editor'
import InstructionEditorInWorkflow from './instruction-editor-in-workflow'
@@ -83,6 +85,9 @@ const GetAutomaticRes: FC<IGetAutomaticResProps> = ({
  onFinished,
}) => {
  const { t } = useTranslation()
+  const { appDetail } = useStore(useShallow(state => ({
+    appDetail: state.appDetail,
+  })))
  const localModel = localStorage.getItem('auto-gen-model')
    ? JSON.parse(localStorage.getItem('auto-gen-model') as string) as Model
    : null
@@ -235,6 +240,7 @@
        instruction,
        model_config: model,
        no_variable: false,
+        app_id: appDetail?.id,
      })
      apiRes = {
        ...res,
@@ -256,6 +262,7 @@
        instruction,
        ideal_output: ideaOutput,
        model_config: model,
+        app_id: appDetail?.id,
      })
      apiRes = res
      if (error) {

View File

@@ -154,7 +154,7 @@ export const GeneralChunkingOptions: FC<GeneralChunkingOptionsProps> = ({
        </div>
      ))}
      {
-        showSummaryIndexSetting && (
+        showSummaryIndexSetting && IS_CE_EDITION && (
          <div className="mt-3">
            <SummaryIndexSetting
              entry="create-document"

View File

@@ -12,6 +12,7 @@ import Divider from '@/app/components/base/divider'
import { ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge'
import RadioCard from '@/app/components/base/radio-card'
import SummaryIndexSetting from '@/app/components/datasets/settings/summary-index-setting'
+import { IS_CE_EDITION } from '@/config'
import { ChunkingMode } from '@/models/datasets'
import FileList from '../../assets/file-list-3-fill.svg'
import Note from '../../assets/note-mod.svg'
@@ -191,7 +192,7 @@ export const ParentChildOptions: FC<ParentChildOptionsProps> = ({
        </div>
      ))}
      {
-        showSummaryIndexSetting && (
+        showSummaryIndexSetting && IS_CE_EDITION && (
          <div className="mt-3">
            <SummaryIndexSetting
              entry="create-document"

View File

@@ -26,6 +26,7 @@ import CustomPopover from '@/app/components/base/popover'
import Switch from '@/app/components/base/switch'
import { ToastContext } from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
+import { IS_CE_EDITION } from '@/config'
import { DataSourceType, DocumentActionType } from '@/models/datasets'
import {
  useDocumentArchive,
@@ -263,10 +264,14 @@ const Operations = ({
              <span className={s.actionName}>{t('list.action.sync', { ns: 'datasetDocuments' })}</span>
            </div>
          )}
-          <div className={s.actionItem} onClick={() => onOperate('summary')}>
-            <SearchLinesSparkle className="h-4 w-4 text-text-tertiary" />
-            <span className={s.actionName}>{t('list.action.summary', { ns: 'datasetDocuments' })}</span>
-          </div>
+          {
+            IS_CE_EDITION && (
+              <div className={s.actionItem} onClick={() => onOperate('summary')}>
+                <SearchLinesSparkle className="h-4 w-4 text-text-tertiary" />
+                <span className={s.actionName}>{t('list.action.summary', { ns: 'datasetDocuments' })}</span>
+              </div>
+            )
+          }
          <Divider className="my-1" />
        </>
      )}

View File

@@ -7,6 +7,7 @@ import Button from '@/app/components/base/button'
import Confirm from '@/app/components/base/confirm'
import Divider from '@/app/components/base/divider'
import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge'
+import { IS_CE_EDITION } from '@/config'
import { cn } from '@/utils/classnames'

const i18nPrefix = 'batchAction'
@@ -87,7 +88,7 @@ const BatchAction: FC<IBatchActionProps> = ({
          <span className="px-0.5">{t('metadata.metadata', { ns: 'dataset' })}</span>
        </Button>
      )}
-      {onBatchSummary && (
+      {onBatchSummary && IS_CE_EDITION && (
        <Button
          variant="ghost"
          className="gap-x-0.5 px-3"

View File

@@ -21,6 +21,7 @@ import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-me
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
+import { IS_CE_EDITION } from '@/config'
import { useSelector as useAppContextWithSelector } from '@/context/app-context'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { useDocLink } from '@/context/i18n'
@@ -359,7 +360,7 @@ const Form = () => {
      {
        indexMethod === IndexingType.QUALIFIED
        && [ChunkingMode.text, ChunkingMode.parentChild].includes(currentDataset?.doc_form as ChunkingMode)
-        && (
+        && IS_CE_EDITION && (
          <>
            <Divider
              type="horizontal"

View File

@@ -104,7 +104,7 @@ const MembersPage = () => {
          <UpgradeBtn className="mr-2" loc="member-invite" />
        )}
        <div className="shrink-0">
-          <InviteButton disabled={!isCurrentWorkspaceManager || isMemberFull} onClick={() => setInviteModalVisible(true)} />
+          {isCurrentWorkspaceManager && <InviteButton disabled={isMemberFull} onClick={() => setInviteModalVisible(true)} />}
        </div>
      </div>
      <div className="overflow-visible lg:overflow-visible">

View File

@@ -129,6 +129,7 @@ export const useToolSelectorState = ({
      extra: {
        description: tool.tool_description,
      },
+      type: tool.provider_type,
    }
  }, [])

View File

@@ -87,6 +87,7 @@ export type ToolValue = {
  enabled?: boolean
  extra?: { description?: string } & Record<string, unknown>
  credential_id?: string
+  type?: string
}

export type DataSourceItem = {

View File

@@ -18,6 +18,7 @@ import {
  Group,
} from '@/app/components/workflow/nodes/_base/components/layout'
import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker'
+import { IS_CE_EDITION } from '@/config'
import Split from '../_base/components/split'
import ChunkStructure from './components/chunk-structure'
import EmbeddingModel from './components/embedding-model'
@@ -172,7 +173,7 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
      {
        data.indexing_technique === IndexMethodEnum.QUALIFIED
        && [ChunkStructureEnum.general, ChunkStructureEnum.parent_child].includes(data.chunk_structure)
-        && (
+        && IS_CE_EDITION && (
          <>
            <SummaryIndexSetting
              summaryIndexSetting={data.summary_index_setting}

View File

@@ -1,7 +1,7 @@
{
  "name": "dify-web",
  "type": "module",
-  "version": "1.12.0",
+  "version": "1.12.1",
  "private": true,
  "packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a",
  "imports": {

web/pnpm-lock.yaml (generated)

File diff suppressed because it is too large.

View File

@@ -9,7 +9,6 @@ import type {
} from '@/types/workflow'
import { get, post } from './base'
import { getFlowPrefix } from './utils'
-import { sanitizeWorkflowDraftPayload } from './workflow-payload'

export const fetchWorkflowDraft = (url: string) => {
  return get(url, {}, { silent: true }) as Promise<FetchWorkflowDraftResponse>
@@ -19,8 +18,7 @@ export const syncWorkflowDraft = ({ url, params }: {
  url: string
  params: Pick<FetchWorkflowDraftResponse, 'graph' | 'features' | 'environment_variables' | 'conversation_variables'>
}) => {
-  const sanitized = sanitizeWorkflowDraftPayload(params)
-  return post<CommonResponse & { updated_at: number, hash: string }>(url, { body: sanitized }, { silent: true })
+  return post<CommonResponse & { updated_at: number, hash: string }>(url, { body: params }, { silent: true })
}

export const fetchNodesDefaultConfigs = (url: string) => {