Merge branch 'main' into feat/add-trial-model-list

This commit is contained in:
zyssyz123
2026-01-22 15:56:44 +08:00
committed by GitHub
365 changed files with 6719 additions and 4035 deletions

View File

@ -715,4 +715,5 @@ ANNOTATION_IMPORT_MAX_CONCURRENT=5
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000

View File

@ -1,52 +0,0 @@
## Purpose
`api/controllers/console/datasets/datasets_document.py` contains the console (authenticated) APIs for managing dataset documents (list/create/update/delete, processing controls, estimates, etc.).
## Storage model (uploaded files)
- For local file uploads into a knowledge base, the binary is stored via `extensions.ext_storage.storage` under the key:
- `upload_files/<tenant_id>/<uuid>.<ext>`
- File metadata is stored in the `upload_files` table (`UploadFile` model), keyed by `UploadFile.id`.
- Dataset `Document` records reference the uploaded file via:
- `Document.data_source_info.upload_file_id`
## Download endpoint
- `GET /datasets/<dataset_id>/documents/<document_id>/download`
- Only supported when `Document.data_source_type == "upload_file"`.
- Performs dataset permission + tenant checks via `DocumentResource.get_document(...)`.
- Delegates `Document -> UploadFile` validation and signed URL generation to `DocumentService.get_document_download_url(...)`.
- Applies `cloud_edition_billing_rate_limit_check("knowledge")` to match other KB operations.
- Response body is **only**: `{ "url": "<signed-url>" }`.
- `POST /datasets/<dataset_id>/documents/download-zip`
- Accepts `{ "document_ids": ["..."] }` (upload-file only).
- Returns `application/zip` as a single attachment download.
- Rationale: browsers often block multiple automatic downloads; a ZIP avoids that limitation.
- Applies `cloud_edition_billing_rate_limit_check("knowledge")`.
- Delegates dataset permission checks, document/upload-file validation, and download-name generation to
`DocumentService.prepare_document_batch_download_zip(...)` before streaming the ZIP.
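A hedged client-side sketch of calling these two endpoints (the base path, auth header, and IDs are placeholders, not confirmed by this doc):
```python
import requests

BASE = "https://example.com/console/api"  # placeholder console API base
HEADERS = {"Authorization": "Bearer <console-session-token>"}  # placeholder auth

# Single document: the body only contains the signed URL.
resp = requests.get(f"{BASE}/datasets/<dataset_id>/documents/<document_id>/download", headers=HEADERS)
signed_url = resp.json()["url"]
file_bytes = requests.get(signed_url).content  # Content-Disposition forces download in a browser

# Batch download: the response is a single application/zip attachment.
resp = requests.post(
    f"{BASE}/datasets/<dataset_id>/documents/download-zip",
    headers=HEADERS,
    json={"document_ids": ["<id-1>", "<id-2>"]},
)
with open("documents.zip", "wb") as f:
    f.write(resp.content)
```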
## Verification plan
- Upload a document from a local file into a dataset.
- Call the download endpoint and confirm it returns a signed URL.
- Open the URL and confirm:
- Response headers force download (`Content-Disposition`), and
- Downloaded bytes match the uploaded file.
- Select multiple uploaded-file documents and download as ZIP; confirm all selected files exist in the archive.
## Shared helper
- `DocumentService.get_document_download_url(document)` resolves the `UploadFile` and signs a download URL.
- `DocumentService.prepare_document_batch_download_zip(...)` performs dataset permission checks, batches
document + upload file lookups, preserves request order, and generates the client-visible ZIP filename.
- Internal helpers now live in `DocumentService` (`_get_upload_file_id_for_upload_file_document(...)`,
`_get_upload_file_for_upload_file_document(...)`, `_get_upload_files_by_document_id_for_zip_download(...)`).
- ZIP packing is handled by `FileService.build_upload_files_zip_tempfile(...)`, which also:
- sanitizes entry names to avoid path traversal, and
- deduplicates names while preserving extensions (e.g., `doc.txt` → `doc (1).txt`).
Streaming the response and deferring cleanup are handled by the route via `send_file(path, ...)` + `ExitStack` + `response.call_on_close(...)`; the file is deleted when the response is closed.
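A minimal sketch of that route-layer pattern, assuming the helper is a context manager yielding the ZIP tempfile path:
```python
from contextlib import ExitStack

from flask import send_file


def stream_zip(zip_tempfile_cm, download_name: str):
    stack = ExitStack()
    # Enter the helper's context manager; the tempfile outlives this function call.
    zip_path = stack.enter_context(zip_tempfile_cm)
    response = send_file(zip_path, mimetype="application/zip", as_attachment=True, download_name=download_name)
    # Defer cleanup (exiting the context, deleting the tempfile) until the response is closed.
    response.call_on_close(stack.close)
    return response
```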

View File

@ -1,18 +0,0 @@
## Purpose
`api/services/dataset_service.py` hosts dataset/document service logic used by console and API controllers.
## Batch document operations
- Batch document workflows should avoid N+1 database queries by using set-based lookups.
- Tenant checks must be enforced consistently across dataset/document operations.
- `DocumentService.get_documents_by_ids(...)` fetches documents for a dataset using `id.in_(...)`.
- `FileService.get_upload_files_by_ids(...)` performs tenant-scoped batch lookup for `UploadFile` (dedupes ids with `set(...)`).
- `DocumentService.get_document_download_url(...)` and `prepare_document_batch_download_zip(...)` handle
dataset/document permission checks plus `Document -> UploadFile` validation for download endpoints.
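A minimal sketch of these set-based lookups (helper names follow the bullets above; the model import paths are assumptions):
```python
from sqlalchemy import select

from extensions.ext_database import db
from models.dataset import Document  # assumed module path
from models.model import UploadFile  # assumed module path


def get_documents_by_ids(dataset_id: str, document_ids: list[str]) -> list[Document]:
    stmt = select(Document).where(Document.dataset_id == dataset_id, Document.id.in_(document_ids))
    return list(db.session.scalars(stmt).all())  # one query instead of N per-document lookups


def get_upload_files_by_ids(tenant_id: str, upload_file_ids: list[str]) -> dict[str, UploadFile]:
    unique_ids = set(upload_file_ids)  # dedupe ids before querying
    if not unique_ids:
        return {}
    stmt = select(UploadFile).where(UploadFile.tenant_id == tenant_id, UploadFile.id.in_(unique_ids))
    return {f.id: f for f in db.session.scalars(stmt).all()}  # tenant-scoped, id-keyed mapping
```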
## Verification plan
- Exercise document list and download endpoints that use the service helpers.
- Confirm batch download uses constant query count for documents + upload files.
- Request a ZIP with a missing document id and confirm a 404 is returned.

View File

@ -1,35 +0,0 @@
## Purpose
`api/services/file_service.py` owns business logic around `UploadFile` objects: upload validation, storage persistence,
previews/generators, and deletion.
## Key invariants
- All storage I/O goes through `extensions.ext_storage.storage`.
- Uploaded file keys follow: `upload_files/<tenant_id>/<uuid>.<ext>`.
- Upload validation is enforced in `FileService.upload_file(...)` (blocked extensions, size limits, dataset-only types).
## Batch lookup helpers
- `FileService.get_upload_files_by_ids(tenant_id, upload_file_ids)` is the canonical tenant-scoped batch loader for
`UploadFile`.
## Dataset document download helpers
The dataset document download/ZIP endpoints now delegate “Document → UploadFile” validation and permission checks to
`DocumentService` (`api/services/dataset_service.py`). `FileService` stays focused on generic `UploadFile` operations
(uploading, previews, deletion), plus generic ZIP serving.
### ZIP serving
- `FileService.build_upload_files_zip_tempfile(...)` builds a ZIP from `UploadFile` objects and yields a seeked
tempfile **path** so callers can stream it (e.g., `send_file(path, ...)`) without hitting "read of closed file"
issues from file-handle lifecycle during streamed responses.
- Flask `send_file(...)` and the `ExitStack`/`call_on_close(...)` cleanup pattern are handled in the route layer.
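A simplified sketch of such a helper, assuming `UploadFile.name` holds the original filename, `UploadFile.key` is the storage key, and `storage.load(key)` returns bytes or streamed chunks (the real sanitization/dedup logic may differ):
```python
import os
import tempfile
import zipfile
from collections.abc import Iterator
from contextlib import contextmanager

from extensions.ext_storage import storage


@contextmanager
def build_upload_files_zip_tempfile(upload_files) -> Iterator[str]:
    counts: dict[str, int] = {}
    tmp = tempfile.NamedTemporaryFile(suffix=".zip", delete=False)
    try:
        with zipfile.ZipFile(tmp, "w", zipfile.ZIP_DEFLATED) as zf:
            for uf in upload_files:
                # Drop directory components so entry names cannot traverse paths.
                name = os.path.basename(uf.name) or "file"
                n = counts.get(name, 0)
                counts[name] = n + 1
                if n:
                    stem, ext = os.path.splitext(name)
                    name = f"{stem} ({n}){ext}"  # dedupe while preserving the extension
                data = storage.load(uf.key)
                if not isinstance(data, bytes):
                    data = b"".join(data)  # storage.load may stream chunks
                zf.writestr(name, data)
        tmp.flush()
        yield tmp.name  # callers stream from this path, e.g. send_file(path, ...)
    finally:
        tmp.close()
        os.unlink(tmp.name)  # the file is removed once the caller exits the context
```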
## Verification plan
- Unit: `api/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py`
- Verify signed URL generation for upload-file documents and ZIP download behavior for multiple documents.
- Unit: `api/tests/unit_tests/services/test_file_service_zip_and_lookup.py`
- Verify ZIP packing produces a valid, openable archive and preserves file content.

View File

@ -1,28 +0,0 @@
## Purpose
Unit tests for the console dataset document download endpoint:
- `GET /datasets/<dataset_id>/documents/<document_id>/download`
## Testing approach
- Uses `Flask.test_request_context()` and calls the `Resource.get(...)` method directly.
- Monkeypatches console decorators (`login_required`, `setup_required`, rate limit) to no-ops to keep the test focused.
- Mocks:
- `DatasetService.get_dataset` / `check_dataset_permission`
- `DocumentService.get_document` for single-file download tests
- `DocumentService.get_documents_by_ids` + `FileService.get_upload_files_by_ids` for ZIP download tests
- `FileService.get_upload_files_by_ids` for `UploadFile` lookups in single-file tests
- `services.dataset_service.file_helpers.get_signed_file_url` to return a deterministic URL
- Document mocks include `id` fields so batch lookups can map documents by id.
## Covered cases
- Success returns `{ "url": "<signed>" }` for upload-file documents.
- 404 when document is not `upload_file`.
- 404 when `upload_file_id` is missing.
- 404 when referenced `UploadFile` row does not exist.
- 403 when document tenant does not match current tenant.
- Batch ZIP download returns `application/zip` for upload-file documents.
- Batch ZIP download rejects non-upload-file documents.
- Batch ZIP download uses a random `.zip` attachment name (`download_name`), so tests only assert the suffix.

View File

@ -1,18 +0,0 @@
## Purpose
Unit tests for `api/services/file_service.py` helper methods that are not covered by higher-level controller tests.
## What's covered
- `FileService.build_upload_files_zip_tempfile(...)`
- ZIP entry name sanitization (no directory components / traversal)
- name deduplication while preserving extensions
- writing streamed bytes from `storage.load(...)` into ZIP entries
- yields a tempfile path so callers can open/stream the ZIP without holding a live file handle
- `FileService.get_upload_files_by_ids(...)`
- returns `{}` for empty id lists
- returns an id-keyed mapping for non-empty lists
## Notes
- These tests intentionally stub `storage.load` and `db.session.scalars(...).all()` to avoid needing a real DB/storage.
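A hedged sketch of one such check (the helper's exact signature and the patch target are assumptions; the real tests may differ):
```python
import zipfile
from types import SimpleNamespace
from unittest import mock

from services.file_service import FileService


def test_zip_tempfile_is_openable_and_keeps_content():
    files = [SimpleNamespace(name="doc.txt", key="upload_files/t1/a.txt")]
    with mock.patch("services.file_service.storage") as storage:
        storage.load.return_value = b"hello"
        with FileService.build_upload_files_zip_tempfile(files) as zip_path:
            with zipfile.ZipFile(zip_path) as zf:
                assert zf.read("doc.txt") == b"hello"  # archive opens and content round-trips
```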

View File

@ -1,96 +0,0 @@
## Configuration
- Import `configs.dify_config` for every runtime toggle. Do not read environment variables directly.
- Add new settings to the proper mixin inside `configs/` (deployment, feature, middleware, etc.) so they load through `DifyConfig`.
- Remote overrides come from the optional providers in `configs/remote_settings_sources`; keep defaults in code safe when the value is missing.
- Example: logging pulls targets from `extensions/ext_logging.py`, and model provider URLs are assembled in `services/entities/model_provider_entities.py`.
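For illustration, a hedged sketch of the convention (the `ENABLE_TRIAL_APP` flag added in this commit is used as the example; the import path follows the guideline above):
```python
from configs import dify_config


def trial_apps_enabled() -> bool:
    # Read the typed setting instead of os.environ.get("ENABLE_TRIAL_APP").
    return bool(dify_config.ENABLE_TRIAL_APP)
```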
## Dependencies
- Runtime dependencies live in `[project].dependencies` inside `pyproject.toml`. Optional clients go into the `storage`, `tools`, or `vdb` groups under `[dependency-groups]`.
- Always pin versions and keep the list alphabetised. Shared tooling (lint, typing, pytest) belongs in the `dev` group.
- When code needs a new package, explain why in the PR and run `uv lock` so the lockfile stays current.
## Storage & Files
- Use `extensions.ext_storage.storage` for all blob IO; it already respects the configured backend.
- Convert files for workflows with helpers in `core/file/file_manager.py`; they handle signed URLs and multimodal payloads.
- When writing controller logic, delegate upload quotas and metadata to `services/file_service.py` instead of touching storage directly.
- All outbound HTTP fetches (webhooks, remote files) must go through the SSRF-safe client in `core/helper/ssrf_proxy.py`; it wraps `httpx` with the allow/deny rules configured for the platform.
## Redis & Shared State
- Access Redis through `extensions.ext_redis.redis_client`. For locking, reuse `redis_client.lock`.
- Prefer higher-level helpers when available: rate limits use `libs.helper.RateLimiter`, provider metadata uses caches in `core/helper/provider_cache.py`.
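A minimal sketch of the locking convention (the lock name and TTL are illustrative; `redis_client.lock` is assumed to expose the standard redis-py lock API):
```python
from extensions.ext_redis import redis_client


def clean_expired_records():
    # The lock auto-expires after `timeout` seconds if the worker dies mid-task.
    with redis_client.lock("sandbox_expired_records_clean_task", timeout=90):
        ...  # run the batch cleanup while holding the lock
```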
## Models
- SQLAlchemy models sit in `models/` and inherit from the shared declarative `Base` defined in `models/base.py` (metadata configured via `models/engine.py`).
- `models/__init__.py` exposes grouped aggregates: account/tenant models, app and conversation tables, datasets, providers, workflow runs, triggers, etc. Import from there to avoid deep path churn.
- Follow the DDD boundary: persistence objects live in `models/`, repositories under `repositories/` translate them into domain entities, and services consume those repositories.
- When adding a table, create the model class, register it in `models/__init__.py`, wire a repository if needed, and generate an Alembic migration as described below.
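A hypothetical example of that flow (the class, table, and column choices are illustrative only):
```python
from sqlalchemy import String
from sqlalchemy.orm import Mapped, mapped_column

from models.base import Base


class ExampleWidget(Base):
    __tablename__ = "example_widgets"

    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    tenant_id: Mapped[str] = mapped_column(String(36), index=True)  # keep rows tenant-scoped
    name: Mapped[str] = mapped_column(String(255))


# Next: export ExampleWidget from models/__init__.py, add a repository if needed,
# and run `flask db revision --autogenerate -m "add example_widgets"` to create the migration.
```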
## Vector Stores
- Vector client implementations live in `core/rag/datasource/vdb/<provider>`, with a common factory in `core/rag/datasource/vdb/vector_factory.py` and enums in `core/rag/datasource/vdb/vector_type.py`.
- Retrieval pipelines call these providers through `core/rag/datasource/retrieval_service.py` and dataset ingestion flows in `services/dataset_service.py`.
- The CLI helper `flask vdb-migrate` orchestrates bulk migrations using routines in `commands.py`; reuse that pattern when adding new backend transitions.
- To add another store, mirror the provider layout, register it with the factory, and include any schema changes in Alembic migrations.
## Observability & OTEL
- OpenTelemetry settings live under the observability mixin in `configs/observability`. Toggle exporters and sampling via `dify_config`, not ad-hoc env reads.
- HTTP, Celery, Redis, SQLAlchemy, and httpx instrumentation is initialised in `extensions/ext_app_metrics.py` and `extensions/ext_request_logging.py`; reuse these hooks when adding new workers or entrypoints.
- When creating background tasks or external calls, propagate tracing context with helpers in the existing instrumented clients (e.g. use the shared `httpx` session from `core/helper/http_client_pooling.py`).
- If you add a new external integration, ensure spans and metrics are emitted by wiring the appropriate OTEL instrumentation package in `pyproject.toml` and configuring it in `extensions/`.
## Ops Integrations
- Langfuse support and other tracing bridges live under `core/ops/opik_trace`. Config toggles sit in `configs/observability`, while exporters are initialised in the OTEL extensions mentioned above.
- External monitoring services should follow this pattern: keep client code in `core/ops`, expose switches via `dify_config`, and hook initialisation in `extensions/ext_app_metrics.py` or sibling modules.
- Before instrumenting new code paths, check whether existing context helpers (e.g. `extensions/ext_request_logging.py`) already capture the necessary metadata.
## Controllers, Services, Core
- Controllers only parse HTTP input and call a service method. Keep business rules in `services/`.
- Services enforce tenant rules, quotas, and orchestration, then call into `core/` engines (workflow execution, tools, LLMs).
- When adding a new endpoint, search for an existing service to extend before introducing a new layer. Example: workflow APIs pipe through `services/workflow_service.py` into `core/workflow`.
## Plugins, Tools, Providers
- In Dify a plugin is a tenant-installable bundle that declares one or more providers (tool, model, datasource, trigger, endpoint, agent strategy) plus its resource needs and version metadata. The manifest (`core/plugin/entities/plugin.py`) mirrors what you see in the marketplace documentation.
- Installation, upgrades, and migrations are orchestrated by `services/plugin/plugin_service.py` together with helpers such as `services/plugin/plugin_migration.py`.
- Runtime loading happens through the implementations under `core/plugin/impl/*` (tool/model/datasource/trigger/endpoint/agent). These modules normalise plugin providers so that downstream systems (`core/tools/tool_manager.py`, `services/model_provider_service.py`, `services/trigger/*`) can treat builtin and plugin capabilities the same way.
- For remote execution, plugin daemons (`core/plugin/entities/plugin_daemon.py`, `core/plugin/impl/plugin.py`) manage lifecycle hooks, credential forwarding, and background workers that keep plugin processes in sync with the main application.
- Acquire tool implementations through `core/tools/tool_manager.py`; it resolves builtin, plugin, and workflow-as-tool providers uniformly, injecting the right context (tenant, credentials, runtime config).
- To add a new plugin capability, extend the relevant `core/plugin/entities` schema and register the implementation in the matching `core/plugin/impl` module rather than importing the provider directly.
## Async Workloads
- See `agent_skills/trigger.md` for more detailed documentation.
- Enqueue background work through `services/async_workflow_service.py`. It routes jobs to the tiered Celery queues defined in `tasks/`.
- Workers boot from `celery_entrypoint.py` and execute functions in `tasks/workflow_execution_tasks.py`, `tasks/trigger_processing_tasks.py`, etc.
- Scheduled workflows poll from `schedule/workflow_schedule_tasks.py`. Follow the same pattern if you need new periodic jobs.
## Database & Migrations
- SQLAlchemy models live under `models/` and map directly to migration files in `migrations/versions`.
- Generate migrations with `uv run --project api flask db revision --autogenerate -m "<summary>"`, then review the diff; never hand-edit the database outside Alembic.
- Apply migrations locally using `uv run --project api flask db upgrade`; production deploys expect the same history.
- If you add tenant-scoped data, confirm the upgrade includes tenant filters or defaults consistent with the service logic touching those tables.
## CLI Commands
- Maintenance commands from `commands.py` are registered on the Flask CLI. Run them via `uv run --project api flask <command>`.
- Use the built-in `db` commands from Flask-Migrate for schema operations (`flask db upgrade`, `flask db stamp`, etc.). Only fall back to custom helpers if you need their extra behaviour.
- Custom entries such as `flask reset-password`, `flask reset-email`, and `flask vdb-migrate` handle self-hosted account recovery and vector database migrations.
- Before adding a new command, check whether an existing service can be reused and ensure the command guards edition-specific behaviour (many enforce `SELF_HOSTED`). Document any additions in the PR.
- Ruff helpers are run directly with `uv`: `uv run --project api --dev ruff format ./api` for formatting and `uv run --project api --dev ruff check ./api` (add `--fix` if you want automatic fixes).
## When You Add Features
- Check for an existing helper or service before writing a new util.
- Uphold tenancy: every service method should receive the tenant ID from controller wrappers such as `controllers/console/wraps.py`.
- Update or create tests alongside behaviour changes (`tests/unit_tests` for fast coverage, `tests/integration_tests` when touching orchestrations).
- Run `uv run --project api --dev ruff check ./api`, `uv run --directory api --dev basedpyright`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh` before submitting changes.

View File

@ -1 +0,0 @@
// TBD

View File

@ -1 +0,0 @@
// TBD

View File

@ -1,53 +0,0 @@
## Overview
Trigger is a collection of nodes that we call `Start` nodes. The concept of `Start` matches `RootNode` in the workflow engine (`core/workflow/graph_engine`): a `Start` node is the entry point of a workflow, and every workflow run begins at a `Start` node.
## Trigger nodes
- `UserInput`
- `Trigger Webhook`
- `Trigger Schedule`
- `Trigger Plugin`
### UserInput
Before the `Trigger` concept was introduced, this was the node we called `Start`. To avoid confusion it has been renamed to `UserInput`, and it is closely tied to `ServiceAPI` in `controllers/service_api/app`:
1. The `UserInput` node declares a list of arguments that must be provided by the user; these are eventually converted into variables in the workflow variable pool.
1. `ServiceAPI` accepts those arguments and passes them through to the `UserInput` node.
1. For the detailed implementation, see `core/workflow/nodes/start`.
### Trigger Webhook
Inside the Webhook node, Dify provides a UI panel that lets the user define an HTTP manifest (`core/workflow/nodes/trigger_webhook/entities.py`.`WebhookData`). Dify also generates a random webhook id for each `Trigger Webhook` node; the implementation lives in `core/trigger/utils/endpoint.py`. `webhook-debug` is the webhook debug mode; see `controllers/trigger/webhook.py`.
Finally, requests to the `webhook` endpoint are converted into variables in the workflow variable pool during workflow execution.
### Trigger Schedule
The `Trigger Schedule` node lets the user define a schedule that triggers the workflow; the detailed manifest is in `core/workflow/nodes/trigger_schedule/entities.py`. A poller and an executor handle millions of schedules; see `docker/entrypoint.sh` and `schedule/workflow_schedule_task.py`.
To achieve this, a `WorkflowSchedulePlan` model was introduced in `models/trigger.py`, and `events/event_handlers/sync_workflow_schedule_when_app_published.py` syncs workflow schedule plans when an app is published.
### Trigger Plugin
The `Trigger Plugin` node lets users define their own distributed trigger plugin. Whenever a request is received, Dify forwards it to the plugin and waits for the parsed variables.
1. Requests are saved in storage by `services/trigger/trigger_request_service.py`, referenced by `services/trigger/trigger_service.py`.`TriggerService`.`process_endpoint`.
1. Plugins accept those requests and parse variables from them; see `core/plugin/impl/trigger.py` for details.
Dify also introduces a `subscription` concept: a Dify endpoint address is bound to a third-party webhook service such as `Github`, `Slack`, `Linear`, `GoogleDrive`, or `Gmail`. Once a subscription is created, Dify continually receives requests from those platforms and handles them one by one.
## Worker Pool / Async Task
Every event that triggers a new workflow run is handled asynchronously; the unified entrypoint is `services/async_workflow_service.py`.`AsyncWorkflowService`.`trigger_workflow_async`.
The underlying infrastructure is `celery`, configured in `docker/entrypoint.sh`; the consumers live in `tasks/async_workflow_tasks.py`. Three queues handle the different user tiers: `PROFESSIONAL_QUEUE`, `TEAM_QUEUE`, and `SANDBOX_QUEUE` (a routing sketch follows).
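A hedged sketch of that tiered routing (queue names come from this section; the task name and plan-to-queue mapping are assumptions):
```python
from celery import shared_task

QUEUE_BY_PLAN = {
    "professional": "PROFESSIONAL_QUEUE",
    "team": "TEAM_QUEUE",
    "sandbox": "SANDBOX_QUEUE",
}


@shared_task
def run_workflow_async(app_id: str, inputs: dict) -> None:
    ...  # executed by the Celery workers configured in docker/entrypoint.sh


def trigger_workflow_async(app_id: str, inputs: dict, plan: str) -> None:
    queue = QUEUE_BY_PLAN.get(plan, "SANDBOX_QUEUE")
    # apply_async lets the producer pick the per-tier queue at enqueue time.
    run_workflow_async.apply_async(args=[app_id, inputs], queue=queue)
```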
## Debug Strategy
Dify divides users into two groups: builders and end users.
Builders are the users who create workflows, and for them debugging is a critical part of workflow development. As the start node of a workflow, a trigger node can `listen` for events from `WebhookDebug`, `Schedule`, and `Plugin`; the debugging entry point is `controllers/console/app/workflow.py`.`DraftWorkflowTriggerNodeApi`.
A polling process is a sequence of individual `poll` operations; each `poll` fetches events cached in `Redis` and returns `None` if no event is found. In more detail, `core/trigger/debug/event_bus.py` handles the polling process and `core/trigger/debug/event_selectors.py` selects the event poller based on the trigger type.

View File

@ -965,6 +965,16 @@ class MailConfig(BaseSettings):
default=None,
)
ENABLE_TRIAL_APP: bool = Field(
description="Enable trial app",
default=False,
)
ENABLE_EXPLORE_BANNER: bool = Field(
description="Enable explore banner",
default=False,
)
class RagEtlConfig(BaseSettings):
"""
@ -1298,6 +1308,10 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings):
description="Retention days for sandbox expired workflow_run records and message records",
default=30,
)
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: PositiveInt = Field(
description="Lock TTL for sandbox expired records clean task in seconds",
default=90000,
)
class FeatureConfig(

View File

@ -9,7 +9,7 @@ from typing import Any, final
from flask import Flask, current_app, g
from context import register_context_capturer
from core.workflow.context import register_context_capturer
from core.workflow.context.execution_context import (
AppContext,
IExecutionContext,

View File

@ -107,10 +107,12 @@ from .datasets.rag_pipeline import (
# Import explore controllers
from .explore import (
banner,
installed_app,
parameter,
recommended_app,
saved_message,
trial,
)
# Import tag controllers
@ -145,6 +147,7 @@ __all__ = [
"apikey",
"app",
"audio",
"banner",
"billing",
"bp",
"completion",
@ -198,6 +201,7 @@ __all__ = [
"statistic",
"tags",
"tool_providers",
"trial",
"trigger_providers",
"version",
"website",

View File

@ -15,7 +15,7 @@ from controllers.console.wraps import only_edition_cloud
from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
P = ParamSpec("P")
R = TypeVar("R")
@ -32,6 +32,8 @@ class InsertExploreAppPayload(BaseModel):
language: str = Field(...)
category: str = Field(...)
position: int = Field(...)
can_trial: bool = Field(default=False)
trial_limit: int = Field(default=0)
@field_validator("language")
@classmethod
@ -39,11 +41,33 @@ class InsertExploreAppPayload(BaseModel):
return supported_language(value)
class InsertExploreBannerPayload(BaseModel):
category: str = Field(...)
title: str = Field(...)
description: str = Field(...)
img_src: str = Field(..., alias="img-src")
language: str = Field(default="en-US")
link: str = Field(...)
sort: int = Field(...)
@field_validator("language")
@classmethod
def validate_language(cls, value: str) -> str:
return supported_language(value)
model_config = {"populate_by_name": True}
console_ns.schema_model(
InsertExploreAppPayload.__name__,
InsertExploreAppPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
InsertExploreBannerPayload.__name__,
InsertExploreBannerPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def admin_required(view: Callable[P, R]):
@wraps(view)
@ -109,6 +133,20 @@ class InsertExploreAppListApi(Resource):
)
db.session.add(recommended_app)
if payload.can_trial:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == payload.app_id)
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=payload.app_id,
tenant_id=app.tenant_id,
trial_limit=payload.trial_limit,
)
)
else:
trial_app.trial_limit = payload.trial_limit
app.is_public = True
db.session.commit()
@ -123,6 +161,20 @@ class InsertExploreAppListApi(Resource):
recommended_app.category = payload.category
recommended_app.position = payload.position
if payload.can_trial:
trial_app = db.session.execute(
select(TrialApp).where(TrialApp.app_id == payload.app_id)
).scalar_one_or_none()
if not trial_app:
db.session.add(
TrialApp(
app_id=payload.app_id,
tenant_id=app.tenant_id,
trial_limit=payload.trial_limit,
)
)
else:
trial_app.trial_limit = payload.trial_limit
app.is_public = True
db.session.commit()
@ -168,7 +220,62 @@ class InsertExploreAppApi(Resource):
for installed_app in installed_apps:
session.delete(installed_app)
trial_app = session.execute(
select(TrialApp).where(TrialApp.app_id == recommended_app.app_id)
).scalar_one_or_none()
if trial_app:
session.delete(trial_app)
db.session.delete(recommended_app)
db.session.commit()
return {"result": "success"}, 204
@console_ns.route("/admin/insert-explore-banner")
class InsertExploreBannerApi(Resource):
@console_ns.doc("insert_explore_banner")
@console_ns.doc(description="Insert an explore banner")
@console_ns.expect(console_ns.models[InsertExploreBannerPayload.__name__])
@console_ns.response(201, "Banner inserted successfully")
@only_edition_cloud
@admin_required
def post(self):
payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
content = {
"category": payload.category,
"title": payload.title,
"description": payload.description,
"img-src": payload.img_src,
}
banner = ExporleBanner(
content=content,
link=payload.link,
sort=payload.sort,
language=payload.language,
)
db.session.add(banner)
db.session.commit()
return {"result": "success"}, 201
@console_ns.route("/admin/delete-explore-banner/<uuid:banner_id>")
class DeleteExploreBannerApi(Resource):
@console_ns.doc("delete_explore_banner")
@console_ns.doc(description="Delete an explore banner")
@console_ns.doc(params={"banner_id": "Banner ID to delete"})
@console_ns.response(204, "Banner deleted successfully")
@only_edition_cloud
@admin_required
def delete(self, banner_id):
banner = db.session.execute(select(ExporleBanner).where(ExporleBanner.id == banner_id)).scalar_one_or_none()
if not banner:
raise NotFound(f"Banner '{banner_id}' not found")
db.session.delete(banner)
db.session.commit()
return {"result": "success"}, 204

View File

@ -115,3 +115,9 @@ class InvokeRateLimitError(BaseHTTPException):
error_code = "rate_limit_error"
description = "Rate Limit Error"
code = 429
class NeedAddIdsError(BaseHTTPException):
error_code = "need_add_ids"
description = "Need to add ids."
code = 400

View File

@ -23,6 +23,11 @@ def _load_app_model(app_id: str) -> App | None:
return app_model
def _load_app_model_with_trial(app_id: str) -> App | None:
app_model = db.session.query(App).where(App.id == app_id, App.status == "normal").first()
return app_model
def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P1, R1]):
@wraps(view_func)
@ -62,3 +67,44 @@ def get_app_model(view: Callable[P, R] | None = None, *, mode: Union[AppMode, li
return decorator
else:
return decorator(view)
def get_app_model_with_trial(view: Callable[P, R] | None = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("app_id"):
raise ValueError("missing app_id in path parameters")
app_id = kwargs.get("app_id")
app_id = str(app_id)
del kwargs["app_id"]
app_model = _load_app_model_with_trial(app_id)
if not app_model:
raise AppNotFoundError()
app_mode = AppMode.value_of(app_model.mode)
if mode is not None:
if isinstance(mode, list):
modes = mode
else:
modes = [mode]
if app_mode not in modes:
mode_values = {m.value for m in modes}
raise AppNotFoundError(f"App mode is not in the supported list: {mode_values}")
kwargs["app_model"] = app_model
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)

View File

@ -0,0 +1,43 @@
from flask import request
from flask_restx import Resource
from controllers.console import api
from controllers.console.explore.wraps import explore_banner_enabled
from extensions.ext_database import db
from models.model import ExporleBanner
class BannerApi(Resource):
"""Resource for banner list."""
@explore_banner_enabled
def get(self):
"""Get banner list."""
language = request.args.get("language", "en-US")
# Build base query for enabled banners
base_query = db.session.query(ExporleBanner).where(ExporleBanner.status == "enabled")
# Try to get banners in the requested language
banners = base_query.where(ExporleBanner.language == language).order_by(ExporleBanner.sort).all()
# Fallback to en-US if no banners found and language is not en-US
if not banners and language != "en-US":
banners = base_query.where(ExporleBanner.language == "en-US").order_by(ExporleBanner.sort).all()
# Convert banners to serializable format
result = []
for banner in banners:
banner_data = {
"id": banner.id,
"content": banner.content, # Already parsed as JSON by SQLAlchemy
"link": banner.link,
"sort": banner.sort,
"status": banner.status,
"created_at": banner.created_at.isoformat() if banner.created_at else None,
}
result.append(banner_data)
return result
api.add_resource(BannerApi, "/explore/banners")

View File

@ -29,3 +29,25 @@ class AppAccessDeniedError(BaseHTTPException):
error_code = "access_denied"
description = "App access denied."
code = 403
class TrialAppNotAllowed(BaseHTTPException):
"""*403* `Trial App Not Allowed`
Raise if the app is not available for trial.
"""
error_code = "trial_app_not_allowed"
code = 403
description = "The app is not available for trial."
class TrialAppLimitExceeded(BaseHTTPException):
"""*403* `Trial App Limit Exceeded`
Raise if the user has exceeded the trial app limit.
"""
error_code = "trial_app_limit_exceeded"
code = 403
description = "The user has exceeded the trial app limit."

View File

@ -29,6 +29,7 @@ recommended_app_fields = {
"category": fields.String,
"position": fields.Integer,
"is_listed": fields.Boolean,
"can_trial": fields.Boolean,
}
recommended_app_list_fields = {

View File

@ -0,0 +1,512 @@
import logging
from typing import Any, cast
from flask import request
from flask_restx import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
from controllers.common.fields import Parameters as ParametersResponse
from controllers.common.fields import Site as SiteResponse
from controllers.console import api
from controllers.console.app.error import (
AppUnavailableError,
AudioTooLargeError,
CompletionRequestError,
ConversationCompletedError,
NeedAddIdsError,
NoAudioUploadedError,
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderNotSupportSpeechToTextError,
ProviderQuotaExceededError,
UnsupportedAudioTypeError,
)
from controllers.console.app.wraps import get_app_model_with_trial
from controllers.console.explore.error import (
AppSuggestedQuestionsAfterAnswerDisabledError,
NotChatAppError,
NotCompletionAppError,
NotWorkflowAppError,
)
from controllers.console.explore.wraps import TrialAppResource, trial_feature_enable
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.errors.error import (
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from fields.app_fields import app_detail_fields_with_site
from fields.dataset_fields import dataset_fields
from fields.workflow_fields import workflow_fields
from libs import helper
from libs.helper import uuid_value
from libs.login import current_user
from models import Account
from models.account import TenantStatus
from models.model import AppMode, Site
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.app_service import AppService
from services.audio_service import AudioService
from services.dataset_service import DatasetService
from services.errors.audio import (
AudioTooLargeServiceError,
NoAudioUploadedServiceError,
ProviderNotSupportSpeechToTextServiceError,
UnsupportedAudioTypeServiceError,
)
from services.errors.conversation import ConversationNotExistsError
from services.errors.llm import InvokeRateLimitError
from services.errors.message import (
MessageNotExistsError,
SuggestedQuestionsAfterAnswerDisabledError,
)
from services.message_service import MessageService
from services.recommended_app_service import RecommendedAppService
logger = logging.getLogger(__name__)
class TrialAppWorkflowRunApi(TrialAppResource):
def post(self, trial_app):
"""
Run workflow
"""
app_model = trial_app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args()
assert current_user is not None
try:
app_id = app_model.id
user_id = current_user.id
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return helper.compact_generate_response(response)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialAppWorkflowTaskStopApi(TrialAppResource):
def post(self, trial_app, task_id: str):
"""
Stop workflow task
"""
app_model = trial_app
if not app_model:
raise NotWorkflowAppError()
app_mode = AppMode.value_of(app_model.mode)
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
assert current_user is not None
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
return {"result": "success"}
class TrialChatApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
args["auto_generate_name"] = False
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=True
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except InvokeRateLimitError as ex:
raise InvokeRateLimitHttpError(ex.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialMessageSuggestedQuestionApi(TrialAppResource):
@trial_feature_enable
def get(self, trial_app, message_id):
app_model = trial_app
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
message_id = str(message_id)
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model, user=current_user, message_id=message_id, invoke_from=InvokeFrom.EXPLORE
)
except MessageNotExistsError:
raise NotFound("Message not found")
except ConversationNotExistsError:
raise NotFound("Conversation not found")
except SuggestedQuestionsAfterAnswerDisabledError:
raise AppSuggestedQuestionsAfterAnswerDisabledError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
return {"data": questions}
class TrialChatAudioApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
file = request.files["file"]
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AudioService.transcript_asr(app_model=app_model, file=file, end_user=None)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logger.exception("internal server error.")
raise InternalServerError()
class TrialChatTextApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
args = parser.parse_args()
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except NoAudioUploadedServiceError:
raise NoAudioUploadedError()
except AudioTooLargeServiceError as e:
raise AudioTooLargeError(str(e))
except UnsupportedAudioTypeServiceError:
raise UnsupportedAudioTypeError()
except ProviderNotSupportSpeechToTextServiceError:
raise ProviderNotSupportSpeechToTextError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception as e:
logger.exception("internal server error.")
raise InternalServerError()
class TrialCompletionApi(TrialAppResource):
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
if app_model.mode != "completion":
raise NotCompletionAppError()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False
try:
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
# Get IDs before they might be detached from session
app_id = app_model.id
user_id = current_user.id
response = AppGenerateService.generate(
app_model=app_model, user=current_user, args=args, invoke_from=InvokeFrom.EXPLORE, streaming=streaming
)
RecommendedAppService.add_trial_app_record(app_id, user_id)
return helper.compact_generate_response(response)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
except services.errors.conversation.ConversationCompletedError:
raise ConversationCompletedError()
except services.errors.app_model_config.AppModelConfigBrokenError:
logger.exception("App model config broken.")
raise AppUnavailableError()
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
class TrialSitApi(Resource):
"""Resource for trial app sites."""
@trial_feature_enable
@get_app_model_with_trial
def get(self, app_model):
"""Retrieve app site info.
Returns the site configuration for the application including theme, icons, and text.
"""
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()
assert app_model.tenant
if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
return SiteResponse.model_validate(site).model_dump(mode="json")
class TrialAppParameterApi(Resource):
"""Resource for app variables."""
@trial_feature_enable
@get_app_model_with_trial
def get(self, app_model):
"""Retrieve app parameters."""
if app_model is None:
raise AppUnavailableError()
if app_model.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
workflow = app_model.workflow
if workflow is None:
raise AppUnavailableError()
features_dict = workflow.features_dict
user_input_form = workflow.user_input_form(to_old_structure=True)
else:
app_model_config = app_model.app_model_config
if app_model_config is None:
raise AppUnavailableError()
features_dict = app_model_config.to_dict()
user_input_form = features_dict.get("user_input_form", [])
parameters = get_parameters_from_feature_dict(features_dict=features_dict, user_input_form=user_input_form)
return ParametersResponse.model_validate(parameters).model_dump(mode="json")
class AppApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
@marshal_with(app_detail_fields_with_site)
def get(self, app_model):
"""Get app detail"""
app_service = AppService()
app_model = app_service.get_app(app_model)
return app_model
class AppWorkflowApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
@marshal_with(workflow_fields)
def get(self, app_model):
"""Get workflow detail"""
if not app_model.workflow_id:
raise AppUnavailableError()
workflow = (
db.session.query(Workflow)
.where(
Workflow.id == app_model.workflow_id,
)
.first()
)
return workflow
class DatasetListApi(Resource):
@trial_feature_enable
@get_app_model_with_trial
def get(self, app_model):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
ids = request.args.getlist("ids")
tenant_id = app_model.tenant_id
if ids:
datasets, total = DatasetService.get_datasets_by_ids(ids, tenant_id)
else:
raise NeedAddIdsError()
data = cast(list[dict[str, Any]], marshal(datasets, dataset_fields))
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response
api.add_resource(TrialChatApi, "/trial-apps/<uuid:app_id>/chat-messages", endpoint="trial_app_chat_completion")
api.add_resource(
TrialMessageSuggestedQuestionApi,
"/trial-apps/<uuid:app_id>/messages/<uuid:message_id>/suggested-questions",
endpoint="trial_app_suggested_question",
)
api.add_resource(TrialChatAudioApi, "/trial-apps/<uuid:app_id>/audio-to-text", endpoint="trial_app_audio")
api.add_resource(TrialChatTextApi, "/trial-apps/<uuid:app_id>/text-to-audio", endpoint="trial_app_text")
api.add_resource(TrialCompletionApi, "/trial-apps/<uuid:app_id>/completion-messages", endpoint="trial_app_completion")
api.add_resource(TrialSitApi, "/trial-apps/<uuid:app_id>/site")
api.add_resource(TrialAppParameterApi, "/trial-apps/<uuid:app_id>/parameters", endpoint="trial_app_parameters")
api.add_resource(AppApi, "/trial-apps/<uuid:app_id>", endpoint="trial_app")
api.add_resource(TrialAppWorkflowRunApi, "/trial-apps/<uuid:app_id>/workflows/run", endpoint="trial_app_workflow_run")
api.add_resource(TrialAppWorkflowTaskStopApi, "/trial-apps/<uuid:app_id>/workflows/tasks/<string:task_id>/stop")
api.add_resource(AppWorkflowApi, "/trial-apps/<uuid:app_id>/workflows", endpoint="trial_app_workflow")
api.add_resource(DatasetListApi, "/trial-apps/<uuid:app_id>/datasets", endpoint="trial_app_datasets")

View File

@ -2,14 +2,15 @@ from collections.abc import Callable
from functools import wraps
from typing import Concatenate, ParamSpec, TypeVar
from flask import abort
from flask_restx import Resource
from werkzeug.exceptions import NotFound
from controllers.console.explore.error import AppAccessDeniedError
from controllers.console.explore.error import AppAccessDeniedError, TrialAppLimitExceeded, TrialAppNotAllowed
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import InstalledApp
from models import AccountTrialAppRecord, App, InstalledApp, TrialApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@ -71,6 +72,61 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
return decorator
def trial_app_required(view: Callable[Concatenate[App, P], R] | None = None):
def decorator(view: Callable[Concatenate[App, P], R]):
@wraps(view)
def decorated(app_id: str, *args: P.args, **kwargs: P.kwargs):
current_user, _ = current_account_with_tenant()
trial_app = db.session.query(TrialApp).where(TrialApp.app_id == str(app_id)).first()
if trial_app is None:
raise TrialAppNotAllowed()
app = trial_app.app
if app is None:
raise TrialAppNotAllowed()
account_trial_app_record = (
db.session.query(AccountTrialAppRecord)
.where(AccountTrialAppRecord.account_id == current_user.id, AccountTrialAppRecord.app_id == app_id)
.first()
)
if account_trial_app_record:
if account_trial_app_record.count >= trial_app.trial_limit:
raise TrialAppLimitExceeded()
return view(app, *args, **kwargs)
return decorated
if view:
return decorator(view)
return decorator
def trial_feature_enable(view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if not features.enable_trial_app:
abort(403, "Trial app feature is not enabled.")
return view(*args, **kwargs)
return decorated
def explore_banner_enabled(view: Callable[..., R]) -> Callable[..., R]:
@wraps(view)
def decorated(*args, **kwargs):
features = FeatureService.get_system_features()
if not features.enable_explore_banner:
abort(403, "Explore banner feature is not enabled.")
return view(*args, **kwargs)
return decorated
class InstalledAppResource(Resource):
# must be reversed if there are multiple decorators
@ -80,3 +136,13 @@ class InstalledAppResource(Resource):
account_initialization_required,
login_required,
]
class TrialAppResource(Resource):
# must be reversed if there are multiple decorators
method_decorators = [
trial_app_required,
account_initialization_required,
login_required,
]

View File

@ -3,8 +3,8 @@ from datetime import UTC, datetime
from typing import Any, ClassVar
from pydantic import TypeAdapter
from sqlalchemy.orm import Session, sessionmaker
from core.db.session_factory import session_factory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events.base import GraphEngineEvent
from core.workflow.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent
@ -31,13 +31,11 @@ class TriggerPostLayer(GraphEngineLayer):
cfs_plan_scheduler_entity: AsyncWorkflowCFSPlanEntity,
start_time: datetime,
trigger_log_id: str,
session_maker: sessionmaker[Session],
):
super().__init__()
self.trigger_log_id = trigger_log_id
self.start_time = start_time
self.cfs_plan_scheduler_entity = cfs_plan_scheduler_entity
self.session_maker = session_maker
def on_graph_start(self):
pass
@ -47,7 +45,7 @@ class TriggerPostLayer(GraphEngineLayer):
Update trigger log with success or failure.
"""
if isinstance(event, tuple(self._STATUS_MAP.keys())):
with self.session_maker() as session:
with session_factory.create_session() as session:
repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_log = repo.get_by_id(self.trigger_log_id)
if not trigger_log:

View File

@ -35,7 +35,6 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog
from repositories.factory import DifyAPIRepositoryFactory
from tasks.ops_trace_task import process_trace_tasks
if TYPE_CHECKING:
@ -473,6 +472,9 @@ class TraceTask:
if cls._workflow_run_repo is None:
with cls._repo_lock:
if cls._workflow_run_repo is None:
# Lazy import to avoid circular import during module initialization
from repositories.factory import DifyAPIRepositoryFactory
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return cls._workflow_run_repo

View File

@ -1,5 +1,6 @@
import contextlib
import json
import logging
from collections.abc import Generator, Iterable
from copy import deepcopy
from datetime import UTC, datetime
@ -36,6 +37,8 @@ from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import Message, MessageFile
logger = logging.getLogger(__name__)
class ToolEngine:
"""
@ -123,25 +126,31 @@ class ToolEngine:
# transform tool invoke message to get LLM friendly message
return plain_text, message_files, meta
except ToolProviderCredentialValidationError as e:
logger.error(e, exc_info=True)
error_response = "Please check your tool provider credentials"
agent_tool_callback.on_tool_error(e)
except (ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError) as e:
error_response = f"there is not a tool named {tool.entity.identity.name}"
logger.error(e, exc_info=True)
agent_tool_callback.on_tool_error(e)
except ToolParameterValidationError as e:
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
agent_tool_callback.on_tool_error(e)
logger.error(e, exc_info=True)
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
agent_tool_callback.on_tool_error(e)
logger.error(e, exc_info=True)
except ToolEngineInvokeError as e:
meta = e.meta
error_response = f"tool invoke error: {meta.error}"
agent_tool_callback.on_tool_error(e)
logger.error(e, exc_info=True)
return error_response, [], meta
except Exception as e:
error_response = f"unknown error: {e}"
agent_tool_callback.on_tool_error(e)
logger.error(e, exc_info=True)
return error_response, [], ToolInvokeMeta.error_instance(error_response)

View File

@ -20,7 +20,6 @@ from core.tools.entities.tool_entities import (
)
from core.tools.errors import ToolInvokeError
from factories.file_factory import build_from_mapping
from libs.login import current_user
from models import Account, Tenant
from models.model import App, EndUser
from models.workflow import Workflow
@ -28,21 +27,6 @@ from models.workflow import Workflow
logger = logging.getLogger(__name__)
def _try_resolve_user_from_request() -> Account | EndUser | None:
"""
Try to resolve user from Flask request context.
Returns None if not in a request context or if user is not available.
"""
# Note: `current_user` is a LocalProxy. Never compare it with None directly.
# Use _get_current_object() to dereference the proxy
user = getattr(current_user, "_get_current_object", lambda: current_user)()
# Check if we got a valid user object
if user is not None and hasattr(user, "id"):
return user
return None
class WorkflowTool(Tool):
"""
Workflow tool.
@ -223,12 +207,6 @@ class WorkflowTool(Tool):
Returns:
Account | EndUser | None: The resolved user object, or None if resolution fails.
"""
# Try to resolve user from request context first
user = _try_resolve_user_from_request()
if user is not None:
return user
# Fall back to database resolution
return self._resolve_user_from_database(user_id=user_id)
def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None:

View File

@ -7,16 +7,28 @@ execution in multi-threaded environments.
from core.workflow.context.execution_context import (
AppContext,
ContextProviderNotFoundError,
ExecutionContext,
IExecutionContext,
NullAppContext,
capture_current_context,
read_context,
register_context,
register_context_capturer,
reset_context_provider,
)
from core.workflow.context.models import SandboxContext
__all__ = [
"AppContext",
"ContextProviderNotFoundError",
"ExecutionContext",
"IExecutionContext",
"NullAppContext",
"SandboxContext",
"capture_current_context",
"read_context",
"register_context",
"register_context_capturer",
"reset_context_provider",
]

View File

@ -4,9 +4,11 @@ Execution Context - Abstracted context management for workflow execution.
import contextvars
from abc import ABC, abstractmethod
from collections.abc import Generator
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Protocol, final, runtime_checkable
from typing import Any, Protocol, TypeVar, final, runtime_checkable
from pydantic import BaseModel
class AppContext(ABC):
@ -204,13 +206,75 @@ class ExecutionContextBuilder:
)
_capturer: Callable[[], IExecutionContext] | None = None
# Tenant-scoped providers using tuple keys for clarity and constant-time lookup.
# Key mapping:
# (name, tenant_id) -> provider
# - name: namespaced identifier (recommend prefixing, e.g. "workflow.sandbox")
# - tenant_id: tenant identifier string
# Value:
# provider: Callable[[], BaseModel] returning the typed context value
# Type-safety note:
# - This registry cannot enforce that all providers for a given name return the same BaseModel type.
# - Implementors SHOULD provide typed wrappers around register/read (like Go's context best practice),
# e.g. def register_sandbox_ctx(tenant_id: str, p: Callable[[], SandboxContext]) and
# def read_sandbox_ctx(tenant_id: str) -> SandboxContext.
_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {}
T = TypeVar("T", bound=BaseModel)
class ContextProviderNotFoundError(KeyError):
"""Raised when a tenant-scoped context provider is missing for a given (name, tenant_id)."""
pass
def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
"""Register a single enterable execution context capturer (e.g., Flask)."""
global _capturer
_capturer = capturer
def register_context(name: str, tenant_id: str, provider: Callable[[], BaseModel]) -> None:
"""Register a tenant-specific provider for a named context.
Tip: use a namespaced "name" (e.g., "workflow.sandbox") to avoid key collisions.
Consider adding a typed wrapper for this registration in your feature module.
"""
_tenant_context_providers[(name, tenant_id)] = provider
def read_context(name: str, *, tenant_id: str) -> BaseModel:
"""
Read a context value for a specific tenant.
Raises ContextProviderNotFoundError (a KeyError subclass) if no provider is registered for (name, tenant_id).
"""
prov = _tenant_context_providers.get((name, tenant_id))
if prov is None:
raise ContextProviderNotFoundError(f"Context provider '{name}' not registered for tenant '{tenant_id}'")
return prov()
def capture_current_context() -> IExecutionContext:
"""
Capture current execution context from the calling environment.
Returns:
IExecutionContext with captured context
If a capturer is registered (e.g., Flask), use it. Otherwise, return a minimal
context with NullAppContext + copy of current contextvars.
"""
from context import capture_current_context
if _capturer is None:
return ExecutionContext(
app_context=NullAppContext(),
context_vars=contextvars.copy_context(),
)
return _capturer()
return capture_current_context()
def reset_context_provider() -> None:
"""Reset the capturer and all tenant-scoped context providers (primarily for tests)."""
global _capturer
_capturer = None
_tenant_context_providers.clear()

View File

@ -0,0 +1,13 @@
from __future__ import annotations
from pydantic import AnyHttpUrl, BaseModel
class SandboxContext(BaseModel):
"""Typed context for sandbox integration. All fields optional by design."""
sandbox_url: AnyHttpUrl | None = None
sandbox_token: str | None = None # optional, if later needed for auth
__all__ = ["SandboxContext"]

View File

@ -235,7 +235,18 @@ class AgentNode(Node[AgentNodeData]):
0,
):
value_param = param.get("value", {})
params[key] = value_param.get("value", "") if value_param is not None else None
if value_param and value_param.get("type", "") == "variable":
variable_selector = value_param.get("value")
if not variable_selector:
raise ValueError("Variable selector is missing for a variable-type parameter.")
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
params[key] = variable.value
else:
params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
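A sketch of the two parameter shapes this branch distinguishes, assuming the surrounding loop walks key/param pairs; the key name, selector, and "constant" type tag are illustrative:

# Literal parameter: any type other than "variable" falls back to the raw value.
param = {"value": {"type": "constant", "value": "hello"}}
# -> params[key] = "hello"

# Variable parameter: the selector is resolved through the variable pool; a missing
# selector or an unknown variable now raises instead of silently passing "" along.
param = {"value": {"type": "variable", "value": ["start_node", "user_query"]}}
# -> params[key] = variable_pool.get(["start_node", "user_query"]).value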

View File

@ -0,0 +1,60 @@
"""make message annotation question not nullable
Revision ID: 9e6fa5cbcd80
Revises: 288345cd01d1
Create Date: 2025-11-06 16:03:54.549378
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '9e6fa5cbcd80'
down_revision = '288345cd01d1'
branch_labels = None
depends_on = None
def upgrade():
bind = op.get_bind()
message_annotations = sa.table(
"message_annotations",
sa.column("id", sa.String),
sa.column("message_id", sa.String),
sa.column("question", sa.Text),
)
messages = sa.table(
"messages",
sa.column("id", sa.String),
sa.column("query", sa.Text),
)
update_question_from_message = (
sa.update(message_annotations)
.where(
sa.and_(
message_annotations.c.question.is_(None),
message_annotations.c.message_id.isnot(None),
)
)
.values(
question=sa.select(sa.func.coalesce(messages.c.query, ""))
.where(messages.c.id == message_annotations.c.message_id)
.scalar_subquery()
)
)
bind.execute(update_question_from_message)
fill_remaining_questions = (
sa.update(message_annotations)
.where(message_annotations.c.question.is_(None))
.values(question="")
)
bind.execute(fill_remaining_questions)
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
batch_op.alter_column('question', existing_type=sa.TEXT(), nullable=False)
def downgrade():
with op.batch_alter_table('message_annotations', schema=None) as batch_op:
batch_op.alter_column('question', existing_type=sa.TEXT(), nullable=True)

View File

@ -0,0 +1,73 @@
"""add table explore banner and trial
Revision ID: f9f6d18a37f9
Revises: 9e6fa5cbcd80
Create Date: 2026-01-17 11:10:18.079355
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'f9f6d18a37f9'
down_revision = '9e6fa5cbcd80'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('account_trial_app_records',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('account_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('count', sa.Integer(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='user_trial_app_pkey'),
sa.UniqueConstraint('account_id', 'app_id', name='unique_account_trial_app_record')
)
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
batch_op.create_index('account_trial_app_record_account_id_idx', ['account_id'], unique=False)
batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)
op.create_table('exporle_banners',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('content', sa.JSON(), nullable=False),
sa.Column('link', sa.String(length=255), nullable=False),
sa.Column('sort', sa.Integer(), nullable=False),
sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False),
sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
)
op.create_table('trial_apps',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('trial_limit', sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint('id', name='trial_app_pkey'),
sa.UniqueConstraint('app_id', name='unique_trail_app_id')
)
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
batch_op.create_index('trial_app_app_id_idx', ['app_id'], unique=False)
batch_op.create_index('trial_app_tenant_id_idx', ['tenant_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('trial_apps', schema=None) as batch_op:
batch_op.drop_index('trial_app_tenant_id_idx')
batch_op.drop_index('trial_app_app_id_idx')
op.drop_table('trial_apps')
op.drop_table('exporle_banners')
with op.batch_alter_table('account_trial_app_records', schema=None) as batch_op:
batch_op.drop_index('account_trial_app_record_app_id_idx')
batch_op.drop_index('account_trial_app_record_account_id_idx')
op.drop_table('account_trial_app_records')
# ### end Alembic commands ###

View File

@ -35,6 +35,7 @@ from .enums import (
WorkflowTriggerStatus,
)
from .model import (
AccountTrialAppRecord,
ApiRequest,
ApiToken,
App,
@ -47,6 +48,7 @@ from .model import (
DatasetRetrieverResource,
DifySetup,
EndUser,
ExporleBanner,
IconType,
InstalledApp,
Message,
@ -62,6 +64,7 @@ from .model import (
TagBinding,
TenantCreditPool,
TraceAppConfig,
TrialApp,
UploadFile,
)
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
@ -114,6 +117,7 @@ __all__ = [
"Account",
"AccountIntegrate",
"AccountStatus",
"AccountTrialAppRecord",
"ApiRequest",
"ApiToken",
"ApiToolProvider",
@ -150,6 +154,7 @@ __all__ = [
"DocumentSegment",
"Embedding",
"EndUser",
"ExporleBanner",
"ExternalKnowledgeApis",
"ExternalKnowledgeBindings",
"IconType",
@ -188,6 +193,7 @@ __all__ = [
"ToolLabelBinding",
"ToolModelInvoke",
"TraceAppConfig",
"TrialApp",
"TriggerOAuthSystemClient",
"TriggerOAuthTenantClient",
"TriggerSubscription",

View File

@ -603,6 +603,64 @@ class InstalledApp(TypeBase):
return tenant
class TrialApp(Base):
__tablename__ = "trial_apps"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="trial_app_pkey"),
sa.Index("trial_app_app_id_idx", "app_id"),
sa.Index("trial_app_tenant_id_idx", "tenant_id"),
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
app_id = mapped_column(StringUUID, nullable=False)
tenant_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
trial_limit = mapped_column(sa.Integer, nullable=False, default=3)
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
class AccountTrialAppRecord(Base):
__tablename__ = "account_trial_app_records"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="user_trial_app_pkey"),
sa.Index("account_trial_app_record_account_id_idx", "account_id"),
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
account_id = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=False)
count = mapped_column(sa.Integer, nullable=False, default=0)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self) -> App | None:
app = db.session.query(App).where(App.id == self.app_id).first()
return app
@property
def user(self) -> Account | None:
user = db.session.query(Account).where(Account.id == self.account_id).first()
return user
class ExporleBanner(Base):
__tablename__ = "exporle_banners"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
content = mapped_column(sa.JSON, nullable=False)
link = mapped_column(String(255), nullable=False)
sort = mapped_column(sa.Integer, nullable=False)
status = mapped_column(sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"))
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
language = mapped_column(String(255), nullable=False, server_default=sa.text("'en-US'::character varying"))
class OAuthProviderApp(TypeBase):
"""
Globally shared OAuth provider app information.
@ -1423,7 +1481,7 @@ class MessageAnnotation(Base):
app_id: Mapped[str] = mapped_column(StringUUID)
conversation_id: Mapped[str | None] = mapped_column(StringUUID, sa.ForeignKey("conversations.id"))
message_id: Mapped[str | None] = mapped_column(StringUUID)
question: Mapped[str | None] = mapped_column(LongText, nullable=True)
question: Mapped[str] = mapped_column(LongText, nullable=False)
content: Mapped[str] = mapped_column(LongText, nullable=False)
hit_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
account_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

View File

@ -2,9 +2,11 @@ import logging
import time
import click
from redis.exceptions import LockError
import app
from configs import dify_config
from extensions.ext_redis import redis_client
from services.retention.conversation.messages_clean_policy import create_message_clean_policy
from services.retention.conversation.messages_clean_service import MessagesCleanService
@ -31,12 +33,16 @@ def clean_messages():
)
# Create and run the cleanup service
service = MessagesCleanService.from_days(
policy=policy,
days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
)
stats = service.run()
# Lock the task to avoid concurrent execution as data volume grows.
with redis_client.lock(
"retention:clean_messages", timeout=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL, blocking=False
):
service = MessagesCleanService.from_days(
policy=policy,
days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
)
stats = service.run()
end_at = time.perf_counter()
click.echo(
@ -50,6 +56,16 @@ def clean_messages():
fg="green",
)
)
except LockError:
end_at = time.perf_counter()
logger.exception("clean_messages: acquire task lock failed, skip current execution")
click.echo(
click.style(
f"clean_messages: skipped (lock already held) - latency: {end_at - start_at:.2f}s",
fg="yellow",
)
)
raise
except Exception as e:
end_at = time.perf_counter()
logger.exception("clean_messages failed")

View File

@ -1,11 +1,16 @@
import logging
from datetime import UTC, datetime
import click
from redis.exceptions import LockError
import app
from configs import dify_config
from extensions.ext_redis import redis_client
from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
logger = logging.getLogger(__name__)
@app.celery.task(queue="retention")
def clean_workflow_runs_task() -> None:
@ -25,19 +30,50 @@ def clean_workflow_runs_task() -> None:
start_time = datetime.now(UTC)
WorkflowRunCleanup(
days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
start_from=None,
end_before=None,
).run()
try:
# Lock the task to avoid concurrent execution as data volume grows.
with redis_client.lock(
"retention:clean_workflow_runs_task",
timeout=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL,
blocking=False,
):
WorkflowRunCleanup(
days=dify_config.SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS,
batch_size=dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE,
start_from=None,
end_before=None,
).run()
end_time = datetime.now(UTC)
elapsed = end_time - start_time
click.echo(
click.style(
f"Scheduled workflow run cleanup finished. start={start_time.isoformat()} "
f"end={end_time.isoformat()} duration={elapsed}",
fg="green",
end_time = datetime.now(UTC)
elapsed = end_time - start_time
click.echo(
click.style(
f"Scheduled workflow run cleanup finished. start={start_time.isoformat()} "
f"end={end_time.isoformat()} duration={elapsed}",
fg="green",
)
)
)
except LockError:
end_time = datetime.now(UTC)
elapsed = end_time - start_time
logger.exception("clean_workflow_runs_task: acquire task lock failed, skip current execution")
click.echo(
click.style(
f"Scheduled workflow run cleanup skipped (lock already held). "
f"start={start_time.isoformat()} end={end_time.isoformat()} duration={elapsed}",
fg="yellow",
)
)
raise
except Exception as e:
end_time = datetime.now(UTC)
elapsed = end_time - start_time
logger.exception("clean_workflow_runs_task failed")
click.echo(
click.style(
f"Scheduled workflow run cleanup failed. start={start_time.isoformat()} "
f"end={end_time.isoformat()} duration={elapsed} - {str(e)}",
fg="red",
)
)
raise
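Both retention jobs wrap their work in the same non-blocking Redis lock. A minimal sketch of the pattern on its own, assuming a redis-py client; the lock name and TTL are illustrative:

from redis.exceptions import LockError

from extensions.ext_redis import redis_client


def run_exclusively() -> None:
    """Run the cleanup only if no other worker currently holds the lock."""
    try:
        # blocking=False makes the context manager raise LockError immediately when
        # another worker already holds the lock; timeout caps how long the lock
        # survives if the holder crashes before releasing it.
        with redis_client.lock("retention:example_task", timeout=900, blocking=False):
            pass  # the actual cleanup work goes here
    except LockError:
        # Another run is in progress; skip this execution instead of waiting.
        return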

View File

@ -209,8 +209,12 @@ class AppAnnotationService:
if not app:
raise NotFound("App not found")
question = args.get("question")
if question is None:
raise ValueError("'question' is required")
annotation = MessageAnnotation(
app_id=app.id, content=args["answer"], question=args["question"], account_id=current_user.id
app_id=app.id, content=args["answer"], question=question, account_id=current_user.id
)
db.session.add(annotation)
db.session.commit()
@ -219,7 +223,7 @@ class AppAnnotationService:
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
args["question"],
question,
current_tenant_id,
app_id,
annotation_setting.collection_binding_id,
@ -244,8 +248,12 @@ class AppAnnotationService:
if not annotation:
raise NotFound("Annotation not found")
question = args.get("question")
if question is None:
raise ValueError("'question' is required")
annotation.content = args["answer"]
annotation.question = args["question"]
annotation.question = question
db.session.commit()
# if annotation reply is enabled, add annotation to index

View File

@ -172,6 +172,8 @@ class SystemFeatureModel(BaseModel):
enable_change_email: bool = True
plugin_manager: PluginManagerModel = PluginManagerModel()
trial_models: list[str] = []
enable_trial_app: bool = False
enable_explore_banner: bool = False
class FeatureService:
@ -228,6 +230,8 @@ class FeatureService:
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
system_features.trial_models = cls._fulfill_trial_models_from_env()
system_features.enable_trial_app = dify_config.ENABLE_TRIAL_APP
system_features.enable_explore_banner = dify_config.ENABLE_EXPLORE_BANNER
@classmethod
def _fulfill_trial_models_from_env(cls) -> list[str]:
@ -240,6 +244,7 @@ class FeatureService:
)
]
@classmethod
def _fulfill_params_from_env(cls, features: FeatureModel):
features.can_replace_logo = dify_config.CAN_REPLACE_LOGO

View File

@ -1,4 +1,7 @@
from configs import dify_config
from extensions.ext_database import db
from models.model import AccountTrialAppRecord, TrialApp
from services.feature_service import FeatureService
from services.recommend_app.recommend_app_factory import RecommendAppRetrievalFactory
@ -20,6 +23,15 @@ class RecommendedAppService:
)
)
if FeatureService.get_system_features().enable_trial_app:
apps = result["recommended_apps"]
for app in apps:
app_id = app["app_id"]
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
if trial_app_model:
app["can_trial"] = True
else:
app["can_trial"] = False
return result
@classmethod
@ -32,4 +44,30 @@ class RecommendedAppService:
mode = dify_config.HOSTED_FETCH_APP_TEMPLATES_MODE
retrieval_instance = RecommendAppRetrievalFactory.get_recommend_app_factory(mode)()
result: dict = retrieval_instance.get_recommend_app_detail(app_id)
if FeatureService.get_system_features().enable_trial_app:
app_id = result["id"]
trial_app_model = db.session.query(TrialApp).where(TrialApp.app_id == app_id).first()
if trial_app_model:
result["can_trial"] = True
else:
result["can_trial"] = False
return result
@classmethod
def add_trial_app_record(cls, app_id: str, account_id: str):
"""
Add or increment the trial app usage record for an account.
:param app_id: app id
:param account_id: account id
:return:
"""
account_trial_app_record = (
db.session.query(AccountTrialAppRecord)
.where(AccountTrialAppRecord.app_id == app_id, AccountTrialAppRecord.account_id == account_id)
.first()
)
if account_trial_app_record:
account_trial_app_record.count += 1
db.session.commit()
else:
db.session.add(AccountTrialAppRecord(app_id=app_id, count=1, account_id=account_id))
db.session.commit()
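A hedged sketch of how calling code might tie these pieces together when an account starts a trial run; the RecommendedAppService import path and the surrounding checks are assumptions, only the feature flag and add_trial_app_record come from this change:

from services.feature_service import FeatureService
from services.recommended_app_service import RecommendedAppService  # assumed module path


def record_trial_run(app_id: str, account_id: str) -> None:
    """Illustrative glue: count a trial run only when the feature is enabled."""
    if not FeatureService.get_system_features().enable_trial_app:
        raise ValueError("Trial apps are disabled on this deployment")
    # The explore listing/detail payloads now carry a can_trial flag, so callers can
    # check it before offering a trial; here we only record the usage.
    RecommendedAppService.add_trial_app_record(app_id, account_id)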

View File

@ -4,11 +4,11 @@ import time
import click
from celery import shared_task
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DatasetAutoDisableLog, DocumentSegment
@ -28,106 +28,106 @@ def add_document_to_index_task(dataset_document_id: str):
logger.info(click.style(f"Start add document to index: {dataset_document_id}", fg="green"))
start_at = time.perf_counter()
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
if not dataset_document:
logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
db.session.close()
return
with session_factory.create_session() as session:
dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
if not dataset_document:
logger.info(click.style(f"Document not found: {dataset_document_id}", fg="red"))
return
if dataset_document.indexing_status != "completed":
db.session.close()
return
if dataset_document.indexing_status != "completed":
return
indexing_cache_key = f"document_{dataset_document.id}_indexing"
indexing_cache_key = f"document_{dataset_document.id}_indexing"
try:
dataset = dataset_document.dataset
if not dataset:
raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")
try:
dataset = dataset_document.dataset
if not dataset:
raise Exception(f"Document {dataset_document.id} dataset {dataset_document.dataset_id} doesn't exist.")
segments = (
db.session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
segments = (
session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
)
.order_by(DocumentSegment.position.asc())
.all()
)
.order_by(DocumentSegment.position.asc())
.all()
)
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
# delete auto disable log
session.query(DatasetAutoDisableLog).where(
DatasetAutoDisableLog.document_id == dataset_document.id
).delete()
# update segment to enable
session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
{
DocumentSegment.enabled: True,
DocumentSegment.disabled_at: None,
DocumentSegment.disabled_by: None,
DocumentSegment.updated_at: naive_utc_now(),
}
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
session.commit()
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
# delete auto disable log
db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete()
# update segment to enable
db.session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
{
DocumentSegment.enabled: True,
DocumentSegment.disabled_at: None,
DocumentSegment.disabled_by: None,
DocumentSegment.updated_at: naive_utc_now(),
}
)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")
)
except Exception as e:
logger.exception("add document to index failed")
dataset_document.enabled = False
dataset_document.disabled_at = naive_utc_now()
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()
end_at = time.perf_counter()
logger.info(
click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")
)
except Exception as e:
logger.exception("add document to index failed")
dataset_document.enabled = False
dataset_document.disabled_at = naive_utc_now()
dataset_document.indexing_status = "error"
dataset_document.error = str(e)
session.commit()
finally:
redis_client.delete(indexing_cache_key)
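This task and the ones that follow are rewritten around the same shape: the module-level db.session (with its manual close()) is replaced by one session_factory.create_session() scope per task. A minimal sketch of that shape, with the real indexing work reduced to a placeholder comment:

from core.db.session_factory import session_factory
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment


def example_reindex_task(segment_id: str) -> None:
    """Illustrative only: one session scope per task, cache cleanup in finally."""
    cache_key = f"segment_{segment_id}_indexing"
    try:
        with session_factory.create_session() as session:
            segment = session.get(DocumentSegment, segment_id)
            if segment is None:
                return
            # ... the real tasks do their indexing work against `session` here ...
            session.commit()
    finally:
        redis_client.delete(cache_key)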

View File

@ -5,9 +5,9 @@ import click
from celery import shared_task
from werkzeug.exceptions import NotFound
from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
@ -32,74 +32,72 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
indexing_cache_key = f"app_annotation_batch_import_{str(job_id)}"
active_jobs_key = f"annotation_import_active:{tenant_id}"
# get app info
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
with session_factory.create_session() as session:
# get app info
app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
if app:
try:
documents = []
for content in content_list:
annotation = MessageAnnotation(
app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id
if app:
try:
documents = []
for content in content_list:
annotation = MessageAnnotation(
app_id=app.id, content=content["answer"], question=content["question"], account_id=user_id
)
session.add(annotation)
session.flush()
document = Document(
page_content=content["question"],
metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
)
documents.append(document)
# if annotation reply is enabled, batch add annotations to the index
app_annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
db.session.add(annotation)
db.session.flush()
document = Document(
page_content=content["question"],
metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
)
documents.append(document)
# if annotation reply is enabled , batch add annotations' index
app_annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if app_annotation_setting:
dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
app_annotation_setting.collection_binding_id, "annotation"
)
)
if not dataset_collection_binding:
raise NotFound("App annotation setting not found")
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,
)
if app_annotation_setting:
dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
app_annotation_setting.collection_binding_id, "annotation"
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.create(documents, duplicate_check=True)
session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
"Build index successful for batch import annotation: {} latency: {}".format(
job_id, end_at - start_at
),
fg="green",
)
)
if not dataset_collection_binding:
raise NotFound("App annotation setting not found")
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=dataset_collection_binding.provider_name,
embedding_model=dataset_collection_binding.model_name,
collection_binding_id=dataset_collection_binding.id,
)
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.create(documents, duplicate_check=True)
db.session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
"Build index successful for batch import annotation: {} latency: {}".format(
job_id, end_at - start_at
),
fg="green",
)
)
except Exception as e:
db.session.rollback()
redis_client.setex(indexing_cache_key, 600, "error")
indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}"
redis_client.setex(indexing_error_msg_key, 600, str(e))
logger.exception("Build index for batch import annotations failed")
finally:
# Clean up active job tracking to release concurrency slot
try:
redis_client.zrem(active_jobs_key, job_id)
logger.debug("Released concurrency slot for job: %s", job_id)
except Exception as cleanup_error:
# Log but don't fail if cleanup fails - the job will be auto-expired
logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error)
# Close database session
db.session.close()
except Exception as e:
session.rollback()
redis_client.setex(indexing_cache_key, 600, "error")
indexing_error_msg_key = f"app_annotation_batch_import_error_msg_{str(job_id)}"
redis_client.setex(indexing_error_msg_key, 600, str(e))
logger.exception("Build index for batch import annotations failed")
finally:
# Clean up active job tracking to release concurrency slot
try:
redis_client.zrem(active_jobs_key, job_id)
logger.debug("Released concurrency slot for job: %s", job_id)
except Exception as cleanup_error:
# Log but don't fail if cleanup fails - the job will be auto-expired
logger.warning("Failed to clean up active job tracking for %s: %s", job_id, cleanup_error)

View File

@ -5,8 +5,8 @@ import click
from celery import shared_task
from sqlalchemy import exists, select
from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset
from models.model import App, AppAnnotationSetting, MessageAnnotation
@ -22,50 +22,55 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
logger.info(click.style(f"Start delete app annotations index: {app_id}", fg="green"))
start_at = time.perf_counter()
# get app info
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
annotations_exists = db.session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
if not app:
logger.info(click.style(f"App not found: {app_id}", fg="red"))
db.session.close()
return
with session_factory.create_session() as session:
app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
annotations_exists = session.scalar(select(exists().where(MessageAnnotation.app_id == app_id)))
if not app:
logger.info(click.style(f"App not found: {app_id}", fg="red"))
return
app_annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if not app_annotation_setting:
logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red"))
db.session.close()
return
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
try:
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
collection_binding_id=app_annotation_setting.collection_binding_id,
app_annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if not app_annotation_setting:
logger.info(click.style(f"App annotation setting not found: {app_id}", fg="red"))
return
disable_app_annotation_key = f"disable_app_annotation_{str(app_id)}"
disable_app_annotation_job_key = f"disable_app_annotation_job_{str(job_id)}"
try:
if annotations_exists:
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.delete()
except Exception:
logger.exception("Delete annotation index failed when annotation deleted.")
redis_client.setex(disable_app_annotation_job_key, 600, "completed")
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
collection_binding_id=app_annotation_setting.collection_binding_id,
)
# delete annotation setting
db.session.delete(app_annotation_setting)
db.session.commit()
try:
if annotations_exists:
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
vector.delete()
except Exception:
logger.exception("Delete annotation index failed when annotation deleted.")
redis_client.setex(disable_app_annotation_job_key, 600, "completed")
end_at = time.perf_counter()
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("Annotation batch deleted index failed")
redis_client.setex(disable_app_annotation_job_key, 600, "error")
disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}"
redis_client.setex(disable_app_annotation_error_key, 600, str(e))
finally:
redis_client.delete(disable_app_annotation_key)
db.session.close()
# delete annotation setting
session.delete(app_annotation_setting)
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"App annotations index deleted : {app_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception as e:
logger.exception("Annotation batch deleted index failed")
redis_client.setex(disable_app_annotation_job_key, 600, "error")
disable_app_annotation_error_key = f"disable_app_annotation_error_{str(job_id)}"
redis_client.setex(disable_app_annotation_error_key, 600, str(e))
finally:
redis_client.delete(disable_app_annotation_key)

View File

@ -5,9 +5,9 @@ import click
from celery import shared_task
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset
@ -33,92 +33,98 @@ def enable_annotation_reply_task(
logger.info(click.style(f"Start add app annotation to index: {app_id}", fg="green"))
start_at = time.perf_counter()
# get app info
app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
with session_factory.create_session() as session:
app = session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
if not app:
logger.info(click.style(f"App not found: {app_id}", fg="red"))
db.session.close()
return
if not app:
logger.info(click.style(f"App not found: {app_id}", fg="red"))
return
annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
annotations = session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
try:
documents = []
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, "annotation"
)
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
if dataset_collection_binding.id != annotation_setting.collection_binding_id:
old_dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
annotation_setting.collection_binding_id, "annotation"
)
)
if old_dataset_collection_binding and annotations:
old_dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=old_dataset_collection_binding.provider_name,
embedding_model=old_dataset_collection_binding.model_name,
collection_binding_id=old_dataset_collection_binding.id,
)
old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"])
try:
old_vector.delete()
except Exception as e:
logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
annotation_setting.score_threshold = score_threshold
annotation_setting.collection_binding_id = dataset_collection_binding.id
annotation_setting.updated_user_id = user_id
annotation_setting.updated_at = naive_utc_now()
db.session.add(annotation_setting)
else:
new_app_annotation_setting = AppAnnotationSetting(
app_id=app_id,
score_threshold=score_threshold,
collection_binding_id=dataset_collection_binding.id,
created_user_id=user_id,
updated_user_id=user_id,
try:
documents = []
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, "annotation"
)
db.session.add(new_app_annotation_setting)
annotation_setting = (
session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if annotation_setting:
if dataset_collection_binding.id != annotation_setting.collection_binding_id:
old_dataset_collection_binding = (
DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type(
annotation_setting.collection_binding_id, "annotation"
)
)
if old_dataset_collection_binding and annotations:
old_dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=old_dataset_collection_binding.provider_name,
embedding_model=old_dataset_collection_binding.model_name,
collection_binding_id=old_dataset_collection_binding.id,
)
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id,
)
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question_text,
metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
old_vector = Vector(old_dataset, attributes=["doc_id", "annotation_id", "app_id"])
try:
old_vector.delete()
except Exception as e:
logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
annotation_setting.score_threshold = score_threshold
annotation_setting.collection_binding_id = dataset_collection_binding.id
annotation_setting.updated_user_id = user_id
annotation_setting.updated_at = naive_utc_now()
session.add(annotation_setting)
else:
new_app_annotation_setting = AppAnnotationSetting(
app_id=app_id,
score_threshold=score_threshold,
collection_binding_id=dataset_collection_binding.id,
created_user_id=user_id,
updated_user_id=user_id,
)
documents.append(document)
session.add(new_app_annotation_setting)
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
try:
vector.delete_by_metadata_field("app_id", app_id)
except Exception as e:
logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
vector.create(documents)
db.session.commit()
redis_client.setex(enable_app_annotation_job_key, 600, "completed")
end_at = time.perf_counter()
logger.info(click.style(f"App annotations added to index: {app_id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("Annotation batch created index failed")
redis_client.setex(enable_app_annotation_job_key, 600, "error")
enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}"
redis_client.setex(enable_app_annotation_error_key, 600, str(e))
db.session.rollback()
finally:
redis_client.delete(enable_app_annotation_key)
db.session.close()
dataset = Dataset(
id=app_id,
tenant_id=tenant_id,
indexing_technique="high_quality",
embedding_model_provider=embedding_provider_name,
embedding_model=embedding_model_name,
collection_binding_id=dataset_collection_binding.id,
)
if annotations:
for annotation in annotations:
document = Document(
page_content=annotation.question_text,
metadata={"annotation_id": annotation.id, "app_id": app_id, "doc_id": annotation.id},
)
documents.append(document)
vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
try:
vector.delete_by_metadata_field("app_id", app_id)
except Exception as e:
logger.info(click.style(f"Delete annotation index error: {str(e)}", fg="red"))
vector.create(documents)
session.commit()
redis_client.setex(enable_app_annotation_job_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"App annotations added to index: {app_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception as e:
logger.exception("Annotation batch created index failed")
redis_client.setex(enable_app_annotation_job_key, 600, "error")
enable_app_annotation_error_key = f"enable_app_annotation_error_{str(job_id)}"
redis_client.setex(enable_app_annotation_error_key, 600, str(e))
session.rollback()
finally:
redis_client.delete(enable_app_annotation_key)

View File

@ -10,13 +10,13 @@ from typing import Any
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.orm import Session
from configs import dify_config
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY, WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.layers.trigger_post_layer import TriggerPostLayer
from extensions.ext_database import db
from core.db.session_factory import session_factory
from models.account import Account
from models.enums import CreatorUserRole, WorkflowTriggerStatus
from models.model import App, EndUser, Tenant
@ -98,10 +98,7 @@ def _execute_workflow_common(
):
"""Execute workflow with common logic and trigger log updates."""
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
with session_factory.create_session() as session:
trigger_log_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
# Get trigger log
@ -157,7 +154,7 @@ def _execute_workflow_common(
root_node_id=trigger_data.root_node_id,
graph_engine_layers=[
# TODO: Re-enable TimeSliceLayer after the HITL release.
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id, session_factory),
TriggerPostLayer(cfs_plan_scheduler_entity, start_time, trigger_log.id),
],
)

View File

@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment
from models.model import UploadFile
@ -28,65 +28,64 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
"""
logger.info(click.style("Start batch clean documents when documents deleted", fg="green"))
start_at = time.perf_counter()
if not doc_form:
raise ValueError("doc_form is required")
try:
if not doc_form:
raise ValueError("doc_form is required")
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
if not dataset:
raise Exception("Document has no dataset")
db.session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
).delete(synchronize_session=False)
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
).delete(synchronize_session=False)
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
for image_file in image_files:
try:
if image_file and image_file.key:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file.id,
)
stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(stmt)
session.delete(segment)
if file_ids:
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
for file in files:
try:
if image_file and image_file.key:
storage.delete(image_file.key)
storage.delete(file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
upload_file_id,
)
db.session.delete(image_file)
db.session.delete(segment)
logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
session.execute(stmt)
db.session.commit()
if file_ids:
files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
for file in files:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
db.session.delete(file)
session.commit()
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned documents when documents deleted latency: {end_at - start_at}",
fg="green",
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned documents when documents deleted latency: {end_at - start_at}",
fg="green",
)
)
)
except Exception:
logger.exception("Cleaned documents when documents deleted failed")
finally:
db.session.close()
except Exception:
logger.exception("Cleaned documents when documents deleted failed")

View File

@ -9,9 +9,9 @@ import pandas as pd
from celery import shared_task
from sqlalchemy import func
from core.db.session_factory import session_factory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from libs import helper
@ -48,104 +48,107 @@ def batch_create_segment_to_index_task(
indexing_cache_key = f"segment_batch_import_{job_id}"
try:
dataset = db.session.get(Dataset, dataset_id)
if not dataset:
raise ValueError("Dataset not exist.")
with session_factory.create_session() as session:
try:
dataset = session.get(Dataset, dataset_id)
if not dataset:
raise ValueError("Dataset not exist.")
dataset_document = db.session.get(Document, document_id)
if not dataset_document:
raise ValueError("Document not exist.")
dataset_document = session.get(Document, document_id)
if not dataset_document:
raise ValueError("Document not exist.")
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
raise ValueError("Document is not available.")
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
):
raise ValueError("Document is not available.")
upload_file = db.session.get(UploadFile, upload_file_id)
if not upload_file:
raise ValueError("UploadFile not found.")
upload_file = session.get(UploadFile, upload_file_id)
if not upload_file:
raise ValueError("UploadFile not found.")
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if dataset_document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(
texts=[segment["content"] for segment in content]
)
else:
tokens_list = [0] * len(content)
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == dataset_document.id)
.scalar()
)
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
completed_at=naive_utc_now(),
)
if dataset_document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
session.add(segment_document)
document_segments.append(segment_document)
document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
assert dataset_document.word_count is not None
dataset_document.word_count += word_count_change
session.add(dataset_document)
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(
texts=[segment["content"] for segment in content]
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)
else:
tokens_list = [0] * len(content)
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
db.session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == dataset_document.id)
.scalar()
)
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
completed_at=naive_utc_now(),
)
if dataset_document.doc_form == "qa_model":
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
db.session.add(segment_document)
document_segments.append(segment_document)
assert dataset_document.word_count is not None
dataset_document.word_count += word_count_change
db.session.add(dataset_document)
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
db.session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")
finally:
db.session.close()
except Exception:
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")

View File

@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import WorkflowType
from models.dataset import (
@ -53,135 +53,155 @@ def clean_dataset_task(
logger.info(click.style(f"Start clean dataset when dataset deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset = Dataset(
id=dataset_id,
tenant_id=tenant_id,
indexing_technique=indexing_technique,
index_struct=index_struct,
collection_binding_id=collection_binding_id,
)
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
# Use JOIN to fetch attachments with bindings in a single query
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(SegmentAttachmentBinding.tenant_id == tenant_id, SegmentAttachmentBinding.dataset_id == dataset_id)
).all()
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
# This ensures all invalid doc_form values are properly handled
if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
# Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
from core.rag.index_processor.constant.index_type import IndexStructureType
doc_form = IndexStructureType.PARAGRAPH_INDEX
logger.info(
click.style(f"Invalid doc_form detected, using default index type for cleanup: {doc_form}", fg="yellow")
)
# Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure
# This ensures Document/Segment deletion can continue even if vector database cleanup fails
with session_factory.create_session() as session:
try:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
except Exception:
logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
# Continue with document and segment deletion even if vector cleanup fails
logger.info(
click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow")
dataset = Dataset(
id=dataset_id,
tenant_id=tenant_id,
indexing_technique=indexing_technique,
index_struct=index_struct,
collection_binding_id=collection_binding_id,
)
documents = session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
# Use JOIN to fetch attachments with bindings in a single query
attachments_with_bindings = session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.tenant_id == tenant_id,
SegmentAttachmentBinding.dataset_id == dataset_id,
)
).all()
if documents is None or len(documents) == 0:
logger.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
else:
logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
# This ensures all invalid doc_form values are properly handled
if doc_form is None or (isinstance(doc_form, str) and not doc_form.strip()):
# Use default paragraph index type for empty/invalid datasets to enable vector database cleanup
from core.rag.index_processor.constant.index_type import IndexStructureType
for document in documents:
db.session.delete(document)
# delete document file
doc_form = IndexStructureType.PARAGRAPH_INDEX
logger.info(
click.style(
f"Invalid doc_form detected, using default index type for cleanup: {doc_form}",
fg="yellow",
)
)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
# Add exception handling around IndexProcessorFactory.clean() to prevent single point of failure
# This ensures Document/Segment deletion can continue even if vector database cleanup fails
try:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
except Exception:
logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
# Continue with document and segment deletion even if vector cleanup fails
logger.info(
click.style(f"Continuing with document and segment deletion for dataset: {dataset_id}", fg="yellow")
)
if documents is None or len(documents) == 0:
logger.info(click.style(f"No documents found for dataset: {dataset_id}", fg="green"))
else:
logger.info(click.style(f"Cleaning documents for dataset: {dataset_id}", fg="green"))
for document in documents:
session.delete(document)
segment_ids = [segment.id for segment in segments]
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
for image_file in image_files:
if image_file is None:
continue
try:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_id: %s",
image_file.id,
)
stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(stmt)
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
# delete segment attachments
if attachments_with_bindings:
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(image_file.key)
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_id: %s",
upload_file_id,
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(image_file)
db.session.delete(segment)
# delete segment attachments
if attachments_with_bindings:
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(attachment_file)
db.session.delete(binding)
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
session.execute(attachment_file_delete_stmt)
db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
db.session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
# delete dataset metadata
db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
# delete pipeline and workflow
if pipeline_id:
db.session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
db.session.query(Workflow).where(
Workflow.tenant_id == tenant_id,
Workflow.app_id == pipeline_id,
Workflow.type == WorkflowType.RAG_PIPELINE,
).delete()
# delete files
if documents:
for document in documents:
try:
binding_delete_stmt = delete(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.id.in_(binding_ids)
)
session.execute(binding_delete_stmt)
session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
# delete dataset metadata
session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
# delete pipeline and workflow
if pipeline_id:
session.query(Pipeline).where(Pipeline.id == pipeline_id).delete()
session.query(Workflow).where(
Workflow.tenant_id == tenant_id,
Workflow.app_id == pipeline_id,
Workflow.type == WorkflowType.RAG_PIPELINE,
).delete()
# delete files
if documents:
file_ids = []
for document in documents:
if document.data_source_type == "upload_file":
if document.data_source_info:
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
file = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.first()
)
if not file:
continue
storage.delete(file.key)
db.session.delete(file)
except Exception:
continue
file_ids.append(file_id)
files = session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
for file in files:
storage.delete(file.key)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}", fg="green")
)
except Exception:
# Add rollback to prevent dirty session state in case of exceptions
# This ensures the database session is properly cleaned up
try:
db.session.rollback()
logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
session.execute(file_delete_stmt)
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned dataset when dataset deleted: {dataset_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Failed to rollback database session")
# Add rollback to prevent dirty session state in case of exceptions
# This ensures the database session is properly cleaned up
try:
session.rollback()
logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
except Exception:
logger.exception("Failed to rollback database session")
logger.exception("Cleaned dataset when dataset deleted failed")
finally:
db.session.close()
logger.exception("Cleaned dataset when dataset deleted failed")
finally:
# Explicitly close the session for test expectations and safety
try:
session.close()
except Exception:
logger.exception("Failed to close database session")

View File

@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DatasetMetadataBinding, DocumentSegment, SegmentAttachmentBinding
from models.model import UploadFile
@ -29,85 +29,94 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
if not dataset:
raise Exception("Document has no dataset")
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
# Use JOIN to fetch attachments with bindings in a single query
attachments_with_bindings = db.session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
SegmentAttachmentBinding.dataset_id == dataset_id,
SegmentAttachmentBinding.document_id == document_id,
)
).all()
# check segments exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
# Use JOIN to fetch attachments with bindings in a single query
attachments_with_bindings = session.execute(
select(SegmentAttachmentBinding, UploadFile)
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
.where(
SegmentAttachmentBinding.tenant_id == dataset.tenant_id,
SegmentAttachmentBinding.dataset_id == dataset_id,
SegmentAttachmentBinding.document_id == document_id,
)
).all()
# check segments exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.scalars(
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
).all()
for image_file in image_files:
if image_file is None:
continue
try:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_id: %s",
image_file.id,
)
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(image_file_delete_stmt)
session.delete(segment)
session.commit()
if file_id:
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(image_file.key)
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
session.delete(file)
# delete segment attachments
if attachments_with_bindings:
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_id: %s",
upload_file_id,
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(image_file)
db.session.delete(segment)
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
session.execute(attachment_file_delete_stmt)
db.session.commit()
if file_id:
file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
db.session.delete(file)
db.session.commit()
# delete segment attachments
if attachments_with_bindings:
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
db.session.delete(attachment_file)
db.session.delete(binding)
binding_delete_stmt = delete(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.id.in_(binding_ids)
)
session.execute(binding_delete_stmt)
# delete dataset metadata binding
db.session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
db.session.commit()
# delete dataset metadata binding
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
)
)
)
except Exception:
logger.exception("Cleaned document when document deleted failed")
finally:
db.session.close()
except Exception:
logger.exception("Cleaned document when document deleted failed")

View File

@ -3,10 +3,10 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment
logger = logging.getLogger(__name__)
@ -24,37 +24,37 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
for document_id in document_ids:
document = db.session.query(Document).where(Document.id == document_id).first()
db.session.delete(document)
if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
session.execute(document_delete_stmt)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for document_id in document_ids:
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
for segment in segments:
db.session.delete(segment)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
),
fg="green",
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
),
fg="green",
)
)
)
except Exception:
logger.exception("Cleaned document when import form notion document deleted failed")
finally:
db.session.close()
except Exception:
logger.exception("Cleaned document when import form notion document deleted failed")

View File

@ -4,9 +4,9 @@ import time
import click
from celery import shared_task
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
@ -25,75 +25,77 @@ def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = N
logger.info(click.style(f"Start create segment to index: {segment_id}", fg="green"))
start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
db.session.close()
return
if segment.status != "waiting":
db.session.close()
return
indexing_cache_key = f"segment_{segment.id}_indexing"
try:
# update segment status to indexing
db.session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: naive_utc_now(),
}
)
db.session.commit()
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
dataset = segment.dataset
if not dataset:
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
with session_factory.create_session() as session:
segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
dataset_document = segment.document
if not dataset_document:
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
if segment.status != "waiting":
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
indexing_cache_key = f"segment_{segment.id}_indexing"
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, [document])
try:
# update segment status to indexing
session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: naive_utc_now(),
}
)
session.commit()
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
# update segment to completed
db.session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "completed",
DocumentSegment.completed_at: naive_utc_now(),
}
)
db.session.commit()
dataset = segment.dataset
end_at = time.perf_counter()
logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("create segment to index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.error = str(e)
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()
if not dataset:
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
return
dataset_document = segment.document
if not dataset_document:
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
return
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
):
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, [document])
# update segment to completed
session.query(DocumentSegment).filter_by(id=segment.id).update(
{
DocumentSegment.status: "completed",
DocumentSegment.completed_at: naive_utc_now(),
}
)
session.commit()
end_at = time.perf_counter()
logger.info(click.style(f"Segment created to index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("create segment to index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.error = str(e)
session.commit()
finally:
redis_client.delete(indexing_cache_key)
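
The indexing task above moves its early-return guards inside the `session_factory.create_session()` block and keeps the Redis indexing-flag cleanup in `finally`. A trimmed sketch of that control flow; the vector-index load is elided to a comment and the function name is illustrative:

```python
from core.db.session_factory import session_factory
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment


def index_segment_sketch(segment_id: str) -> None:
    """Sketch: guard, flip status to indexing, do the work, flip to completed; clear the Redis flag last."""
    with session_factory.create_session() as session:
        segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
        if not segment or segment.status != "waiting":
            return  # the context manager still closes the session on early return

        indexing_cache_key = f"segment_{segment.id}_indexing"
        try:
            session.query(DocumentSegment).filter_by(id=segment.id).update(
                {DocumentSegment.status: "indexing", DocumentSegment.indexing_at: naive_utc_now()}
            )
            session.commit()
            # ... build the RAG Document and call index_processor.load(dataset, [document]) here ...
            session.query(DocumentSegment).filter_by(id=segment.id).update(
                {DocumentSegment.status: "completed", DocumentSegment.completed_at: naive_utc_now()}
            )
            session.commit()
        except Exception as e:
            segment.enabled = False
            segment.disabled_at = naive_utc_now()
            segment.status = "error"
            segment.error = str(e)
            session.commit()
        finally:
            redis_client.delete(indexing_cache_key)
```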

View File

@ -4,11 +4,11 @@ import time
import click
from celery import shared_task # type: ignore
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@ -24,166 +24,174 @@ def deal_dataset_index_update_task(dataset_id: str, action: str):
logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "upgrade":
dataset_documents = (
db.session.query(DatasetDocument)
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "upgrade":
dataset_documents = (
session.query(DatasetDocument)
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
.all()
)
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
session.commit()
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
documents.append(document)
# save vector index
# clean keywords
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
elif action == "update":
dataset_documents = (
db.session.query(DatasetDocument)
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
# add new index
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
# clean index
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
.order_by(DocumentSegment.position.asc())
.all()
)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
end_at = time.perf_counter()
logging.info(
click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
)
except Exception:
logging.exception("Deal dataset vector index failed")
finally:
db.session.close()
documents.append(document)
# save vector index
# clean keywords
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
index_processor.load(dataset, documents, with_keywords=False)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
session.commit()
elif action == "update":
dataset_documents = (
session.query(DatasetDocument)
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
# add new index
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
session.commit()
# clean index
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
end_at = time.perf_counter()
logging.info(
click.style(
"Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at),
fg="green",
)
)
except Exception:
logging.exception("Deal dataset vector index failed")

View File

@ -5,11 +5,11 @@ import click
from celery import shared_task
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@ -27,160 +27,170 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
logger.info(click.style(f"Start deal dataset vector index: {dataset_id}", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)
elif action == "add":
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
).all()
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexStructureType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)
elif action == "add":
dataset_documents = session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
).all()
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
session.commit()
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
elif action == "update":
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
).all()
# add new index
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
# clean index
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
.order_by(DocumentSegment.position.asc())
.all()
)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
end_at = time.perf_counter()
logger.info(click.style(f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Deal dataset vector index failed")
finally:
db.session.close()
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
session.commit()
elif action == "update":
dataset_documents = session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
).all()
# add new index
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
session.commit()
# clean index
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
session.query(DocumentSegment)
.where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == True,
)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(
dataset, documents, multimodal_documents=multimodal_documents, with_keywords=False
)
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
session.commit()
except Exception as e:
session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
end_at = time.perf_counter()
logger.info(
click.style(
f"Deal dataset vector index: {dataset_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Deal dataset vector index failed")

View File

@ -3,7 +3,7 @@ import logging
from celery import shared_task
from configs import dify_config
from extensions.ext_database import db
from core.db.session_factory import session_factory
from models import Account
from services.billing_service import BillingService
from tasks.mail_account_deletion_task import send_deletion_success_task
@ -13,16 +13,17 @@ logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def delete_account_task(account_id):
account = db.session.query(Account).where(Account.id == account_id).first()
try:
if dify_config.BILLING_ENABLED:
BillingService.delete_account(account_id)
except Exception:
logger.exception("Failed to delete account %s from billing service.", account_id)
raise
with session_factory.create_session() as session:
account = session.query(Account).where(Account.id == account_id).first()
try:
if dify_config.BILLING_ENABLED:
BillingService.delete_account(account_id)
except Exception:
logger.exception("Failed to delete account %s from billing service.", account_id)
raise
if not account:
logger.error("Account %s not found.", account_id)
return
# send success email
send_deletion_success_task.delay(account.email)
if not account:
logger.error("Account %s not found.", account_id)
return
# send success email
send_deletion_success_task.delay(account.email)
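
This is the simplest shape of the migration in this commit: the `db.session` query moves inside a `session_factory.create_session()` context manager, which closes the session automatically. A minimal sketch; the task name and return value are illustrative:

```python
from celery import shared_task

from core.db.session_factory import session_factory
from models import Account


@shared_task(queue="dataset")
def lookup_account_email(account_id: str) -> str | None:
    """Sketch: scoped session lookup; no explicit db.session.close() needed."""
    with session_factory.create_session() as session:
        account = session.query(Account).where(Account.id == account_id).first()
        return account.email if account else None
```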

View File

@ -4,7 +4,7 @@ import time
import click
from celery import shared_task
from extensions.ext_database import db
from core.db.session_factory import session_factory
from models import ConversationVariable
from models.model import Message, MessageAnnotation, MessageFeedback
from models.tools import ToolConversationVariables, ToolFile
@ -27,44 +27,46 @@ def delete_conversation_related_data(conversation_id: str):
)
start_at = time.perf_counter()
try:
db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.query(ToolConversationVariables).where(
ToolConversationVariables.conversation_id == conversation_id
).delete(synchronize_session=False)
db.session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)
db.session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)
db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Succeeded cleaning data from db for conversation_id {conversation_id} latency: {end_at - start_at}",
fg="green",
with session_factory.create_session() as session:
try:
session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == conversation_id).delete(
synchronize_session=False
)
)
except Exception as e:
logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id)
db.session.rollback()
raise e
finally:
db.session.close()
session.query(MessageFeedback).where(MessageFeedback.conversation_id == conversation_id).delete(
synchronize_session=False
)
session.query(ToolConversationVariables).where(
ToolConversationVariables.conversation_id == conversation_id
).delete(synchronize_session=False)
session.query(ToolFile).where(ToolFile.conversation_id == conversation_id).delete(synchronize_session=False)
session.query(ConversationVariable).where(ConversationVariable.conversation_id == conversation_id).delete(
synchronize_session=False
)
session.query(Message).where(Message.conversation_id == conversation_id).delete(synchronize_session=False)
session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
)
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
(
f"Succeeded cleaning data from db for conversation_id {conversation_id} "
f"latency: {end_at - start_at}"
),
fg="green",
)
)
except Exception:
logger.exception("Failed to delete data from db for conversation_id: %s failed", conversation_id)
session.rollback()
raise
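
The conversation cleanup above deletes several related tables in one transaction, rolls back on failure, and re-raises so Celery can record the error. A condensed sketch of that transaction shape; the model subset and function name are illustrative:

```python
from core.db.session_factory import session_factory
from models.model import Message, MessageAnnotation, MessageFeedback


def delete_conversation_rows(conversation_id: str) -> None:
    """Sketch: purge related tables in one transaction; roll back and re-raise on failure."""
    with session_factory.create_session() as session:
        try:
            for model in (MessageAnnotation, MessageFeedback, Message):
                session.query(model).where(model.conversation_id == conversation_id).delete(
                    synchronize_session=False
                )
            session.commit()
        except Exception:
            session.rollback()
            raise
```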

View File

@ -4,8 +4,8 @@ import time
import click
from celery import shared_task
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from models.dataset import Dataset, Document, SegmentAttachmentBinding
from models.model import UploadFile
@ -26,49 +26,52 @@ def delete_segment_from_index_task(
"""
logger.info(click.style("Start delete segment from index", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
return
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
return
dataset_document = db.session.query(Document).where(Document.id == document_id).first()
if not dataset_document:
return
dataset_document = session.query(Document).where(Document.id == document_id).first()
if not dataset_document:
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info("Document not in valid state for index operations, skipping")
return
doc_form = dataset_document.doc_form
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
):
logging.info("Document not in valid state for index operations, skipping")
return
doc_form = dataset_document.doc_form
# Proceed with index cleanup using the index_node_ids directly
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset,
index_node_ids,
with_keywords=True,
delete_child_chunks=True,
precomputed_child_node_ids=child_node_ids,
)
if dataset.is_multimodal:
# delete segment attachment binding
segment_attachment_bindings = (
db.session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
.all()
# Proceed with index cleanup using the index_node_ids directly
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset,
index_node_ids,
with_keywords=True,
delete_child_chunks=True,
precomputed_child_node_ids=child_node_ids,
)
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
for binding in segment_attachment_bindings:
db.session.delete(binding)
# delete upload file
db.session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
db.session.commit()
if dataset.is_multimodal:
# delete segment attachment binding
segment_attachment_bindings = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
.all()
)
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
for binding in segment_attachment_bindings:
session.delete(binding)
# delete upload file
session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
session.commit()
end_at = time.perf_counter()
logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("delete segment from index failed")
finally:
db.session.close()
end_at = time.perf_counter()
logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("delete segment from index failed")

View File

@ -4,8 +4,8 @@ import time
import click
from celery import shared_task
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DocumentSegment
@ -23,46 +23,53 @@ def disable_segment_from_index_task(segment_id: str):
logger.info(click.style(f"Start disable segment from index: {segment_id}", fg="green"))
start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
db.session.close()
return
if segment.status != "completed":
logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
db.session.close()
return
indexing_cache_key = f"segment_{segment.id}_indexing"
try:
dataset = segment.dataset
if not dataset:
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
with session_factory.create_session() as session:
segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
dataset_document = segment.document
if not dataset_document:
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
if segment.status != "completed":
logger.info(click.style(f"Segment is not completed, disable is not allowed: {segment_id}", fg="red"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
indexing_cache_key = f"segment_{segment.id}_indexing"
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.clean(dataset, [segment.index_node_id])
try:
dataset = segment.dataset
end_at = time.perf_counter()
logger.info(click.style(f"Segment removed from index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("remove segment from index failed")
segment.enabled = True
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()
if not dataset:
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
return
dataset_document = segment.document
if not dataset_document:
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
return
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
):
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.clean(dataset, [segment.index_node_id])
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment removed from index: {segment.id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("remove segment from index failed")
segment.enabled = True
session.commit()
finally:
redis_client.delete(indexing_cache_key)

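The hunk above is representative of the whole change set: each Celery task now opens its own scoped session via `session_factory.create_session()` and lets the context manager close it, instead of reusing the shared `db.session` and calling `db.session.close()` on every early return and in `finally`. Below is a minimal, self-contained sketch of that pattern using plain SQLAlchemy; the in-memory engine, the `Segment` model, and the task body are illustrative stand-ins, not the project's real definitions.

```python
# Sketch only: `SessionLocal` stands in for session_factory.create_session().
from sqlalchemy import String, create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker


class Base(DeclarativeBase):
    pass


class Segment(Base):  # hypothetical stand-in for models.dataset.DocumentSegment
    __tablename__ = "segments"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    status: Mapped[str] = mapped_column(String, default="completed")
    enabled: Mapped[bool] = mapped_column(default=True)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
SessionLocal = sessionmaker(bind=engine)


def disable_segment_sketch(segment_id: str) -> None:
    # One session per task; leaving the `with` block closes it, which is why
    # the scattered db.session.close() calls can simply be deleted.
    with SessionLocal() as session:
        segment = session.get(Segment, segment_id)
        if segment is None or segment.status != "completed":
            return  # early returns just exit the context manager
        try:
            pass  # index cleanup would run here
        except Exception:
            segment.enabled = True  # revert the flag if cleanup fails
            session.commit()        # commit on the same session that loaded the row
```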
View File

@ -5,8 +5,8 @@ import click
from celery import shared_task
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment, SegmentAttachmentBinding
from models.dataset import Document as DatasetDocument
@ -26,69 +26,65 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
db.session.close()
return
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
return
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
if not dataset_document:
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
db.session.close()
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
db.session.close()
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
if not dataset_document:
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
).all()
if not segments:
db.session.close()
return
try:
index_node_ids = [segment.index_node_id for segment in segments]
if dataset.is_multimodal:
segment_ids = [segment.id for segment in segments]
segment_attachment_bindings = (
db.session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
.all()
segments = session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_node_ids.extend(attachment_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
).all()
end_at = time.perf_counter()
logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
except Exception:
# update segment error msg
db.session.query(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"disabled_at": None,
"disabled_by": None,
"enabled": True,
}
)
db.session.commit()
finally:
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
redis_client.delete(indexing_cache_key)
db.session.close()
if not segments:
return
try:
index_node_ids = [segment.index_node_id for segment in segments]
if dataset.is_multimodal:
segment_ids = [segment.id for segment in segments]
segment_attachment_bindings = (
session.query(SegmentAttachmentBinding)
.where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
.all()
)
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_node_ids.extend(attachment_ids)
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
end_at = time.perf_counter()
logger.info(click.style(f"Segments removed from index latency: {end_at - start_at}", fg="green"))
except Exception:
# update segment error msg
session.query(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"disabled_at": None,
"disabled_by": None,
"enabled": True,
}
)
session.commit()
finally:
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
redis_client.delete(indexing_cache_key)

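For multimodal datasets, the hunk above also looks up `SegmentAttachmentBinding` rows and appends their attachment ids to the node ids passed to `index_processor.clean(...)`, so a segment's text and its attachments are removed from the index in one call. A DB-free sketch of that id-gathering step, with dataclass stand-ins for the real models:

```python
from dataclasses import dataclass
from typing import Sequence


@dataclass
class SegmentRow:            # stand-in for DocumentSegment
    id: str
    index_node_id: str


@dataclass
class AttachmentBinding:     # stand-in for SegmentAttachmentBinding
    segment_id: str
    attachment_id: str


def collect_index_node_ids(
    segments: Sequence[SegmentRow],
    bindings: Sequence[AttachmentBinding],
    is_multimodal: bool,
) -> list[str]:
    # Every segment contributes its own index node id.
    node_ids = [segment.index_node_id for segment in segments]
    if is_multimodal:
        # Attachment ids ride along so a single clean() call covers both.
        segment_ids = {segment.id for segment in segments}
        node_ids.extend(b.attachment_id for b in bindings if b.segment_id in segment_ids)
    return node_ids


segments = [SegmentRow(id="s1", index_node_id="node-1")]
bindings = [AttachmentBinding(segment_id="s1", attachment_id="att-1")]
assert collect_index_node_ids(segments, bindings, is_multimodal=True) == ["node-1", "att-1"]
```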
View File

@ -3,12 +3,12 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.extractor.notion_extractor import NotionExtractor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.datasource_provider_service import DatasourceProviderService
@ -28,105 +28,103 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
start_at = time.perf_counter()
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
with session_factory.create_session() as session:
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
data_source_info = document.data_source_info_dict
if document.data_source_type == "notion_import":
if (
not data_source_info
or "notion_page_id" not in data_source_info
or "notion_workspace_id" not in data_source_info
):
raise ValueError("no notion page found")
workspace_id = data_source_info["notion_workspace_id"]
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
credential_id = data_source_info.get("credential_id")
# Get credentials from datasource provider
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=document.tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
if not credential:
logger.error(
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
document_id,
document.tenant_id,
credential_id,
)
document.indexing_status = "error"
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
document.stopped_at = naive_utc_now()
db.session.commit()
db.session.close()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
loader = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=document.tenant_id,
)
data_source_info = document.data_source_info_dict
if document.data_source_type == "notion_import":
if (
not data_source_info
or "notion_page_id" not in data_source_info
or "notion_workspace_id" not in data_source_info
):
raise ValueError("no notion page found")
workspace_id = data_source_info["notion_workspace_id"]
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
credential_id = data_source_info.get("credential_id")
last_edited_time = loader.get_notion_last_edited_time()
# Get credentials from datasource provider
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=document.tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
# check the page is updated
if last_edited_time != page_edited_time:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
db.session.commit()
# delete all document segment and index
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
end_at = time.perf_counter()
logger.info(
click.style(
"Cleaned document when document update data source or process rule: {} latency: {}".format(
document_id, end_at - start_at
),
fg="green",
)
if not credential:
logger.error(
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
document_id,
document.tenant_id,
credential_id,
)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
document.indexing_status = "error"
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
document.stopped_at = naive_utc_now()
session.commit()
return
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
finally:
db.session.close()
loader = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=document.tenant_id,
)
last_edited_time = loader.get_notion_last_edited_time()
# check the page is updated
if last_edited_time != page_edited_time:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.commit()
# delete all document segment and index
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
end_at = time.perf_counter()
logger.info(
click.style(
"Cleaned document when document update data source or process rule: {} latency: {}".format(
document_id, end_at - start_at
),
fg="green",
)
)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)

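Besides the session change, the sync task (like the update and duplicate tasks further down) now clears stale segments with one bulk `delete(DocumentSegment).where(DocumentSegment.id.in_(...))` statement instead of calling `session.delete(segment)` per row. A self-contained sketch of that bulk-delete pattern with a throwaway SQLite model in place of `DocumentSegment`:

```python
from sqlalchemy import String, create_engine, delete, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker


class Base(DeclarativeBase):
    pass


class Segment(Base):  # hypothetical stand-in for models.dataset.DocumentSegment
    __tablename__ = "segments"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    document_id: Mapped[str] = mapped_column(String)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
SessionLocal = sessionmaker(bind=engine)

with SessionLocal() as session:
    session.add_all(Segment(id=f"s{i}", document_id="doc-1") for i in range(3))
    session.commit()

    segments = session.scalars(select(Segment).where(Segment.document_id == "doc-1")).all()
    segment_ids = [segment.id for segment in segments]

    # One bulk DELETE replaces the old per-row session.delete(segment) loop.
    session.execute(delete(Segment).where(Segment.id.in_(segment_ids)))
    session.commit()

    assert session.scalars(select(Segment)).all() == []
```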
View File

@ -6,11 +6,11 @@ import click
from celery import shared_task
from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from services.feature_service import FeatureService
@ -46,66 +46,63 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
documents = []
start_at = time.perf_counter()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
db.session.close()
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset is not found: {dataset_id}", fg="yellow"))
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
)
except Exception as e:
for document_id in document_ids:
document = (
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
except Exception as e:
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
session.commit()
return
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
db.session.close()
return
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
session.add(document)
session.commit()
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
finally:
db.session.close()
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
def _document_indexing_with_tenant_queue(

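The control flow in `_document_indexing` is unchanged in substance: if the billing pre-check raises, every requested document is marked as errored inside the same session and the task returns; otherwise each document is flipped to "parsing" and handed to the `IndexingRunner`. A DB-free sketch of that flow follows; all names here are placeholders, not the real models or runner.

```python
from collections.abc import Callable, Sequence
from dataclasses import dataclass
from datetime import datetime


@dataclass
class Doc:  # stand-in for models.dataset.Document
    id: str
    indexing_status: str = "waiting"
    error: str | None = None
    stopped_at: datetime | None = None


def index_documents_sketch(
    docs: Sequence[Doc],
    check_limits: Callable[[], None],
    run_indexing: Callable[[Sequence[Doc]], None],
) -> None:
    try:
        check_limits()  # batch-size / vector-space checks
    except Exception as exc:
        # Pre-check failed: record the error on every document and stop early.
        for doc in docs:
            doc.indexing_status = "error"
            doc.error = str(exc)
            doc.stopped_at = datetime.now()
        return
    for doc in docs:
        doc.indexing_status = "parsing"  # mark before handing off to the runner
    run_indexing(docs)


def failing_check() -> None:
    raise ValueError("You have reached the batch upload limit.")


docs = [Doc(id="d1"), Doc(id="d2")]
index_documents_sketch(docs, check_limits=failing_check, run_indexing=lambda batch: None)
assert all(doc.indexing_status == "error" for doc in docs)
```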
View File

@ -3,8 +3,9 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@ -26,56 +27,54 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Start update document: {document_id}", fg="green"))
start_at = time.perf_counter()
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
with session_factory.create_session() as session:
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.commit()
# delete all document segment and index
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
# delete all document segment and index
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
"Cleaned document when document update data source or process rule: {} latency: {}".format(
document_id, end_at - start_at
),
fg="green",
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
db.session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
"Cleaned document when document update data source or process rule: {} latency: {}".format(
document_id, end_at - start_at
),
fg="green",
)
)
)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
finally:
db.session.close()
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)

View File

@ -4,15 +4,15 @@ from collections.abc import Callable, Sequence
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from configs import dify_config
from core.db.session_factory import session_factory
from core.entities.document_task import DocumentTask
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.pipeline.queue import TenantIsolatedTaskQueue
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
@ -76,63 +76,64 @@ def _duplicate_document_indexing_task_with_tenant_queue(
def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[str]):
documents = []
documents: list[Document] = []
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
db.session.close()
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
with session_factory.create_session() as session:
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
current = int(getattr(vector_space, "size", 0) or 0)
limit = int(getattr(vector_space, "limit", 0) or 0)
if limit > 0 and (current + count) > limit:
raise ValueError(
"Your total number of documents plus the number of uploads have exceeded the limit of "
"your subscription."
)
except Exception as e:
for document_id in document_ids:
document = (
db.session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
if features.billing.subscription.plan == CloudPlan.SANDBOX and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
current = int(getattr(vector_space, "size", 0) or 0)
limit = int(getattr(vector_space, "limit", 0) or 0)
if limit > 0 and (current + count) > limit:
raise ValueError(
"Your total number of documents plus the number of uploads have exceeded the limit of "
"your subscription."
)
except Exception as e:
documents = list(
session.scalars(
select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
).all()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
return
for document in documents:
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
session.commit()
return
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
documents = list(
session.scalars(
select(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
).all()
)
if document:
for document in documents:
logger.info(click.style(f"Start process document: {document.id}", fg="green"))
# clean old data
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document.id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
@ -140,26 +141,24 @@ def _duplicate_document_indexing_task(dataset_id: str, document_ids: Sequence[st
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
session.add(document)
session.commit()
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
finally:
db.session.close()
indexing_runner = IndexingRunner()
indexing_runner.run(list(documents))
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("duplicate_document_indexing_task failed, dataset_id: %s", dataset_id)
@shared_task(queue="dataset")

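Worth noting in the duplicate-indexing hunk above: instead of issuing one query per document id, the batch is now loaded with a single `select(Document).where(Document.id.in_(document_ids))` and kept as `list[Document]`. A compact sketch of that batched lookup, with a stand-in model and data:

```python
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, sessionmaker


class Base(DeclarativeBase):
    pass


class Document(Base):  # hypothetical stand-in for models.dataset.Document
    __tablename__ = "documents"
    id: Mapped[str] = mapped_column(String, primary_key=True)
    dataset_id: Mapped[str] = mapped_column(String)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
SessionLocal = sessionmaker(bind=engine)

with SessionLocal() as session:
    session.add_all(Document(id=f"d{i}", dataset_id="ds-1") for i in range(3))
    session.commit()

    document_ids = ["d0", "d2", "missing"]
    # One IN-query for the whole batch; ids that do not exist are simply absent
    # from the result, so no per-id None checks are needed.
    documents: list[Document] = list(
        session.scalars(
            select(Document).where(Document.id.in_(document_ids), Document.dataset_id == "ds-1")
        ).all()
    )
    assert {doc.id for doc in documents} == {"d0", "d2"}
```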
View File

@ -4,11 +4,11 @@ import time
import click
from celery import shared_task
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import DocumentSegment
@ -27,91 +27,93 @@ def enable_segment_to_index_task(segment_id: str):
logger.info(click.style(f"Start enable segment to index: {segment_id}", fg="green"))
start_at = time.perf_counter()
segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
db.session.close()
return
if segment.status != "completed":
logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
db.session.close()
return
indexing_cache_key = f"segment_{segment.id}_indexing"
try:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
dataset = segment.dataset
if not dataset:
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
with session_factory.create_session() as session:
segment = session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logger.info(click.style(f"Segment not found: {segment_id}", fg="red"))
return
dataset_document = segment.document
if not dataset_document:
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
if segment.status != "completed":
logger.info(click.style(f"Segment is not completed, enable is not allowed: {segment_id}", fg="red"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
indexing_cache_key = f"segment_{segment.id}_indexing"
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
try:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
dataset = segment.dataset
if not dataset:
logger.info(click.style(f"Segment {segment.id} has no dataset, pass.", fg="cyan"))
return
dataset_document = segment.document
if not dataset_document:
logger.info(click.style(f"Segment {segment.id} has no document, pass.", fg="cyan"))
return
if (
not dataset_document.enabled
or dataset_document.archived
or dataset_document.indexing_status != "completed"
):
logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
return
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
multimodel_documents = []
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodel_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
child_documents.append(child_document)
document.children = child_documents
multimodel_documents = []
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodel_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
# save vector index
index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)
# save vector index
index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)
end_at = time.perf_counter()
logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("enable segment to index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.error = str(e)
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()
end_at = time.perf_counter()
logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("enable segment to index failed")
segment.enabled = False
segment.disabled_at = naive_utc_now()
segment.status = "error"
segment.error = str(e)
session.commit()
finally:
redis_client.delete(indexing_cache_key)

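The documents rebuilt in this task follow the structure the index processor expects: a parent `Document` carrying the segment content, optional `ChildDocument` entries attached via `document.children` for parent-child indexes, and `AttachmentDocument` entries for multimodal datasets. A plain-dataclass sketch of assembling that hierarchy; the real classes live in `core.rag.models.document`, and these are simplified stand-ins:

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class ChildDoc:  # stand-in for core.rag.models.document.ChildDocument
    page_content: str
    metadata: dict[str, Any]


@dataclass
class ParentDoc:  # stand-in for core.rag.models.document.Document
    page_content: str
    metadata: dict[str, Any]
    children: list[ChildDoc] = field(default_factory=list)


def build_parent_child_doc(
    segment_content: str,
    segment_node_id: str,
    child_chunks: list[tuple[str, str]],  # (content, index_node_id) pairs
    document_id: str,
    dataset_id: str,
) -> ParentDoc:
    parent = ParentDoc(
        page_content=segment_content,
        metadata={"doc_id": segment_node_id, "document_id": document_id, "dataset_id": dataset_id},
    )
    # For parent-child indexes, each child chunk becomes its own entry under
    # the parent; the index processor loads parent and children together.
    parent.children = [
        ChildDoc(
            page_content=content,
            metadata={"doc_id": node_id, "document_id": document_id, "dataset_id": dataset_id},
        )
        for content, node_id in child_chunks
    ]
    return parent


doc = build_parent_child_doc("parent text", "node-1", [("child text", "node-1-1")], "doc-1", "ds-1")
assert doc.children[0].metadata["doc_id"] == "node-1-1"
```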
View File

@ -5,11 +5,11 @@ import click
from celery import shared_task
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, DocumentSegment
@ -29,105 +29,102 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id)
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
return
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset {dataset_id} not found, pass.", fg="cyan"))
return
dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
dataset_document = session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
if not dataset_document:
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
db.session.close()
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
db.session.close()
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
if not dataset_document:
logger.info(click.style(f"Document {document_id} not found, pass.", fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logger.info(click.style(f"Document {document_id} status is invalid, pass.", fg="cyan"))
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
).all()
if not segments:
logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
db.session.close()
return
try:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": document_id,
"dataset_id": dataset_id,
},
segments = session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
).all()
if not segments:
logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
return
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": document_id,
"dataset_id": dataset_id,
},
try:
documents = []
multimodal_documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": document_id,
"dataset_id": dataset_id,
},
)
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": document_id,
"dataset_id": dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# save vector index
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
if dataset.is_multimodal:
for attachment in segment.attachments:
multimodal_documents.append(
AttachmentDocument(
page_content=attachment["name"],
metadata={
"doc_id": attachment["id"],
"doc_hash": "",
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
"doc_type": DocType.IMAGE,
},
)
)
documents.append(document)
# save vector index
index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)
end_at = time.perf_counter()
logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("enable segments to index failed")
# update segment error msg
db.session.query(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"error": str(e),
"status": "error",
"disabled_at": naive_utc_now(),
"enabled": False,
}
)
db.session.commit()
finally:
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
redis_client.delete(indexing_cache_key)
db.session.close()
end_at = time.perf_counter()
logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception("enable segments to index failed")
# update segment error msg
session.query(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"error": str(e),
"status": "error",
"disabled_at": naive_utc_now(),
"enabled": False,
}
)
session.commit()
finally:
for segment in segments:
indexing_cache_key = f"segment_{segment.id}_indexing"
redis_client.delete(indexing_cache_key)

View File

@ -4,8 +4,8 @@ import time
import click
from celery import shared_task
from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from extensions.ext_database import db
from models.dataset import Document
logger = logging.getLogger(__name__)
@ -23,26 +23,24 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Recover document: {document_id}", fg="green"))
start_at = time.perf_counter()
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
with session_factory.create_session() as session:
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
try:
indexing_runner = IndexingRunner()
if document.indexing_status in {"waiting", "parsing", "cleaning"}:
indexing_runner.run([document])
elif document.indexing_status == "splitting":
indexing_runner.run_in_splitting_status(document)
elif document.indexing_status == "indexing":
indexing_runner.run_in_indexing_status(document)
end_at = time.perf_counter()
logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("recover_document_indexing_task failed, document_id: %s", document_id)
finally:
db.session.close()
try:
indexing_runner = IndexingRunner()
if document.indexing_status in {"waiting", "parsing", "cleaning"}:
indexing_runner.run([document])
elif document.indexing_status == "splitting":
indexing_runner.run_in_splitting_status(document)
elif document.indexing_status == "indexing":
indexing_runner.run_in_indexing_status(document)
end_at = time.perf_counter()
logger.info(click.style(f"Processed document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("recover_document_indexing_task failed, document_id: %s", document_id)

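The recover task keeps its original dispatch: documents stuck in the early stages are re-run from the start, while "splitting" and "indexing" resume through the runner's dedicated entry points. A tiny sketch of that status dispatch; the runner here is a placeholder, not the real `IndexingRunner`:

```python
from dataclasses import dataclass


@dataclass
class Doc:  # stand-in for models.dataset.Document
    id: str
    indexing_status: str


class RunnerSketch:  # placeholder for core.indexing_runner.IndexingRunner
    def run(self, docs: list[Doc]) -> str:
        return f"run {[d.id for d in docs]}"

    def run_in_splitting_status(self, doc: Doc) -> str:
        return f"resume splitting {doc.id}"

    def run_in_indexing_status(self, doc: Doc) -> str:
        return f"resume indexing {doc.id}"


def recover_sketch(doc: Doc, runner: RunnerSketch) -> str:
    # Early stages restart from scratch; later stages resume where they stopped.
    if doc.indexing_status in {"waiting", "parsing", "cleaning"}:
        return runner.run([doc])
    if doc.indexing_status == "splitting":
        return runner.run_in_splitting_status(doc)
    if doc.indexing_status == "indexing":
        return runner.run_in_indexing_status(doc)
    return "nothing to recover"


assert recover_sketch(Doc(id="d1", indexing_status="splitting"), RunnerSketch()) == "resume splitting d1"
```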
View File

@ -1,14 +1,17 @@
import logging
import time
from collections.abc import Callable
from typing import Any, cast
import click
import sqlalchemy as sa
from celery import shared_task
from sqlalchemy import delete
from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
from core.db.session_factory import session_factory
from extensions.ext_database import db
from models import (
ApiToken,
@ -77,7 +80,6 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
_delete_workflow_webhook_triggers(tenant_id, app_id)
_delete_workflow_schedule_plans(tenant_id, app_id)
_delete_workflow_trigger_logs(tenant_id, app_id)
end_at = time.perf_counter()
logger.info(click.style(f"App and related data deleted: {app_id} latency: {end_at - start_at}", fg="green"))
except SQLAlchemyError as e:
@ -89,8 +91,8 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
def _delete_app_model_configs(tenant_id: str, app_id: str):
def del_model_config(model_config_id: str):
db.session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
def del_model_config(session, model_config_id: str):
session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_model_configs where app_id=:app_id limit 1000""",
@ -101,8 +103,8 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):
def _delete_app_site(tenant_id: str, app_id: str):
def del_site(site_id: str):
db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
def del_site(session, site_id: str):
session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
_delete_records(
"""select id from sites where app_id=:app_id limit 1000""",
@ -113,8 +115,8 @@ def _delete_app_site(tenant_id: str, app_id: str):
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def del_mcp_server(mcp_server_id: str):
db.session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
def del_mcp_server(session, mcp_server_id: str):
session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_mcp_servers where app_id=:app_id limit 1000""",
@ -125,8 +127,8 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def _delete_app_api_tokens(tenant_id: str, app_id: str):
def del_api_token(api_token_id: str):
db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
def del_api_token(session, api_token_id: str):
session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
_delete_records(
"""select id from api_tokens where app_id=:app_id limit 1000""",
@ -137,8 +139,8 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
def _delete_installed_apps(tenant_id: str, app_id: str):
def del_installed_app(installed_app_id: str):
db.session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
def del_installed_app(session, installed_app_id: str):
session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
_delete_records(
"""select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -149,10 +151,8 @@ def _delete_installed_apps(tenant_id: str, app_id: str):
def _delete_recommended_apps(tenant_id: str, app_id: str):
def del_recommended_app(recommended_app_id: str):
db.session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(
synchronize_session=False
)
def del_recommended_app(session, recommended_app_id: str):
session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(synchronize_session=False)
_delete_records(
"""select id from recommended_apps where app_id=:app_id limit 1000""",
@ -163,8 +163,8 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):
def _delete_app_annotation_data(tenant_id: str, app_id: str):
def del_annotation_hit_history(annotation_hit_history_id: str):
db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
def del_annotation_hit_history(session, annotation_hit_history_id: str):
session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
synchronize_session=False
)
@ -175,8 +175,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
"annotation hit history",
)
def del_annotation_setting(annotation_setting_id: str):
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
def del_annotation_setting(session, annotation_setting_id: str):
session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
synchronize_session=False
)
@ -189,8 +189,8 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def del_dataset_join(dataset_join_id: str):
db.session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
def del_dataset_join(session, dataset_join_id: str):
session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_dataset_joins where app_id=:app_id limit 1000""",
@ -201,8 +201,8 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def _delete_app_workflows(tenant_id: str, app_id: str):
def del_workflow(workflow_id: str):
db.session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
def del_workflow(session, workflow_id: str):
session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -241,10 +241,8 @@ def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def del_workflow_app_log(workflow_app_log_id: str):
db.session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(
synchronize_session=False
)
def del_workflow_app_log(session, workflow_app_log_id: str):
session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_app_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -255,11 +253,11 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def _delete_app_conversations(tenant_id: str, app_id: str):
def del_conversation(conversation_id: str):
db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
def del_conversation(session, conversation_id: str):
session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
)
db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
_delete_records(
"""select id from conversations where app_id=:app_id limit 1000""",
@ -270,28 +268,26 @@ def _delete_app_conversations(tenant_id: str, app_id: str):
def _delete_conversation_variables(*, app_id: str):
stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
with db.engine.connect() as conn:
conn.execute(stmt)
conn.commit()
with session_factory.create_session() as session:
stmt = delete(ConversationVariable).where(ConversationVariable.app_id == app_id)
session.execute(stmt)
session.commit()
logger.info(click.style(f"Deleted conversation variables for app {app_id}", fg="green"))
def _delete_app_messages(tenant_id: str, app_id: str):
def del_message(message_id: str):
db.session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(
def del_message(session, message_id: str):
session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(synchronize_session=False)
session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
synchronize_session=False
)
db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
synchronize_session=False
)
db.session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
synchronize_session=False
)
db.session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
db.session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
db.session.query(Message).where(Message.id == message_id).delete()
session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
session.query(Message).where(Message.id == message_id).delete()
_delete_records(
"""select id from messages where app_id=:app_id limit 1000""",
@ -302,8 +298,8 @@ def _delete_app_messages(tenant_id: str, app_id: str):
def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def del_tool_provider(tool_provider_id: str):
db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
def del_tool_provider(session, tool_provider_id: str):
session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
synchronize_session=False
)
@ -316,8 +312,8 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def del_tag_binding(tag_binding_id: str):
db.session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
def del_tag_binding(session, tag_binding_id: str):
session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
_delete_records(
"""select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""",
@ -328,8 +324,8 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def _delete_end_users(tenant_id: str, app_id: str):
def del_end_user(end_user_id: str):
db.session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
def del_end_user(session, end_user_id: str):
session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
_delete_records(
"""select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -340,10 +336,8 @@ def _delete_end_users(tenant_id: str, app_id: str):
def _delete_trace_app_configs(tenant_id: str, app_id: str):
def del_trace_app_config(trace_app_config_id: str):
db.session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(
synchronize_session=False
)
def del_trace_app_config(session, trace_app_config_id: str):
session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(synchronize_session=False)
_delete_records(
"""select id from trace_app_config where app_id=:app_id limit 1000""",
@ -381,14 +375,14 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
total_files_deleted = 0
while True:
with db.engine.begin() as conn:
with session_factory.create_session() as session:
# Get a batch of draft variable IDs along with their file_ids
query_sql = """
SELECT id, file_id FROM workflow_draft_variables
WHERE app_id = :app_id
LIMIT :batch_size
"""
result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
result = session.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
rows = list(result)
if not rows:
@ -399,7 +393,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
# Clean up associated Offload data first
if file_ids:
files_deleted = _delete_draft_variable_offload_data(conn, file_ids)
files_deleted = _delete_draft_variable_offload_data(session, file_ids)
total_files_deleted += files_deleted
# Delete the draft variables
@ -407,8 +401,11 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
DELETE FROM workflow_draft_variables
WHERE id IN :ids
"""
deleted_result = conn.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)})
batch_deleted = deleted_result.rowcount
deleted_result = cast(
CursorResult[Any],
session.execute(sa.text(delete_sql), {"ids": tuple(draft_var_ids)}),
)
batch_deleted: int = int(getattr(deleted_result, "rowcount", 0) or 0)
total_deleted += batch_deleted
logger.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green"))
@ -423,7 +420,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
return total_deleted
def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
def _delete_draft_variable_offload_data(session, file_ids: list[str]) -> int:
"""
Delete Offload data associated with WorkflowDraftVariable file_ids.
@ -434,7 +431,7 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
4. Deletes WorkflowDraftVariableFile records
Args:
conn: Database connection
session: Database connection
file_ids: List of WorkflowDraftVariableFile IDs
Returns:
@ -450,12 +447,12 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
try:
# Get WorkflowDraftVariableFile records and their associated UploadFile keys
query_sql = """
SELECT wdvf.id, uf.key, uf.id as upload_file_id
FROM workflow_draft_variable_files wdvf
JOIN upload_files uf ON wdvf.upload_file_id = uf.id
WHERE wdvf.id IN :file_ids
"""
result = conn.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)})
SELECT wdvf.id, uf.key, uf.id as upload_file_id
FROM workflow_draft_variable_files wdvf
JOIN upload_files uf ON wdvf.upload_file_id = uf.id
WHERE wdvf.id IN :file_ids \
"""
result = session.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)})
file_records = list(result)
# Delete from object storage and collect upload file IDs
@ -473,17 +470,19 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
# Delete UploadFile records
if upload_file_ids:
delete_upload_files_sql = """
DELETE FROM upload_files
WHERE id IN :upload_file_ids
"""
conn.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)})
DELETE FROM upload_files
WHERE id IN :upload_file_ids
"""
session.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)})
# Delete WorkflowDraftVariableFile records
delete_variable_files_sql = """
DELETE FROM workflow_draft_variable_files
WHERE id IN :file_ids
"""
conn.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)})
DELETE FROM workflow_draft_variable_files
WHERE id IN :file_ids
"""
session.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)})
except Exception:
logging.exception("Error deleting draft variable offload data:")
@ -493,8 +492,8 @@ def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
def _delete_app_triggers(tenant_id: str, app_id: str):
def del_app_trigger(trigger_id: str):
db.session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
def del_app_trigger(session, trigger_id: str):
session.query(AppTrigger).where(AppTrigger.id == trigger_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_triggers where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -505,8 +504,8 @@ def _delete_app_triggers(tenant_id: str, app_id: str):
def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
def del_plugin_trigger(trigger_id: str):
db.session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
def del_plugin_trigger(session, trigger_id: str):
session.query(WorkflowPluginTrigger).where(WorkflowPluginTrigger.id == trigger_id).delete(
synchronize_session=False
)
@ -519,8 +518,8 @@ def _delete_workflow_plugin_triggers(tenant_id: str, app_id: str):
def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
def del_webhook_trigger(trigger_id: str):
db.session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
def del_webhook_trigger(session, trigger_id: str):
session.query(WorkflowWebhookTrigger).where(WorkflowWebhookTrigger.id == trigger_id).delete(
synchronize_session=False
)
@ -533,10 +532,8 @@ def _delete_workflow_webhook_triggers(tenant_id: str, app_id: str):
def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
def del_schedule_plan(plan_id: str):
db.session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(
synchronize_session=False
)
def del_schedule_plan(session, plan_id: str):
session.query(WorkflowSchedulePlan).where(WorkflowSchedulePlan.id == plan_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_schedule_plans where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -547,8 +544,8 @@ def _delete_workflow_schedule_plans(tenant_id: str, app_id: str):
def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
def del_trigger_log(log_id: str):
db.session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
def del_trigger_log(session, log_id: str):
session.query(WorkflowTriggerLog).where(WorkflowTriggerLog.id == log_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflow_trigger_logs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@ -560,18 +557,22 @@ def _delete_workflow_trigger_logs(tenant_id: str, app_id: str):
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
while True:
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query_sql), params)
if rs.rowcount == 0:
with session_factory.create_session() as session:
rs = session.execute(sa.text(query_sql), params)
rows = rs.fetchall()
if not rows:
break
for i in rs:
for i in rows:
record_id = str(i.id)
try:
delete_func(record_id)
db.session.commit()
delete_func(session, record_id)
logger.info(click.style(f"Deleted {name} {record_id}", fg="green"))
except Exception:
logger.exception("Error occurred while deleting %s %s", name, record_id)
continue
# continue with next record even if one deletion fails
session.rollback()
break
session.commit()
rs.close()
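Every helper in this file now follows the same shape: open a short-lived session from `session_factory`, fetch a bounded batch of ids with a raw `LIMIT` query, and run the per-record delete callback inside that same session. A condensed sketch of the pattern, assuming `create_session()` yields an ordinary SQLAlchemy `Session` context manager (`delete_in_batches` and its callback are illustrative names, not part of the diff):

```python
import logging

import sqlalchemy as sa

from core.db.session_factory import session_factory


def delete_in_batches(query_sql: str, params: dict, delete_one, name: str) -> int:
    """Fetch ids in LIMITed batches and delete them in the same session, stopping at the first failure."""
    total = 0
    while True:
        with session_factory.create_session() as session:
            rows = session.execute(sa.text(query_sql), params).fetchall()
            if not rows:
                return total
            for row in rows:
                try:
                    delete_one(session, str(row.id))
                    total += 1
                except Exception:
                    logging.exception("Error occurred while deleting %s %s", name, row.id)
                    session.rollback()
                    return total  # stop instead of re-fetching the same failing row forever
            session.commit()
```

Callers such as `_delete_app_triggers` then only supply the id-listing SQL plus a `delete_one(session, record_id)` callback.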

View File

@ -5,8 +5,8 @@ import click
from celery import shared_task
from sqlalchemy import select
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Document, DocumentSegment
@ -25,52 +25,55 @@ def remove_document_from_index_task(document_id: str):
logger.info(click.style(f"Start remove document segments from index: {document_id}", fg="green"))
start_at = time.perf_counter()
document = db.session.query(Document).where(Document.id == document_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
db.session.close()
return
with session_factory.create_session() as session:
document = session.query(Document).where(Document.id == document_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
if document.indexing_status != "completed":
logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red"))
db.session.close()
return
if document.indexing_status != "completed":
logger.info(click.style(f"Document is not completed, remove is not allowed: {document_id}", fg="red"))
return
indexing_cache_key = f"document_{document.id}_indexing"
indexing_cache_key = f"document_{document.id}_indexing"
try:
dataset = document.dataset
try:
dataset = document.dataset
if not dataset:
raise Exception("Document has no dataset")
if not dataset:
raise Exception("Document has no dataset")
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
if index_node_ids:
try:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
except Exception:
logger.exception("clean dataset %s from index failed", dataset.id)
# update segment to disable
db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
{
DocumentSegment.enabled: False,
DocumentSegment.disabled_at: naive_utc_now(),
DocumentSegment.disabled_by: document.disabled_by,
DocumentSegment.updated_at: naive_utc_now(),
}
)
db.session.commit()
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
if index_node_ids:
try:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
except Exception:
logger.exception("clean dataset %s from index failed", dataset.id)
# update segment to disable
session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
{
DocumentSegment.enabled: False,
DocumentSegment.disabled_at: naive_utc_now(),
DocumentSegment.disabled_by: document.disabled_by,
DocumentSegment.updated_at: naive_utc_now(),
}
)
session.commit()
end_at = time.perf_counter()
logger.info(click.style(f"Document removed from index: {document.id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("remove document from index failed")
if not document.archived:
document.enabled = True
db.session.commit()
finally:
redis_client.delete(indexing_cache_key)
db.session.close()
end_at = time.perf_counter()
logger.info(
click.style(
f"Document removed from index: {document.id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("remove document from index failed")
if not document.archived:
document.enabled = True
session.commit()
finally:
redis_client.delete(indexing_cache_key)

View File

@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models import Account, Tenant
@ -29,97 +29,97 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_
Usage: retry_document_indexing_task.delay(dataset_id, document_ids, user_id)
"""
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
return
user = db.session.query(Account).where(Account.id == user_id).first()
if not user:
logger.info(click.style(f"User not found: {user_id}", fg="red"))
return
tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
if not tenant:
raise ValueError("Tenant not found")
user.current_tenant = tenant
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
return
user = session.query(Account).where(Account.id == user_id).first()
if not user:
logger.info(click.style(f"User not found: {user_id}", fg="red"))
return
tenant = session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
if not tenant:
raise ValueError("Tenant not found")
user.current_tenant = tenant
for document_id in document_ids:
retry_indexing_cache_key = f"document_{document_id}_is_retried"
# check document limit
features = FeatureService.get_features(tenant.id)
try:
if features.billing.enabled:
vector_space = features.vector_space
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
)
except Exception as e:
for document_id in document_ids:
retry_indexing_cache_key = f"document_{document_id}_is_retried"
# check document limit
features = FeatureService.get_features(tenant.id)
try:
if features.billing.enabled:
vector_space = features.vector_space
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
)
except Exception as e:
document = (
session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
session.commit()
redis_client.delete(retry_indexing_cache_key)
return
logger.info(click.style(f"Start retry document: {document_id}", fg="green"))
document = (
db.session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
return
try:
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.add(document)
session.commit()
if dataset.runtime_mode == "rag_pipeline":
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.retry_error_document(dataset, document, user)
else:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(retry_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.error = str(e)
document.error = str(ex)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
redis_client.delete(retry_indexing_cache_key)
return
logger.info(click.style(f"Start retry document: {document_id}", fg="green"))
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
session.add(document)
session.commit()
logger.info(click.style(str(ex), fg="yellow"))
redis_client.delete(retry_indexing_cache_key)
logger.exception("retry_document_indexing_task failed, document_id: %s", document_id)
end_at = time.perf_counter()
logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception(
"retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids
)
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
return
try:
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
db.session.add(document)
db.session.commit()
if dataset.runtime_mode == "rag_pipeline":
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.retry_error_document(dataset, document, user)
else:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(retry_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
logger.info(click.style(str(ex), fg="yellow"))
redis_client.delete(retry_indexing_cache_key)
logger.exception("retry_document_indexing_task failed, document_id: %s", document_id)
end_at = time.perf_counter()
logger.info(click.style(f"Retry dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
logger.exception(
"retry_document_indexing_task failed, dataset_id: %s, document_ids: %s", dataset_id, document_ids
)
raise e
finally:
db.session.close()
raise e
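Both indexing tasks in this diff replace the per-row `db.session.delete(segment)` loop with a single bulk `DELETE ... WHERE id IN (...)`; the real code also collects each segment's `index_node_id` first so the vector index can be cleaned. A small sketch of just the bulk-delete step, assuming the `DocumentSegment` model from `models.dataset` (the helper name is illustrative):

```python
from sqlalchemy import delete, select

from models.dataset import DocumentSegment


def purge_document_segments(session, document_id: str) -> int:
    """Delete every segment of a document with one bulk statement instead of row-by-row ORM deletes."""
    segment_ids = session.scalars(
        select(DocumentSegment.id).where(DocumentSegment.document_id == document_id)
    ).all()
    if not segment_ids:
        return 0
    result = session.execute(delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)))
    session.commit()
    return int(result.rowcount or 0)
```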

View File

@ -3,11 +3,11 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
@ -27,69 +27,71 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
raise ValueError("Dataset not found")
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
raise ValueError("Dataset not found")
sync_indexing_cache_key = f"document_{document_id}_is_sync"
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
)
except Exception as e:
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
sync_indexing_cache_key = f"document_{document_id}_is_sync"
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
)
except Exception as e:
document = (
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
session.add(document)
session.commit()
redis_client.delete(sync_indexing_cache_key)
return
logger.info(click.style(f"Start sync website document: {document_id}", fg="green"))
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
return
try:
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.add(document)
session.commit()
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(sync_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.error = str(e)
document.error = str(ex)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
redis_client.delete(sync_indexing_cache_key)
return
logger.info(click.style(f"Start sync website document: {document_id}", fg="green"))
document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="yellow"))
return
try:
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
db.session.add(document)
db.session.commit()
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(sync_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"
document.error = str(ex)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
logger.info(click.style(str(ex), fg="yellow"))
redis_client.delete(sync_indexing_cache_key)
logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id)
end_at = time.perf_counter()
logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green"))
session.add(document)
session.commit()
logger.info(click.style(str(ex), fg="yellow"))
redis_client.delete(sync_indexing_cache_key)
logger.exception("sync_website_document_indexing_task failed, document_id: %s", document_id)
end_at = time.perf_counter()
logger.info(click.style(f"Sync document: {document_id} latency: {end_at - start_at}", fg="green"))

View File

@ -16,6 +16,7 @@ from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from core.db.session_factory import session_factory
from core.plugin.entities.plugin_daemon import CredentialType
from core.plugin.entities.request import TriggerInvokeEventResponse
from core.plugin.impl.exc import PluginInvokeError
@ -27,7 +28,6 @@ from core.trigger.trigger_manager import TriggerManager
from core.workflow.enums import NodeType, WorkflowExecutionStatus
from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData
from enums.quota_type import QuotaType, unlimited
from extensions.ext_database import db
from models.enums import (
AppTriggerType,
CreatorUserRole,
@ -257,7 +257,7 @@ def dispatch_triggered_workflow(
tenant_id=subscription.tenant_id, provider_id=TriggerProviderID(subscription.provider_id)
)
trigger_entity: TriggerProviderEntity = provider_controller.entity
with Session(db.engine) as session:
with session_factory.create_session() as session:
workflows: Mapping[str, Workflow] = _get_latest_workflows_by_app_ids(session, subscribers)
end_users: Mapping[str, EndUser] = EndUserService.create_end_user_batch(

View File

@ -7,9 +7,9 @@ from celery import shared_task
from sqlalchemy.orm import Session
from configs import dify_config
from core.db.session_factory import session_factory
from core.plugin.entities.plugin_daemon import CredentialType
from core.trigger.utils.locks import build_trigger_refresh_lock_key
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.trigger import TriggerSubscription
from services.trigger.trigger_provider_service import TriggerProviderService
@ -92,7 +92,7 @@ def trigger_subscription_refresh(tenant_id: str, subscription_id: str) -> None:
logger.info("Begin subscription refresh: tenant=%s id=%s", tenant_id, subscription_id)
try:
now: int = _now_ts()
with Session(db.engine) as session:
with session_factory.create_session() as session:
subscription: TriggerSubscription | None = _load_subscription(session, tenant_id, subscription_id)
if not subscription:

View File

@ -10,11 +10,10 @@ import logging
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.db.session_factory import session_factory
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowRun
from models.enums import WorkflowRunTriggeredFrom
@ -46,10 +45,7 @@ def save_workflow_execution_task(
True if successful, False otherwise
"""
try:
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
with session_factory.create_session() as session:
# Deserialize execution data
execution = WorkflowExecution.model_validate(execution_data)

View File

@ -10,13 +10,12 @@ import logging
from celery import shared_task
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.db.session_factory import session_factory
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
)
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from models import CreatorUserRole, WorkflowNodeExecutionModel
from models.workflow import WorkflowNodeExecutionTriggeredFrom
@ -48,10 +47,7 @@ def save_workflow_node_execution_task(
True if successful, False otherwise
"""
try:
# Create a new session for this task
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
with session_factory.create_session() as session:
# Deserialize execution data
execution = WorkflowNodeExecution.model_validate(execution_data)

View File

@ -1,15 +1,14 @@
import logging
from celery import shared_task
from sqlalchemy.orm import sessionmaker
from core.db.session_factory import session_factory
from core.workflow.nodes.trigger_schedule.exc import (
ScheduleExecutionError,
ScheduleNotFoundError,
TenantOwnerNotFoundError,
)
from enums.quota_type import QuotaType, unlimited
from extensions.ext_database import db
from models.trigger import WorkflowSchedulePlan
from services.async_workflow_service import AsyncWorkflowService
from services.errors.app import QuotaExceededError
@ -33,10 +32,7 @@ def run_schedule_trigger(schedule_id: str) -> None:
TenantOwnerNotFoundError: If no owner/admin for tenant
ScheduleExecutionError: If workflow trigger fails
"""
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_factory() as session:
with session_factory.create_session() as session:
schedule = session.get(WorkflowSchedulePlan, schedule_id)
if not schedule:
raise ScheduleNotFoundError(f"Schedule {schedule_id} not found")

View File

@ -4,8 +4,8 @@ from unittest.mock import patch
import pytest
from sqlalchemy import delete
from core.db.session_factory import session_factory
from core.variables.segments import StringSegment
from extensions.ext_database import db
from models import Tenant
from models.enums import CreatorUserRole
from models.model import App, UploadFile
@ -16,26 +16,23 @@ from tasks.remove_app_and_related_data_task import _delete_draft_variables, dele
@pytest.fixture
def app_and_tenant(flask_req_ctx):
tenant_id = uuid.uuid4()
tenant = Tenant(
id=tenant_id,
name="test_tenant",
)
db.session.add(tenant)
with session_factory.create_session() as session:
tenant = Tenant(name="test_tenant")
session.add(tenant)
session.flush()
app = App(
tenant_id=tenant_id, # Now tenant.id will have a value
name=f"Test App for tenant {tenant.id}",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(app)
db.session.flush()
yield (tenant, app)
app = App(
tenant_id=tenant.id,
name=f"Test App for tenant {tenant.id}",
mode="workflow",
enable_site=True,
enable_api=True,
)
session.add(app)
session.flush()
# Cleanup with proper error handling
db.session.delete(app)
db.session.delete(tenant)
# return detached objects (ids will be used by tests)
return (tenant, app)
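Since the fixture now closes its session before the tests run, the returned `tenant`/`app` objects are detached and only their ids are meant to be read afterwards. A sketch of the same idea made explicit, assuming default SQLAlchemy session behaviour from `create_session()` (the helper name is illustrative):

```python
from core.db.session_factory import session_factory
from models import Tenant
from models.model import App


def make_tenant_and_app_ids() -> tuple[str, str]:
    """Create a tenant and an app, returning plain ids so callers never touch detached ORM state."""
    with session_factory.create_session() as session:
        tenant = Tenant(name="test_tenant")
        session.add(tenant)
        session.flush()  # assigns tenant.id

        app = App(
            tenant_id=tenant.id,
            name=f"Test App for tenant {tenant.id}",
            mode="workflow",
            enable_site=True,
            enable_api=True,
        )
        session.add(app)
        session.commit()
        # read attributes while the session is still open, before the objects detach
        return str(tenant.id), str(app.id)
```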
class TestDeleteDraftVariablesIntegration:
@ -44,334 +41,285 @@ class TestDeleteDraftVariablesIntegration:
"""Create test data with apps and draft variables."""
tenant, app = app_and_tenant
# Create a second app for testing
app2 = App(
tenant_id=tenant.id,
name="Test App 2",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(app2)
db.session.commit()
# Create draft variables for both apps
variables_app1 = []
variables_app2 = []
for i in range(5):
var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
with session_factory.create_session() as session:
app2 = App(
tenant_id=tenant.id,
name="Test App 2",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(var1)
variables_app1.append(var1)
session.add(app2)
session.flush()
var2 = WorkflowDraftVariable.new_node_variable(
app_id=app2.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
db.session.add(var2)
variables_app2.append(var2)
variables_app1 = []
variables_app2 = []
for i in range(5):
var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var1)
variables_app1.append(var1)
# Commit all the variables to the database
db.session.commit()
var2 = WorkflowDraftVariable.new_node_variable(
app_id=app2.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var2)
variables_app2.append(var2)
session.commit()
app2_id = app2.id
yield {
"app1": app,
"app2": app2,
"app2": App(id=app2_id), # dummy with id to avoid open session
"tenant": tenant,
"variables_app1": variables_app1,
"variables_app2": variables_app2,
}
# Cleanup - refresh session and check if objects still exist
db.session.rollback() # Clear any pending changes
# Clean up remaining variables
cleanup_query = (
delete(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.app_id.in_([app.id, app2.id]),
with session_factory.create_session() as session:
cleanup_query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.app_id.in_([app.id, app2_id]))
.execution_options(synchronize_session=False)
)
.execution_options(synchronize_session=False)
)
db.session.execute(cleanup_query)
# Clean up app2
app2_obj = db.session.get(App, app2.id)
if app2_obj:
db.session.delete(app2_obj)
db.session.commit()
session.execute(cleanup_query)
app2_obj = session.get(App, app2_id)
if app2_obj:
session.delete(app2_obj)
session.commit()
def test_delete_draft_variables_batch_removes_correct_variables(self, setup_test_data):
"""Test that batch deletion only removes variables for the specified app."""
data = setup_test_data
app1_id = data["app1"].id
app2_id = data["app2"].id
# Verify initial state
app1_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
with session_factory.create_session() as session:
app1_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
assert app1_vars_before == 5
assert app2_vars_before == 5
# Delete app1 variables
deleted_count = delete_draft_variables_batch(app1_id, batch_size=10)
# Verify results
assert deleted_count == 5
app1_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
assert app1_vars_after == 0 # All app1 variables deleted
assert app2_vars_after == 5 # App2 variables unchanged
with session_factory.create_session() as session:
app1_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
app2_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app2_id).count()
assert app1_vars_after == 0
assert app2_vars_after == 5
def test_delete_draft_variables_batch_with_small_batch_size(self, setup_test_data):
"""Test batch deletion with small batch size processes all records."""
data = setup_test_data
app1_id = data["app1"].id
# Use small batch size to force multiple batches
deleted_count = delete_draft_variables_batch(app1_id, batch_size=2)
assert deleted_count == 5
# Verify all variables are deleted
remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
with session_factory.create_session() as session:
remaining_vars = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
assert remaining_vars == 0
def test_delete_draft_variables_batch_nonexistent_app(self, setup_test_data):
"""Test that deleting variables for nonexistent app returns 0."""
nonexistent_app_id = str(uuid.uuid4()) # Use a valid UUID format
nonexistent_app_id = str(uuid.uuid4())
deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=100)
assert deleted_count == 0
def test_delete_draft_variables_wrapper_function(self, setup_test_data):
"""Test that _delete_draft_variables wrapper function works correctly."""
data = setup_test_data
app1_id = data["app1"].id
# Verify initial state
vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
with session_factory.create_session() as session:
vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
assert vars_before == 5
# Call wrapper function
deleted_count = _delete_draft_variables(app1_id)
# Verify results
assert deleted_count == 5
vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
with session_factory.create_session() as session:
vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app1_id).count()
assert vars_after == 0
def test_batch_deletion_handles_large_dataset(self, app_and_tenant):
"""Test batch deletion with larger dataset to verify batching logic."""
tenant, app = app_and_tenant
# Create many draft variables
variables = []
for i in range(25):
var = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
db.session.add(var)
variables.append(var)
variable_ids = [i.id for i in variables]
# Commit the variables to the database
db.session.commit()
variable_ids: list[str] = []
with session_factory.create_session() as session:
variables = []
for i in range(25):
var = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var)
variables.append(var)
session.commit()
variable_ids = [v.id for v in variables]
try:
# Use small batch size to force multiple batches
deleted_count = delete_draft_variables_batch(app.id, batch_size=8)
assert deleted_count == 25
# Verify all variables are deleted
remaining_vars = db.session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count()
assert remaining_vars == 0
with session_factory.create_session() as session:
remaining = session.query(WorkflowDraftVariable).filter_by(app_id=app.id).count()
assert remaining == 0
finally:
query = (
delete(WorkflowDraftVariable)
.where(
WorkflowDraftVariable.id.in_(variable_ids),
with session_factory.create_session() as session:
query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.id.in_(variable_ids))
.execution_options(synchronize_session=False)
)
.execution_options(synchronize_session=False)
)
db.session.execute(query)
session.execute(query)
session.commit()
class TestDeleteDraftVariablesWithOffloadIntegration:
"""Integration tests for draft variable deletion with Offload data."""
@pytest.fixture
def setup_offload_test_data(self, app_and_tenant):
"""Create test data with draft variables that have associated Offload files."""
tenant, app = app_and_tenant
# Create UploadFile records
from core.variables.types import SegmentType
from libs.datetime_utils import naive_utc_now
upload_file1 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file1.json",
name="file1.json",
size=1024,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
upload_file2 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file2.json",
name="file2.json",
size=2048,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
db.session.add(upload_file1)
db.session.add(upload_file2)
db.session.flush()
with session_factory.create_session() as session:
upload_file1 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file1.json",
name="file1.json",
size=1024,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
upload_file2 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file2.json",
name="file2.json",
size=2048,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
session.add(upload_file1)
session.add(upload_file2)
session.flush()
# Create WorkflowDraftVariableFile records
from core.variables.types import SegmentType
var_file1 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file1.id,
size=1024,
length=10,
value_type=SegmentType.STRING,
)
var_file2 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file2.id,
size=2048,
length=20,
value_type=SegmentType.OBJECT,
)
session.add(var_file1)
session.add(var_file2)
session.flush()
var_file1 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file1.id,
size=1024,
length=10,
value_type=SegmentType.STRING,
)
var_file2 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file2.id,
size=2048,
length=20,
value_type=SegmentType.OBJECT,
)
db.session.add(var_file1)
db.session.add(var_file2)
db.session.flush()
draft_var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_1",
name="large_var_1",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file1.id,
)
draft_var2 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_2",
name="large_var_2",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file2.id,
)
draft_var3 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_3",
name="regular_var",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(draft_var1)
session.add(draft_var2)
session.add(draft_var3)
session.commit()
# Create WorkflowDraftVariable records with file associations
draft_var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_1",
name="large_var_1",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file1.id,
)
draft_var2 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_2",
name="large_var_2",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file2.id,
)
# Create a regular variable without Offload data
draft_var3 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_3",
name="regular_var",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
)
data = {
"app": app,
"tenant": tenant,
"upload_files": [upload_file1, upload_file2],
"variable_files": [var_file1, var_file2],
"draft_variables": [draft_var1, draft_var2, draft_var3],
}
db.session.add(draft_var1)
db.session.add(draft_var2)
db.session.add(draft_var3)
db.session.commit()
yield data
yield {
"app": app,
"tenant": tenant,
"upload_files": [upload_file1, upload_file2],
"variable_files": [var_file1, var_file2],
"draft_variables": [draft_var1, draft_var2, draft_var3],
}
# Cleanup
db.session.rollback()
# Clean up any remaining records
for table, ids in [
(WorkflowDraftVariable, [v.id for v in [draft_var1, draft_var2, draft_var3]]),
(WorkflowDraftVariableFile, [vf.id for vf in [var_file1, var_file2]]),
(UploadFile, [uf.id for uf in [upload_file1, upload_file2]]),
]:
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
db.session.execute(cleanup_query)
db.session.commit()
with session_factory.create_session() as session:
session.rollback()
for table, ids in [
(WorkflowDraftVariable, [v.id for v in data["draft_variables"]]),
(WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]),
(UploadFile, [uf.id for uf in data["upload_files"]]),
]:
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
session.execute(cleanup_query)
session.commit()
@patch("extensions.ext_storage.storage")
def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
"""Test that deleting draft variables also cleans up associated Offload data."""
data = setup_offload_test_data
app_id = data["app"].id
# Mock storage deletion to succeed
mock_storage.delete.return_value = None
# Verify initial state
draft_vars_before = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = db.session.query(WorkflowDraftVariableFile).count()
upload_files_before = db.session.query(UploadFile).count()
assert draft_vars_before == 3 # 2 with files + 1 regular
with session_factory.create_session() as session:
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = session.query(WorkflowDraftVariableFile).count()
upload_files_before = session.query(UploadFile).count()
assert draft_vars_before == 3
assert var_files_before == 2
assert upload_files_before == 2
# Delete draft variables
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
# Verify results
assert deleted_count == 3
# Check that all draft variables are deleted
draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
with session_factory.create_session() as session:
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert draft_vars_after == 0
# Check that associated Offload data is cleaned up
var_files_after = db.session.query(WorkflowDraftVariableFile).count()
upload_files_after = db.session.query(UploadFile).count()
with session_factory.create_session() as session:
var_files_after = session.query(WorkflowDraftVariableFile).count()
upload_files_after = session.query(UploadFile).count()
assert var_files_after == 0
assert upload_files_after == 0
assert var_files_after == 0 # All variable files should be deleted
assert upload_files_after == 0 # All upload files should be deleted
# Verify storage deletion was called for both files
assert mock_storage.delete.call_count == 2
storage_keys_deleted = [call.args[0] for call in mock_storage.delete.call_args_list]
assert "test/file1.json" in storage_keys_deleted
@ -379,92 +327,71 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
@patch("extensions.ext_storage.storage")
def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
"""Test that database cleanup continues even when storage deletion fails."""
data = setup_offload_test_data
app_id = data["app"].id
# Mock storage deletion to fail for first file, succeed for second
mock_storage.delete.side_effect = [Exception("Storage error"), None]
# Delete draft variables
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
# Verify that all draft variables are still deleted
assert deleted_count == 3
draft_vars_after = db.session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
with session_factory.create_session() as session:
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert draft_vars_after == 0
# Database cleanup should still succeed even with storage errors
var_files_after = db.session.query(WorkflowDraftVariableFile).count()
upload_files_after = db.session.query(UploadFile).count()
with session_factory.create_session() as session:
var_files_after = session.query(WorkflowDraftVariableFile).count()
upload_files_after = session.query(UploadFile).count()
assert var_files_after == 0
assert upload_files_after == 0
# Verify storage deletion was attempted for both files
assert mock_storage.delete.call_count == 2
@patch("extensions.ext_storage.storage")
def test_delete_draft_variables_partial_offload_data(self, mock_storage, setup_offload_test_data):
"""Test deletion with mix of variables with and without Offload data."""
data = setup_offload_test_data
app_id = data["app"].id
# Create additional app with only regular variables (no offload data)
tenant = data["tenant"]
app2 = App(
tenant_id=tenant.id,
name="Test App 2",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(app2)
db.session.flush()
# Add regular variables to app2
regular_vars = []
for i in range(3):
var = WorkflowDraftVariable.new_node_variable(
app_id=app2.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
with session_factory.create_session() as session:
app2 = App(
tenant_id=tenant.id,
name="Test App 2",
mode="workflow",
enable_site=True,
enable_api=True,
)
db.session.add(var)
regular_vars.append(var)
db.session.commit()
session.add(app2)
session.flush()
for i in range(3):
var = WorkflowDraftVariable.new_node_variable(
app_id=app2.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var)
session.commit()
try:
# Mock storage deletion
mock_storage.delete.return_value = None
# Delete variables for app2 (no offload data)
deleted_count_app2 = delete_draft_variables_batch(app2.id, batch_size=10)
assert deleted_count_app2 == 3
# Verify storage wasn't called for app2 (no offload files)
mock_storage.delete.assert_not_called()
# Delete variables for original app (with offload data)
deleted_count_app1 = delete_draft_variables_batch(app_id, batch_size=10)
assert deleted_count_app1 == 3
# Now storage should be called for the offload files
assert mock_storage.delete.call_count == 2
finally:
# Cleanup app2 and its variables
cleanup_vars_query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.app_id == app2.id)
.execution_options(synchronize_session=False)
)
db.session.execute(cleanup_vars_query)
app2_obj = db.session.get(App, app2.id)
if app2_obj:
db.session.delete(app2_obj)
db.session.commit()
with session_factory.create_session() as session:
cleanup_vars_query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.app_id == app2.id)
.execution_options(synchronize_session=False)
)
session.execute(cleanup_vars_query)
app2_obj = session.get(App, app2.id)
if app2_obj:
session.delete(app2_obj)
session.commit()

View File

@ -220,6 +220,23 @@ class TestAnnotationService:
# Note: In this test, no annotation setting exists, so task should not be called
mock_external_service_dependencies["add_task"].delay.assert_not_called()
def test_insert_app_annotation_directly_requires_question(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Question must be provided when inserting annotations directly.
"""
fake = Faker()
app, _ = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies)
annotation_args = {
"question": None,
"answer": fake.text(max_nb_chars=200),
}
with pytest.raises(ValueError):
AppAnnotationService.insert_app_annotation_directly(annotation_args, app.id)
def test_insert_app_annotation_directly_app_not_found(
self, db_session_with_containers, mock_external_service_dependencies
):

View File

@ -39,23 +39,22 @@ class TestCleanDatasetTask:
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers):
"""Clean up database before each test to ensure isolation."""
from extensions.ext_database import db
from extensions.ext_redis import redis_client
# Clear all test data
db.session.query(DatasetMetadataBinding).delete()
db.session.query(DatasetMetadata).delete()
db.session.query(AppDatasetJoin).delete()
db.session.query(DatasetQuery).delete()
db.session.query(DatasetProcessRule).delete()
db.session.query(DocumentSegment).delete()
db.session.query(Document).delete()
db.session.query(Dataset).delete()
db.session.query(UploadFile).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Tenant).delete()
db.session.query(Account).delete()
db.session.commit()
# Clear all test data using the provided session fixture
db_session_with_containers.query(DatasetMetadataBinding).delete()
db_session_with_containers.query(DatasetMetadata).delete()
db_session_with_containers.query(AppDatasetJoin).delete()
db_session_with_containers.query(DatasetQuery).delete()
db_session_with_containers.query(DatasetProcessRule).delete()
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(UploadFile).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.commit()
# Clear Redis cache
redis_client.flushdb()
@ -103,10 +102,8 @@ class TestCleanDatasetTask:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
@ -115,8 +112,8 @@ class TestCleanDatasetTask:
status="active",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account relationship
tenant_account_join = TenantAccountJoin(
@ -125,8 +122,8 @@ class TestCleanDatasetTask:
role=TenantAccountRole.OWNER,
)
db.session.add(tenant_account_join)
db.session.commit()
db_session_with_containers.add(tenant_account_join)
db_session_with_containers.commit()
return account, tenant
@ -155,10 +152,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
@ -194,10 +189,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
return document
@ -232,10 +225,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
return segment
@ -267,10 +258,8 @@ class TestCleanDatasetTask:
used=False,
)
from extensions.ext_database import db
db.session.add(upload_file)
db.session.commit()
db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
return upload_file
@ -302,31 +291,29 @@ class TestCleanDatasetTask:
)
# Verify results
from extensions.ext_database import db
# Check that dataset-related data was cleaned up
documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(documents) == 0
segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(segments) == 0
# Check that metadata and bindings were cleaned up
metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
assert len(metadata) == 0
bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
bindings = db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
assert len(bindings) == 0
# Check that process rules and queries were cleaned up
process_rules = db.session.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all()
process_rules = db_session_with_containers.query(DatasetProcessRule).filter_by(dataset_id=dataset.id).all()
assert len(process_rules) == 0
queries = db.session.query(DatasetQuery).filter_by(dataset_id=dataset.id).all()
queries = db_session_with_containers.query(DatasetQuery).filter_by(dataset_id=dataset.id).all()
assert len(queries) == 0
# Check that app dataset joins were cleaned up
app_joins = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all()
app_joins = db_session_with_containers.query(AppDatasetJoin).filter_by(dataset_id=dataset.id).all()
assert len(app_joins) == 0
# Verify index processor was called
@ -378,9 +365,7 @@ class TestCleanDatasetTask:
import json
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Create dataset metadata and bindings
metadata = DatasetMetadata(
@ -403,11 +388,9 @@ class TestCleanDatasetTask:
binding.id = str(uuid.uuid4())
binding.created_at = datetime.now()
from extensions.ext_database import db
db.session.add(metadata)
db.session.add(binding)
db.session.commit()
db_session_with_containers.add(metadata)
db_session_with_containers.add(binding)
db_session_with_containers.commit()
# Execute the task
clean_dataset_task(
@ -421,22 +404,24 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db.session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
assert len(remaining_files) == 0
# Check that metadata and bindings were cleaned up
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
assert len(remaining_metadata) == 0
remaining_bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
remaining_bindings = (
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
)
assert len(remaining_bindings) == 0
# Verify index processor was called
@ -489,12 +474,13 @@ class TestCleanDatasetTask:
mock_index_processor.clean.assert_called_once()
# Check that all data was cleaned up
from extensions.ext_database import db
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = (
db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
)
assert len(remaining_segments) == 0
# Recreate data for next test case
@ -540,14 +526,13 @@ class TestCleanDatasetTask:
)
# Verify results - even with vector cleanup failure, documents and segments should be deleted
from extensions.ext_database import db
# Check that documents were still deleted despite vector cleanup failure
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that segments were still deleted despite vector cleanup failure
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Verify that index processor was called and failed
@ -608,10 +593,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
# Mock the get_image_upload_file_ids function to return our image file IDs
with patch("tasks.clean_dataset_task.get_image_upload_file_ids") as mock_get_image_ids:
@ -629,16 +612,18 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Check that all image files were deleted from database
image_file_ids = [f.id for f in image_files]
remaining_image_files = db.session.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all()
remaining_image_files = (
db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(image_file_ids)).all()
)
assert len(remaining_image_files) == 0
# Verify that storage.delete was called for each image file
@ -745,22 +730,24 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db.session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
remaining_files = db_session_with_containers.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).all()
assert len(remaining_files) == 0
# Check that all metadata and bindings were deleted
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
assert len(remaining_metadata) == 0
remaining_bindings = db.session.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
remaining_bindings = (
db_session_with_containers.query(DatasetMetadataBinding).filter_by(dataset_id=dataset.id).all()
)
assert len(remaining_bindings) == 0
# Verify performance expectations
@ -808,9 +795,7 @@ class TestCleanDatasetTask:
import json
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
from extensions.ext_database import db
db.session.commit()
db_session_with_containers.commit()
# Mock storage to raise exceptions
mock_storage = mock_external_service_dependencies["storage"]
@ -827,18 +812,13 @@ class TestCleanDatasetTask:
)
# Verify results
# Check that documents were still deleted despite storage failure
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that segments were still deleted despite storage failure
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Note: When storage operations fail, database deletions may be rolled back by the implementation.
# This test focuses on ensuring the task handles the exception and continues execution/logging.
# Check that upload file was still deleted from database despite storage failure
# Note: When storage operations fail, the upload file may not be deleted
# This demonstrates that the cleanup process continues even with storage errors
remaining_files = db.session.query(UploadFile).filter_by(id=upload_file.id).all()
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file.id).all()
# The upload file should still be deleted from the database even if storage cleanup fails
# However, this depends on the specific implementation of clean_dataset_task
if len(remaining_files) > 0:
@ -890,10 +870,8 @@ class TestCleanDatasetTask:
updated_at=datetime.now(),
)
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create document with special characters in name
special_content = "Special chars: !@#$%^&*()_+-=[]{}|;':\",./<>?`~"
@ -912,8 +890,8 @@ class TestCleanDatasetTask:
created_at=datetime.now(),
updated_at=datetime.now(),
)
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Create segment with special characters and very long content
long_content = "Very long content " * 100 # Long content within reasonable limits
@ -934,8 +912,8 @@ class TestCleanDatasetTask:
created_at=datetime.now(),
updated_at=datetime.now(),
)
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
# Create upload file with special characters in name
special_filename = f"test_file_{special_content}.txt"
@ -952,14 +930,14 @@ class TestCleanDatasetTask:
created_at=datetime.now(),
used=False,
)
db.session.add(upload_file)
db.session.commit()
db_session_with_containers.add(upload_file)
db_session_with_containers.commit()
# Update document with file reference
import json
document.data_source_info = json.dumps({"upload_file_id": upload_file.id})
db.session.commit()
db_session_with_containers.commit()
# Save upload file ID for verification
upload_file_id = upload_file.id
@ -975,8 +953,8 @@ class TestCleanDatasetTask:
special_metadata.id = str(uuid.uuid4())
special_metadata.created_at = datetime.now()
db.session.add(special_metadata)
db.session.commit()
db_session_with_containers.add(special_metadata)
db_session_with_containers.commit()
# Execute the task
clean_dataset_task(
@ -990,19 +968,19 @@ class TestCleanDatasetTask:
# Verify results
# Check that all documents were deleted
remaining_documents = db.session.query(Document).filter_by(dataset_id=dataset.id).all()
remaining_documents = db_session_with_containers.query(Document).filter_by(dataset_id=dataset.id).all()
assert len(remaining_documents) == 0
# Check that all segments were deleted
remaining_segments = db.session.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
remaining_segments = db_session_with_containers.query(DocumentSegment).filter_by(dataset_id=dataset.id).all()
assert len(remaining_segments) == 0
# Check that all upload files were deleted
remaining_files = db.session.query(UploadFile).filter_by(id=upload_file_id).all()
remaining_files = db_session_with_containers.query(UploadFile).filter_by(id=upload_file_id).all()
assert len(remaining_files) == 0
# Check that all metadata was deleted
remaining_metadata = db.session.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
remaining_metadata = db_session_with_containers.query(DatasetMetadata).filter_by(dataset_id=dataset.id).all()
assert len(remaining_metadata) == 0
# Verify that storage.delete was called
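Aside for readers of the storage-failure tests above: they assume blob deletion is best-effort while row cleanup proceeds afterwards. A minimal sketch of that pattern, assuming the `storage` client from extensions.ext_storage and an UploadFile model (import path assumed); this is illustrative only, not the actual clean_dataset_task body:

import logging

from sqlalchemy import delete

from extensions.ext_storage import storage
from models.model import UploadFile  # import path assumed for this sketch

logger = logging.getLogger(__name__)


def delete_upload_files_best_effort(session, upload_files):
    # Blob deletion is best-effort: a storage outage must not block DB cleanup.
    for upload_file in upload_files:
        try:
            storage.delete(upload_file.key)
        except Exception:
            logger.exception("Failed to delete file from storage: %s", upload_file.key)
    # Rows are then removed in one batch DELETE, so cleanup continues despite storage errors.
    ids = [f.id for f in upload_files]
    if ids:
        session.execute(delete(UploadFile).where(UploadFile.id.in_(ids)))
    session.commit()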

View File

@ -24,16 +24,15 @@ class TestCreateSegmentToIndexTask:
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers):
"""Clean up database and Redis before each test to ensure isolation."""
from extensions.ext_database import db
# Clear all test data
db.session.query(DocumentSegment).delete()
db.session.query(Document).delete()
db.session.query(Dataset).delete()
db.session.query(TenantAccountJoin).delete()
db.session.query(Tenant).delete()
db.session.query(Account).delete()
db.session.commit()
# Clear all test data using fixture session
db_session_with_containers.query(DocumentSegment).delete()
db_session_with_containers.query(Document).delete()
db_session_with_containers.query(Dataset).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.commit()
# Clear Redis cache
redis_client.flushdb()
@ -73,10 +72,8 @@ class TestCreateSegmentToIndexTask:
status="active",
)
from extensions.ext_database import db
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Create tenant
tenant = Tenant(
@ -84,8 +81,8 @@ class TestCreateSegmentToIndexTask:
status="normal",
plan="basic",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join with owner role
join = TenantAccountJoin(
@ -94,8 +91,8 @@ class TestCreateSegmentToIndexTask:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Set current tenant for account
account.current_tenant = tenant
@ -746,20 +743,9 @@ class TestCreateSegmentToIndexTask:
db_session_with_containers, dataset.id, document.id, tenant.id, account.id, status="waiting"
)
# Mock global database session to simulate transaction issues
from extensions.ext_database import db
original_commit = db.session.commit
commit_called = False
def mock_commit():
nonlocal commit_called
if not commit_called:
commit_called = True
raise Exception("Database commit failed")
return original_commit()
db.session.commit = mock_commit
# Simulate an error during indexing to trigger rollback path
mock_processor = mock_external_service_dependencies["index_processor"]
mock_processor.load.side_effect = Exception("Simulated indexing error")
# Act: Execute the task
create_segment_to_index_task(segment.id)
@ -771,9 +757,6 @@ class TestCreateSegmentToIndexTask:
assert segment.disabled_at is not None
assert segment.error is not None
# Restore original commit method
db.session.commit = original_commit
def test_create_segment_to_index_metadata_validation(
self, db_session_with_containers, mock_external_service_dependencies
):

View File

@ -70,11 +70,9 @@ class TestDisableSegmentsFromIndexTask:
tenant.created_at = fake.date_time_this_year()
tenant.updated_at = tenant.created_at
from extensions.ext_database import db
db.session.add(tenant)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.add(account)
db_session_with_containers.commit()
# Set the current tenant for the account
account.current_tenant = tenant
@ -110,10 +108,8 @@ class TestDisableSegmentsFromIndexTask:
built_in_field_enabled=False,
)
from extensions.ext_database import db
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
return dataset
@ -158,10 +154,8 @@ class TestDisableSegmentsFromIndexTask:
document.archived = False
document.doc_form = "text_model" # Use text_model form for testing
document.doc_language = "en"
from extensions.ext_database import db
db.session.add(document)
db.session.commit()
db_session_with_containers.add(document)
db_session_with_containers.commit()
return document
@ -211,11 +205,9 @@ class TestDisableSegmentsFromIndexTask:
segments.append(segment)
from extensions.ext_database import db
for segment in segments:
db.session.add(segment)
db.session.commit()
db_session_with_containers.add(segment)
db_session_with_containers.commit()
return segments
@ -645,15 +637,12 @@ class TestDisableSegmentsFromIndexTask:
with patch("tasks.disable_segments_from_index_task.redis_client") as mock_redis:
mock_redis.delete.return_value = True
# Mock db.session.close to verify it's called
with patch("tasks.disable_segments_from_index_task.db.session.close") as mock_close:
# Act
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
# Act
result = disable_segments_from_index_task(segment_ids, dataset.id, document.id)
# Assert
assert result is None # Task should complete without returning a value
# Verify session was closed
mock_close.assert_called()
# Assert
assert result is None # Task should complete without returning a value
# Session lifecycle is managed by context manager; no explicit close assertion
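As an editorial aside on the change above: once the task opens its session via a context manager, teardown is implicit, which is why the `db.session.close` patching was dropped. A minimal sketch of the assumed task shape (`session_factory.create_session()` is the factory mocked in the fixtures elsewhere in this change, import path not shown here; the body is illustrative only):

from models.dataset import DocumentSegment

def disable_segments_sketch(segment_ids, dataset_id, document_id):
    # The session closes automatically when the with-block exits, even on error,
    # so tests assert behavior rather than an explicit close() call.
    with session_factory.create_session() as session:
        segments = (
            session.query(DocumentSegment)
            .where(DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset_id)
            .all()
        )
        for segment in segments:
            segment.enabled = False
        session.commit()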
def test_disable_segments_empty_segment_ids(self, db_session_with_containers):
"""

View File

@ -6,7 +6,6 @@ from faker import Faker
from core.entities.document_task import DocumentTask
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document
from tasks.document_indexing_task import (
@ -75,15 +74,15 @@ class TestDocumentIndexingTasks:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -92,8 +91,8 @@ class TestDocumentIndexingTasks:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@ -105,8 +104,8 @@ class TestDocumentIndexingTasks:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create documents
documents = []
@ -124,13 +123,13 @@ class TestDocumentIndexingTasks:
indexing_status="waiting",
enabled=True,
)
db.session.add(document)
db_session_with_containers.add(document)
documents.append(document)
db.session.commit()
db_session_with_containers.commit()
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, documents
@ -157,15 +156,15 @@ class TestDocumentIndexingTasks:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -174,8 +173,8 @@ class TestDocumentIndexingTasks:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@ -187,8 +186,8 @@ class TestDocumentIndexingTasks:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create documents
documents = []
@ -206,10 +205,10 @@ class TestDocumentIndexingTasks:
indexing_status="waiting",
enabled=True,
)
db.session.add(document)
db_session_with_containers.add(document)
documents.append(document)
db.session.commit()
db_session_with_containers.commit()
# Configure billing features
mock_external_service_dependencies["features"].billing.enabled = billing_enabled
@ -219,7 +218,7 @@ class TestDocumentIndexingTasks:
mock_external_service_dependencies["features"].vector_space.size = 50
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, documents
@ -242,6 +241,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify the expected outcomes
# Verify indexing runner was called correctly
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -250,7 +252,7 @@ class TestDocumentIndexingTasks:
# Verify documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -310,6 +312,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task with mixed document IDs
_document_indexing(dataset.id, all_document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify only existing documents were processed
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@ -317,7 +322,7 @@ class TestDocumentIndexingTasks:
# Verify only existing documents were updated
# Re-query documents from database since _document_indexing uses a different session
for doc_id in existing_document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -353,6 +358,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -361,7 +369,7 @@ class TestDocumentIndexingTasks:
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing closes the session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -400,7 +408,7 @@ class TestDocumentIndexingTasks:
indexing_status="completed", # Already completed
enabled=True,
)
db.session.add(doc1)
db_session_with_containers.add(doc1)
extra_documents.append(doc1)
# Document with disabled status
@ -417,10 +425,10 @@ class TestDocumentIndexingTasks:
indexing_status="waiting",
enabled=False, # Disabled
)
db.session.add(doc2)
db_session_with_containers.add(doc2)
extra_documents.append(doc2)
db.session.commit()
db_session_with_containers.commit()
all_documents = base_documents + extra_documents
document_ids = [doc.id for doc in all_documents]
@ -428,6 +436,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task with mixed document states
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify processing
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@ -435,7 +446,7 @@ class TestDocumentIndexingTasks:
# Verify all documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -482,20 +493,23 @@ class TestDocumentIndexingTasks:
indexing_status="waiting",
enabled=True,
)
db.session.add(document)
db_session_with_containers.add(document)
extra_documents.append(document)
db.session.commit()
db_session_with_containers.commit()
all_documents = documents + extra_documents
document_ids = [doc.id for doc in all_documents]
# Act: Execute the task with too many documents for sandbox plan
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify error handling
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "batch upload" in updated_document.error
@ -526,6 +540,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task with billing disabled
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify successful processing
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@ -533,7 +550,7 @@ class TestDocumentIndexingTasks:
# Verify documents were updated to parsing status
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -565,6 +582,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the task
_document_indexing(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -573,7 +593,7 @@ class TestDocumentIndexingTasks:
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -674,6 +694,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the wrapper function
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify core processing occurred (same as _document_indexing)
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@ -681,7 +704,7 @@ class TestDocumentIndexingTasks:
# Verify documents were updated (same as _document_indexing)
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -794,6 +817,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the wrapper function
_document_indexing_with_tenant_queue(tenant_id, dataset.id, document_ids, mock_task_func)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify error was handled gracefully
# The function should not raise exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -802,7 +828,7 @@ class TestDocumentIndexingTasks:
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _document_indexing uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -865,6 +891,9 @@ class TestDocumentIndexingTasks:
# Act: Execute the wrapper function for tenant1 only
_document_indexing_with_tenant_queue(tenant1_id, dataset1.id, document_ids1, mock_task_func)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify core processing occurred for tenant1
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
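A brief note on the `expire_all()` calls added throughout this file: the task commits through its own session, so the fixture session may still hold identity-mapped instances loaded earlier, and expiring them forces the next query to hit the database. Illustrative usage (variable names assumed from the surrounding tests):

# Loaded into the fixture session's identity map before the task runs.
doc = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert doc.indexing_status == "waiting"

_document_indexing(dataset_id, [doc_id])  # commits via its own session

# Without expire_all(), the re-query could return the stale cached instance.
db_session_with_containers.expire_all()
fresh = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert fresh.indexing_status == "parsing"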

View File

@ -4,7 +4,6 @@ import pytest
from faker import Faker
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from tasks.duplicate_document_indexing_task import (
@ -82,15 +81,15 @@ class TestDuplicateDocumentIndexingTasks:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -99,8 +98,8 @@ class TestDuplicateDocumentIndexingTasks:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@ -112,8 +111,8 @@ class TestDuplicateDocumentIndexingTasks:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create documents
documents = []
@ -132,13 +131,13 @@ class TestDuplicateDocumentIndexingTasks:
enabled=True,
doc_form="text_model",
)
db.session.add(document)
db_session_with_containers.add(document)
documents.append(document)
db.session.commit()
db_session_with_containers.commit()
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, documents
@ -183,14 +182,14 @@ class TestDuplicateDocumentIndexingTasks:
indexing_at=fake.date_time_this_year(),
created_by=dataset.created_by, # Add required field
)
db.session.add(segment)
db_session_with_containers.add(segment)
segments.append(segment)
db.session.commit()
db_session_with_containers.commit()
# Refresh to ensure all relationships are loaded
for document in documents:
db.session.refresh(document)
db_session_with_containers.refresh(document)
return dataset, documents, segments
@ -217,15 +216,15 @@ class TestDuplicateDocumentIndexingTasks:
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(
name=fake.company(),
status="normal",
)
db.session.add(tenant)
db.session.commit()
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# Create tenant-account join
join = TenantAccountJoin(
@ -234,8 +233,8 @@ class TestDuplicateDocumentIndexingTasks:
role=TenantAccountRole.OWNER,
current=True,
)
db.session.add(join)
db.session.commit()
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Create dataset
dataset = Dataset(
@ -247,8 +246,8 @@ class TestDuplicateDocumentIndexingTasks:
indexing_technique="high_quality",
created_by=account.id,
)
db.session.add(dataset)
db.session.commit()
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Create documents
documents = []
@ -267,10 +266,10 @@ class TestDuplicateDocumentIndexingTasks:
enabled=True,
doc_form="text_model",
)
db.session.add(document)
db_session_with_containers.add(document)
documents.append(document)
db.session.commit()
db_session_with_containers.commit()
# Configure billing features
mock_external_service_dependencies["features"].billing.enabled = billing_enabled
@ -280,7 +279,7 @@ class TestDuplicateDocumentIndexingTasks:
mock_external_service_dependencies["features"].vector_space.size = 50
# Refresh dataset to ensure it's properly loaded
db.session.refresh(dataset)
db_session_with_containers.refresh(dataset)
return dataset, documents
@ -305,6 +304,9 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify the expected outcomes
# Verify indexing runner was called correctly
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -313,7 +315,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify documents were updated to parsing status
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -340,23 +342,32 @@ class TestDuplicateDocumentIndexingTasks:
db_session_with_containers, mock_external_service_dependencies, document_count=2, segments_per_doc=3
)
document_ids = [doc.id for doc in documents]
segment_ids = [seg.id for seg in segments]
# Act: Execute the task
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify segment cleanup
# Verify index processor clean was called for each document with segments
assert mock_external_service_dependencies["index_processor"].clean.call_count == len(documents)
# Verify segments were deleted from database
# Re-query segments from database since _duplicate_document_indexing_task uses a different session
for segment in segments:
deleted_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
# Re-query segments from database using captured IDs to avoid stale ORM instances
for seg_id in segment_ids:
deleted_segment = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.id == seg_id).first()
)
assert deleted_segment is None
# Verify documents were updated to parsing status
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -415,6 +426,9 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task with mixed document IDs
_duplicate_document_indexing_task(dataset.id, all_document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify only existing documents were processed
mock_external_service_dependencies["indexing_runner"].assert_called_once()
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
@ -422,7 +436,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify only existing documents were updated
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in existing_document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -458,6 +472,9 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify exception was handled gracefully
# The task should complete without raising exceptions
mock_external_service_dependencies["indexing_runner"].assert_called_once()
@ -466,7 +483,7 @@ class TestDuplicateDocumentIndexingTasks:
# Verify documents were still updated to parsing status before the exception
# Re-query documents from database since _duplicate_document_indexing_task closes the session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
assert updated_document.processing_started_at is not None
@ -508,20 +525,23 @@ class TestDuplicateDocumentIndexingTasks:
enabled=True,
doc_form="text_model",
)
db.session.add(document)
db_session_with_containers.add(document)
extra_documents.append(document)
db.session.commit()
db_session_with_containers.commit()
all_documents = documents + extra_documents
document_ids = [doc.id for doc in all_documents]
# Act: Execute the task with too many documents for sandbox plan
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify error handling
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "batch upload" in updated_document.error.lower()
@ -557,10 +577,13 @@ class TestDuplicateDocumentIndexingTasks:
# Act: Execute the task with documents that will exceed vector space limit
_duplicate_document_indexing_task(dataset.id, document_ids)
# Ensure we see committed changes from a different session
db_session_with_containers.expire_all()
# Assert: Verify error handling
# Re-query documents from database since _duplicate_document_indexing_task uses a different session
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "error"
assert updated_document.error is not None
assert "limit" in updated_document.error.lower()
@ -620,11 +643,11 @@ class TestDuplicateDocumentIndexingTasks:
mock_external_service_dependencies["indexing_runner_instance"].run.assert_called_once()
# Clear session cache to see database updates from task's session
db.session.expire_all()
db_session_with_containers.expire_all()
# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
@ -663,11 +686,11 @@ class TestDuplicateDocumentIndexingTasks:
mock_queue.delete_task_key.assert_called_once()
# Clear session cache to see database updates from task's session
db.session.expire_all()
db_session_with_containers.expire_all()
# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")
@ -707,11 +730,11 @@ class TestDuplicateDocumentIndexingTasks:
mock_queue.delete_task_key.assert_called_once()
# Clear session cache to see database updates from task's session
db.session.expire_all()
db_session_with_containers.expire_all()
# Verify documents were processed
for doc_id in document_ids:
updated_document = db.session.query(Document).where(Document.id == doc_id).first()
updated_document = db_session_with_containers.query(Document).where(Document.id == doc_id).first()
assert updated_document.indexing_status == "parsing"
@patch("tasks.duplicate_document_indexing_task.TenantIsolatedTaskQueue")

View File

@ -5,6 +5,7 @@ from typing import Any
from unittest.mock import MagicMock
import pytest
from pydantic import BaseModel
from core.workflow.context.execution_context import (
AppContext,
@ -12,6 +13,8 @@ from core.workflow.context.execution_context import (
ExecutionContextBuilder,
IExecutionContext,
NullAppContext,
read_context,
register_context,
)
@ -256,3 +259,31 @@ class TestCaptureCurrentContext:
# Context variables should be captured
assert result.context_vars is not None
class TestTenantScopedContextRegistry:
def setup_method(self):
from core.workflow.context import reset_context_provider
reset_context_provider()
def teardown_method(self):
from core.workflow.context import reset_context_provider
reset_context_provider()
def test_tenant_provider_read_ok(self):
class SandboxContext(BaseModel):
base_url: str | None = None
register_context("workflow.sandbox", "t1", lambda: SandboxContext(base_url="http://t1"))
register_context("workflow.sandbox", "t2", lambda: SandboxContext(base_url="http://t2"))
assert read_context("workflow.sandbox", tenant_id="t1").base_url == "http://t1"
assert read_context("workflow.sandbox", tenant_id="t2").base_url == "http://t2"
def test_missing_provider_raises_keyerror(self):
from core.workflow.context import ContextProviderNotFoundError
with pytest.raises(ContextProviderNotFoundError):
read_context("missing", tenant_id="unknown")

View File

@ -49,10 +49,14 @@ def pipeline_id():
@pytest.fixture
def mock_db_session():
"""Mock database session with query capabilities."""
with patch("tasks.clean_dataset_task.db") as mock_db:
"""Mock database session via session_factory.create_session()."""
with patch("tasks.clean_dataset_task.session_factory") as mock_sf:
mock_session = MagicMock()
mock_db.session = mock_session
# context manager for create_session()
cm = MagicMock()
cm.__enter__.return_value = mock_session
cm.__exit__.return_value = None
mock_sf.create_session.return_value = cm
# Setup query chain
mock_query = MagicMock()
@ -66,7 +70,10 @@ def mock_db_session():
# Setup execute for JOIN queries
mock_session.execute.return_value.all.return_value = []
yield mock_db
# Yield an object with a `.session` attribute to keep tests unchanged
wrapper = MagicMock()
wrapper.session = mock_session
yield wrapper
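Design note on the fixture above: `MagicMock` pre-configures magic methods, so the standalone `cm` object could be collapsed by configuring the auto-created return value directly; an equivalent sketch (optional, not a required change):

with patch("tasks.clean_dataset_task.session_factory") as mock_sf:
    mock_session = MagicMock()
    # create_session() already returns a MagicMock with __enter__/__exit__ defined;
    # pointing __enter__ at the session makes `with create_session() as s:` yield it.
    mock_sf.create_session.return_value.__enter__.return_value = mock_session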
@pytest.fixture
@ -227,7 +234,9 @@ class TestBasicCleanup:
# Assert
mock_db_session.session.delete.assert_any_call(mock_document)
mock_db_session.session.delete.assert_any_call(mock_segment)
# Segments are deleted in batch; verify a DELETE on document_segments was issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
mock_db_session.session.commit.assert_called_once()
def test_clean_dataset_task_deletes_related_records(
@ -413,7 +422,9 @@ class TestErrorHandling:
# Assert - documents and segments should still be deleted
mock_db_session.session.delete.assert_any_call(mock_document)
mock_db_session.session.delete.assert_any_call(mock_segment)
# Segments are deleted in batch; verify a DELETE on document_segments was issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
mock_db_session.session.commit.assert_called_once()
def test_clean_dataset_task_storage_delete_failure_continues(
@ -461,7 +472,7 @@ class TestErrorHandling:
[mock_segment], # segments
]
mock_get_image_upload_file_ids.return_value = [image_file_id]
mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file
mock_db_session.session.query.return_value.where.return_value.all.return_value = [mock_upload_file]
mock_storage.delete.side_effect = Exception("Storage service unavailable")
# Act
@ -476,8 +487,9 @@ class TestErrorHandling:
# Assert - storage delete was attempted for image file
mock_storage.delete.assert_called_with(mock_upload_file.key)
# Image file should still be deleted from database
mock_db_session.session.delete.assert_any_call(mock_upload_file)
# Upload files are deleted in batch; verify a DELETE on upload_files was issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
def test_clean_dataset_task_database_error_rollback(
self,
@ -691,8 +703,10 @@ class TestSegmentAttachmentCleanup:
# Assert
mock_storage.delete.assert_called_with(mock_attachment_file.key)
mock_db_session.session.delete.assert_any_call(mock_attachment_file)
mock_db_session.session.delete.assert_any_call(mock_binding)
# Attachment file and binding are deleted in batch; verify DELETEs were issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
assert any("DELETE FROM segment_attachment_bindings" in sql for sql in execute_sqls)
def test_clean_dataset_task_attachment_storage_failure(
self,
@ -734,9 +748,10 @@ class TestSegmentAttachmentCleanup:
# Assert - storage delete was attempted
mock_storage.delete.assert_called_once()
# Records should still be deleted from database
mock_db_session.session.delete.assert_any_call(mock_attachment_file)
mock_db_session.session.delete.assert_any_call(mock_binding)
# Records are deleted in batch; verify DELETEs were issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
assert any("DELETE FROM segment_attachment_bindings" in sql for sql in execute_sqls)
# ============================================================================
@ -784,7 +799,7 @@ class TestUploadFileCleanup:
[mock_document], # documents
[], # segments
]
mock_db_session.session.query.return_value.where.return_value.first.return_value = mock_upload_file
mock_db_session.session.query.return_value.where.return_value.all.return_value = [mock_upload_file]
# Act
clean_dataset_task(
@ -798,7 +813,9 @@ class TestUploadFileCleanup:
# Assert
mock_storage.delete.assert_called_with(mock_upload_file.key)
mock_db_session.session.delete.assert_any_call(mock_upload_file)
# Upload files are deleted in batch; verify a DELETE on upload_files was issued
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM upload_files" in sql for sql in execute_sqls)
def test_clean_dataset_task_handles_missing_upload_file(
self,
@ -832,7 +849,7 @@ class TestUploadFileCleanup:
[mock_document], # documents
[], # segments
]
mock_db_session.session.query.return_value.where.return_value.first.return_value = None
mock_db_session.session.query.return_value.where.return_value.all.return_value = []
# Act - should not raise exception
clean_dataset_task(
@ -949,11 +966,11 @@ class TestImageFileCleanup:
[mock_segment], # segments
]
# Setup a mock query chain that returns files in sequence
# Setup a mock query chain that returns files in batch (align with .in_().all())
mock_query = MagicMock()
mock_where = MagicMock()
mock_query.where.return_value = mock_where
mock_where.first.side_effect = mock_image_files
mock_where.all.return_value = mock_image_files
mock_db_session.session.query.return_value = mock_query
# Act
@ -966,10 +983,10 @@ class TestImageFileCleanup:
doc_form="paragraph_index",
)
# Assert
assert mock_storage.delete.call_count == 2
mock_storage.delete.assert_any_call("images/image-1.jpg")
mock_storage.delete.assert_any_call("images/image-2.jpg")
# Assert - each expected image key was deleted at least once
calls = [c.args[0] for c in mock_storage.delete.call_args_list]
assert "images/image-1.jpg" in calls
assert "images/image-2.jpg" in calls
def test_clean_dataset_task_handles_missing_image_file(
self,
@ -1010,7 +1027,7 @@ class TestImageFileCleanup:
]
# Image file not found
mock_db_session.session.query.return_value.where.return_value.first.return_value = None
mock_db_session.session.query.return_value.where.return_value.all.return_value = []
# Act - should not raise exception
clean_dataset_task(
@ -1086,14 +1103,15 @@ class TestEdgeCases:
doc_form="paragraph_index",
)
# Assert - all documents and segments should be deleted
# Assert - all documents and segments should be deleted (documents per-entity, segments in batch)
delete_calls = mock_db_session.session.delete.call_args_list
deleted_items = [call[0][0] for call in delete_calls]
for doc in mock_documents:
assert doc in deleted_items
for seg in mock_segments:
assert seg in deleted_items
# Verify a batch DELETE on document_segments occurred
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
def test_clean_dataset_task_document_with_empty_data_source_info(
self,

View File

@ -81,12 +81,25 @@ def mock_documents(document_ids, dataset_id):
@pytest.fixture
def mock_db_session():
"""Mock database session."""
with patch("tasks.document_indexing_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
yield mock_session
"""Mock database session via session_factory.create_session()."""
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
session = MagicMock()
# Ensure tests that expect session.close() to be called can observe it via the context manager
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
# Link __exit__ to session.close so "close" expectations reflect context manager teardown
def _exit_side_effect(*args, **kwargs):
session.close()
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
yield session
@pytest.fixture

View File

@ -18,12 +18,18 @@ from tasks.delete_account_task import delete_account_task
@pytest.fixture
def mock_db_session():
"""Mock the db.session used in delete_account_task."""
with patch("tasks.delete_account_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
yield mock_session
"""Mock session via session_factory.create_session()."""
with patch("tasks.delete_account_task.session_factory") as mock_sf:
session = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
cm.__exit__.return_value = None
mock_sf.create_session.return_value = cm
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
yield session
@pytest.fixture

View File

@ -109,13 +109,25 @@ def mock_document_segments(document_id):
@pytest.fixture
def mock_db_session():
"""Mock database session."""
with patch("tasks.document_indexing_sync_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_session.scalars.return_value = MagicMock()
yield mock_session
"""Mock database session via session_factory.create_session()."""
with patch("tasks.document_indexing_sync_task.session_factory") as mock_sf:
session = MagicMock()
# Ensure tests can observe session.close() via context manager teardown
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
def _exit_side_effect(*args, **kwargs):
session.close()
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
session.scalars.return_value = MagicMock()
yield session
@pytest.fixture
@ -251,8 +263,8 @@ class TestDocumentIndexingSyncTask:
# Assert
# Document status should remain unchanged
assert mock_document.indexing_status == "completed"
# No session operations should be performed beyond the initial query
mock_db_session.close.assert_not_called()
# Session should still be closed via context manager teardown
assert mock_db_session.close.called
def test_successful_sync_when_page_updated(
self,
@ -286,9 +298,9 @@ class TestDocumentIndexingSyncTask:
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_processor.clean.assert_called_once()
# Verify segments were deleted from database
for segment in mock_document_segments:
mock_db_session.delete.assert_any_call(segment)
# Verify segments were deleted from database in batch (DELETE FROM document_segments)
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
# Verify indexing runner was called
mock_indexing_runner.run.assert_called_once_with([mock_document])

View File

@ -94,13 +94,25 @@ def mock_document_segments(document_ids):
@pytest.fixture
def mock_db_session():
"""Mock database session."""
with patch("tasks.duplicate_document_indexing_task.db.session") as mock_session:
mock_query = MagicMock()
mock_session.query.return_value = mock_query
mock_query.where.return_value = mock_query
mock_session.scalars.return_value = MagicMock()
yield mock_session
"""Mock database session via session_factory.create_session()."""
with patch("tasks.duplicate_document_indexing_task.session_factory") as mock_sf:
session = MagicMock()
# Allow tests to observe session.close() via context manager teardown
session.close = MagicMock()
cm = MagicMock()
cm.__enter__.return_value = session
def _exit_side_effect(*args, **kwargs):
session.close()
cm.__exit__.side_effect = _exit_side_effect
mock_sf.create_session.return_value = cm
query = MagicMock()
session.query.return_value = query
query.where.return_value = query
session.scalars.return_value = MagicMock()
yield session
@pytest.fixture
@ -200,8 +212,25 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test successful duplicate document indexing flow."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
# Dataset via query.first()
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# scalars() call sequence:
# 1) documents list
# 2..N) segments per document
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
# First call returns documents; subsequent calls return segments
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = mock_document_segments
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
# Act
_duplicate_document_indexing_task(dataset_id, document_ids)
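The `_scalars_side_effect` helper above reappears in several tests below; an optional refactor sketch shows the same "first call returns documents, later calls return segments" sequencing as a reusable factory, avoiding the function-attribute counter (illustrative only):

def make_scalars_side_effect(first_result, later_result):
    """side_effect factory: first scalars() call yields documents, later calls yield segments."""
    state = {"calls": 0}

    def _side_effect(*args, **kwargs):
        result = MagicMock()
        result.all.return_value = first_result if state["calls"] == 0 else later_result
        state["calls"] += 1
        return result

    return _side_effect

# Usage:
mock_db_session.scalars.side_effect = make_scalars_side_effect(mock_documents, mock_document_segments)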
@ -264,8 +293,21 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test duplicate document indexing when billing limit is exceeded."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = [] # No segments to clean
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# First scalars() -> documents; subsequent -> empty segments
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = []
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
mock_features = mock_feature_service.get_features.return_value
mock_features.billing.enabled = True
mock_features.billing.subscription.plan = CloudPlan.TEAM
@ -294,8 +336,20 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test duplicate document indexing when IndexingRunner raises an error."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = []
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = []
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
mock_indexing_runner.run.side_effect = Exception("Indexing error")
# Act
@ -318,8 +372,20 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test duplicate document indexing when document is paused."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = []
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = []
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document paused")
# Act
@ -343,8 +409,20 @@ class TestDuplicateDocumentIndexingTaskCore:
):
"""Test that duplicate document indexing cleans old segments."""
# Arrange
mock_db_session.query.return_value.where.return_value.first.side_effect = [mock_dataset] + mock_documents
mock_db_session.scalars.return_value.all.return_value = mock_document_segments
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
def _scalars_side_effect(*args, **kwargs):
m = MagicMock()
if not hasattr(_scalars_side_effect, "_calls"):
_scalars_side_effect._calls = 0
if _scalars_side_effect._calls == 0:
m.all.return_value = mock_documents
else:
m.all.return_value = mock_document_segments
_scalars_side_effect._calls += 1
return m
mock_db_session.scalars.side_effect = _scalars_side_effect
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
# Act
@ -354,9 +432,9 @@ class TestDuplicateDocumentIndexingTaskCore:
# Verify clean was called for each document
assert mock_processor.clean.call_count == len(mock_documents)
# Verify segments were deleted
for segment in mock_document_segments:
mock_db_session.delete.assert_any_call(segment)
# Verify segments were deleted in batch (DELETE FROM document_segments)
execute_sqls = [" ".join(str(c[0][0]).split()) for c in mock_db_session.execute.call_args_list]
assert any("DELETE FROM document_segments" in sql for sql in execute_sqls)
# ============================================================================

View File

@ -11,21 +11,18 @@ from tasks.remove_app_and_related_data_task import (
class TestDeleteDraftVariablesBatch:
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.db")
def test_delete_draft_variables_batch_success(self, mock_db, mock_offload_cleanup):
@patch("tasks.remove_app_and_related_data_task.session_factory")
def test_delete_draft_variables_batch_success(self, mock_sf, mock_offload_cleanup):
"""Test successful deletion of draft variables in batches."""
app_id = "test-app-id"
batch_size = 100
# Mock database connection and engine
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_db.engine = mock_engine
# Properly mock the context manager
# Mock session via session_factory
mock_session = MagicMock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_conn
mock_context_manager.__enter__.return_value = mock_session
mock_context_manager.__exit__.return_value = None
mock_engine.begin.return_value = mock_context_manager
mock_sf.create_session.return_value = mock_context_manager
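# These mocks stand in for a call shaped roughly like
#     with session_factory.create_session() as session:
#         session.execute(...)
# so create_session() has to return a context manager whose __enter__ yields
# the session object that the code under test runs its statements on.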
# Mock two batches of results, then empty
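# Even-indexed rows carry a file_id; odd-indexed rows have None.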
batch1_data = [(f"var-{i}", f"file-{i}" if i % 2 == 0 else None) for i in range(100)]
@ -68,7 +65,7 @@ class TestDeleteDraftVariablesBatch:
select_result3.__iter__.return_value = iter([])
# Configure side effects in the correct order
mock_conn.execute.side_effect = [
mock_session.execute.side_effect = [
select_result1, # First SELECT
delete_result1, # First DELETE
select_result2, # Second SELECT
@ -86,54 +83,49 @@ class TestDeleteDraftVariablesBatch:
assert result == 150
# Verify database calls
assert mock_conn.execute.call_count == 5 # 3 selects + 2 deletes
assert mock_session.execute.call_count == 5 # 3 selects + 2 deletes
# Verify offload cleanup was called for both batches with file_ids
expected_offload_calls = [call(mock_conn, batch1_file_ids), call(mock_conn, batch2_file_ids)]
expected_offload_calls = [call(mock_session, batch1_file_ids), call(mock_session, batch2_file_ids)]
mock_offload_cleanup.assert_has_calls(expected_offload_calls)
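# After the session_factory change, the offload cleanup receives the session
# rather than a raw connection.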
# Simplified verification - check that the right number of calls were made
# and that the SQL queries contain the expected patterns
actual_calls = mock_conn.execute.call_args_list
actual_calls = mock_session.execute.call_args_list
for i, actual_call in enumerate(actual_calls):
sql_text = str(actual_call[0][0])
normalized = " ".join(sql_text.split())
if i % 2 == 0: # SELECT calls (even indices: 0, 2, 4)
# Verify it's a SELECT query that now includes file_id
sql_text = str(actual_call[0][0])
assert "SELECT id, file_id FROM workflow_draft_variables" in sql_text
assert "WHERE app_id = :app_id" in sql_text
assert "LIMIT :batch_size" in sql_text
assert "SELECT id, file_id FROM workflow_draft_variables" in normalized
assert "WHERE app_id = :app_id" in normalized
assert "LIMIT :batch_size" in normalized
else: # DELETE calls (odd indices: 1, 3)
# Verify it's a DELETE query
sql_text = str(actual_call[0][0])
assert "DELETE FROM workflow_draft_variables" in sql_text
assert "WHERE id IN :ids" in sql_text
assert "DELETE FROM workflow_draft_variables" in normalized
assert "WHERE id IN :ids" in normalized
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.db")
def test_delete_draft_variables_batch_empty_result(self, mock_db, mock_offload_cleanup):
@patch("tasks.remove_app_and_related_data_task.session_factory")
def test_delete_draft_variables_batch_empty_result(self, mock_sf, mock_offload_cleanup):
"""Test deletion when no draft variables exist for the app."""
app_id = "nonexistent-app-id"
batch_size = 1000
# Mock database connection
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_db.engine = mock_engine
# Properly mock the context manager
# Mock session via session_factory
mock_session = MagicMock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_conn
mock_context_manager.__enter__.return_value = mock_session
mock_context_manager.__exit__.return_value = None
mock_engine.begin.return_value = mock_context_manager
mock_sf.create_session.return_value = mock_context_manager
# Mock empty result
empty_result = MagicMock()
empty_result.__iter__.return_value = iter([])
mock_conn.execute.return_value = empty_result
mock_session.execute.return_value = empty_result
result = delete_draft_variables_batch(app_id, batch_size)
assert result == 0
assert mock_conn.execute.call_count == 1 # Only one select query
assert mock_session.execute.call_count == 1 # Only one select query
mock_offload_cleanup.assert_not_called() # No files to clean up
def test_delete_draft_variables_batch_invalid_batch_size(self):
@ -147,22 +139,19 @@ class TestDeleteDraftVariablesBatch:
delete_draft_variables_batch(app_id, 0)
@patch("tasks.remove_app_and_related_data_task._delete_draft_variable_offload_data")
@patch("tasks.remove_app_and_related_data_task.db")
@patch("tasks.remove_app_and_related_data_task.session_factory")
@patch("tasks.remove_app_and_related_data_task.logger")
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_db, mock_offload_cleanup):
def test_delete_draft_variables_batch_logs_progress(self, mock_logging, mock_sf, mock_offload_cleanup):
"""Test that batch deletion logs progress correctly."""
app_id = "test-app-id"
batch_size = 50
# Mock database
mock_conn = MagicMock()
mock_engine = MagicMock()
mock_db.engine = mock_engine
# Properly mock the context manager
# Mock session via session_factory
mock_session = MagicMock()
mock_context_manager = MagicMock()
mock_context_manager.__enter__.return_value = mock_conn
mock_context_manager.__enter__.return_value = mock_session
mock_context_manager.__exit__.return_value = None
mock_engine.begin.return_value = mock_context_manager
mock_sf.create_session.return_value = mock_context_manager
# Mock one batch then empty
batch_data = [(f"var-{i}", f"file-{i}" if i % 3 == 0 else None) for i in range(30)]
@ -183,7 +172,7 @@ class TestDeleteDraftVariablesBatch:
empty_result = MagicMock()
empty_result.__iter__.return_value = iter([])
mock_conn.execute.side_effect = [
mock_session.execute.side_effect = [
# Select query result
select_result,
# Delete query result
@ -201,7 +190,7 @@ class TestDeleteDraftVariablesBatch:
# Verify offload cleanup was called with file_ids
if batch_file_ids:
mock_offload_cleanup.assert_called_once_with(mock_conn, batch_file_ids)
mock_offload_cleanup.assert_called_once_with(mock_session, batch_file_ids)
# Verify logging calls
assert mock_logging.info.call_count == 2
@ -261,19 +250,19 @@ class TestDeleteDraftVariableOffloadData:
actual_calls = mock_conn.execute.call_args_list
# First call should be the SELECT query
select_call_sql = str(actual_calls[0][0][0])
select_call_sql = " ".join(str(actual_calls[0][0][0]).split())
assert "SELECT wdvf.id, uf.key, uf.id as upload_file_id" in select_call_sql
assert "FROM workflow_draft_variable_files wdvf" in select_call_sql
assert "JOIN upload_files uf ON wdvf.upload_file_id = uf.id" in select_call_sql
assert "WHERE wdvf.id IN :file_ids" in select_call_sql
# Second call should be DELETE upload_files
delete_upload_call_sql = str(actual_calls[1][0][0])
delete_upload_call_sql = " ".join(str(actual_calls[1][0][0]).split())
assert "DELETE FROM upload_files" in delete_upload_call_sql
assert "WHERE id IN :upload_file_ids" in delete_upload_call_sql
# Third call should be DELETE workflow_draft_variable_files
delete_variable_files_call_sql = str(actual_calls[2][0][0])
delete_variable_files_call_sql = " ".join(str(actual_calls[2][0][0]).split())
assert "DELETE FROM workflow_draft_variable_files" in delete_variable_files_call_sql
assert "WHERE id IN :file_ids" in delete_variable_files_call_sql