Merge origin/release/e-1.12.1 into 1.12.1-otel-ee

Sync enterprise 1.12.1 changes:
- feat: implement heartbeat mechanism for database migration lock
- refactor: replace AutoRenewRedisLock with DbMigrationAutoRenewLock
- fix: improve logging for database migration lock release
- fix: make flask upgrade-db fail on error
- fix: include sso_verified in access_mode validation
- fix: inherit web app permission from original app
- fix: make e-1.12.1 enterprise migrations database-agnostic
- fix: get_message_event_type return wrong message type
- refactor: document_indexing_sync_task split db session
- fix: trigger output schema miss
- test: remove unrelated enterprise service test

Conflict resolution:
- Combined OTEL telemetry imports with tool signature import in easy_ui_based_generate_task_pipeline.py
@@ -30,6 +30,7 @@ from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
 from extensions.storage.opendal_storage import OpenDALStorage
 from extensions.storage.storage_type import StorageType
+from libs.db_migration_lock import DbMigrationAutoRenewLock
 from libs.helper import email as email_validate
 from libs.password import hash_password, password_pattern, valid_password
 from libs.rsa import generate_key_pair
@@ -54,6 +55,8 @@ from tasks.remove_app_and_related_data_task import delete_draft_variables_batch

 logger = logging.getLogger(__name__)

+DB_UPGRADE_LOCK_TTL_SECONDS = 60
+

 @click.command("reset-password", help="Reset the account password.")
 @click.option("--email", prompt=True, help="Account email to reset password for")
@@ -727,8 +730,15 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
 @click.command("upgrade-db", help="Upgrade the database")
 def upgrade_db():
     click.echo("Preparing database migration...")
-    lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
+    lock = DbMigrationAutoRenewLock(
+        redis_client=redis_client,
+        name="db_upgrade_lock",
+        ttl_seconds=DB_UPGRADE_LOCK_TTL_SECONDS,
+        logger=logger,
+        log_context="db_migration",
+    )
     if lock.acquire(blocking=False):
+        migration_succeeded = False
         try:
             click.echo(click.style("Starting database migration.", fg="green"))
@@ -737,12 +747,16 @@ def upgrade_db():

             flask_migrate.upgrade()

+            migration_succeeded = True
             click.echo(click.style("Database migration successful!", fg="green"))

-        except Exception:
+        except Exception as e:
             logger.exception("Failed to execute database migration")
+            click.echo(click.style(f"Database migration failed: {e}", fg="red"))
+            raise SystemExit(1)
         finally:
-            lock.release()
+            status = "successful" if migration_succeeded else "failed"
+            lock.release_safely(status=status)
     else:
         click.echo("Database migration skipped")
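The reworked upgrade-db command tracks the migration outcome explicitly, so a failed migration exits non-zero and the lock release can never mask the original error. A minimal sketch of the same pattern (hypothetical run_migration callable and lock object, not the actual implementation):

    def guarded_migration(lock, run_migration):
        # Acquire without blocking: only one process should run migrations at a time.
        if not lock.acquire(blocking=False):
            return  # another instance holds the lock; skip
        succeeded = False
        try:
            run_migration()
            succeeded = True
        except Exception:
            # Fail loudly so orchestration sees a non-zero exit code.
            raise SystemExit(1)
        finally:
            # release_safely() never raises, so it cannot mask the SystemExit above.
            lock.release_safely(status="successful" if succeeded else "failed")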
@@ -1,3 +1,4 @@
+import logging
 import uuid
 from datetime import datetime
 from typing import Any, Literal, TypeAlias
@@ -54,6 +55,8 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co

 register_enum_models(console_ns, IconType)

+_logger = logging.getLogger(__name__)
+

 class AppListQuery(BaseModel):
     page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
@@ -499,6 +502,7 @@ class AppListApi(Resource):
                     select(Workflow).where(
                         Workflow.version == Workflow.VERSION_DRAFT,
                         Workflow.app_id.in_(workflow_capable_app_ids),
+                        Workflow.tenant_id == current_tenant_id,
                     )
                 )
                 .scalars()
@@ -510,12 +514,14 @@ class AppListApi(Resource):
                 NodeType.TRIGGER_PLUGIN,
             }
             for workflow in draft_workflows:
+                node_id = None
                 try:
-                    for _, node_data in workflow.walk_nodes():
+                    for node_id, node_data in workflow.walk_nodes():
                         if node_data.get("type") in trigger_node_types:
                             draft_trigger_app_ids.add(str(workflow.app_id))
                             break
                 except Exception:
+                    _logger.exception("error while walking nodes, workflow_id=%s, node_id=%s", workflow.id, node_id)
                     continue

         for app in app_pagination.items:
@@ -654,6 +660,19 @@ class AppCopyApi(Resource):
             )
             session.commit()

+            # Inherit web app permission from original app
+            if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
+                try:
+                    # Get the original app's access mode
+                    original_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_model.id)
+                    access_mode = original_settings.access_mode
+                except Exception:
+                    # If original app has no settings (old app), default to public to match fallback behavior
+                    access_mode = "public"
+
+                # Apply the same access mode to the copied app
+                EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, access_mode)
+
             stmt = select(App).where(App.id == result.app_id)
             app = session.scalar(stmt)

@@ -878,7 +878,11 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
         current_user, current_tenant_id = current_account_with_tenant()
         payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
         return BuiltinToolManageService.set_default_provider(
-            tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id
+            tenant_id=current_tenant_id,
+            user_id=current_user.id,
+            provider=provider,
+            id=args["id"],
+            account=current_user,
         )

@@ -45,6 +45,8 @@ from core.app.entities.task_entities import (
 from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
 from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
 from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
+from core.file import helpers as file_helpers
+from core.file.enums import FileTransferMethod
 from core.model_manager import ModelInstance
 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
 from core.model_runtime.entities.message_entities import (
@@ -57,10 +59,11 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
 from core.prompt.utils.prompt_template_parser import PromptTemplateParser
 from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
 from core.telemetry import emit as telemetry_emit
+from core.tools.signature import sign_tool_file
 from events.message_event import message_was_created
 from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
-from models.model import AppMode, Conversation, Message, MessageAgentThought
+from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile

 logger = logging.getLogger(__name__)
@@ -473,6 +476,85 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
                 metadata=metadata_dict,
             )

+    def _record_files(self):
+        with Session(db.engine, expire_on_commit=False) as session:
+            message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
+            if not message_files:
+                return None
+
+            files_list = []
+            upload_file_ids = [
+                mf.upload_file_id
+                for mf in message_files
+                if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
+            ]
+            upload_files_map = {}
+            if upload_file_ids:
+                upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
+                upload_files_map = {uf.id: uf for uf in upload_files}
+
+            for message_file in message_files:
+                upload_file = None
+                if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
+                    upload_file = upload_files_map.get(message_file.upload_file_id)
+
+                url = None
+                filename = "file"
+                mime_type = "application/octet-stream"
+                size = 0
+                extension = ""
+
+                if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
+                    url = message_file.url
+                    if message_file.url:
+                        filename = message_file.url.split("/")[-1].split("?")[0]  # Remove query params
+                elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
+                    if upload_file:
+                        url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
+                        filename = upload_file.name
+                        mime_type = upload_file.mime_type or "application/octet-stream"
+                        size = upload_file.size or 0
+                        extension = f".{upload_file.extension}" if upload_file.extension else ""
+                    elif message_file.upload_file_id:
+                        # Fallback: generate URL even if upload_file not found
+                        url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
+                elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
+                    # For tool files, use URL directly if it's HTTP, otherwise sign it
+                    if message_file.url.startswith("http"):
+                        url = message_file.url
+                        filename = message_file.url.split("/")[-1].split("?")[0]
+                    else:
+                        # Extract tool file id and extension from URL
+                        url_parts = message_file.url.split("/")
+                        if url_parts:
+                            file_part = url_parts[-1].split("?")[0]  # Remove query params first
+                            # Use rsplit to correctly handle filenames with multiple dots
+                            if "." in file_part:
+                                tool_file_id, ext = file_part.rsplit(".", 1)
+                                extension = f".{ext}"
+                            else:
+                                tool_file_id = file_part
+                                extension = ".bin"
+                            url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
+                            filename = file_part
+
+                transfer_method_value = message_file.transfer_method
+                remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
+                file_dict = {
+                    "related_id": message_file.id,
+                    "extension": extension,
+                    "filename": filename,
+                    "size": size,
+                    "mime_type": mime_type,
+                    "transfer_method": transfer_method_value,
+                    "type": message_file.type,
+                    "url": url or "",
+                    "upload_file_id": message_file.upload_file_id or message_file.id,
+                    "remote_url": remote_url,
+                }
+                files_list.append(file_dict)
+            return files_list or None
+
     def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
         """
         Agent message to stream response.
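The tool-file branch above recovers the file id and extension from the URL tail with rsplit(".", 1), so names containing multiple dots keep everything before the final dot as the id. A quick standalone illustration of that parsing rule (not part of the diff):

    def split_tool_file(file_part: str) -> tuple[str, str]:
        # Everything before the final dot is the tool file id.
        if "." in file_part:
            tool_file_id, ext = file_part.rsplit(".", 1)
            return tool_file_id, f".{ext}"
        return file_part, ".bin"  # no extension: fall back to .bin

    assert split_tool_file("report.v2.pdf") == ("report.v2", ".pdf")
    assert split_tool_file("blob") == ("blob", ".bin")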
@@ -64,7 +64,13 @@ class MessageCycleManager:

        # Use SQLAlchemy 2.x style session.scalar(select(...))
        with session_factory.create_session() as session:
-           message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id))
+           message_file = session.scalar(
+               select(MessageFile)
+               .where(
+                   MessageFile.message_id == message_id,
+               )
+               .where(MessageFile.belongs_to == "assistant")
+           )

            if message_file:
                self._message_has_file.add(message_id)
new file: api/libs/db_migration_lock.py (213 lines)
@@ -0,0 +1,213 @@
+"""
+DB migration Redis lock with heartbeat renewal.
+
+This is intentionally migration-specific. Background renewal is a trade-off that makes sense
+for unbounded, blocking operations like DB migrations (DDL/DML) where the main thread cannot
+periodically refresh the lock TTL.
+
+Do NOT use this as a general-purpose lock primitive for normal application code. Prefer explicit
+lock lifecycle management (e.g. redis-py Lock context manager + `extend()` / `reacquire()` from
+the same thread) when execution flow is under control.
+"""
+
+from __future__ import annotations
+
+import logging
+import threading
+from typing import Any
+
+from redis.exceptions import LockNotOwnedError, RedisError
+
+logger = logging.getLogger(__name__)
+
+MIN_RENEW_INTERVAL_SECONDS = 0.1
+DEFAULT_RENEW_INTERVAL_DIVISOR = 3
+MIN_JOIN_TIMEOUT_SECONDS = 0.5
+MAX_JOIN_TIMEOUT_SECONDS = 5.0
+JOIN_TIMEOUT_MULTIPLIER = 2.0
+
+
+class DbMigrationAutoRenewLock:
+    """
+    Redis lock wrapper that automatically renews TTL while held (migration-only).
+
+    Notes:
+    - We force `thread_local=False` when creating the underlying redis-py lock, because the
+      lock token must be accessible from the heartbeat thread for `reacquire()` to work.
+    - `release_safely()` is best-effort: it never raises, so it won't mask the caller's
+      primary error/exit code.
+    """
+
+    _redis_client: Any
+    _name: str
+    _ttl_seconds: float
+    _renew_interval_seconds: float
+    _log_context: str | None
+    _logger: logging.Logger
+
+    _lock: Any
+    _stop_event: threading.Event | None
+    _thread: threading.Thread | None
+    _acquired: bool
+
+    def __init__(
+        self,
+        redis_client: Any,
+        name: str,
+        ttl_seconds: float = 60,
+        renew_interval_seconds: float | None = None,
+        *,
+        logger: logging.Logger | None = None,
+        log_context: str | None = None,
+    ) -> None:
+        self._redis_client = redis_client
+        self._name = name
+        self._ttl_seconds = float(ttl_seconds)
+        self._renew_interval_seconds = (
+            float(renew_interval_seconds)
+            if renew_interval_seconds is not None
+            else max(MIN_RENEW_INTERVAL_SECONDS, self._ttl_seconds / DEFAULT_RENEW_INTERVAL_DIVISOR)
+        )
+        self._logger = logger or logging.getLogger(__name__)
+        self._log_context = log_context
+
+        self._lock = None
+        self._stop_event = None
+        self._thread = None
+        self._acquired = False
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    def acquire(self, *args: Any, **kwargs: Any) -> bool:
+        """
+        Acquire the lock and start heartbeat renewal on success.
+
+        Accepts the same args/kwargs as redis-py `Lock.acquire()`.
+        """
+        # Prevent accidental double-acquire which could leave the previous heartbeat thread running.
+        if self._acquired:
+            raise RuntimeError("DB migration lock is already acquired; call release_safely() before acquiring again.")
+
+        # Reuse the lock object if we already created one.
+        if self._lock is None:
+            self._lock = self._redis_client.lock(
+                name=self._name,
+                timeout=self._ttl_seconds,
+                thread_local=False,
+            )
+        acquired = bool(self._lock.acquire(*args, **kwargs))
+        self._acquired = acquired
+        if acquired:
+            self._start_heartbeat()
+        return acquired
+
+    def owned(self) -> bool:
+        if self._lock is None:
+            return False
+        try:
+            return bool(self._lock.owned())
+        except Exception:
+            # Ownership checks are best-effort and must not break callers.
+            return False
+
+    def _start_heartbeat(self) -> None:
+        if self._lock is None:
+            return
+        if self._stop_event is not None:
+            return
+
+        self._stop_event = threading.Event()
+        self._thread = threading.Thread(
+            target=self._heartbeat_loop,
+            args=(self._lock, self._stop_event),
+            daemon=True,
+            name=f"DbMigrationAutoRenewLock({self._name})",
+        )
+        self._thread.start()
+
+    def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None:
+        while not stop_event.wait(self._renew_interval_seconds):
+            try:
+                lock.reacquire()
+            except LockNotOwnedError:
+                self._logger.warning(
+                    "DB migration lock is no longer owned during heartbeat; stop renewing. log_context=%s",
+                    self._log_context,
+                    exc_info=True,
+                )
+                return
+            except RedisError:
+                self._logger.warning(
+                    "Failed to renew DB migration lock due to Redis error; will retry. log_context=%s",
+                    self._log_context,
+                    exc_info=True,
+                )
+            except Exception:
+                self._logger.warning(
+                    "Unexpected error while renewing DB migration lock; will retry. log_context=%s",
+                    self._log_context,
+                    exc_info=True,
+                )
+
+    def release_safely(self, *, status: str | None = None) -> None:
+        """
+        Stop heartbeat and release lock. Never raises.
+
+        Args:
+            status: Optional caller-provided status (e.g. 'successful'/'failed') to add context to logs.
+        """
+        lock = self._lock
+        if lock is None:
+            return
+
+        self._stop_heartbeat()
+
+        # Lock release errors should never mask the real error/exit code.
+        try:
+            lock.release()
+        except LockNotOwnedError:
+            self._logger.warning(
+                "DB migration lock not owned on release; ignoring. status=%s log_context=%s",
+                status,
+                self._log_context,
+                exc_info=True,
+            )
+        except RedisError:
+            self._logger.warning(
+                "Failed to release DB migration lock due to Redis error; ignoring. status=%s log_context=%s",
+                status,
+                self._log_context,
+                exc_info=True,
+            )
+        except Exception:
+            self._logger.warning(
+                "Unexpected error while releasing DB migration lock; ignoring. status=%s log_context=%s",
+                status,
+                self._log_context,
+                exc_info=True,
+            )
+        finally:
+            self._acquired = False
+            self._lock = None
+
+    def _stop_heartbeat(self) -> None:
+        if self._stop_event is None:
+            return
+        self._stop_event.set()
+        if self._thread is not None:
+            # Best-effort join: if Redis calls are blocked, the daemon thread may remain alive.
+            join_timeout_seconds = max(
+                MIN_JOIN_TIMEOUT_SECONDS,
+                min(MAX_JOIN_TIMEOUT_SECONDS, self._renew_interval_seconds * JOIN_TIMEOUT_MULTIPLIER),
+            )
+            self._thread.join(timeout=join_timeout_seconds)
+            if self._thread.is_alive():
+                self._logger.warning(
+                    "DB migration lock heartbeat thread did not stop within %.2fs; ignoring. log_context=%s",
+                    join_timeout_seconds,
+                    self._log_context,
+                )
+        self._stop_event = None
+        self._thread = None
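The docstring's rationale in brief: redis-py's Lock.reacquire() resets the TTL to the original timeout, but it needs the lock token, which lives in thread-local storage by default; passing thread_local=False lets a background thread call it. A condensed sketch of that mechanism (assumes a configured redis-py client named r; not the class above):

    import threading

    import redis

    r = redis.Redis()
    lock = r.lock("db_upgrade_lock", timeout=60, thread_local=False)

    def heartbeat(stop: threading.Event, interval: float = 20.0):
        # reacquire() resets the TTL to `timeout`; it needs the token, which
        # thread_local=False makes visible outside the acquiring thread.
        while not stop.wait(interval):
            lock.reacquire()

    if lock.acquire(blocking=False):
        stop = threading.Event()
        threading.Thread(target=heartbeat, args=(stop,), daemon=True).start()
        try:
            ...  # long-running DDL/DML that cannot refresh the TTL itself
        finally:
            stop.set()
            lock.release()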
@@ -8,7 +8,6 @@ Create Date: 2025-12-25 10:39:15.139304
 from alembic import op
 import models as models
 import sqlalchemy as sa
-from sqlalchemy.dialects import postgresql

 # revision identifiers, used by Alembic.
 revision = '7df29de0f6be'

@@ -20,7 +19,7 @@ depends_on = None
 def upgrade():
     # ### commands auto generated by Alembic - please adjust! ###
     op.create_table('tenant_credit_pools',
-    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('id', models.types.StringUUID(), nullable=False),
     sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
     sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False),
     sa.Column('quota_limit', sa.BigInteger(), nullable=False),
@@ -8,7 +8,6 @@ Create Date: 2026-01-017 11:10:18.079355
 from alembic import op
 import models as models
 import sqlalchemy as sa
-from sqlalchemy.dialects import postgresql

 # revision identifiers, used by Alembic.
 revision = 'f9f6d18a37f9'

@@ -20,7 +19,7 @@ depends_on = None
 def upgrade():
     # ### commands auto generated by Alembic - please adjust! ###
     op.create_table('account_trial_app_records',
-    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('id', models.types.StringUUID(), nullable=False),
     sa.Column('account_id', models.types.StringUUID(), nullable=False),
     sa.Column('app_id', models.types.StringUUID(), nullable=False),
     sa.Column('count', sa.Integer(), nullable=False),

@@ -33,17 +32,17 @@ def upgrade():
     batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)

     op.create_table('exporle_banners',
-    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('id', models.types.StringUUID(), nullable=False),
     sa.Column('content', sa.JSON(), nullable=False),
     sa.Column('link', sa.String(length=255), nullable=False),
     sa.Column('sort', sa.Integer(), nullable=False),
-    sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
+    sa.Column('status', sa.String(length=255), server_default='enabled', nullable=False),
     sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
-    sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False),
+    sa.Column('language', sa.String(length=255), server_default='en-US', nullable=False),
     sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
     )
     op.create_table('trial_apps',
-    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+    sa.Column('id', models.types.StringUUID(), nullable=False),
     sa.Column('app_id', models.types.StringUUID(), nullable=False),
     sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
     sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
@@ -620,7 +620,7 @@ class TrialApp(Base):
         sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
     )

-    id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
+    id = mapped_column(StringUUID, default=lambda: str(uuid4()))
     app_id = mapped_column(StringUUID, nullable=False)
     tenant_id = mapped_column(StringUUID, nullable=False)
     created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())

@@ -640,7 +640,7 @@ class AccountTrialAppRecord(Base):
         sa.Index("account_trial_app_record_app_id_idx", "app_id"),
         sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
     )
-    id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
+    id = mapped_column(StringUUID, default=lambda: str(uuid4()))
     account_id = mapped_column(StringUUID, nullable=False)
     app_id = mapped_column(StringUUID, nullable=False)
     count = mapped_column(sa.Integer, nullable=False, default=0)

@@ -660,18 +660,18 @@ class AccountTrialAppRecord(Base):
 class ExporleBanner(TypeBase):
     __tablename__ = "exporle_banners"
     __table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
-    id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
+    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
     content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
     link: Mapped[str] = mapped_column(String(255), nullable=False)
     sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
     status: Mapped[str] = mapped_column(
-        sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled"
+        sa.String(255), nullable=False, server_default='enabled', default="enabled"
     )
     created_at: Mapped[datetime] = mapped_column(
         sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
     )
     language: Mapped[str] = mapped_column(
-        String(255), nullable=False, server_default=sa.text("'en-US'::character varying"), default="en-US"
+        String(255), nullable=False, server_default='en-US', default="en-US"
     )

@@ -2166,7 +2166,7 @@ class TenantCreditPool(TypeBase):
         sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"),
     )

-    id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"), init=False)
+    id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
     tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
     pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
     quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
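These model and migration edits drop the Postgres-only uuid_generate_v4() server default (and the '::character varying' casts) so the DDL runs on any backend; UUIDs are now generated client-side. A minimal sketch of the portable pattern (illustrative table, not from the diff):

    import uuid

    import sqlalchemy as sa
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


    class Base(DeclarativeBase):
        pass


    class Example(Base):
        __tablename__ = "examples"
        # default= runs in Python at INSERT time, so no uuid-ossp extension is required;
        # a plain-string server_default avoids Postgres-specific casts like '::character varying'.
        id: Mapped[str] = mapped_column(sa.String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
        status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default="enabled")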
@@ -1,6 +1,6 @@
 [project]
 name = "dify-api"
-version = "1.12.0"
+version = "1.12.1"
 requires-python = ">=3.11,<3.13"

 dependencies = [

@@ -81,7 +81,7 @@ dependencies = [
     "starlette==0.49.1",
     "tiktoken~=0.9.0",
     "transformers~=4.56.1",
-    "unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
+    "unstructured[docx,epub,md,ppt,pptx]~=0.18.18",
     "yarl~=1.18.3",
     "webvtt-py~=0.5.1",
     "sseclient-py~=1.8.0",
@@ -327,6 +327,12 @@ class AccountService:
     @staticmethod
     def delete_account(account: Account):
         """Delete account. This method only adds a task to the queue for deletion."""
+        # Queue account deletion sync tasks for all workspaces BEFORE account deletion (enterprise only)
+        from services.enterprise.account_deletion_sync import sync_account_deletion
+
+        sync_account_deletion(account_id=account.id, source="account_deleted")
+
+        # Now proceed with async account deletion
         delete_account_task.delay(account.id)

     @staticmethod

@@ -1230,6 +1236,11 @@ class TenantService:
         if dify_config.BILLING_ENABLED:
             BillingService.clean_billing_info_cache(tenant.id)

+        # Queue account deletion sync task for enterprise backend to reassign resources (enterprise only)
+        from services.enterprise.account_deletion_sync import sync_workspace_member_removal
+
+        sync_workspace_member_removal(workspace_id=tenant.id, member_id=account.id, source="workspace_member_removed")
+
     @staticmethod
     def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
         """Update member role"""
new file: api/services/enterprise/account_deletion_sync.py (115 lines)
@@ -0,0 +1,115 @@
+import json
+import logging
+import uuid
+from datetime import UTC, datetime
+
+from redis import RedisError
+
+from configs import dify_config
+from extensions.ext_database import db
+from extensions.ext_redis import redis_client
+from models.account import TenantAccountJoin
+
+logger = logging.getLogger(__name__)
+
+ACCOUNT_DELETION_SYNC_QUEUE = "enterprise:member:sync:queue"
+ACCOUNT_DELETION_SYNC_TASK_TYPE = "sync_member_deletion_from_workspace"
+
+
+def _queue_task(workspace_id: str, member_id: str, *, source: str) -> bool:
+    """
+    Queue an account deletion sync task to Redis.
+
+    Internal helper function. Do not call directly - use the public functions instead.
+
+    Args:
+        workspace_id: The workspace/tenant ID to sync
+        member_id: The member/account ID that was removed
+        source: Source of the sync request (for debugging/tracking)
+
+    Returns:
+        bool: True if task was queued successfully, False otherwise
+    """
+    try:
+        task = {
+            "task_id": str(uuid.uuid4()),
+            "workspace_id": workspace_id,
+            "member_id": member_id,
+            "retry_count": 0,
+            "created_at": datetime.now(UTC).isoformat(),
+            "source": source,
+            "type": ACCOUNT_DELETION_SYNC_TASK_TYPE,
+        }
+
+        # Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP
+        redis_client.lpush(ACCOUNT_DELETION_SYNC_QUEUE, json.dumps(task))
+
+        logger.info(
+            "Queued account deletion sync task for workspace %s, member %s, task_id: %s, source: %s",
+            workspace_id,
+            member_id,
+            task["task_id"],
+            source,
+        )
+        return True
+
+    except (RedisError, TypeError) as e:
+        logger.error(
+            "Failed to queue account deletion sync for workspace %s, member %s: %s",
+            workspace_id,
+            member_id,
+            str(e),
+            exc_info=True,
+        )
+        # Don't raise - we don't want to fail member deletion if queueing fails
+        return False
+
+
+def sync_workspace_member_removal(workspace_id: str, member_id: str, *, source: str) -> bool:
+    """
+    Sync a single workspace member removal (enterprise only).
+
+    Queues a task for the enterprise backend to reassign resources from the removed member.
+    Handles enterprise edition check internally. Safe to call in community edition (no-op).
+
+    Args:
+        workspace_id: The workspace/tenant ID
+        member_id: The member/account ID that was removed
+        source: Source of the sync request (e.g., "workspace_member_removed")
+
+    Returns:
+        bool: True if task was queued (or skipped in community), False if queueing failed
+    """
+    if not dify_config.ENTERPRISE_ENABLED:
+        return True
+
+    return _queue_task(workspace_id=workspace_id, member_id=member_id, source=source)
+
+
+def sync_account_deletion(account_id: str, *, source: str) -> bool:
+    """
+    Sync full account deletion across all workspaces (enterprise only).
+
+    Fetches all workspace memberships for the account and queues a sync task for each.
+    Handles enterprise edition check internally. Safe to call in community edition (no-op).
+
+    Args:
+        account_id: The account ID being deleted
+        source: Source of the sync request (e.g., "account_deleted")
+
+    Returns:
+        bool: True if all tasks were queued (or skipped in community), False if any queueing failed
+    """
+    if not dify_config.ENTERPRISE_ENABLED:
+        return True
+
+    # Fetch all workspaces the account belongs to
+    workspace_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).all()
+
+    # Queue sync task for each workspace
+    success = True
+    for join in workspace_joins:
+        if not _queue_task(workspace_id=join.tenant_id, member_id=account_id, source=source):
+            success = False
+
+    return success
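The _queue_task comment implies FIFO consumption on the enterprise side: producers LPUSH to the head and a worker pops from the tail. A hypothetical consumer loop (the worker itself is not part of this diff; handle_task is an assumed callback):

    import json

    from extensions.ext_redis import redis_client

    QUEUE = "enterprise:member:sync:queue"


    def drain_queue(handle_task) -> int:
        """Pop tasks oldest-first; RPOP pairs with LPUSH to give FIFO order."""
        processed = 0
        while True:
            raw = redis_client.rpop(QUEUE)
            if raw is None:
                break  # queue is empty
            handle_task(json.loads(raw))
            processed += 1
        return processed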
@@ -4,6 +4,8 @@ from pydantic import BaseModel, Field

 from services.enterprise.base import EnterpriseRequest

+ALLOWED_ACCESS_MODES = ["public", "private", "private_all", "sso_verified"]
+

 class WebAppSettings(BaseModel):
     access_mode: str = Field(

@@ -123,8 +125,8 @@ class EnterpriseService:
     def update_app_access_mode(cls, app_id: str, access_mode: str):
         if not app_id:
             raise ValueError("app_id must be provided.")
-        if access_mode not in ["public", "private", "private_all"]:
-            raise ValueError("access_mode must be either 'public', 'private', or 'private_all'")
+        if access_mode not in ALLOWED_ACCESS_MODES:
+            raise ValueError(f"access_mode must be one of: {', '.join(ALLOWED_ACCESS_MODES)}")

         data = {"appId": app_id, "accessMode": access_mode}
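With sso_verified added to ALLOWED_ACCESS_MODES, update_app_access_mode now accepts it instead of raising. The guard in isolation (re-stated for illustration only):

    ALLOWED_ACCESS_MODES = ["public", "private", "private_all", "sso_verified"]


    def validate_access_mode(access_mode: str) -> None:
        if access_mode not in ALLOWED_ACCESS_MODES:
            raise ValueError(f"access_mode must be one of: {', '.join(ALLOWED_ACCESS_MODES)}")


    validate_access_mode("sso_verified")  # previously raised ValueError; now passes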
@@ -2,7 +2,10 @@ import json
 import logging
 from collections.abc import Mapping
 from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from models.account import Account

 from sqlalchemy import exists, select
 from sqlalchemy.orm import Session

@@ -406,20 +409,37 @@ class BuiltinToolManageService:
         return {"result": "success"}

     @staticmethod
-    def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
+    def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str, account: "Account | None" = None):
         """
         set default provider
         """
         with Session(db.engine) as session:
-            # get provider
-            target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first()
+            # get provider (verify tenant ownership to prevent IDOR)
+            target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
             if target_provider is None:
                 raise ValueError("provider not found")

             # clear default provider
-            session.query(BuiltinToolProvider).filter_by(
-                tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
-            ).update({"is_default": False})
+            if dify_config.ENTERPRISE_ENABLED:
+                # Enterprise: verify admin permission for tenant-wide operation
+                from models.account import TenantAccountRole
+
+                if account is None:
+                    # In enterprise mode, an account context is required to perform permission checks
+                    raise ValueError("Account is required to set default credentials in enterprise mode")
+
+                if not TenantAccountRole.is_privileged_role(account.current_role):
+                    raise ValueError("Only workspace admins/owners can set default credentials in enterprise mode")
+                # Enterprise: clear ALL defaults for this provider in the tenant
+                # (regardless of user_id, since enterprise credentials may have different user_id)
+                session.query(BuiltinToolProvider).filter_by(
+                    tenant_id=tenant_id, provider=provider, is_default=True
+                ).update({"is_default": False})
+            else:
+                # Non-enterprise: only clear defaults for the current user
+                session.query(BuiltinToolProvider).filter_by(
+                    tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
+                ).update({"is_default": False})

             # set new default provider
             target_provider.is_default = True
@@ -6,7 +6,6 @@ from celery import shared_task

 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.models.document import Document
-from extensions.ext_database import db
 from models.dataset import Dataset
 from services.dataset_service import DatasetCollectionBindingService

@@ -58,5 +57,3 @@ def add_annotation_to_index_task(
         )
     except Exception:
         logger.exception("Build index for annotation failed")
-    finally:
-        db.session.close()
@@ -5,7 +5,6 @@ import click
 from celery import shared_task

 from core.rag.datasource.vdb.vector_factory import Vector
-from extensions.ext_database import db
 from models.dataset import Dataset
 from services.dataset_service import DatasetCollectionBindingService

@@ -40,5 +39,3 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
         logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
     except Exception:
         logger.exception("Annotation deleted index failed")
-    finally:
-        db.session.close()
@@ -6,7 +6,6 @@ from celery import shared_task

 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.models.document import Document
-from extensions.ext_database import db
 from models.dataset import Dataset
 from services.dataset_service import DatasetCollectionBindingService

@@ -59,5 +58,3 @@ def update_annotation_to_index_task(
         )
     except Exception:
         logger.exception("Build index for annotation failed")
-    finally:
-        db.session.close()
@@ -14,6 +14,9 @@ from models.model import UploadFile

 logger = logging.getLogger(__name__)

+# Batch size for database operations to keep transactions short
+BATCH_SIZE = 1000
+

 @shared_task(queue="dataset")
 def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
@ -31,63 +34,179 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
|
|||||||
if not doc_form:
|
if not doc_form:
|
||||||
raise ValueError("doc_form is required")
|
raise ValueError("doc_form is required")
|
||||||
|
|
||||||
with session_factory.create_session() as session:
|
storage_keys_to_delete: list[str] = []
|
||||||
try:
|
index_node_ids: list[str] = []
|
||||||
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
segment_ids: list[str] = []
|
||||||
|
total_image_upload_file_ids: list[str] = []
|
||||||
if not dataset:
|
|
||||||
raise Exception("Document has no dataset")
|
|
||||||
|
|
||||||
session.query(DatasetMetadataBinding).where(
|
|
||||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
|
||||||
DatasetMetadataBinding.document_id.in_(document_ids),
|
|
||||||
).delete(synchronize_session=False)
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
# ============ Step 1: Query segment and file data (short read-only transaction) ============
|
||||||
|
with session_factory.create_session() as session:
|
||||||
|
# Get segments info
|
||||||
segments = session.scalars(
|
segments = session.scalars(
|
||||||
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
|
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
|
||||||
).all()
|
).all()
|
||||||
# check segment is exist
|
|
||||||
if segments:
|
if segments:
|
||||||
index_node_ids = [segment.index_node_id for segment in segments]
|
index_node_ids = [segment.index_node_id for segment in segments]
|
||||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
segment_ids = [segment.id for segment in segments]
|
||||||
index_processor.clean(
|
|
||||||
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Collect image file IDs from segment content
|
||||||
for segment in segments:
|
for segment in segments:
|
||||||
image_upload_file_ids = get_image_upload_file_ids(segment.content)
|
image_upload_file_ids = get_image_upload_file_ids(segment.content)
|
||||||
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
|
total_image_upload_file_ids.extend(image_upload_file_ids)
|
||||||
for image_file in image_files:
|
|
||||||
try:
|
# Query storage keys for image files
|
||||||
if image_file and image_file.key:
|
if total_image_upload_file_ids:
|
||||||
storage.delete(image_file.key)
|
image_files = session.scalars(
|
||||||
except Exception:
|
select(UploadFile).where(UploadFile.id.in_(total_image_upload_file_ids))
|
||||||
logger.exception(
|
).all()
|
||||||
"Delete image_files failed when storage deleted, \
|
storage_keys_to_delete.extend([f.key for f in image_files if f and f.key])
|
||||||
image_upload_file_is: %s",
|
|
||||||
image_file.id,
|
# Query storage keys for document files
|
||||||
)
|
|
||||||
stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
|
|
||||||
session.execute(stmt)
|
|
||||||
session.delete(segment)
|
|
||||||
if file_ids:
|
if file_ids:
|
||||||
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
|
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
|
||||||
for file in files:
|
storage_keys_to_delete.extend([f.key for f in files if f and f.key])
|
||||||
try:
|
|
||||||
storage.delete(file.key)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
|
|
||||||
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
|
|
||||||
session.execute(stmt)
|
|
||||||
|
|
||||||
session.commit()
|
# ============ Step 2: Clean vector index (external service, fresh session for dataset) ============
|
||||||
|
if index_node_ids:
|
||||||
end_at = time.perf_counter()
|
try:
|
||||||
logger.info(
|
# Fetch dataset in a fresh session to avoid DetachedInstanceError
|
||||||
click.style(
|
with session_factory.create_session() as session:
|
||||||
f"Cleaned documents when documents deleted latency: {end_at - start_at}",
|
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||||
fg="green",
|
if not dataset:
|
||||||
|
logger.warning("Dataset not found for vector index cleanup, dataset_id: %s", dataset_id)
|
||||||
|
else:
|
||||||
|
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||||
|
index_processor.clean(
|
||||||
|
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to clean vector index for dataset_id: %s, document_ids: %s, index_node_ids count: %d",
|
||||||
|
dataset_id,
|
||||||
|
document_ids,
|
||||||
|
len(index_node_ids),
|
||||||
)
|
)
|
||||||
)
|
|
||||||
|
# ============ Step 3: Delete metadata binding (separate short transaction) ============
|
||||||
|
try:
|
||||||
|
with session_factory.create_session() as session:
|
||||||
|
deleted_count = (
|
||||||
|
session.query(DatasetMetadataBinding)
|
||||||
|
.where(
|
||||||
|
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||||
|
DatasetMetadataBinding.document_id.in_(document_ids),
|
||||||
|
)
|
||||||
|
.delete(synchronize_session=False)
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Cleaned documents when documents deleted failed")
|
logger.exception(
|
||||||
|
"Failed to delete metadata bindings for dataset_id: %s, document_ids: %s",
|
||||||
|
dataset_id,
|
||||||
|
document_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ============ Step 4: Batch delete UploadFile records (multiple short transactions) ============
|
||||||
|
if total_image_upload_file_ids:
|
||||||
|
failed_batches = 0
|
||||||
|
total_batches = (len(total_image_upload_file_ids) + BATCH_SIZE - 1) // BATCH_SIZE
|
||||||
|
for i in range(0, len(total_image_upload_file_ids), BATCH_SIZE):
|
||||||
|
batch = total_image_upload_file_ids[i : i + BATCH_SIZE]
|
||||||
|
try:
|
||||||
|
with session_factory.create_session() as session:
|
||||||
|
stmt = delete(UploadFile).where(UploadFile.id.in_(batch))
|
||||||
|
session.execute(stmt)
|
||||||
|
session.commit()
|
||||||
|
except Exception:
|
||||||
|
failed_batches += 1
|
||||||
|
logger.exception(
|
||||||
|
"Failed to delete image UploadFile batch %d-%d for dataset_id: %s",
|
||||||
|
i,
|
||||||
|
i + len(batch),
|
||||||
|
dataset_id,
|
||||||
|
)
|
||||||
|
if failed_batches > 0:
|
||||||
|
logger.warning(
|
||||||
|
"Image UploadFile deletion: %d/%d batches failed for dataset_id: %s",
|
||||||
|
failed_batches,
|
||||||
|
total_batches,
|
||||||
|
dataset_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ============ Step 5: Batch delete DocumentSegment records (multiple short transactions) ============
|
||||||
|
if segment_ids:
|
||||||
|
failed_batches = 0
|
||||||
|
total_batches = (len(segment_ids) + BATCH_SIZE - 1) // BATCH_SIZE
|
||||||
|
for i in range(0, len(segment_ids), BATCH_SIZE):
|
||||||
|
batch = segment_ids[i : i + BATCH_SIZE]
|
||||||
|
try:
|
||||||
|
with session_factory.create_session() as session:
|
||||||
|
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(batch))
|
||||||
|
session.execute(segment_delete_stmt)
|
||||||
|
session.commit()
|
||||||
|
except Exception:
|
||||||
|
failed_batches += 1
|
||||||
|
logger.exception(
|
||||||
|
"Failed to delete DocumentSegment batch %d-%d for dataset_id: %s, document_ids: %s",
|
||||||
|
i,
|
||||||
|
i + len(batch),
|
||||||
|
dataset_id,
|
||||||
|
document_ids,
|
||||||
|
)
|
||||||
|
if failed_batches > 0:
|
||||||
|
logger.warning(
|
||||||
|
"DocumentSegment deletion: %d/%d batches failed, document_ids: %s",
|
||||||
|
failed_batches,
|
||||||
|
total_batches,
|
||||||
|
document_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ============ Step 6: Delete document-associated files (separate short transaction) ============
|
||||||
|
if file_ids:
|
||||||
|
try:
|
||||||
|
with session_factory.create_session() as session:
|
||||||
|
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
|
||||||
|
session.execute(stmt)
|
||||||
|
session.commit()
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Failed to delete document UploadFile records for dataset_id: %s, file_ids: %s",
|
||||||
|
dataset_id,
|
||||||
|
file_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ============ Step 7: Delete storage files (I/O operations, no DB transaction) ============
|
||||||
|
storage_delete_failures = 0
|
||||||
|
for storage_key in storage_keys_to_delete:
|
||||||
|
try:
|
||||||
|
storage.delete(storage_key)
|
||||||
|
except Exception:
|
||||||
|
storage_delete_failures += 1
|
||||||
|
logger.exception("Failed to delete file from storage, key: %s", storage_key)
|
||||||
|
if storage_delete_failures > 0:
|
||||||
|
logger.warning(
|
||||||
|
"Storage file deletion completed with %d failures out of %d total files for dataset_id: %s",
|
||||||
|
storage_delete_failures,
|
||||||
|
len(storage_keys_to_delete),
|
||||||
|
dataset_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
end_at = time.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
click.style(
|
||||||
|
f"Cleaned documents when documents deleted latency: {end_at - start_at:.2f}s, "
|
||||||
|
f"dataset_id: {dataset_id}, document_ids: {document_ids}, "
|
||||||
|
f"segments: {len(segment_ids)}, image_files: {len(total_image_upload_file_ids)}, "
|
||||||
|
f"storage_files: {len(storage_keys_to_delete)}",
|
||||||
|
fg="green",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Batch clean documents failed for dataset_id: %s, document_ids: %s",
|
||||||
|
dataset_id,
|
||||||
|
document_ids,
|
||||||
|
)
|
||||||
|
|||||||
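Every batch step above follows the same shape: fixed-size ID batches, one short session and commit per batch, and failures counted and logged rather than aborting the loop. As a standalone illustration — a minimal self-contained sketch, not the repository's code; the engine, `Record` model, and helper name are hypothetical:

import logging

from sqlalchemy import Column, String, create_engine, delete
from sqlalchemy.orm import Session, declarative_base

logger = logging.getLogger(__name__)

Base = declarative_base()
engine = create_engine("sqlite:///:memory:")  # stand-in engine for the sketch
BATCH_SIZE = 1000


class Record(Base):  # hypothetical table, standing in for UploadFile / DocumentSegment
    __tablename__ = "records"
    id = Column(String, primary_key=True)


Base.metadata.create_all(engine)


def delete_in_batches(ids: list[str]) -> int:
    """Delete rows in short, independent transactions; return the failed batch count."""
    failed_batches = 0
    for i in range(0, len(ids), BATCH_SIZE):
        batch = ids[i : i + BATCH_SIZE]
        try:
            # One session per batch keeps each transaction (and its locks) short.
            with Session(engine) as session:
                session.execute(delete(Record).where(Record.id.in_(batch)))
                session.commit()
        except Exception:
            # A failed batch is logged and skipped; later batches still run.
            failed_batches += 1
            logger.exception("Failed to delete batch %d-%d", i, i + len(batch))
    return failed_batches

The `(n + BATCH_SIZE - 1) // BATCH_SIZE` expression used in the diff is the usual integer ceiling division for the total number of batches.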
@@ -48,6 +48,11 @@ def batch_create_segment_to_index_task(
 
     indexing_cache_key = f"segment_batch_import_{job_id}"
 
+    # Initialize variables with default values
+    upload_file_key: str | None = None
+    dataset_config: dict | None = None
+    document_config: dict | None = None
+
     with session_factory.create_session() as session:
         try:
             dataset = session.get(Dataset, dataset_id)
@@ -69,86 +74,115 @@ def batch_create_segment_to_index_task(
             if not upload_file:
                 raise ValueError("UploadFile not found.")
 
-            with tempfile.TemporaryDirectory() as temp_dir:
-                suffix = Path(upload_file.key).suffix
-                file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"  # type: ignore
-                storage.download(upload_file.key, file_path)
+            dataset_config = {
+                "id": dataset.id,
+                "indexing_technique": dataset.indexing_technique,
+                "tenant_id": dataset.tenant_id,
+                "embedding_model_provider": dataset.embedding_model_provider,
+                "embedding_model": dataset.embedding_model,
+            }
 
-                df = pd.read_csv(file_path)
-                content = []
-                for _, row in df.iterrows():
-                    if dataset_document.doc_form == "qa_model":
-                        data = {"content": row.iloc[0], "answer": row.iloc[1]}
-                    else:
-                        data = {"content": row.iloc[0]}
-                    content.append(data)
-                if len(content) == 0:
-                    raise ValueError("The CSV file is empty.")
+            document_config = {
+                "id": dataset_document.id,
+                "doc_form": dataset_document.doc_form,
+                "word_count": dataset_document.word_count or 0,
+            }
 
-                document_segments = []
-                embedding_model = None
-                if dataset.indexing_technique == "high_quality":
-                    model_manager = ModelManager()
-                    embedding_model = model_manager.get_model_instance(
-                        tenant_id=dataset.tenant_id,
-                        provider=dataset.embedding_model_provider,
-                        model_type=ModelType.TEXT_EMBEDDING,
-                        model=dataset.embedding_model,
-                    )
+            upload_file_key = upload_file.key
+        except Exception:
+            logger.exception("Segments batch created index failed")
+            redis_client.setex(indexing_cache_key, 600, "error")
+            return
 
-                word_count_change = 0
-                if embedding_model:
-                    tokens_list = embedding_model.get_text_embedding_num_tokens(
-                        texts=[segment["content"] for segment in content]
-                    )
+    # Ensure required variables are set before proceeding
+    if upload_file_key is None or dataset_config is None or document_config is None:
+        logger.error("Required configuration not set due to session error")
+        redis_client.setex(indexing_cache_key, 600, "error")
+        return
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        suffix = Path(upload_file_key).suffix
+        file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"  # type: ignore
+        storage.download(upload_file_key, file_path)
+
+        df = pd.read_csv(file_path)
+        content = []
+        for _, row in df.iterrows():
+            if document_config["doc_form"] == "qa_model":
+                data = {"content": row.iloc[0], "answer": row.iloc[1]}
             else:
-                tokens_list = [0] * len(content)
+                data = {"content": row.iloc[0]}
+            content.append(data)
+        if len(content) == 0:
+            raise ValueError("The CSV file is empty.")
 
-            for segment, tokens in zip(content, tokens_list):
-                content = segment["content"]
-                doc_id = str(uuid.uuid4())
-                segment_hash = helper.generate_text_hash(content)
-                max_position = (
-                    session.query(func.max(DocumentSegment.position))
-                    .where(DocumentSegment.document_id == dataset_document.id)
-                    .scalar()
-                )
-                segment_document = DocumentSegment(
-                    tenant_id=tenant_id,
-                    dataset_id=dataset_id,
-                    document_id=document_id,
-                    index_node_id=doc_id,
-                    index_node_hash=segment_hash,
-                    position=max_position + 1 if max_position else 1,
-                    content=content,
-                    word_count=len(content),
-                    tokens=tokens,
-                    created_by=user_id,
-                    indexing_at=naive_utc_now(),
-                    status="completed",
-                    completed_at=naive_utc_now(),
-                )
-                if dataset_document.doc_form == "qa_model":
-                    segment_document.answer = segment["answer"]
-                    segment_document.word_count += len(segment["answer"])
-                word_count_change += segment_document.word_count
-                session.add(segment_document)
-                document_segments.append(segment_document)
+        document_segments = []
+        embedding_model = None
+        if dataset_config["indexing_technique"] == "high_quality":
+            model_manager = ModelManager()
+            embedding_model = model_manager.get_model_instance(
+                tenant_id=dataset_config["tenant_id"],
+                provider=dataset_config["embedding_model_provider"],
+                model_type=ModelType.TEXT_EMBEDDING,
+                model=dataset_config["embedding_model"],
+            )
 
+        word_count_change = 0
+        if embedding_model:
+            tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content])
+        else:
+            tokens_list = [0] * len(content)
+
+        with session_factory.create_session() as session, session.begin():
+            for segment, tokens in zip(content, tokens_list):
+                content = segment["content"]
+                doc_id = str(uuid.uuid4())
+                segment_hash = helper.generate_text_hash(content)
+                max_position = (
+                    session.query(func.max(DocumentSegment.position))
+                    .where(DocumentSegment.document_id == document_config["id"])
+                    .scalar()
+                )
+                segment_document = DocumentSegment(
+                    tenant_id=tenant_id,
+                    dataset_id=dataset_id,
+                    document_id=document_id,
+                    index_node_id=doc_id,
+                    index_node_hash=segment_hash,
+                    position=max_position + 1 if max_position else 1,
+                    content=content,
+                    word_count=len(content),
+                    tokens=tokens,
+                    created_by=user_id,
+                    indexing_at=naive_utc_now(),
+                    status="completed",
+                    completed_at=naive_utc_now(),
+                )
+                if document_config["doc_form"] == "qa_model":
+                    segment_document.answer = segment["answer"]
+                    segment_document.word_count += len(segment["answer"])
+                word_count_change += segment_document.word_count
+                session.add(segment_document)
+                document_segments.append(segment_document)
 
+    with session_factory.create_session() as session, session.begin():
+        dataset_document = session.get(Document, document_id)
+        if dataset_document:
             assert dataset_document.word_count is not None
             dataset_document.word_count += word_count_change
             session.add(dataset_document)
 
-            VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
-            session.commit()
-            redis_client.setex(indexing_cache_key, 600, "completed")
-            end_at = time.perf_counter()
-            logger.info(
-                click.style(
-                    f"Segment batch created job: {job_id} latency: {end_at - start_at}",
-                    fg="green",
-                )
-            )
-        except Exception:
-            logger.exception("Segments batch created index failed")
-            redis_client.setex(indexing_cache_key, 600, "error")
+    with session_factory.create_session() as session:
+        dataset = session.get(Dataset, dataset_id)
+        if dataset:
+            VectorService.create_segments_vector(None, document_segments, dataset, document_config["doc_form"])
+
+    redis_client.setex(indexing_cache_key, 600, "completed")
+    end_at = time.perf_counter()
+    logger.info(
+        click.style(
+            f"Segment batch created job: {job_id} latency: {end_at - start_at}",
+            fg="green",
+        )
+    )
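The rewrite above splits one long-lived session into three phases: copy the needed scalars into plain dicts inside a short session, do the slow CSV/embedding work with no transaction open, then write in fresh short transactions. A minimal sketch of that shape under assumed names — `Doc` and `expensive_transform` are hypothetical stand-ins, not the repository's code:

import logging

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

logger = logging.getLogger(__name__)

Base = declarative_base()
engine = create_engine("sqlite:///:memory:")  # stand-in engine for the sketch


class Doc(Base):  # hypothetical minimal model standing in for Document
    __tablename__ = "docs"
    id = Column(Integer, primary_key=True)
    form = Column(String)
    body = Column(String, default="")


Base.metadata.create_all(engine)


def expensive_transform(form: str | None) -> str:
    # Placeholder for the slow work (CSV parsing, embedding) done with no session open.
    return (form or "").upper()


def process(doc_id: int) -> None:
    # Phase 1: copy the scalars we need into a plain dict, then release the session.
    with Session(engine) as session:
        doc = session.get(Doc, doc_id)
        if doc is None:
            return
        config = {"id": doc.id, "form": doc.form}

    # Phase 2: slow work runs with no session (and no transaction) held open.
    new_body = expensive_transform(config["form"])

    # Phase 3: a fresh, short transaction for the write; begin() commits on exit.
    with Session(engine) as session, session.begin():
        doc = session.get(Doc, doc_id)
        if doc:
            doc.body = new_body


with Session(engine) as s, s.begin():
    s.add(Doc(id=1, form="qa_model"))
process(1)

The dicts also decouple the later phases from detached ORM instances, which is why the task re-fetches `Document` and `Dataset` before the final writes.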
@@ -28,6 +28,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
     """
    logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
    start_at = time.perf_counter()
+    total_attachment_files = []
 
    with session_factory.create_session() as session:
        try:
@@ -47,78 +48,91 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
                     SegmentAttachmentBinding.document_id == document_id,
                 )
             ).all()
-            # check segment is exist
-            if segments:
-                index_node_ids = [segment.index_node_id for segment in segments]
-                index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+            attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
+            binding_ids = [binding.id for binding, _ in attachments_with_bindings]
+            total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings])
+
+            index_node_ids = [segment.index_node_id for segment in segments]
+            segment_contents = [segment.content for segment in segments]
+        except Exception:
+            logger.exception("Cleaned document when document deleted failed")
+            return
+
+    # check segment is exist
+    if index_node_ids:
+        index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+        with session_factory.create_session() as session:
+            dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+            if dataset:
                 index_processor.clean(
                     dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
                 )
 
-                for segment in segments:
-                    image_upload_file_ids = get_image_upload_file_ids(segment.content)
-                    image_files = session.scalars(
-                        select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
-                    ).all()
-                    for image_file in image_files:
-                        if image_file is None:
-                            continue
-                        try:
-                            storage.delete(image_file.key)
-                        except Exception:
-                            logger.exception(
-                                "Delete image_files failed when storage deleted, \
-                                image_upload_file_is: %s",
-                                image_file.id,
-                            )
-
-                    image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
-                    session.execute(image_file_delete_stmt)
-                    session.delete(segment)
-
-            session.commit()
-            if file_id:
-                file = session.query(UploadFile).where(UploadFile.id == file_id).first()
-                if file:
-                    try:
-                        storage.delete(file.key)
-                    except Exception:
-                        logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
-                    session.delete(file)
-
-            # delete segment attachments
-            if attachments_with_bindings:
-                attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
-                binding_ids = [binding.id for binding, _ in attachments_with_bindings]
-                for binding, attachment_file in attachments_with_bindings:
-                    try:
-                        storage.delete(attachment_file.key)
-                    except Exception:
-                        logger.exception(
-                            "Delete attachment_file failed when storage deleted, \
-                            attachment_file_id: %s",
-                            binding.attachment_id,
-                        )
-                attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
-                session.execute(attachment_file_delete_stmt)
-
-                binding_delete_stmt = delete(SegmentAttachmentBinding).where(
-                    SegmentAttachmentBinding.id.in_(binding_ids)
-                )
-                session.execute(binding_delete_stmt)
-
-            # delete dataset metadata binding
-            session.query(DatasetMetadataBinding).where(
-                DatasetMetadataBinding.dataset_id == dataset_id,
-                DatasetMetadataBinding.document_id == document_id,
-            ).delete()
-            session.commit()
-
-            end_at = time.perf_counter()
-            logger.info(
-                click.style(
-                    f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
-                    fg="green",
-                )
-            )
+    total_image_files = []
+    with session_factory.create_session() as session, session.begin():
+        for segment_content in segment_contents:
+            image_upload_file_ids = get_image_upload_file_ids(segment_content)
+            image_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))).all()
+            total_image_files.extend([image_file.key for image_file in image_files])
+            image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
+            session.execute(image_file_delete_stmt)
+
+    with session_factory.create_session() as session, session.begin():
+        segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
+        session.execute(segment_delete_stmt)
+
+    for image_file_key in total_image_files:
+        try:
+            storage.delete(image_file_key)
         except Exception:
-            logger.exception("Cleaned document when document deleted failed")
+            logger.exception(
+                "Delete image_files failed when storage deleted, \
+                image_upload_file_is: %s",
+                image_file_key,
+            )
+
+    with session_factory.create_session() as session, session.begin():
+        if file_id:
+            file = session.query(UploadFile).where(UploadFile.id == file_id).first()
+            if file:
+                try:
+                    storage.delete(file.key)
+                except Exception:
+                    logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
+                session.delete(file)
+
+    with session_factory.create_session() as session, session.begin():
+        # delete segment attachments
+        if attachment_ids:
+            attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
+            session.execute(attachment_file_delete_stmt)
+
+        if binding_ids:
+            binding_delete_stmt = delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.id.in_(binding_ids))
+            session.execute(binding_delete_stmt)
+
+    for attachment_file_key in total_attachment_files:
+        try:
+            storage.delete(attachment_file_key)
+        except Exception:
+            logger.exception(
+                "Delete attachment_file failed when storage deleted, \
+                attachment_file_id: %s",
+                attachment_file_key,
+            )
+
+    with session_factory.create_session() as session, session.begin():
+        # delete dataset metadata binding
+        session.query(DatasetMetadataBinding).where(
+            DatasetMetadataBinding.dataset_id == dataset_id,
+            DatasetMetadataBinding.document_id == document_id,
+        ).delete()
+
+    end_at = time.perf_counter()
+    logger.info(
+        click.style(
+            f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
+            fg="green",
+        )
+    )
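One pattern worth noting in the hunk above: `storage.delete()` is no longer called while a transaction is open. Keys are first collected (`total_image_files`, `total_attachment_files`) inside short transactions, and the object-store I/O runs afterwards, best-effort. A sketch of that helper shape — the `Storage` protocol here is a hypothetical stand-in for any client with a `delete(key)` method:

import logging
from typing import Protocol

logger = logging.getLogger(__name__)


class Storage(Protocol):  # minimal stand-in for the storage client interface
    def delete(self, key: str) -> None: ...


def delete_storage_keys(storage: Storage, keys: list[str]) -> int:
    """Best-effort storage cleanup, run only after the DB work has committed."""
    failures = 0
    for key in keys:
        try:
            storage.delete(key)
        except Exception:
            # Log and continue: a flaky object store must not undo the row deletions.
            failures += 1
            logger.exception("Failed to delete file from storage, key: %s", key)
    return failures

In the task above the returned count only feeds a warning log; the database deletions have already committed by the time storage cleanup starts.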
@@ -23,40 +23,40 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
     """
    logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
    start_at = time.perf_counter()
+    total_index_node_ids = []
 
    with session_factory.create_session() as session:
-        try:
-            dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+        dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
 
         if not dataset:
             raise Exception("Document has no dataset")
         index_type = dataset.doc_form
         index_processor = IndexProcessorFactory(index_type).init_index_processor()
 
         document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
         session.execute(document_delete_stmt)
 
         for document_id in document_ids:
-            segments = session.scalars(
-                select(DocumentSegment).where(DocumentSegment.document_id == document_id)
-            ).all()
-            index_node_ids = [segment.index_node_id for segment in segments]
+            segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
+            total_index_node_ids.extend([segment.index_node_id for segment in segments])
 
-            index_processor.clean(
-                dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
-            )
-            segment_ids = [segment.id for segment in segments]
-            segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
-            session.execute(segment_delete_stmt)
-            session.commit()
-            end_at = time.perf_counter()
-            logger.info(
-                click.style(
-                    "Clean document when import form notion document deleted end :: {} latency: {}".format(
-                        dataset_id, end_at - start_at
-                    ),
-                    fg="green",
-                )
+    with session_factory.create_session() as session:
+        dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+        if dataset:
+            index_processor.clean(
+                dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
             )
-        except Exception:
-            logger.exception("Cleaned document when import form notion document deleted failed")
+
+    with session_factory.create_session() as session, session.begin():
+        segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
+        session.execute(segment_delete_stmt)
+
+    end_at = time.perf_counter()
+    logger.info(
+        click.style(
+            "Clean document when import form notion document deleted end :: {} latency: {}".format(
+                dataset_id, end_at - start_at
+            ),
+            fg="green",
+        )
+    )
@@ -3,6 +3,7 @@ import time
 
 import click
 from celery import shared_task
+from sqlalchemy import delete
 
 from core.db.session_factory import session_factory
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -67,8 +68,14 @@ def delete_segment_from_index_task(
             if segment_attachment_bindings:
                 attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
                 index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
-                for binding in segment_attachment_bindings:
-                    session.delete(binding)
+                segment_attachment_bind_ids = [i.id for i in segment_attachment_bindings]
+                for i in range(0, len(segment_attachment_bind_ids), 1000):
+                    segment_attachment_bind_delete_stmt = delete(SegmentAttachmentBinding).where(
+                        SegmentAttachmentBinding.id.in_(segment_attachment_bind_ids[i : i + 1000])
+                    )
+                    session.execute(segment_attachment_bind_delete_stmt)
+
                 # delete upload file
                 session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
             session.commit()
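The change above replaces per-row `session.delete(binding)` with bulk `DELETE ... WHERE id IN (...)` statements issued in slices of 1000, which avoids both per-row ORM overhead and oversized IN lists. A generic sketch of the same idea — `model` is any mapped class with an `id` column, names are illustrative:

from sqlalchemy import delete
from sqlalchemy.orm import Session

CHUNK_SIZE = 1000


def bulk_delete_by_ids(session: Session, model, ids: list[str]) -> None:
    """Issue one bulk DELETE per chunk of ids instead of one ORM delete per row."""
    for i in range(0, len(ids), CHUNK_SIZE):
        stmt = delete(model).where(model.id.in_(ids[i : i + CHUNK_SIZE]))
        session.execute(stmt)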
@@ -27,104 +27,129 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
     """
    logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
    start_at = time.perf_counter()
+    tenant_id = None
 
-    with session_factory.create_session() as session:
+    with session_factory.create_session() as session, session.begin():
        document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
 
        if not document:
            logger.info(click.style(f"Document not found: {document_id}", fg="red"))
            return
 
+        if document.indexing_status == "parsing":
+            logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
+            return
+
+        dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+        if not dataset:
+            raise Exception("Dataset not found")
+
        data_source_info = document.data_source_info_dict
-        if document.data_source_type == "notion_import":
-            if (
-                not data_source_info
-                or "notion_page_id" not in data_source_info
-                or "notion_workspace_id" not in data_source_info
-            ):
-                raise ValueError("no notion page found")
-            workspace_id = data_source_info["notion_workspace_id"]
-            page_id = data_source_info["notion_page_id"]
-            page_type = data_source_info["type"]
-            page_edited_time = data_source_info["last_edited_time"]
-            credential_id = data_source_info.get("credential_id")
-
-            # Get credentials from datasource provider
-            datasource_provider_service = DatasourceProviderService()
-            credential = datasource_provider_service.get_datasource_credentials(
-                tenant_id=document.tenant_id,
-                credential_id=credential_id,
-                provider="notion_datasource",
-                plugin_id="langgenius/notion_datasource",
-            )
-
-            if not credential:
-                logger.error(
-                    "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
-                    document_id,
-                    document.tenant_id,
-                    credential_id,
-                )
+        if document.data_source_type != "notion_import":
+            logger.info(click.style(f"Document {document_id} is not a notion_import, skipping", fg="yellow"))
+            return
+
+        if (
+            not data_source_info
+            or "notion_page_id" not in data_source_info
+            or "notion_workspace_id" not in data_source_info
+        ):
+            raise ValueError("no notion page found")
+
+        workspace_id = data_source_info["notion_workspace_id"]
+        page_id = data_source_info["notion_page_id"]
+        page_type = data_source_info["type"]
+        page_edited_time = data_source_info["last_edited_time"]
+        credential_id = data_source_info.get("credential_id")
+        tenant_id = document.tenant_id
+        index_type = document.doc_form
+
+        segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
+        index_node_ids = [segment.index_node_id for segment in segments]
+
+    # Get credentials from datasource provider
+    datasource_provider_service = DatasourceProviderService()
+    credential = datasource_provider_service.get_datasource_credentials(
+        tenant_id=tenant_id,
+        credential_id=credential_id,
+        provider="notion_datasource",
+        plugin_id="langgenius/notion_datasource",
+    )
+
+    if not credential:
+        logger.error(
+            "Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
+            document_id,
+            tenant_id,
+            credential_id,
+        )
+
+        with session_factory.create_session() as session, session.begin():
+            document = session.query(Document).filter_by(id=document_id).first()
+            if document:
                document.indexing_status = "error"
                document.error = "Datasource credential not found. Please reconnect your Notion workspace."
                document.stopped_at = naive_utc_now()
-                session.commit()
-                return
+        return
 
    loader = NotionExtractor(
        notion_workspace_id=workspace_id,
        notion_obj_id=page_id,
        notion_page_type=page_type,
        notion_access_token=credential.get("integration_secret"),
-        tenant_id=document.tenant_id,
+        tenant_id=tenant_id,
    )
 
    last_edited_time = loader.get_notion_last_edited_time()
+    if last_edited_time == page_edited_time:
+        logger.info(click.style(f"Document {document_id} content unchanged, skipping sync", fg="yellow"))
+        return
 
-    # check the page is updated
-    if last_edited_time != page_edited_time:
-        document.indexing_status = "parsing"
-        document.processing_started_at = naive_utc_now()
-        session.commit()
+    logger.info(click.style(f"Document {document_id} content changed, starting sync", fg="green"))
 
-        # delete all document segment and index
-        try:
-            dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
-            if not dataset:
-                raise Exception("Dataset not found")
-            index_type = document.doc_form
-            index_processor = IndexProcessorFactory(index_type).init_index_processor()
-
-            segments = session.scalars(
-                select(DocumentSegment).where(DocumentSegment.document_id == document_id)
-            ).all()
-            index_node_ids = [segment.index_node_id for segment in segments]
-
-            # delete from vector index
-            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
-
-            segment_ids = [segment.id for segment in segments]
-            segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
-            session.execute(segment_delete_stmt)
-
-            end_at = time.perf_counter()
-            logger.info(
-                click.style(
-                    "Cleaned document when document update data source or process rule: {} latency: {}".format(
-                        document_id, end_at - start_at
-                    ),
-                    fg="green",
-                )
-            )
-        except Exception:
-            logger.exception("Cleaned document when document update data source or process rule failed")
+    try:
+        index_processor = IndexProcessorFactory(index_type).init_index_processor()
+        with session_factory.create_session() as session:
+            dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+            if dataset:
+                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+                logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green"))
+    except Exception:
+        logger.exception("Failed to clean vector index for document %s", document_id)
 
-        try:
-            indexing_runner = IndexingRunner()
-            indexing_runner.run([document])
-            end_at = time.perf_counter()
-            logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
-        except DocumentIsPausedError as ex:
-            logger.info(click.style(str(ex), fg="yellow"))
-        except Exception:
-            logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
+    with session_factory.create_session() as session, session.begin():
+        document = session.query(Document).filter_by(id=document_id).first()
+        if not document:
+            logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow"))
+            return
+
+        data_source_info = document.data_source_info_dict
+        data_source_info["last_edited_time"] = last_edited_time
+        document.data_source_info = data_source_info
+
+        document.indexing_status = "parsing"
+        document.processing_started_at = naive_utc_now()
+
+        segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
+        session.execute(segment_delete_stmt)
+
+    logger.info(click.style(f"Deleted segments for document {document_id}", fg="green"))
+
+    try:
+        indexing_runner = IndexingRunner()
+        with session_factory.create_session() as session:
+            document = session.query(Document).filter_by(id=document_id).first()
+            if document:
+                indexing_runner.run([document])
+        end_at = time.perf_counter()
+        logger.info(click.style(f"Sync completed for document {document_id} latency: {end_at - start_at}", fg="green"))
+    except DocumentIsPausedError as ex:
+        logger.info(click.style(str(ex), fg="yellow"))
+    except Exception as e:
+        logger.exception("document_indexing_sync_task failed for document_id: %s", document_id)
+        with session_factory.create_session() as session, session.begin():
+            document = session.query(Document).filter_by(id=document_id).first()
+            if document:
+                document.indexing_status = "error"
+                document.error = str(e)
+                document.stopped_at = naive_utc_now()
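Note how the error path above re-queries the `Document` in a brand-new session before flagging it as errored: the session that was active when the exception fired may already be in a rolled-back state. A minimal self-contained sketch of that pattern — engine and `Task` model are hypothetical stand-ins:

import logging
from datetime import datetime

from sqlalchemy import Column, DateTime, String, create_engine
from sqlalchemy.orm import Session, declarative_base

logger = logging.getLogger(__name__)

Base = declarative_base()
engine = create_engine("sqlite:///:memory:")  # stand-in engine for the sketch


class Task(Base):  # hypothetical stand-in for the Document model
    __tablename__ = "tasks"
    id = Column(String, primary_key=True)
    status = Column(String, default="waiting")
    error = Column(String)
    stopped_at = Column(DateTime)


Base.metadata.create_all(engine)


def run(task_id: str) -> None:
    try:
        raise RuntimeError("boom")  # stand-in for the failing indexing run
    except Exception as e:
        logger.exception("task failed, id: %s", task_id)
        # Open a fresh session to record the failure: the session that was
        # active when the exception fired may already be rolled back.
        with Session(engine) as session, session.begin():
            task = session.get(Task, task_id)
            if task:
                task.status = "error"
                task.error = str(e)
                task.stopped_at = datetime.utcnow()


with Session(engine) as s, s.begin():
    s.add(Task(id="t1"))
run("t1")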
@@ -81,26 +81,35 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
             session.commit()
             return
 
-    for document_id in document_ids:
-        logger.info(click.style(f"Start process document: {document_id}", fg="green"))
-
-        document = (
-            session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
-        )
+    # Phase 1: Update status to parsing (short transaction)
+    with session_factory.create_session() as session, session.begin():
+        documents = (
+            session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all()
+        )
 
+        for document in documents:
            if document:
                document.indexing_status = "parsing"
                document.processing_started_at = naive_utc_now()
-                documents.append(document)
                session.add(document)
-        session.commit()
+    # Transaction committed and closed
 
-    try:
-        indexing_runner = IndexingRunner()
-        indexing_runner.run(documents)
-        end_at = time.perf_counter()
-        logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
+    # Phase 2: Execute indexing (no transaction - IndexingRunner creates its own sessions)
+    has_error = False
+    try:
+        indexing_runner = IndexingRunner()
+        indexing_runner.run(documents)
+        end_at = time.perf_counter()
+        logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
+    except DocumentIsPausedError as ex:
+        logger.info(click.style(str(ex), fg="yellow"))
+        has_error = True
+    except Exception:
+        logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
+        has_error = True
+
+    if not has_error:
+        with session_factory.create_session() as session:
            # Trigger summary index generation for completed documents if enabled
            # Only generate for high_quality indexing technique and when summary_index_setting is enabled
            # Re-query dataset to get latest summary_index_setting (in case it was updated)
@@ -115,17 +124,18 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
             # expire all session to get latest document's indexing status
             session.expire_all()
             # Check each document's indexing status and trigger summary generation if completed
-            for document_id in document_ids:
-                # Re-query document to get latest status (IndexingRunner may have updated it)
-                document = (
-                    session.query(Document)
-                    .where(Document.id == document_id, Document.dataset_id == dataset_id)
-                    .first()
-                )
+            documents = (
+                session.query(Document)
+                .where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
+                .all()
+            )
+
+            for document in documents:
                if document:
                    logger.info(
                        "Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s",
-                        document_id,
+                        document.id,
                        document.indexing_status,
                        document.doc_form,
                        document.need_summary,
@@ -136,46 +146,36 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
                     and document.need_summary is True
                 ):
                     try:
-                        generate_summary_index_task.delay(dataset.id, document_id, None)
+                        generate_summary_index_task.delay(dataset.id, document.id, None)
                        logger.info(
                            "Queued summary index generation task for document %s in dataset %s "
                            "after indexing completed",
-                            document_id,
+                            document.id,
                            dataset.id,
                        )
                    except Exception:
                        logger.exception(
                            "Failed to queue summary index generation task for document %s",
-                            document_id,
+                            document.id,
                        )
                        # Don't fail the entire indexing process if summary task queuing fails
                else:
                    logger.info(
                        "Skipping summary generation for document %s: "
                        "status=%s, doc_form=%s, need_summary=%s",
-                        document_id,
+                        document.id,
                        document.indexing_status,
                        document.doc_form,
                        document.need_summary,
                    )
            else:
-                logger.warning("Document %s not found after indexing", document_id)
-        else:
-            logger.info(
-                "Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
-                dataset.id,
-                summary_index_setting.get("enable") if summary_index_setting else None,
-            )
+                logger.warning("Document %s not found after indexing", document.id)
        else:
            logger.info(
                "Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
                dataset.id,
                dataset.indexing_technique,
            )
-    except DocumentIsPausedError as ex:
-        logger.info(click.style(str(ex), fg="yellow"))
-    except Exception:
-        logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
 
 
 def _document_indexing_with_tenant_queue(
@@ -8,7 +8,6 @@ from sqlalchemy import delete, select
 from core.db.session_factory import session_factory
 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
-from extensions.ext_database import db
 from libs.datetime_utils import naive_utc_now
 from models.dataset import Dataset, Document, DocumentSegment
 
@@ -27,7 +26,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
     logger.info(click.style(f"Start update document: {document_id}", fg="green"))
    start_at = time.perf_counter()
 
-    with session_factory.create_session() as session:
+    with session_factory.create_session() as session, session.begin():
        document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
 
        if not document:
@@ -36,27 +35,20 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
 
         document.indexing_status = "parsing"
        document.processing_started_at = naive_utc_now()
-        session.commit()
 
-        # delete all document segment and index
-        try:
-            dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
-            if not dataset:
-                raise Exception("Dataset not found")
+        dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+        if not dataset:
+            return
 
         index_type = document.doc_form
-        index_processor = IndexProcessorFactory(index_type).init_index_processor()
+        segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
+        index_node_ids = [segment.index_node_id for segment in segments]
 
-        segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
-        if segments:
-            index_node_ids = [segment.index_node_id for segment in segments]
-            # delete from vector index
-            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
-            segment_ids = [segment.id for segment in segments]
-            segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
-            session.execute(segment_delete_stmt)
-            db.session.commit()
+    clean_success = False
+    try:
+        index_processor = IndexProcessorFactory(index_type).init_index_processor()
+        if index_node_ids:
+            index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
 
         end_at = time.perf_counter()
        logger.info(
            click.style(
@@ -66,15 +58,21 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
                 fg="green",
             )
         )
-    except Exception:
-        logger.exception("Cleaned document when document update data source or process rule failed")
+        clean_success = True
+    except Exception:
+        logger.exception("Failed to clean document index during update, document_id: %s", document_id)
 
-    try:
-        indexing_runner = IndexingRunner()
-        indexing_runner.run([document])
-        end_at = time.perf_counter()
-        logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
-    except DocumentIsPausedError as ex:
-        logger.info(click.style(str(ex), fg="yellow"))
-    except Exception:
-        logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
+    if clean_success:
+        with session_factory.create_session() as session, session.begin():
+            segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
+            session.execute(segment_delete_stmt)
+
+    try:
+        indexing_runner = IndexingRunner()
+        indexing_runner.run([document])
+        end_at = time.perf_counter()
+        logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
+    except DocumentIsPausedError as ex:
+        logger.info(click.style(str(ex), fg="yellow"))
+    except Exception:
+        logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
@@ -259,8 +259,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
 
 
 def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
-    def del_workflow_archive_log(workflow_archive_log_id: str):
-        db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
+    def del_workflow_archive_log(session, workflow_archive_log_id: str):
+        session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
             synchronize_session=False
         )
 
@@ -420,7 +420,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
     total_files_deleted = 0
 
    while True:
-        with session_factory.create_session() as session:
+        with session_factory.create_session() as session, session.begin():
            # Get a batch of draft variable IDs along with their file_ids
            query_sql = """
                SELECT id, file_id FROM workflow_draft_variables
@@ -6,9 +6,8 @@ improving performance by offloading storage operations to background workers.
 """
 
 from celery import shared_task  # type: ignore[import-untyped]
-from sqlalchemy.orm import Session
 
-from extensions.ext_database import db
+from core.db.session_factory import session_factory
 from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService
 
 
@@ -17,6 +16,6 @@ def save_workflow_execution_task(
     self,
     deletions: list[DraftVarFileDeletion],
 ):
-    with Session(bind=db.engine) as session, session.begin():
+    with session_factory.create_session() as session, session.begin():
         srv = WorkflowDraftVariableService(session=session)
         srv.delete_workflow_draft_variable_file(deletions=deletions)
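The tests that follow lean on one behavior of this pattern: `with ... as session, session.begin():` commits automatically when the block exits cleanly and rolls back if it raises, so no explicit `session.commit()` is needed. A tiny self-contained illustration — engine and `Row` model are stand-ins, not the repository's code:

from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()
engine = create_engine("sqlite:///:memory:")  # stand-in engine for the sketch


class Row(Base):  # hypothetical table for the demo
    __tablename__ = "rows"
    id = Column(Integer, primary_key=True)


Base.metadata.create_all(engine)

# Commits on clean exit: no session.commit() call needed.
with Session(engine) as session, session.begin():
    session.add(Row(id=1))

# Rolls back if the block raises, leaving the table unchanged.
try:
    with Session(engine) as session, session.begin():
        session.add(Row(id=2))
        raise RuntimeError("abort")
except RuntimeError:
    pass

with Session(engine) as session:
    assert [r.id for r in session.query(Row).all()] == [1]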
@@ -10,7 +10,10 @@ from models import Tenant
 from models.enums import CreatorUserRole
 from models.model import App, UploadFile
 from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
-from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch
+from tasks.remove_app_and_related_data_task import (
+    _delete_draft_variables,
+    delete_draft_variables_batch,
+)
 
 
 @pytest.fixture
@@ -297,12 +300,18 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
     def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
         data = setup_offload_test_data
         app_id = data["app"].id
+        upload_file_ids = [uf.id for uf in data["upload_files"]]
+        variable_file_ids = [vf.id for vf in data["variable_files"]]
         mock_storage.delete.return_value = None
 
         with session_factory.create_session() as session:
             draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
-            var_files_before = session.query(WorkflowDraftVariableFile).count()
-            upload_files_before = session.query(UploadFile).count()
+            var_files_before = (
+                session.query(WorkflowDraftVariableFile)
+                .where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
+                .count()
+            )
+            upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
             assert draft_vars_before == 3
             assert var_files_before == 2
             assert upload_files_before == 2
@@ -315,8 +324,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
         assert draft_vars_after == 0
 
         with session_factory.create_session() as session:
-            var_files_after = session.query(WorkflowDraftVariableFile).count()
-            upload_files_after = session.query(UploadFile).count()
+            var_files_after = (
+                session.query(WorkflowDraftVariableFile)
+                .where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
+                .count()
+            )
+            upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
             assert var_files_after == 0
             assert upload_files_after == 0
 
@@ -329,6 +342,8 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
     def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
         data = setup_offload_test_data
         app_id = data["app"].id
+        upload_file_ids = [uf.id for uf in data["upload_files"]]
+        variable_file_ids = [vf.id for vf in data["variable_files"]]
         mock_storage.delete.side_effect = [Exception("Storage error"), None]
 
         deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
@@ -339,8 +354,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
         assert draft_vars_after == 0
 
         with session_factory.create_session() as session:
-            var_files_after = session.query(WorkflowDraftVariableFile).count()
-            upload_files_after = session.query(UploadFile).count()
+            var_files_after = (
+                session.query(WorkflowDraftVariableFile)
+                .where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
+                .count()
+            )
+            upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
             assert var_files_after == 0
             assert upload_files_after == 0
 
@@ -395,3 +414,275 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
             if app2_obj:
                 session.delete(app2_obj)
             session.commit()
+
+
+class TestDeleteDraftVariablesSessionCommit:
+    """Test suite to verify session commit behavior in delete_draft_variables_batch."""
+
+    @pytest.fixture
+    def setup_offload_test_data(self, app_and_tenant):
+        """Create test data with offload files for session commit tests."""
+        from core.variables.types import SegmentType
+        from libs.datetime_utils import naive_utc_now
+
+        tenant, app = app_and_tenant
+
+        with session_factory.create_session() as session:
+            upload_file1 = UploadFile(
+                tenant_id=tenant.id,
+                storage_type="local",
+                key="test/file1.json",
+                name="file1.json",
+                size=1024,
+                extension="json",
+                mime_type="application/json",
+                created_by_role=CreatorUserRole.ACCOUNT,
+                created_by=str(uuid.uuid4()),
+                created_at=naive_utc_now(),
+                used=False,
+            )
+            upload_file2 = UploadFile(
+                tenant_id=tenant.id,
+                storage_type="local",
+                key="test/file2.json",
+                name="file2.json",
+                size=2048,
+                extension="json",
+                mime_type="application/json",
+                created_by_role=CreatorUserRole.ACCOUNT,
+                created_by=str(uuid.uuid4()),
+                created_at=naive_utc_now(),
+                used=False,
+            )
+            session.add(upload_file1)
+            session.add(upload_file2)
+            session.flush()
+
+            var_file1 = WorkflowDraftVariableFile(
+                tenant_id=tenant.id,
+                app_id=app.id,
+                user_id=str(uuid.uuid4()),
+                upload_file_id=upload_file1.id,
+                size=1024,
+                length=10,
+                value_type=SegmentType.STRING,
+            )
+            var_file2 = WorkflowDraftVariableFile(
+                tenant_id=tenant.id,
+                app_id=app.id,
+                user_id=str(uuid.uuid4()),
+                upload_file_id=upload_file2.id,
+                size=2048,
+                length=20,
+                value_type=SegmentType.OBJECT,
+            )
+            session.add(var_file1)
+            session.add(var_file2)
+            session.flush()
+
+            draft_var1 = WorkflowDraftVariable.new_node_variable(
+                app_id=app.id,
+                node_id="node_1",
+                name="large_var_1",
+                value=StringSegment(value="truncated..."),
+                node_execution_id=str(uuid.uuid4()),
+                file_id=var_file1.id,
+            )
+            draft_var2 = WorkflowDraftVariable.new_node_variable(
+                app_id=app.id,
+                node_id="node_2",
+                name="large_var_2",
+                value=StringSegment(value="truncated..."),
+                node_execution_id=str(uuid.uuid4()),
+                file_id=var_file2.id,
+            )
+            draft_var3 = WorkflowDraftVariable.new_node_variable(
+                app_id=app.id,
+                node_id="node_3",
+                name="regular_var",
+                value=StringSegment(value="regular_value"),
+                node_execution_id=str(uuid.uuid4()),
+            )
+            session.add(draft_var1)
+            session.add(draft_var2)
+            session.add(draft_var3)
+            session.commit()
+
+            data = {
+                "app": app,
+                "tenant": tenant,
+                "upload_files": [upload_file1, upload_file2],
+                "variable_files": [var_file1, var_file2],
+                "draft_variables": [draft_var1, draft_var2, draft_var3],
+            }
+
+        yield data
+
+        with session_factory.create_session() as session:
+            for table, ids in [
+                (WorkflowDraftVariable, [v.id for v in data["draft_variables"]]),
+                (WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]),
+                (UploadFile, [uf.id for uf in data["upload_files"]]),
+            ]:
+                cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
+                session.execute(cleanup_query)
+            session.commit()
+
+    @pytest.fixture
+    def setup_commit_test_data(self, app_and_tenant):
+        """Create test data for session commit tests."""
+        tenant, app = app_and_tenant
+        variable_ids: list[str] = []
+
+        with session_factory.create_session() as session:
+            variables = []
+            for i in range(10):
+                var = WorkflowDraftVariable.new_node_variable(
+                    app_id=app.id,
+                    node_id=f"node_{i}",
+                    name=f"var_{i}",
+                    value=StringSegment(value="test_value"),
+                    node_execution_id=str(uuid.uuid4()),
+                )
+                session.add(var)
+                variables.append(var)
+            session.commit()
+            variable_ids = [v.id for v in variables]
+
+        yield {
+            "app": app,
+            "tenant": tenant,
+            "variable_ids": variable_ids,
+        }
+
+        with session_factory.create_session() as session:
+            cleanup_query = (
+                delete(WorkflowDraftVariable)
+                .where(WorkflowDraftVariable.id.in_(variable_ids))
+                .execution_options(synchronize_session=False)
+            )
+            session.execute(cleanup_query)
+            session.commit()
+
+    def test_session_commit_is_called_after_each_batch(self, setup_commit_test_data):
+        """Test that session.begin() is used for automatic transaction management."""
+        data = setup_commit_test_data
+        app_id = data["app"].id
+
+        # Since session.begin() is used, the transaction is automatically committed
+        # when the with block exits successfully. We verify this by checking that
+        # data is actually persisted.
+        deleted_count = delete_draft_variables_batch(app_id, batch_size=3)
+
+        # Verify all data was deleted (proves transaction was committed)
+        with session_factory.create_session() as session:
+            remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
+
+        assert deleted_count == 10
+        assert remaining_count == 0
+
+    def test_data_persisted_after_batch_deletion(self, setup_commit_test_data):
+        """Test that data is actually persisted to database after batch deletion with commits."""
+        data = setup_commit_test_data
+        app_id = data["app"].id
+        variable_ids = data["variable_ids"]
+
+        # Verify initial state
+        with session_factory.create_session() as session:
+            initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
+            assert initial_count == 10
+
+        # Perform deletion with small batch size to force multiple commits
+        deleted_count = delete_draft_variables_batch(app_id, batch_size=3)
+
+        assert deleted_count == 10
+
+        # Verify all data is deleted in a new session (proves commits worked)
+        with session_factory.create_session() as session:
+            final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
+            assert final_count == 0
+
+        # Verify specific IDs are deleted
+        with session_factory.create_session() as session:
|
||||||
|
remaining_vars = (
|
||||||
|
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count()
|
||||||
|
)
|
||||||
|
assert remaining_vars == 0
|
||||||
|
|
||||||
|
def test_session_commit_with_empty_dataset(self, setup_commit_test_data):
|
||||||
|
"""Test session behavior when deleting from an empty dataset."""
|
||||||
|
nonexistent_app_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# Should not raise any errors and should return 0
|
||||||
|
deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=10)
|
||||||
|
assert deleted_count == 0
|
||||||
|
|
||||||
|
def test_session_commit_with_single_batch(self, setup_commit_test_data):
|
||||||
|
"""Test that commit happens correctly when all data fits in a single batch."""
|
||||||
|
data = setup_commit_test_data
|
||||||
|
app_id = data["app"].id
|
||||||
|
|
||||||
|
with session_factory.create_session() as session:
|
||||||
|
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||||
|
assert initial_count == 10
|
||||||
|
|
||||||
|
# Delete all in a single batch
|
||||||
|
deleted_count = delete_draft_variables_batch(app_id, batch_size=100)
|
||||||
|
assert deleted_count == 10
|
||||||
|
|
||||||
|
# Verify data is persisted
|
||||||
|
with session_factory.create_session() as session:
|
||||||
|
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||||
|
assert final_count == 0
|
||||||
|
|
||||||
|
def test_invalid_batch_size_raises_error(self, setup_commit_test_data):
|
||||||
|
"""Test that invalid batch size raises ValueError."""
|
||||||
|
data = setup_commit_test_data
|
||||||
|
app_id = data["app"].id
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="batch_size must be positive"):
|
||||||
|
delete_draft_variables_batch(app_id, batch_size=0)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="batch_size must be positive"):
|
||||||
|
delete_draft_variables_batch(app_id, batch_size=-1)
|
||||||
|
|
||||||
|
@patch("extensions.ext_storage.storage")
|
||||||
|
def test_session_commit_with_offload_data_cleanup(self, mock_storage, setup_offload_test_data):
|
||||||
|
"""Test that session commits correctly when cleaning up offload data."""
|
||||||
|
data = setup_offload_test_data
|
||||||
|
app_id = data["app"].id
|
||||||
|
upload_file_ids = [uf.id for uf in data["upload_files"]]
|
||||||
|
mock_storage.delete.return_value = None
|
||||||
|
|
||||||
|
# Verify initial state
|
||||||
|
with session_factory.create_session() as session:
|
||||||
|
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||||
|
var_files_before = (
|
||||||
|
session.query(WorkflowDraftVariableFile)
|
||||||
|
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||||
|
assert draft_vars_before == 3
|
||||||
|
assert var_files_before == 2
|
||||||
|
assert upload_files_before == 2
|
||||||
|
|
||||||
|
# Delete variables with offload data
|
||||||
|
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
|
||||||
|
assert deleted_count == 3
|
||||||
|
|
||||||
|
# Verify all data is persisted (deleted) in new session
|
||||||
|
with session_factory.create_session() as session:
|
||||||
|
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
|
||||||
|
var_files_after = (
|
||||||
|
session.query(WorkflowDraftVariableFile)
|
||||||
|
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
|
||||||
|
assert draft_vars_after == 0
|
||||||
|
assert var_files_after == 0
|
||||||
|
assert upload_files_after == 0
|
||||||
|
|
||||||
|
# Verify storage cleanup was called
|
||||||
|
assert mock_storage.delete.call_count == 2
|
||||||
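These tests pin down the contract of delete_draft_variables_batch: a positive batch_size, one committed transaction per batch via session.begin(), and the number of deleted rows as the return value. A minimal sketch of that pattern, reusing the names this test module already imports; the production implementation in tasks.remove_app_and_related_data_task may differ in detail:

from sqlalchemy import delete, select


def delete_draft_variables_batch_sketch(app_id: str, batch_size: int) -> int:
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")

    total_deleted = 0
    while True:
        with session_factory.create_session() as session, session.begin():
            # session.begin() commits automatically when the block exits,
            # which is exactly what the "session commit" tests verify.
            ids = (
                session.execute(
                    select(WorkflowDraftVariable.id)
                    .where(WorkflowDraftVariable.app_id == app_id)
                    .limit(batch_size)
                )
                .scalars()
                .all()
            )
            if not ids:
                return total_deleted
            session.execute(
                delete(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(ids))
            )
            total_deleted += len(ids)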
@@ -0,0 +1,38 @@
"""
Integration tests for DbMigrationAutoRenewLock using real Redis via TestContainers.
"""

import time
import uuid

import pytest

from extensions.ext_redis import redis_client
from libs.db_migration_lock import DbMigrationAutoRenewLock


@pytest.mark.usefixtures("flask_app_with_containers")
def test_db_migration_lock_renews_ttl_and_releases():
    lock_name = f"test:db_migration_auto_renew_lock:{uuid.uuid4().hex}"

    # Keep base TTL very small, and renew frequently so the test is stable even on slower CI.
    lock = DbMigrationAutoRenewLock(
        redis_client=redis_client,
        name=lock_name,
        ttl_seconds=1.0,
        renew_interval_seconds=0.2,
        log_context="test_db_migration_lock",
    )

    acquired = lock.acquire(blocking=True, blocking_timeout=5)
    assert acquired is True

    # Wait beyond the base TTL; key should still exist due to renewal.
    time.sleep(1.5)
    ttl = redis_client.ttl(lock_name)
    assert ttl > 0

    lock.release_safely(status="successful")

    # After release, the key should not exist.
    assert redis_client.exists(lock_name) == 0
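The constructor and method names below match the call sites in this commit; the internals (a heartbeat thread, the interval derivation) are an illustration of the renewal mechanism this test observes, not the shipped libs.db_migration_lock code:

import threading

from redis.exceptions import LockNotOwnedError, RedisError


class AutoRenewLockSketch:
    def __init__(self, redis_client, name: str, ttl_seconds: float, renew_interval_seconds: float | None = None):
        # thread_local=False so the heartbeat thread may touch the lock token.
        self._lock = redis_client.lock(name=name, timeout=ttl_seconds, thread_local=False)
        self._interval = renew_interval_seconds or ttl_seconds / 3
        self._stop = threading.Event()
        self._heartbeat: threading.Thread | None = None

    def acquire(self, blocking: bool = False, blocking_timeout: float | None = None) -> bool:
        acquired = self._lock.acquire(blocking=blocking, blocking_timeout=blocking_timeout)
        if acquired:
            self._heartbeat = threading.Thread(target=self._renew_loop, daemon=True)
            self._heartbeat.start()
        return acquired

    def _renew_loop(self) -> None:
        while not self._stop.wait(self._interval):
            try:
                self._lock.reacquire()  # resets the TTL to the full timeout
            except RedisError:
                # A transient renewal failure must not crash the migration.
                continue

    def release_safely(self, status: str = "unknown") -> None:
        self._stop.set()
        if self._heartbeat is not None:
            self._heartbeat.join()
        try:
            self._lock.release()
        except LockNotOwnedError:
            pass  # lock already expired; the real implementation logs this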
@@ -605,26 +605,20 @@ class TestBatchCreateSegmentToIndexTask:
        mock_storage.download.side_effect = mock_download

-        # Execute the task
+        # Execute the task - should raise ValueError for empty CSV
        job_id = str(uuid.uuid4())
-        batch_create_segment_to_index_task(
-            job_id=job_id,
-            upload_file_id=upload_file.id,
-            dataset_id=dataset.id,
-            document_id=document.id,
-            tenant_id=tenant.id,
-            user_id=account.id,
-        )
+        with pytest.raises(ValueError, match="The CSV file is empty"):
+            batch_create_segment_to_index_task(
+                job_id=job_id,
+                upload_file_id=upload_file.id,
+                dataset_id=dataset.id,
+                document_id=document.id,
+                tenant_id=tenant.id,
+                user_id=account.id,
+            )

        # Verify error handling
-        # Check Redis cache was set to error status
-        from extensions.ext_redis import redis_client
-
-        cache_key = f"segment_batch_import_{job_id}"
-        cache_value = redis_client.get(cache_key)
-        assert cache_value == b"error"
-
-        # Verify no segments were created
+        # Since exception was raised, no segments should be created
        from extensions.ext_database import db

        segments = db.session.query(DocumentSegment).all()
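The rewritten assertion means empty CSV input now fails loudly instead of only flagging an error status in Redis. A hedged sketch of such a validation; whether the real task parses with pandas, and where exactly it raises, are assumptions, and only the message comes from the test:

import io

import pandas as pd


def parse_segments_csv(csv_bytes: bytes) -> pd.DataFrame:
    df = pd.read_csv(io.BytesIO(csv_bytes))
    if df.empty:
        # Raised to the caller instead of silently caching an "error" status,
        # so batch imports fail loudly on empty input.
        raise ValueError("The CSV file is empty")
    return df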
@@ -153,8 +153,7 @@ class TestCleanNotionDocumentTask:
        # Execute cleanup task
        clean_notion_document_task(document_ids, dataset.id)

-        # Verify documents and segments are deleted
-        assert db_session_with_containers.query(Document).filter(Document.id.in_(document_ids)).count() == 0
+        # Verify segments are deleted
        assert (
            db_session_with_containers.query(DocumentSegment)
            .filter(DocumentSegment.document_id.in_(document_ids))
@@ -162,9 +161,9 @@
            == 0
        )

-        # Verify index processor was called for each document
+        # Verify index processor was called
        mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
-        assert mock_processor.clean.call_count == len(document_ids)
+        mock_processor.clean.assert_called_once()

        # This test successfully verifies:
        # 1. Document records are properly deleted from the database
@@ -186,12 +185,12 @@
        non_existent_dataset_id = str(uuid.uuid4())
        document_ids = [str(uuid.uuid4()), str(uuid.uuid4())]

-        # Execute cleanup task with non-existent dataset
-        clean_notion_document_task(document_ids, non_existent_dataset_id)
+        # Execute cleanup task with non-existent dataset - expect exception
+        with pytest.raises(Exception, match="Document has no dataset"):
+            clean_notion_document_task(document_ids, non_existent_dataset_id)

-        # Verify that the index processor was not called
-        mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
-        mock_processor.clean.assert_not_called()
+        # Verify that the index processor factory was not used
+        mock_index_processor_factory.return_value.init_index_processor.assert_not_called()

    def test_clean_notion_document_task_empty_document_list(
        self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
@@ -229,9 +228,13 @@
        # Execute cleanup task with empty document list
        clean_notion_document_task([], dataset.id)

-        # Verify that the index processor was not called
+        # Verify that the index processor was called once with empty node list
        mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
-        mock_processor.clean.assert_not_called()
+        assert mock_processor.clean.call_count == 1
+        args, kwargs = mock_processor.clean.call_args
+        # args: (dataset, total_index_node_ids)
+        assert isinstance(args[0], Dataset)
+        assert args[1] == []

    def test_clean_notion_document_task_with_different_index_types(
        self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
@@ -315,8 +318,7 @@
        # Note: This test successfully verifies cleanup with different document types.
        # The task properly handles various index types and document configurations.

-        # Verify documents and segments are deleted
-        assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
+        # Verify segments are deleted
        assert (
            db_session_with_containers.query(DocumentSegment)
            .filter(DocumentSegment.document_id == document.id)
@@ -404,8 +406,7 @@
        # Execute cleanup task
        clean_notion_document_task([document.id], dataset.id)

-        # Verify documents and segments are deleted
-        assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
+        # Verify segments are deleted
        assert (
            db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
            == 0
@@ -508,8 +509,7 @@

        clean_notion_document_task(documents_to_clean, dataset.id)

-        # Verify only specified documents and segments are deleted
-        assert db_session_with_containers.query(Document).filter(Document.id.in_(documents_to_clean)).count() == 0
+        # Verify only specified documents' segments are deleted
        assert (
            db_session_with_containers.query(DocumentSegment)
            .filter(DocumentSegment.document_id.in_(documents_to_clean))
@@ -697,11 +697,12 @@
        db_session_with_containers.commit()

        # Mock index processor to raise an exception
-        mock_index_processor = mock_index_processor_factory.init_index_processor.return_value
+        mock_index_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
        mock_index_processor.clean.side_effect = Exception("Index processor error")

-        # Execute cleanup task - it should handle the exception gracefully
-        clean_notion_document_task([document.id], dataset.id)
+        # Execute cleanup task - current implementation propagates the exception
+        with pytest.raises(Exception, match="Index processor error"):
+            clean_notion_document_task([document.id], dataset.id)

        # Note: This test demonstrates the task's error handling capability.
        # Even with external service errors, the database operations complete successfully.
@@ -803,8 +804,7 @@
        all_document_ids = [doc.id for doc in documents]
        clean_notion_document_task(all_document_ids, dataset.id)

-        # Verify all documents and segments are deleted
-        assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0
+        # Verify all segments are deleted
        assert (
            db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
            == 0
@@ -914,8 +914,7 @@

        clean_notion_document_task([target_document.id], target_dataset.id)

-        # Verify only documents from target dataset are deleted
-        assert db_session_with_containers.query(Document).filter(Document.id == target_document.id).count() == 0
+        # Verify only documents' segments from target dataset are deleted
        assert (
            db_session_with_containers.query(DocumentSegment)
            .filter(DocumentSegment.document_id == target_document.id)
@@ -1030,8 +1029,7 @@
        all_document_ids = [doc.id for doc in documents]
        clean_notion_document_task(all_document_ids, dataset.id)

-        # Verify all documents and segments are deleted regardless of status
-        assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0
+        # Verify all segments are deleted regardless of status
        assert (
            db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
            == 0
@@ -1142,8 +1140,7 @@
        # Execute cleanup task
        clean_notion_document_task([document.id], dataset.id)

-        # Verify documents and segments are deleted
-        assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
+        # Verify segments are deleted
        assert (
            db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
            == 0
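The non-existent-dataset test above now expects the task to fail fast. A sketch of the guard it exercises; only the error message is taken from the pytest.raises match, and the surrounding code is an assumption:

def _require_dataset(session, dataset_id: str):
    dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
    if dataset is None:
        # The tests match on this message via pytest.raises(..., match=...)
        raise Exception("Document has no dataset")
    return dataset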
@@ -0,0 +1,182 @@
from unittest.mock import MagicMock, patch

import pytest
from faker import Faker

from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from tasks.document_indexing_update_task import document_indexing_update_task


class TestDocumentIndexingUpdateTask:
    @pytest.fixture
    def mock_external_dependencies(self):
        """Patch external collaborators used by the update task.

        - IndexProcessorFactory.init_index_processor().clean(...)
        - IndexingRunner.run([...])
        """
        with (
            patch("tasks.document_indexing_update_task.IndexProcessorFactory") as mock_factory,
            patch("tasks.document_indexing_update_task.IndexingRunner") as mock_runner,
        ):
            processor_instance = MagicMock()
            mock_factory.return_value.init_index_processor.return_value = processor_instance

            runner_instance = MagicMock()
            mock_runner.return_value = runner_instance

            yield {
                "factory": mock_factory,
                "processor": processor_instance,
                "runner": mock_runner,
                "runner_instance": runner_instance,
            }

    def _create_dataset_document_with_segments(self, db_session_with_containers, *, segment_count: int = 2):
        fake = Faker()

        # Account and tenant
        account = Account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            status="active",
        )
        db_session_with_containers.add(account)
        db_session_with_containers.commit()

        tenant = Tenant(name=fake.company(), status="normal")
        db_session_with_containers.add(tenant)
        db_session_with_containers.commit()

        join = TenantAccountJoin(
            tenant_id=tenant.id,
            account_id=account.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        db_session_with_containers.add(join)
        db_session_with_containers.commit()

        # Dataset and document
        dataset = Dataset(
            tenant_id=tenant.id,
            name=fake.company(),
            description=fake.text(max_nb_chars=64),
            data_source_type="upload_file",
            indexing_technique="high_quality",
            created_by=account.id,
        )
        db_session_with_containers.add(dataset)
        db_session_with_containers.commit()

        document = Document(
            tenant_id=tenant.id,
            dataset_id=dataset.id,
            position=0,
            data_source_type="upload_file",
            batch="test_batch",
            name=fake.file_name(),
            created_from="upload_file",
            created_by=account.id,
            indexing_status="waiting",
            enabled=True,
            doc_form="text_model",
        )
        db_session_with_containers.add(document)
        db_session_with_containers.commit()

        # Segments
        node_ids = []
        for i in range(segment_count):
            node_id = f"node-{i + 1}"
            seg = DocumentSegment(
                tenant_id=tenant.id,
                dataset_id=dataset.id,
                document_id=document.id,
                position=i,
                content=fake.text(max_nb_chars=32),
                answer=None,
                word_count=10,
                tokens=5,
                index_node_id=node_id,
                status="completed",
                created_by=account.id,
            )
            db_session_with_containers.add(seg)
            node_ids.append(node_id)
        db_session_with_containers.commit()

        # Refresh to ensure ORM state
        db_session_with_containers.refresh(dataset)
        db_session_with_containers.refresh(document)

        return dataset, document, node_ids

    def test_cleans_segments_and_reindexes(self, db_session_with_containers, mock_external_dependencies):
        dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)

        # Act
        document_indexing_update_task(dataset.id, document.id)

        # Ensure we see committed changes from another session
        db_session_with_containers.expire_all()

        # Assert document status updated before reindex
        updated = db_session_with_containers.query(Document).where(Document.id == document.id).first()
        assert updated.indexing_status == "parsing"
        assert updated.processing_started_at is not None

        # Segments should be deleted
        remaining = (
            db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
        )
        assert remaining == 0

        # Assert index processor clean was called with expected args
        clean_call = mock_external_dependencies["processor"].clean.call_args
        assert clean_call is not None
        args, kwargs = clean_call
        # args[0] is a Dataset instance (from another session); validate by id
        assert getattr(args[0], "id", None) == dataset.id
        # args[1] should contain our node_ids
        assert set(args[1]) == set(node_ids)
        assert kwargs.get("with_keywords") is True
        assert kwargs.get("delete_child_chunks") is True

        # Assert indexing runner invoked with the updated document
        run_call = mock_external_dependencies["runner_instance"].run.call_args
        assert run_call is not None
        run_docs = run_call[0][0]
        assert len(run_docs) == 1
        first = run_docs[0]
        assert getattr(first, "id", None) == document.id

    def test_clean_error_is_logged_and_indexing_continues(self, db_session_with_containers, mock_external_dependencies):
        dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)

        # Force clean to raise; task should continue to indexing
        mock_external_dependencies["processor"].clean.side_effect = Exception("boom")

        document_indexing_update_task(dataset.id, document.id)

        # Ensure we see committed changes from another session
        db_session_with_containers.expire_all()

        # Indexing should still be triggered
        mock_external_dependencies["runner_instance"].run.assert_called_once()

        # Segments should remain (since clean failed before DB delete)
        remaining = (
            db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
        )
        assert remaining > 0

    def test_document_not_found_noop(self, db_session_with_containers, mock_external_dependencies):
        fake = Faker()
        # Act with non-existent document id
        document_indexing_update_task(dataset_id=fake.uuid4(), document_id=fake.uuid4())

        # Neither processor nor runner should be called
        mock_external_dependencies["processor"].clean.assert_not_called()
        mock_external_dependencies["runner_instance"].run.assert_not_called()
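Taken together, the three tests pin down the task's flow: mark the document as parsing, clean old index nodes, delete segments, then re-run indexing, with a not-found no-op and a tolerant clean step. A sketch of that flow under the names this test module already imports; session handling, logging, and the factory call shape are assumptions:

import logging


def document_indexing_update_task_sketch(dataset_id: str, document_id: str) -> None:
    with session_factory.create_session() as session:  # hypothetical session handling
        document = session.query(Document).where(Document.id == document_id).first()
        if not document:
            return  # not-found is a no-op: neither clean() nor run() is called

        # Status flips to "parsing" before any indexing work starts
        document.indexing_status = "parsing"
        document.processing_started_at = naive_utc_now()  # assumed helper
        session.commit()

        dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
        segments = session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
        index_node_ids = [seg.index_node_id for seg in segments]
        try:
            processor = IndexProcessorFactory(document.doc_form).init_index_processor()
            processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
            # Segments are deleted only after a successful clean, which is why
            # they survive in test_clean_error_is_logged_and_indexing_continues.
            for seg in segments:
                session.delete(seg)
            session.commit()
        except Exception:
            logging.exception("Cleanup failed; continuing with re-indexing")

    IndexingRunner().run([document])  # still called even when clean() raised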
api/tests/unit_tests/commands/test_upgrade_db.py (new file, 146 lines)
@@ -0,0 +1,146 @@
import sys
import threading
import types
from unittest.mock import MagicMock

import commands
from libs.db_migration_lock import LockNotOwnedError, RedisError

HEARTBEAT_WAIT_TIMEOUT_SECONDS = 5.0


def _install_fake_flask_migrate(monkeypatch, upgrade_impl) -> None:
    module = types.ModuleType("flask_migrate")
    module.upgrade = upgrade_impl
    monkeypatch.setitem(sys.modules, "flask_migrate", module)


def _invoke_upgrade_db() -> int:
    try:
        commands.upgrade_db.callback()
    except SystemExit as e:
        return int(e.code or 0)
    return 0


def test_upgrade_db_skips_when_lock_not_acquired(monkeypatch, capsys):
    monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 1234)

    lock = MagicMock()
    lock.acquire.return_value = False
    commands.redis_client.lock.return_value = lock

    exit_code = _invoke_upgrade_db()
    captured = capsys.readouterr()

    assert exit_code == 0
    assert "Database migration skipped" in captured.out

    commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=1234, thread_local=False)
    lock.acquire.assert_called_once_with(blocking=False)
    lock.release.assert_not_called()


def test_upgrade_db_failure_not_masked_by_lock_release(monkeypatch, capsys):
    monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 321)

    lock = MagicMock()
    lock.acquire.return_value = True
    lock.release.side_effect = LockNotOwnedError("simulated")
    commands.redis_client.lock.return_value = lock

    def _upgrade():
        raise RuntimeError("boom")

    _install_fake_flask_migrate(monkeypatch, _upgrade)

    exit_code = _invoke_upgrade_db()
    captured = capsys.readouterr()

    assert exit_code == 1
    assert "Database migration failed: boom" in captured.out

    commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=321, thread_local=False)
    lock.acquire.assert_called_once_with(blocking=False)
    lock.release.assert_called_once()


def test_upgrade_db_success_ignores_lock_not_owned_on_release(monkeypatch, capsys):
    monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 999)

    lock = MagicMock()
    lock.acquire.return_value = True
    lock.release.side_effect = LockNotOwnedError("simulated")
    commands.redis_client.lock.return_value = lock

    _install_fake_flask_migrate(monkeypatch, lambda: None)

    exit_code = _invoke_upgrade_db()
    captured = capsys.readouterr()

    assert exit_code == 0
    assert "Database migration successful!" in captured.out

    commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=999, thread_local=False)
    lock.acquire.assert_called_once_with(blocking=False)
    lock.release.assert_called_once()


def test_upgrade_db_renews_lock_during_migration(monkeypatch, capsys):
    """
    Ensure the lock is renewed while migrations are running, so the base TTL can stay short.
    """

    # Use a small TTL so the heartbeat interval triggers quickly.
    monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3)

    lock = MagicMock()
    lock.acquire.return_value = True
    commands.redis_client.lock.return_value = lock

    renewed = threading.Event()

    def _reacquire():
        renewed.set()
        return True

    lock.reacquire.side_effect = _reacquire

    def _upgrade():
        assert renewed.wait(HEARTBEAT_WAIT_TIMEOUT_SECONDS)

    _install_fake_flask_migrate(monkeypatch, _upgrade)

    exit_code = _invoke_upgrade_db()
    _ = capsys.readouterr()

    assert exit_code == 0
    assert lock.reacquire.call_count >= 1


def test_upgrade_db_ignores_reacquire_errors(monkeypatch, capsys):
    # Use a small TTL so heartbeat runs during the upgrade call.
    monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3)

    lock = MagicMock()
    lock.acquire.return_value = True
    commands.redis_client.lock.return_value = lock

    attempted = threading.Event()

    def _reacquire():
        attempted.set()
        raise RedisError("simulated")

    lock.reacquire.side_effect = _reacquire

    def _upgrade():
        assert attempted.wait(HEARTBEAT_WAIT_TIMEOUT_SECONDS)

    _install_fake_flask_migrate(monkeypatch, _upgrade)

    exit_code = _invoke_upgrade_db()
    _ = capsys.readouterr()

    assert exit_code == 0
    assert lock.reacquire.call_count >= 1
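test_upgrade_db_failure_not_masked_by_lock_release and test_upgrade_db_success_ignores_lock_not_owned_on_release encode the same release contract: a LockNotOwnedError on release (the base TTL elapsed before release) must neither fail a successful run nor mask a real migration error. A sketch of that contract; the helper name and logging here are assumptions:

from libs.db_migration_lock import LockNotOwnedError


def release_quietly(lock, logger) -> None:
    # Called from a finally block, so it must never raise: a raise here would
    # replace the SystemExit (or success) produced by the migration itself.
    try:
        lock.release()
    except LockNotOwnedError:
        # Expected when the lock already expired; log it and move on.
        logger.warning("db_upgrade_lock was not owned at release time")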
@@ -25,15 +25,19 @@ class TestMessageCycleManagerOptimization:
        task_state = Mock()
        return MessageCycleManager(application_generate_entity=mock_application_generate_entity, task_state=task_state)

-    def test_get_message_event_type_with_message_file(self, message_cycle_manager):
-        """Test get_message_event_type returns MESSAGE_FILE when message has files."""
+    def test_get_message_event_type_with_assistant_file(self, message_cycle_manager):
+        """Test get_message_event_type returns MESSAGE_FILE when message has assistant-generated files.
+
+        This ensures that AI-generated images (belongs_to='assistant') trigger the MESSAGE_FILE event,
+        allowing the frontend to properly display generated image files with the url field.
+        """
        with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
            # Setup mock session and message file
            mock_session = Mock()
            mock_session_factory.create_session.return_value.__enter__.return_value = mock_session

            mock_message_file = Mock()
-            # Current implementation uses session.scalar(select(...))
+            mock_message_file.belongs_to = "assistant"
            mock_session.scalar.return_value = mock_message_file

            # Execute
@@ -44,6 +48,31 @@
        assert result == StreamEvent.MESSAGE_FILE
        mock_session.scalar.assert_called_once()

+    def test_get_message_event_type_with_user_file(self, message_cycle_manager):
+        """Test get_message_event_type returns MESSAGE when message only has user-uploaded files.
+
+        This is a regression test for the issue where user-uploaded images (belongs_to='user')
+        caused the LLM text response to be incorrectly tagged with the MESSAGE_FILE event,
+        resulting in broken images in the chat UI. The query filters for belongs_to='assistant',
+        so when only user files exist, the database query returns None, resulting in the MESSAGE event type.
+        """
+        with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
+            # Setup mock session and message file
+            mock_session = Mock()
+            mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
+
+            # When querying for assistant files with only user files present, return None
+            # (simulates database query with belongs_to='assistant' filter returning no results)
+            mock_session.scalar.return_value = None
+
+            # Execute
+            with current_app.app_context():
+                result = message_cycle_manager.get_message_event_type("test-message-id")
+
+            # Assert
+            assert result == StreamEvent.MESSAGE
+            mock_session.scalar.assert_called_once()
+
    def test_get_message_event_type_without_message_file(self, message_cycle_manager):
        """Test get_message_event_type returns MESSAGE when message has no files."""
        with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
@@ -69,7 +98,7 @@
            mock_session_factory.create_session.return_value.__enter__.return_value = mock_session

            mock_message_file = Mock()
-            # Current implementation uses session.scalar(select(...))
+            mock_message_file.belongs_to = "assistant"
            mock_session.scalar.return_value = mock_message_file

            # Execute: compute event type once, then pass to message_to_stream_response
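The belongs_to assertions describe the query shape the implementation is expected to use. A sketch, assuming the MessageFile model and the StreamEvent enum as used elsewhere in this commit; the exact module paths are not taken from the diff:

from sqlalchemy import select


def get_message_event_type_sketch(session, message_id: str):
    # Only assistant-generated files flip the event type; user uploads
    # (belongs_to='user') fall through to a plain MESSAGE event.
    message_file = session.scalar(
        select(MessageFile)
        .where(MessageFile.message_id == message_id, MessageFile.belongs_to == "assistant")
        .limit(1)
    )
    return StreamEvent.MESSAGE_FILE if message_file is not None else StreamEvent.MESSAGE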
@@ -4,7 +4,7 @@ from typing import Any
from uuid import uuid4

import pytest
-from hypothesis import given, settings
+from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st

from core.file import File, FileTransferMethod, FileType
@@ -493,7 +493,7 @@ def _scalar_value() -> st.SearchStrategy[int | float | str | File | None]:
    )


-@settings(max_examples=50)
+@settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None)
@given(_scalar_value())
def test_build_segment_and_extract_values_for_scalar_types(value):
    seg = variable_factory.build_segment(value)
@@ -504,7 +504,7 @@ def test_build_segment_and_extract_values_for_scalar_types(value):
    assert seg.value == value


-@settings(max_examples=50)
+@settings(max_examples=30, suppress_health_check=[HealthCheck.too_slow, HealthCheck.filter_too_much], deadline=None)
@given(values=st.lists(_scalar_value(), max_size=20))
def test_build_segment_and_extract_values_for_array_types(values):
    seg = variable_factory.build_segment(values)
@@ -83,23 +83,127 @@ def mock_documents(document_ids, dataset_id):
def mock_db_session():
    """Mock database session via session_factory.create_session()."""
    with patch("tasks.document_indexing_task.session_factory") as mock_sf:
-        session = MagicMock()
-        # Ensure tests that expect session.close() to be called can observe it via the context manager
-        session.close = MagicMock()
-        cm = MagicMock()
-        cm.__enter__.return_value = session
-        # Link __exit__ to session.close so "close" expectations reflect context manager teardown
-
-        def _exit_side_effect(*args, **kwargs):
-            session.close()
-
-        cm.__exit__.side_effect = _exit_side_effect
-        mock_sf.create_session.return_value = cm
-
-        query = MagicMock()
-        session.query.return_value = query
-        query.where.return_value = query
-        yield session
+        sessions = []  # Track all created sessions
+        # Shared mock data that all sessions will access
+        shared_mock_data = {"dataset": None, "documents": None, "doc_iter": None}
+
+        def create_session_side_effect():
+            session = MagicMock()
+            session.close = MagicMock()
+
+            # Track commit calls
+            commit_mock = MagicMock()
+            session.commit = commit_mock
+            cm = MagicMock()
+            cm.__enter__.return_value = session
+
+            def _exit_side_effect(*args, **kwargs):
+                session.close()
+
+            cm.__exit__.side_effect = _exit_side_effect
+
+            # Support session.begin() for transactions
+            begin_cm = MagicMock()
+            begin_cm.__enter__.return_value = session
+
+            def begin_exit_side_effect(*args, **kwargs):
+                # Auto-commit on transaction exit (like SQLAlchemy)
+                session.commit()
+                # Also mark wrapper's commit as called
+                if sessions:
+                    sessions[0].commit()
+
+            begin_cm.__exit__ = MagicMock(side_effect=begin_exit_side_effect)
+            session.begin = MagicMock(return_value=begin_cm)
+
+            sessions.append(session)
+
+            # Setup query with side_effect to handle both Dataset and Document queries
+            def query_side_effect(*args):
+                query = MagicMock()
+                if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
+                    where_result = MagicMock()
+                    where_result.first.return_value = shared_mock_data["dataset"]
+                    query.where = MagicMock(return_value=where_result)
+                elif args and args[0] == Document and shared_mock_data["documents"] is not None:
+                    # Support both .first() and .all() calls with chaining
+                    where_result = MagicMock()
+                    where_result.where = MagicMock(return_value=where_result)
+
+                    # Create an iterator for .first() calls if not exists
+                    if shared_mock_data["doc_iter"] is None:
+                        docs = shared_mock_data["documents"] or [None]
+                        shared_mock_data["doc_iter"] = iter(docs)
+
+                    where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
+                    docs_or_empty = shared_mock_data["documents"] or []
+                    where_result.all = MagicMock(return_value=docs_or_empty)
+                    query.where = MagicMock(return_value=where_result)
+                else:
+                    query.where = MagicMock(return_value=query)
+                return query
+
+            session.query = MagicMock(side_effect=query_side_effect)
+            return cm
+
+        mock_sf.create_session.side_effect = create_session_side_effect
+
+        # Create a wrapper that behaves like the first session but has access to all sessions
+        class SessionWrapper:
+            def __init__(self):
+                self._sessions = sessions
+                self._shared_data = shared_mock_data
+                # Create a default session for setup phase
+                self._default_session = MagicMock()
+                self._default_session.close = MagicMock()
+                self._default_session.commit = MagicMock()
+
+                # Support session.begin() for default session too
+                begin_cm = MagicMock()
+                begin_cm.__enter__.return_value = self._default_session
+
+                def default_begin_exit_side_effect(*args, **kwargs):
+                    self._default_session.commit()
+
+                begin_cm.__exit__ = MagicMock(side_effect=default_begin_exit_side_effect)
+                self._default_session.begin = MagicMock(return_value=begin_cm)
+
+                def default_query_side_effect(*args):
+                    query = MagicMock()
+                    if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
+                        where_result = MagicMock()
+                        where_result.first.return_value = shared_mock_data["dataset"]
+                        query.where = MagicMock(return_value=where_result)
+                    elif args and args[0] == Document and shared_mock_data["documents"] is not None:
+                        where_result = MagicMock()
+                        where_result.where = MagicMock(return_value=where_result)
+
+                        if shared_mock_data["doc_iter"] is None:
+                            docs = shared_mock_data["documents"] or [None]
+                            shared_mock_data["doc_iter"] = iter(docs)
+
+                        where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
+                        docs_or_empty = shared_mock_data["documents"] or []
+                        where_result.all = MagicMock(return_value=docs_or_empty)
+                        query.where = MagicMock(return_value=where_result)
+                    else:
+                        query.where = MagicMock(return_value=query)
+                    return query
+
+                self._default_session.query = MagicMock(side_effect=default_query_side_effect)
+
+            def __getattr__(self, name):
+                # Forward all attribute access to the first session, or default if none created yet
+                target_session = self._sessions[0] if self._sessions else self._default_session
+                return getattr(target_session, name)
+
+            @property
+            def all_sessions(self):
+                """Access all created sessions for testing."""
+                return self._sessions
+
+        wrapper = SessionWrapper()
+        yield wrapper


@pytest.fixture
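How the updated call sites below are expected to drive this fixture, as a hypothetical test that is not part of the diff: seed the shared state once, then let the task open as many sessions as it needs:

def test_example(mock_db_session, mock_dataset, mock_documents):  # hypothetical
    mock_db_session._shared_data["dataset"] = mock_dataset
    mock_db_session._shared_data["documents"] = mock_documents
    # ... invoke document_indexing_task(...) here ...
    assert all(s.close.called for s in mock_db_session.all_sessions)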
@ -252,18 +356,9 @@ class TestTaskEnqueuing:
|
|||||||
use the deprecated function.
|
use the deprecated function.
|
||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
def mock_query_side_effect(*args):
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
# Return documents one by one for each call
|
|
||||||
mock_query.where.return_value.first.side_effect = mock_documents
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||||
mock_features.return_value.billing.enabled = False
|
mock_features.return_value.billing.enabled = False
|
||||||
@ -304,21 +399,9 @@ class TestBatchProcessing:
|
|||||||
doc.processing_started_at = None
|
doc.processing_started_at = None
|
||||||
mock_documents.append(doc)
|
mock_documents.append(doc)
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
# Create an iterator for documents
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
doc_iter = iter(mock_documents)
|
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
# Return documents one by one for each call
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||||
mock_features.return_value.billing.enabled = False
|
mock_features.return_value.billing.enabled = False
|
||||||
@ -357,19 +440,9 @@ class TestBatchProcessing:
|
|||||||
doc.stopped_at = None
|
doc.stopped_at = None
|
||||||
mock_documents.append(doc)
|
mock_documents.append(doc)
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
doc_iter = iter(mock_documents)
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||||
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
||||||
@ -407,19 +480,9 @@ class TestBatchProcessing:
|
|||||||
doc.stopped_at = None
|
doc.stopped_at = None
|
||||||
mock_documents.append(doc)
|
mock_documents.append(doc)
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
doc_iter = iter(mock_documents)
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||||
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX
|
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX
|
||||||
@ -444,7 +507,10 @@ class TestBatchProcessing:
|
|||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
document_ids = []
|
document_ids = []
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
|
||||||
|
# Set shared mock data with empty documents list
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
|
mock_db_session._shared_data["documents"] = []
|
||||||
|
|
||||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||||
mock_features.return_value.billing.enabled = False
|
mock_features.return_value.billing.enabled = False
|
||||||
@@ -482,19 +548,9 @@ class TestProgressTracking:
 doc.processing_started_at = None
 mock_documents.append(doc)
 
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-doc_iter = iter(mock_documents)
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first = lambda: next(doc_iter, None)
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = mock_documents
 
 with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
     mock_features.return_value.billing.enabled = False
@@ -528,19 +584,9 @@ class TestProgressTracking:
 doc.processing_started_at = None
 mock_documents.append(doc)
 
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-doc_iter = iter(mock_documents)
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first = lambda: next(doc_iter, None)
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = mock_documents
 
 with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
     mock_features.return_value.billing.enabled = False
@@ -635,19 +681,9 @@ class TestErrorHandling:
 doc.stopped_at = None
 mock_documents.append(doc)
 
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-doc_iter = iter(mock_documents)
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first = lambda: next(doc_iter, None)
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = mock_documents
 
 # Set up to trigger vector space limit error
 mock_feature_service.get_features.return_value.billing.enabled = True
@@ -674,17 +710,9 @@ class TestErrorHandling:
 Errors during indexing should be caught and logged, but not crash the task.
 """
 # Arrange
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first.side_effect = mock_documents
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = mock_documents
 
 # Make IndexingRunner raise an exception
 mock_indexing_runner.run.side_effect = Exception("Indexing failed")
@@ -708,17 +736,9 @@ class TestErrorHandling:
 but not treated as a failure.
 """
 # Arrange
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first.side_effect = mock_documents
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = mock_documents
 
 # Make IndexingRunner raise DocumentIsPausedError
 mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused")
@@ -853,17 +873,9 @@ class TestTaskCancellation:
 Session cleanup should happen in finally block.
 """
 # Arrange
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first.side_effect = mock_documents
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = mock_documents
 
 with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
     mock_features.return_value.billing.enabled = False
@@ -883,17 +895,9 @@ class TestTaskCancellation:
 Session cleanup should happen even when errors occur.
 """
 # Arrange
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first.side_effect = mock_documents
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = mock_documents
 
 # Make IndexingRunner raise an exception
 mock_indexing_runner.run.side_effect = Exception("Test error")
@@ -962,6 +966,7 @@ class TestAdvancedScenarios:
 document_ids = [str(uuid.uuid4()) for _ in range(3)]
 
 # Create only 2 documents (simulate one missing)
+# The new code uses .all() which will only return existing documents
 mock_documents = []
 for i, doc_id in enumerate([document_ids[0], document_ids[2]]):  # Skip middle one
     doc = MagicMock(spec=Document)
@@ -971,21 +976,9 @@ class TestAdvancedScenarios:
 doc.processing_started_at = None
 mock_documents.append(doc)
 
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-# Create iterator that returns None for missing document
-doc_responses = [mock_documents[0], None, mock_documents[1]]
-doc_iter = iter(doc_responses)
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first = lambda: next(doc_iter, None)
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data - .all() will only return existing documents
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = mock_documents
 
 with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
     mock_features.return_value.billing.enabled = False
@@ -1075,19 +1068,9 @@ class TestAdvancedScenarios:
 doc.stopped_at = None
 mock_documents.append(doc)
 
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-doc_iter = iter(mock_documents)
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first = lambda: next(doc_iter, None)
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = mock_documents
 
 # Set vector space exactly at limit
 mock_feature_service.get_features.return_value.billing.enabled = True
@@ -1219,19 +1202,9 @@ class TestAdvancedScenarios:
 doc.processing_started_at = None
 mock_documents.append(doc)
 
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-doc_iter = iter(mock_documents)
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first = lambda: next(doc_iter, None)
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = mock_documents
 
 # Billing disabled - limits should not be checked
 mock_feature_service.get_features.return_value.billing.enabled = False
@@ -1273,19 +1246,9 @@ class TestIntegration:
 
 # Set up rpop to return None for concurrency check (no more tasks)
 mock_redis.rpop.side_effect = [None]
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-doc_iter = iter(mock_documents)
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first = lambda: next(doc_iter, None)
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = mock_documents
 
 with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
     mock_features.return_value.billing.enabled = False
@@ -1321,19 +1284,9 @@ class TestIntegration:
 
 # Set up rpop to return None for concurrency check (no more tasks)
 mock_redis.rpop.side_effect = [None]
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-doc_iter = iter(mock_documents)
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first = lambda: next(doc_iter, None)
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = mock_documents
 
 with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
     mock_features.return_value.billing.enabled = False
@@ -1415,17 +1368,9 @@ class TestEdgeCases:
 mock_document.indexing_status = "waiting"
 mock_document.processing_started_at = None
 
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first = lambda: mock_document
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = [mock_document]
 
 with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
     mock_features.return_value.billing.enabled = False
@@ -1465,17 +1410,9 @@ class TestEdgeCases:
 mock_document.indexing_status = "waiting"
 mock_document.processing_started_at = None
 
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first = lambda: mock_document
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = [mock_document]
 
 with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
     mock_features.return_value.billing.enabled = False
@@ -1555,19 +1492,9 @@ class TestEdgeCases:
 doc.processing_started_at = None
 mock_documents.append(doc)
 
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-doc_iter = iter(mock_documents)
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first = lambda: next(doc_iter, None)
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = mock_documents
 
 # Set vector space limit to 0 (unlimited)
 mock_feature_service.get_features.return_value.billing.enabled = True
@@ -1612,19 +1539,9 @@ class TestEdgeCases:
 doc.processing_started_at = None
 mock_documents.append(doc)
 
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-doc_iter = iter(mock_documents)
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first = lambda: next(doc_iter, None)
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = mock_documents
 
 # Set negative vector space limit
 mock_feature_service.get_features.return_value.billing.enabled = True
@@ -1675,19 +1592,9 @@ class TestPerformanceScenarios:
 doc.processing_started_at = None
 mock_documents.append(doc)
 
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-doc_iter = iter(mock_documents)
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first = lambda: next(doc_iter, None)
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = mock_documents
 
 # Configure billing with sufficient limits
 mock_feature_service.get_features.return_value.billing.enabled = True
@@ -1826,19 +1733,9 @@ class TestRobustness:
 doc.processing_started_at = None
 mock_documents.append(doc)
 
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-doc_iter = iter(mock_documents)
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first = lambda: next(doc_iter, None)
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = mock_documents
 
 # Make IndexingRunner raise an exception
 mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error")
@@ -1866,7 +1763,7 @@ class TestRobustness:
 - No exceptions occur
 
 Expected behavior:
-- Database session is closed
+- All database sessions are closed
 - No connection leaks
 """
 # Arrange
@@ -1879,19 +1776,9 @@ class TestRobustness:
 doc.processing_started_at = None
 mock_documents.append(doc)
 
-mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-doc_iter = iter(mock_documents)
-
-def mock_query_side_effect(*args):
-    mock_query = MagicMock()
-    if args[0] == Dataset:
-        mock_query.where.return_value.first.return_value = mock_dataset
-    elif args[0] == Document:
-        mock_query.where.return_value.first = lambda: next(doc_iter, None)
-    return mock_query
-
-mock_db_session.query.side_effect = mock_query_side_effect
+# Set shared mock data so all sessions can access it
+mock_db_session._shared_data["dataset"] = mock_dataset
+mock_db_session._shared_data["documents"] = mock_documents
 
 with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
     mock_features.return_value.billing.enabled = False
@@ -1899,10 +1786,11 @@ class TestRobustness:
 # Act
 _document_indexing(dataset_id, document_ids)
 
-# Assert
-assert mock_db_session.close.called
-# Verify close is called exactly once
-assert mock_db_session.close.call_count == 1
+# Assert - All created sessions should be closed
+# The code creates multiple sessions: validation, Phase 1 (parsing), Phase 3 (summary)
+assert len(mock_db_session.all_sessions) >= 1
+for session in mock_db_session.all_sessions:
+    assert session.close.called, "All sessions should be closed"
 
 def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis):
     """
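The multi-session assertions above follow from the session-split refactor: each phase of the task opens its own short-lived session, so a test must track every session the factory hands out and check that each one was closed. A minimal, self-contained sketch of that tracking pattern (helper names here are illustrative, not the project's API):

from unittest.mock import MagicMock


def make_tracking_factory():
    """Record every session the factory creates so tests can assert on all of them."""
    created = []

    def create_session():
        session = MagicMock()
        created.append(session)
        cm = MagicMock()
        cm.__enter__.return_value = session

        def _exit(exc_type, exc, tb):
            session.close()
            return False  # never suppress exceptions

        cm.__exit__.side_effect = _exit
        return cm

    factory = MagicMock()
    factory.create_session.side_effect = create_session
    factory.all_sessions = created
    return factory


factory = make_tracking_factory()
with factory.create_session():
    pass
with factory.create_session():
    pass
assert all(s.close.called for s in factory.all_sessions)  # both sessions were closed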
@@ -109,25 +109,87 @@ def mock_document_segments(document_id):
 
 @pytest.fixture
 def mock_db_session():
-    """Mock database session via session_factory.create_session()."""
+    """Mock database session via session_factory.create_session().
+
+    After session split refactor, the code calls create_session() multiple times.
+    This fixture creates shared query mocks so all sessions use the same
+    query configuration, simulating database persistence across sessions.
+
+    The fixture automatically converts side_effect to cycle to prevent StopIteration.
+    Tests configure mocks the same way as before, but behind the scenes the values
+    are cycled infinitely for all sessions.
+    """
+    from itertools import cycle
+
     with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf:
-        session = MagicMock()
-        # Ensure tests can observe session.close() via context manager teardown
-        session.close = MagicMock()
-        cm = MagicMock()
-        cm.__enter__.return_value = session
-
-        def _exit_side_effect(*args, **kwargs):
-            session.close()
-
-        cm.__exit__.side_effect = _exit_side_effect
-        mock_sf.create_session.return_value = cm
-
-        query = MagicMock()
-        session.query.return_value = query
-        query.where.return_value = query
-        session.scalars.return_value = MagicMock()
-        yield session
+        sessions = []
+
+        # Shared query mocks - all sessions use these
+        shared_query = MagicMock()
+        shared_filter_by = MagicMock()
+        shared_scalars_result = MagicMock()
+
+        # Create custom first mock that auto-cycles side_effect
+        class CyclicMock(MagicMock):
+            def __setattr__(self, name, value):
+                if name == "side_effect" and value is not None:
+                    # Convert list/tuple to infinite cycle
+                    if isinstance(value, (list, tuple)):
+                        value = cycle(value)
+                super().__setattr__(name, value)
+
+        shared_query.where.return_value.first = CyclicMock()
+        shared_filter_by.first = CyclicMock()
+
+        def _create_session():
+            """Create a new mock session for each create_session() call."""
+            session = MagicMock()
+            session.close = MagicMock()
+            session.commit = MagicMock()
+
+            # Mock session.begin() context manager
+            begin_cm = MagicMock()
+            begin_cm.__enter__.return_value = session
+
+            def _begin_exit_side_effect(exc_type, exc, tb):
+                # commit on success
+                if exc_type is None:
+                    session.commit()
+                # return False to propagate exceptions
+                return False
+
+            begin_cm.__exit__.side_effect = _begin_exit_side_effect
+            session.begin.return_value = begin_cm
+
+            # Mock create_session() context manager
+            cm = MagicMock()
+            cm.__enter__.return_value = session
+
+            def _exit_side_effect(exc_type, exc, tb):
+                session.close()
+                return False
+
+            cm.__exit__.side_effect = _exit_side_effect
+
+            # All sessions use the same shared query mocks
+            session.query.return_value = shared_query
+            shared_query.where.return_value = shared_query
+            shared_query.filter_by.return_value = shared_filter_by
+            session.scalars.return_value = shared_scalars_result
+
+            sessions.append(session)
+            # Attach helpers on the first created session for assertions across all sessions
+            if len(sessions) == 1:
+                session.get_all_sessions = lambda: sessions
+                session.any_close_called = lambda: any(s.close.called for s in sessions)
+                session.any_commit_called = lambda: any(s.commit.called for s in sessions)
+            return cm
+
+        mock_sf.create_session.side_effect = _create_session
+
+        # Create first session and return it
+        _create_session()
+        yield sessions[0]
 
 
 @pytest.fixture
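Pulled out of the diff, the auto-cycling mock in the fixture reduces to a small standalone trick; this sketch keeps only the essential behavior:

from itertools import cycle
from unittest.mock import MagicMock


class CyclicMock(MagicMock):
    """Turn a list/tuple side_effect into an infinite cycle so repeated
    calls across many sessions never raise StopIteration."""

    def __setattr__(self, name, value):
        if name == "side_effect" and isinstance(value, (list, tuple)):
            value = cycle(value)
        super().__setattr__(name, value)


first = CyclicMock()
first.side_effect = ["document", "dataset"]
# The two values now repeat forever instead of exhausting after two calls.
assert [first() for _ in range(5)] == ["document", "dataset", "document", "dataset", "document"]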
@@ -186,8 +248,8 @@ class TestDocumentIndexingSyncTask:
 # Act
 document_indexing_sync_task(dataset_id, document_id)
 
-# Assert
-mock_db_session.close.assert_called_once()
+# Assert - at least one session should have been closed
+assert mock_db_session.any_close_called()
 
 def test_missing_notion_workspace_id(self, mock_db_session, mock_document, dataset_id, document_id):
     """Test that task raises error when notion_workspace_id is missing."""
@@ -230,6 +292,7 @@ class TestDocumentIndexingSyncTask:
 """Test that task handles missing credentials by updating document status."""
 # Arrange
 mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
+mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
 mock_datasource_provider_service.get_datasource_credentials.return_value = None
 
 # Act
@@ -239,8 +302,8 @@ class TestDocumentIndexingSyncTask:
 assert mock_document.indexing_status == "error"
 assert "Datasource credential not found" in mock_document.error
 assert mock_document.stopped_at is not None
-mock_db_session.commit.assert_called()
-mock_db_session.close.assert_called()
+assert mock_db_session.any_commit_called()
+assert mock_db_session.any_close_called()
 
 def test_page_not_updated(
     self,
@@ -254,6 +317,7 @@ class TestDocumentIndexingSyncTask:
 """Test that task does nothing when page has not been updated."""
 # Arrange
 mock_db_session.query.return_value.where.return_value.first.return_value = mock_document
+mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
 # Return same time as stored in document
 mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-01T00:00:00Z"
 
@@ -263,8 +327,8 @@ class TestDocumentIndexingSyncTask:
 # Assert
 # Document status should remain unchanged
 assert mock_document.indexing_status == "completed"
-# Session should still be closed via context manager teardown
-assert mock_db_session.close.called
+# At least one session should have been closed via context manager teardown
+assert mock_db_session.any_close_called()
 
 def test_successful_sync_when_page_updated(
     self,
@@ -281,7 +345,20 @@ class TestDocumentIndexingSyncTask:
 ):
 """Test successful sync flow when Notion page has been updated."""
 # Arrange
-mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
+# Set exact sequence of returns across calls to `.first()`:
+# 1) document (initial fetch)
+# 2) dataset (pre-check)
+# 3) dataset (cleaning phase)
+# 4) document (pre-indexing update)
+# 5) document (indexing runner fetch)
+mock_db_session.query.return_value.where.return_value.first.side_effect = [
+    mock_document,
+    mock_dataset,
+    mock_dataset,
+    mock_document,
+    mock_document,
+]
+mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
 mock_db_session.scalars.return_value.all.return_value = mock_document_segments
 # NotionExtractor returns updated time
 mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
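The enumerated sequence above works because a plain list passed as `side_effect` is consumed one element per call: each `.first()` in the task pops the next queued return value, and a sixth call would raise StopIteration. In isolation:

from unittest.mock import MagicMock

document, dataset = object(), object()
first = MagicMock(side_effect=[document, dataset, dataset, document, document])

# Five calls return the queued values in order, matching the five fetches listed above.
assert [first() for _ in range(5)] == [document, dataset, dataset, document, document]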
@@ -299,28 +376,40 @@ class TestDocumentIndexingSyncTask:
 mock_processor.clean.assert_called_once()
 
 # Verify segments were deleted from database in batch (DELETE FROM document_segments)
-execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
+# Aggregate execute calls across all created sessions
+execute_sqls = []
+for s in mock_db_session.get_all_sessions():
+    execute_sqls.extend([" ".join(str(c[0][0]).split()) for c in s.execute.call_args_list])
 assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
 
 # Verify indexing runner was called
 mock_indexing_runner.run.assert_called_once_with([mock_document])
 
-# Verify session operations
-assert mock_db_session.commit.called
-mock_db_session.close.assert_called_once()
+# Verify session operations (across any created session)
+assert mock_db_session.any_commit_called()
+assert mock_db_session.any_close_called()
 
 def test_dataset_not_found_during_cleaning(
     self,
     mock_db_session,
     mock_datasource_provider_service,
     mock_notion_extractor,
+    mock_indexing_runner,
    mock_document,
     dataset_id,
     document_id,
 ):
 """Test that task handles dataset not found during cleaning phase."""
 # Arrange
-mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, None]
+# Sequence: document (initial), dataset (pre-check), None (cleaning), document (update), document (indexing)
+mock_db_session.query.return_value.where.return_value.first.side_effect = [
+    mock_document,
+    mock_dataset,
+    None,
+    mock_document,
+    mock_document,
+]
+mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
 mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
 
 # Act
@@ -329,8 +418,8 @@ class TestDocumentIndexingSyncTask:
 # Assert
 # Document should still be set to parsing
 assert mock_document.indexing_status == "parsing"
-# Session should be closed after error
-mock_db_session.close.assert_called_once()
+# At least one session should be closed after error
+assert mock_db_session.any_close_called()
 
 def test_cleaning_error_continues_to_indexing(
     self,
@@ -346,8 +435,14 @@ class TestDocumentIndexingSyncTask:
 ):
 """Test that indexing continues even if cleaning fails."""
 # Arrange
-mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
-mock_db_session.scalars.return_value.all.side_effect = Exception("Cleaning error")
+from itertools import cycle
+
+mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
+mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
+# Make the cleaning step fail but not the segment fetch
+processor = mock_index_processor_factory.return_value.init_index_processor.return_value
+processor.clean.side_effect = Exception("Cleaning error")
+mock_db_session.scalars.return_value.all.return_value = []
 mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
 
 # Act
@@ -356,7 +451,7 @@ class TestDocumentIndexingSyncTask:
 # Assert
 # Indexing should still be attempted despite cleaning error
 mock_indexing_runner.run.assert_called_once_with([mock_document])
-mock_db_session.close.assert_called_once()
+assert mock_db_session.any_close_called()
 
 def test_indexing_runner_document_paused_error(
     self,
@@ -373,7 +468,10 @@ class TestDocumentIndexingSyncTask:
 ):
 """Test that DocumentIsPausedError is handled gracefully."""
 # Arrange
-mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
+from itertools import cycle
+
+mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
+mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
 mock_db_session.scalars.return_value.all.return_value = mock_document_segments
 mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
 mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
@@ -383,7 +481,7 @@ class TestDocumentIndexingSyncTask:
 
 # Assert
 # Session should be closed after handling error
-mock_db_session.close.assert_called_once()
+assert mock_db_session.any_close_called()
 
 def test_indexing_runner_general_error(
     self,
@@ -400,7 +498,10 @@ class TestDocumentIndexingSyncTask:
 ):
 """Test that general exceptions during indexing are handled."""
 # Arrange
-mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
+from itertools import cycle
+
+mock_db_session.query.return_value.where.return_value.first.side_effect = cycle([mock_document, mock_dataset])
+mock_db_session.query.return_value.filter_by.return_value.first.return_value = mock_document
 mock_db_session.scalars.return_value.all.return_value = mock_document_segments
 mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
 mock_indexing_runner.run.side_effect = Exception("Indexing error")
@@ -410,7 +511,7 @@ class TestDocumentIndexingSyncTask:
 
 # Assert
 # Session should be closed after error
-mock_db_session.close.assert_called_once()
+assert mock_db_session.any_close_called()
 
 def test_notion_extractor_initialized_with_correct_params(
     self,
@@ -517,7 +618,14 @@ class TestDocumentIndexingSyncTask:
 ):
 """Test that index processor clean is called with correct parameters."""
 # Arrange
-mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_document, mock_dataset]
+# Sequence: document (initial), dataset (pre-check), dataset (cleaning), document (update), document (indexing)
+mock_db_session.query.return_value.where.return_value.first.side_effect = [
+    mock_document,
+    mock_dataset,
+    mock_dataset,
+    mock_document,
+    mock_document,
+]
 mock_db_session.scalars.return_value.all.return_value = mock_document_segments
 mock_notion_extractor.get_notion_last_edited_time.return_value = "2024-01-02T00:00:00Z"
 
@@ -350,7 +350,7 @@ class TestDeleteWorkflowArchiveLogs:
 mock_query.where.return_value = mock_delete_query
 mock_db.session.query.return_value = mock_query
 
-delete_func("log-1")
+delete_func(mock_db.session, "log-1")
 
 mock_db.session.query.assert_called_once_with(WorkflowArchiveLog)
 mock_query.where.assert_called_once()
api/uv.lock (generated, 19 changes)
@@ -1368,7 +1368,7 @@ wheels = [
 
 [[package]]
 name = "dify-api"
-version = "1.12.0"
+version = "1.12.1"
 source = { virtual = "." }
 dependencies = [
     { name = "aliyun-log-python-sdk" },
@@ -1653,7 +1653,7 @@ requires-dist = [
     { name = "starlette", specifier = "==0.49.1" },
     { name = "tiktoken", specifier = "~=0.9.0" },
     { name = "transformers", specifier = "~=4.56.1" },
-    { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" },
+    { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.18.18" },
     { name = "weave", specifier = ">=0.52.16" },
     { name = "weaviate-client", specifier = "==4.17.0" },
     { name = "webvtt-py", specifier = "~=0.5.1" },
@@ -4433,15 +4433,15 @@ wheels = [
 
 [[package]]
 name = "pdfminer-six"
-version = "20251230"
+version = "20260107"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "charset-normalizer" },
     { name = "cryptography" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/46/9a/d79d8fa6d47a0338846bb558b39b9963b8eb2dfedec61867c138c1b17eeb/pdfminer_six-20251230.tar.gz", hash = "sha256:e8f68a14c57e00c2d7276d26519ea64be1b48f91db1cdc776faa80528ca06c1e", size = 8511285, upload-time = "2025-12-30T15:49:13.104Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/34/a4/5cec1112009f0439a5ca6afa8ace321f0ab2f48da3255b7a1c8953014670/pdfminer_six-20260107.tar.gz", hash = "sha256:96bfd431e3577a55a0efd25676968ca4ce8fd5b53f14565f85716ff363889602", size = 8512094, upload-time = "2026-01-07T13:29:12.937Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/65/d7/b288ea32deb752a09aab73c75e1e7572ab2a2b56c3124a5d1eb24c62ceb3/pdfminer_six-20251230-py3-none-any.whl", hash = "sha256:9ff2e3466a7dfc6de6fd779478850b6b7c2d9e9405aa2a5869376a822771f485", size = 6591909, upload-time = "2025-12-30T15:49:10.76Z" },
+    { url = "https://files.pythonhosted.org/packages/20/8b/28c4eaec9d6b036a52cb44720408f26b1a143ca9bce76cc19e8f5de00ab4/pdfminer_six-20260107-py3-none-any.whl", hash = "sha256:366585ba97e80dffa8f00cebe303d2f381884d8637af4ce422f1df3ef38111a9", size = 6592252, upload-time = "2026-01-07T13:29:10.742Z" },
 ]
 
 [[package]]
@@ -6814,12 +6814,12 @@ wheels = [
 
 [[package]]
 name = "unstructured"
-version = "0.16.25"
+version = "0.18.31"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "backoff" },
     { name = "beautifulsoup4" },
-    { name = "chardet" },
+    { name = "charset-normalizer" },
     { name = "dataclasses-json" },
     { name = "emoji" },
     { name = "filetype" },
@@ -6827,6 +6827,7 @@ dependencies = [
     { name = "langdetect" },
     { name = "lxml" },
    { name = "nltk" },
+    { name = "numba" },
     { name = "numpy" },
     { name = "psutil" },
     { name = "python-iso639" },
@@ -6839,9 +6840,9 @@ dependencies = [
     { name = "unstructured-client" },
     { name = "wrapt" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/64/31/98c4c78e305d1294888adf87fd5ee30577a4c393951341ca32b43f167f1e/unstructured-0.16.25.tar.gz", hash = "sha256:73b9b0f51dbb687af572ecdb849a6811710b9cac797ddeab8ee80fa07d8aa5e6", size = 1683097, upload-time = "2025-03-07T11:19:39.507Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/a9/5f/64285bd69a538bc28753f1423fcaa9d64cd79a9e7c097171b1f0d27e9cdb/unstructured-0.18.31.tar.gz", hash = "sha256:af4bbe32d1894ae6e755f0da6fc0dd307a1d0adeebe0e7cc6278f6cf744339ca", size = 1707700, upload-time = "2026-01-27T15:33:05.378Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/12/4f/ad08585b5c8a33c82ea119494c4d3023f4796958c56e668b15cc282ec0a0/unstructured-0.16.25-py3-none-any.whl", hash = "sha256:14719ccef2830216cf1c5bf654f75e2bf07b17ca5dcee9da5ac74618130fd337", size = 1769286, upload-time = "2025-03-07T11:19:37.299Z" },
+    { url = "https://files.pythonhosted.org/packages/c8/4a/9c43f39d9e443c9bc3f2e379b305bca27110adc653b071221b3132c18de5/unstructured-0.18.31-py3-none-any.whl", hash = "sha256:fab4641176cb9b192ed38048758aa0d9843121d03626d18f42275afb31e5b2d3", size = 1794889, upload-time = "2026-01-27T15:33:03.136Z" },
 ]
 
 [package.optional-dependencies]
@@ -21,7 +21,7 @@ services:
 
   # API service
   api:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.12.1
     restart: always
     environment:
       # Use the shared environment variables.
@@ -63,7 +63,7 @@ services:
   # worker service
   # The Celery worker for processing all queues (dataset, workflow, mail, etc.)
   worker:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.12.1
     restart: always
     environment:
       # Use the shared environment variables.
@@ -102,7 +102,7 @@ services:
   # worker_beat service
   # Celery beat for scheduling periodic tasks.
   worker_beat:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.12.1
     restart: always
     environment:
       # Use the shared environment variables.
@@ -132,7 +132,7 @@ services:
 
   # Frontend web application.
   web:
-    image: langgenius/dify-web:1.12.0
+    image: langgenius/dify-web:1.12.1
     restart: always
     environment:
       CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@@ -707,7 +707,7 @@ services:
 
   # API service
   api:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.12.1
     restart: always
     environment:
       # Use the shared environment variables.
@@ -749,7 +749,7 @@ services:
   # worker service
   # The Celery worker for processing all queues (dataset, workflow, mail, etc.)
   worker:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.12.1
     restart: always
     environment:
       # Use the shared environment variables.
@@ -788,7 +788,7 @@ services:
   # worker_beat service
   # Celery beat for scheduling periodic tasks.
   worker_beat:
-    image: langgenius/dify-api:1.12.0
+    image: langgenius/dify-api:1.12.1
     restart: always
     environment:
       # Use the shared environment variables.
@@ -818,7 +818,7 @@ services:
 
   # Frontend web application.
   web:
-    image: langgenius/dify-web:1.12.0
+    image: langgenius/dify-web:1.12.1
     restart: always
     environment:
       CONSOLE_API_URL: ${CONSOLE_API_URL:-}
@@ -109,6 +109,7 @@ const AgentTools: FC = () => {
       tool_parameters: paramsWithDefaultValue,
       notAuthor: !tool.is_team_authorization,
       enabled: true,
+      type: tool.provider_type as CollectionType,
     }
   }
   const handleSelectTool = (tool: ToolDefaultValue) => {
@@ -19,19 +19,21 @@ import { useBoolean, useSessionStorageState } from 'ahooks'
 import * as React from 'react'
 import { useCallback, useEffect, useState } from 'react'
 import { useTranslation } from 'react-i18next'
+import { useShallow } from 'zustand/react/shallow'
 import Button from '@/app/components/base/button'
 import Confirm from '@/app/components/base/confirm'
 import { Generator } from '@/app/components/base/icons/src/vender/other'
-import Loading from '@/app/components/base/loading'
 
+import Loading from '@/app/components/base/loading'
 import Modal from '@/app/components/base/modal'
 import Toast from '@/app/components/base/toast'
-import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 
+import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
 import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
 import { generateBasicAppFirstTimeRule, generateRule } from '@/service/debug'
 import { useGenerateRuleTemplate } from '@/service/use-apps'
+import { useStore } from '../../../store'
 import IdeaOutput from './idea-output'
 import InstructionEditorInBasic from './instruction-editor'
 import InstructionEditorInWorkflow from './instruction-editor-in-workflow'
@@ -83,6 +85,9 @@ const GetAutomaticRes: FC<IGetAutomaticResProps> = ({
   onFinished,
 }) => {
   const { t } = useTranslation()
+  const { appDetail } = useStore(useShallow(state => ({
+    appDetail: state.appDetail,
+  })))
   const localModel = localStorage.getItem('auto-gen-model')
     ? JSON.parse(localStorage.getItem('auto-gen-model') as string) as Model
     : null
@@ -235,6 +240,7 @@ const GetAutomaticRes: FC<IGetAutomaticResProps> = ({
   instruction,
   model_config: model,
   no_variable: false,
+  app_id: appDetail?.id,
 })
 apiRes = {
   ...res,
@@ -256,6 +262,7 @@ const GetAutomaticRes: FC<IGetAutomaticResProps> = ({
   instruction,
   ideal_output: ideaOutput,
   model_config: model,
+  app_id: appDetail?.id,
 })
 apiRes = res
 if (error) {
@@ -154,7 +154,7 @@ export const GeneralChunkingOptions: FC<GeneralChunkingOptionsProps> = ({
   </div>
 ))}
 {
-  showSummaryIndexSetting && (
+  showSummaryIndexSetting && IS_CE_EDITION && (
     <div className="mt-3">
       <SummaryIndexSetting
         entry="create-document"
@@ -12,6 +12,7 @@ import Divider from '@/app/components/base/divider'
 import { ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge'
 import RadioCard from '@/app/components/base/radio-card'
 import SummaryIndexSetting from '@/app/components/datasets/settings/summary-index-setting'
+import { IS_CE_EDITION } from '@/config'
 import { ChunkingMode } from '@/models/datasets'
 import FileList from '../../assets/file-list-3-fill.svg'
 import Note from '../../assets/note-mod.svg'
@@ -191,7 +192,7 @@ export const ParentChildOptions: FC<ParentChildOptionsProps> = ({
   </div>
 ))}
 {
-  showSummaryIndexSetting && (
+  showSummaryIndexSetting && IS_CE_EDITION && (
     <div className="mt-3">
       <SummaryIndexSetting
         entry="create-document"
|
|||||||
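Both hunks gate the summary-index UI on the web app's edition flag imported from '@/config'. As a rough sketch only — the actual definition lives in web/config and the env var name here is an assumption — IS_CE_EDITION is a build-time constant derived from the deployment edition, so in non-CE builds the gated JSX simply never renders:

// Hedged sketch of the gate; the env var name is an assumption, not taken from this diff.
export const IS_CE_EDITION = process.env.NEXT_PUBLIC_EDITION === 'SELF_HOSTED'

// Recurring usage in these hunks:
//   {showSummaryIndexSetting && IS_CE_EDITION && <SummaryIndexSetting entry="create-document" />}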
@@ -26,6 +26,7 @@ import CustomPopover from '@/app/components/base/popover'
 import Switch from '@/app/components/base/switch'
 import { ToastContext } from '@/app/components/base/toast'
 import Tooltip from '@/app/components/base/tooltip'
+import { IS_CE_EDITION } from '@/config'
 import { DataSourceType, DocumentActionType } from '@/models/datasets'
 import {
   useDocumentArchive,
@@ -263,10 +264,14 @@ const Operations = ({
             <span className={s.actionName}>{t('list.action.sync', { ns: 'datasetDocuments' })}</span>
           </div>
         )}
-        <div className={s.actionItem} onClick={() => onOperate('summary')}>
-          <SearchLinesSparkle className="h-4 w-4 text-text-tertiary" />
-          <span className={s.actionName}>{t('list.action.summary', { ns: 'datasetDocuments' })}</span>
-        </div>
+        {
+          IS_CE_EDITION && (
+            <div className={s.actionItem} onClick={() => onOperate('summary')}>
+              <SearchLinesSparkle className="h-4 w-4 text-text-tertiary" />
+              <span className={s.actionName}>{t('list.action.summary', { ns: 'datasetDocuments' })}</span>
+            </div>
+          )
+        }
         <Divider className="my-1" />
       </>
     )}
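The same edition gate is inlined at every call site in this commit. A hypothetical wrapper — not present in the codebase, shown only to name the pattern these hunks repeat — would express it once:

// Hypothetical helper for illustration only; the diff inlines the
// condition rather than introducing a component like this.
import type { FC, PropsWithChildren } from 'react'
import { IS_CE_EDITION } from '@/config'

const CeOnly: FC<PropsWithChildren> = ({ children }) => (
  IS_CE_EDITION ? <>{children}</> : null
)

// Usage: <CeOnly><div onClick={() => onOperate('summary')}>…</div></CeOnly>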
@@ -7,6 +7,7 @@ import Button from '@/app/components/base/button'
 import Confirm from '@/app/components/base/confirm'
 import Divider from '@/app/components/base/divider'
 import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge'
+import { IS_CE_EDITION } from '@/config'
 import { cn } from '@/utils/classnames'
 
 const i18nPrefix = 'batchAction'
@@ -87,7 +88,7 @@ const BatchAction: FC<IBatchActionProps> = ({
           <span className="px-0.5">{t('metadata.metadata', { ns: 'dataset' })}</span>
         </Button>
       )}
-      {onBatchSummary && (
+      {onBatchSummary && IS_CE_EDITION && (
         <Button
           variant="ghost"
           className="gap-x-0.5 px-3"
@@ -21,6 +21,7 @@ import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-me
 import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
 import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
 import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
+import { IS_CE_EDITION } from '@/config'
 import { useSelector as useAppContextWithSelector } from '@/context/app-context'
 import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
 import { useDocLink } from '@/context/i18n'
@@ -359,7 +360,7 @@ const Form = () => {
       {
         indexMethod === IndexingType.QUALIFIED
         && [ChunkingMode.text, ChunkingMode.parentChild].includes(currentDataset?.doc_form as ChunkingMode)
-        && (
+        && IS_CE_EDITION && (
           <>
             <Divider
               type="horizontal"
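Spelled out, the visibility of the summary-index block in this form is now a three-way conjunction. A sketch using the names from the hunk above:

// Equivalent predicate for the gated block (sketch; names from the hunk above).
const showSummaryIndexBlock
  = indexMethod === IndexingType.QUALIFIED
    && [ChunkingMode.text, ChunkingMode.parentChild].includes(currentDataset?.doc_form as ChunkingMode)
    && IS_CE_EDITION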
@@ -104,7 +104,7 @@ const MembersPage = () => {
             <UpgradeBtn className="mr-2" loc="member-invite" />
           )}
           <div className="shrink-0">
-            <InviteButton disabled={!isCurrentWorkspaceManager || isMemberFull} onClick={() => setInviteModalVisible(true)} />
+            {isCurrentWorkspaceManager && <InviteButton disabled={isMemberFull} onClick={() => setInviteModalVisible(true)} />}
           </div>
         </div>
         <div className="overflow-visible lg:overflow-visible">
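This hunk changes who sees the invite button, not just whether it is enabled: previously every member rendered a button that was disabled for non-managers; now only workspace managers render it at all. A sketch of the resulting behavior matrix, derived from the hunk:

// Visibility/enabled matrix after the change (sketch):
//   isCurrentWorkspaceManager && !isMemberFull  -> button shown, enabled
//   isCurrentWorkspaceManager && isMemberFull   -> button shown, disabled (seat limit)
//   !isCurrentWorkspaceManager                  -> button not rendered at all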
@@ -129,6 +129,7 @@ export const useToolSelectorState = ({
       extra: {
         description: tool.tool_description,
       },
+      type: tool.provider_type,
     }
   }, [])
 
@@ -87,6 +87,7 @@ export type ToolValue = {
   enabled?: boolean
   extra?: { description?: string } & Record<string, unknown>
   credential_id?: string
+  type?: string
 }
 
 export type DataSourceItem = {
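Taken together, the two hunks above thread the tool's provider type through the selector state into the ToolValue shape. Consolidated from the diff (the field stays a plain string; no enum of provider types is asserted here):

// ToolValue as it reads after this change, assembled from the hunks above.
export type ToolValue = {
  enabled?: boolean
  extra?: { description?: string } & Record<string, unknown>
  credential_id?: string
  // Newly added; populated from tool.provider_type in useToolSelectorState.
  type?: string
}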
@@ -18,6 +18,7 @@ import {
   Group,
 } from '@/app/components/workflow/nodes/_base/components/layout'
 import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker'
+import { IS_CE_EDITION } from '@/config'
 import Split from '../_base/components/split'
 import ChunkStructure from './components/chunk-structure'
 import EmbeddingModel from './components/embedding-model'
@@ -172,7 +173,7 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
       {
         data.indexing_technique === IndexMethodEnum.QUALIFIED
         && [ChunkStructureEnum.general, ChunkStructureEnum.parent_child].includes(data.chunk_structure)
-        && (
+        && IS_CE_EDITION && (
           <>
             <SummaryIndexSetting
               summaryIndexSetting={data.summary_index_setting}
@@ -1,7 +1,7 @@
 {
   "name": "dify-web",
   "type": "module",
-  "version": "1.12.0",
+  "version": "1.12.1",
   "private": true,
   "packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a",
   "imports": {
web/pnpm-lock.yaml: 2205 changes (generated file; diff suppressed because it is too large)
@@ -9,7 +9,6 @@ import type {
 } from '@/types/workflow'
 import { get, post } from './base'
 import { getFlowPrefix } from './utils'
-import { sanitizeWorkflowDraftPayload } from './workflow-payload'
 
 export const fetchWorkflowDraft = (url: string) => {
   return get(url, {}, { silent: true }) as Promise<FetchWorkflowDraftResponse>
@@ -19,8 +18,7 @@ export const syncWorkflowDraft = ({ url, params }: {
   url: string
   params: Pick<FetchWorkflowDraftResponse, 'graph' | 'features' | 'environment_variables' | 'conversation_variables'>
 }) => {
-  const sanitized = sanitizeWorkflowDraftPayload(params)
-  return post<CommonResponse & { updated_at: number, hash: string }>(url, { body: sanitized }, { silent: true })
+  return post<CommonResponse & { updated_at: number, hash: string }>(url, { body: params }, { silent: true })
 }
 
 export const fetchNodesDefaultConfigs = (url: string) => {
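After the two hunks above, the draft-sync helper posts the caller's payload unchanged; the sanitizeWorkflowDraftPayload pre-processing step and its import are gone. Consolidated from the diff, the function now reads:

// Assembled from the hunks above; FetchWorkflowDraftResponse comes from
// '@/types/workflow' and post from './base'.
export const syncWorkflowDraft = ({ url, params }: {
  url: string
  params: Pick<FetchWorkflowDraftResponse, 'graph' | 'features' | 'environment_variables' | 'conversation_variables'>
}) => {
  return post<CommonResponse & { updated_at: number, hash: string }>(url, { body: params }, { silent: true })
}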