Mirror of https://github.com/langgenius/dify.git (synced 2026-01-20 12:09:27 +08:00)

Compare commits: feat/colla ... deploy/dev (181 commits)
| SHA1 | Author | Date | |
|---|---|---|---|
| 3c22d22835 | |||
| 008a5f361d | |||
| 08caa4fce3 | |||
| 5293fbe8ba | |||
| ed555c5fe7 | |||
| 22974ea6b0 | |||
| 754b01366a | |||
| 8af626092e | |||
| 49b3bad26b | |||
| 50616c25d4 | |||
| 3b4b5b332c | |||
| 62c3f14570 | |||
| 41c3b1c57c | |||
| 994357d8b5 | |||
| 5fb9fe3c94 | |||
| 4fb08ae7d2 | |||
| 7481762acb | |||
| fcb2fe55e7 | |||
| a0aa8cdb45 | |||
| ae8618877b | |||
| 1c55602445 | |||
| a3f1220d23 | |||
| 4d7384731e | |||
| d62e16b9bb | |||
| 13f2a43ccc | |||
| 553dd3266b | |||
| 5b0590d58e | |||
| d97f2df85c | |||
| d3c09f16a9 | |||
| fde8efa4a2 | |||
| 5f6d1297b0 | |||
| 869e70964f | |||
| 1f313eb15c | |||
| f02adc26e5 | |||
| 73027eab0a | |||
| 74245fea8e | |||
| 5bc4bba668 | |||
| 1126a2aa95 | |||
| 2107a3c32c | |||
| 22d0c55363 | |||
| 7c3ce7b1e6 | |||
| f4d20a02aa | |||
| 7eb65b07c8 | |||
| 830a7fb034 | |||
| 9b7e807690 | |||
| af86f8de6f | |||
| ec78676949 | |||
| 01a7dbcee8 | |||
| 4fe8d2491e | |||
| 76da8b4ff3 | |||
| 25bfc1cc3b | |||
| 5c2ae922bc | |||
| a92df530da | |||
| 13eec13a14 | |||
| 431936beb9 | |||
| 163540bf4a | |||
| 221130b448 | |||
| b1eb265fa5 | |||
| c2a0950660 | |||
| bfe98009fd | |||
| ea1704d211 | |||
| 3ed0937734 | |||
| 1fcf6e4943 | |||
| f4a7efde3d | |||
| 38d4f0fd96 | |||
| ec4f885dad | |||
| 3781c2a025 | |||
| 3782f17dc7 | |||
| 29698aeed2 | |||
| 15ff8efb15 | |||
| 407e1c8276 | |||
| e368825c21 | |||
| 8dad6b6a6d | |||
| 2f54965a72 | |||
| a1a3fa0283 | |||
| ff7344f3d3 | |||
| bcd33be22a | |||
| 0fb339ca4f | |||
| c1871e67aa | |||
| f711f9a317 | |||
| 9ff3310cb6 | |||
| b6bdcc7052 | |||
| 67b0771081 | |||
| 9a07488da9 | |||
| ef043c6906 | |||
| ab814e3eac | |||
| a0e1eeb3f1 | |||
| b1ebeb67a7 | |||
| 082179f70f | |||
| 8786ebdbca | |||
| b49a4eab62 | |||
| 0a7b59f500 | |||
| c264d9152f | |||
| 3bf9d898c0 | |||
| a7f2849e74 | |||
| 0957ece92f | |||
| 949bf38d3c | |||
| 7bafb7f959 | |||
| 9735f55ca4 | |||
| 4c1f9b949b | |||
| 0af0c94dde | |||
| 8e4f0640cc | |||
| 1f513e3b43 | |||
| aa0841e2a8 | |||
| b6a1562357 | |||
| bee0797401 | |||
| e085f39c13 | |||
| 344844d3e0 | |||
| 6e9f82491d | |||
| 372b1c3db8 | |||
| 58d305dbed | |||
| 0360a0416b | |||
| 72282b6e8f | |||
| 8391884c4e | |||
| b018f2b0a0 | |||
| ab56b4a818 | |||
| 61ebc756aa | |||
| 4bea38042a | |||
| 337abc536b | |||
| cc02b78aca | |||
| 18f2d24f8e | |||
| 0c7b9a462f | |||
| 4dd5580854 | |||
| 440bd825d8 | |||
| d2379c38bd | |||
| cbc55c577b | |||
| 8e962d15d1 | |||
| b07c766551 | |||
| 9e3dd69277 | |||
| db9e5665c2 | |||
| cad77ce0bf | |||
| 6f4518ebf7 | |||
| a8f5748dee | |||
| 738d3001be | |||
| df4e32aaa0 | |||
| a25e37a96d | |||
| f156b46705 | |||
| 3b64e118d0 | |||
| 566cd20849 | |||
| df76527f29 | |||
| 53a80a5dbe | |||
| 1507792a0c | |||
| 00b9bbff75 | |||
| e1f8b4b387 | |||
| 1539d86f7d | |||
| 67bb14d3ee | |||
| 5653309080 | |||
| 0f52b34b61 | |||
| 75e35857c1 | |||
| 4f81be70e3 | |||
| 1d4d627d05 | |||
| 2357234f39 | |||
| a3f7d8f996 | |||
| 56f12e70c1 | |||
| b14afda160 | |||
| 44b4948972 | |||
| 487eac3b91 | |||
| 84b2913cd9 | |||
| 176d810c8d | |||
| 9e66564526 | |||
| 781a9a56cd | |||
| 93be1219eb | |||
| 3276d6429d | |||
| 50072a63ae | |||
| 1ab7e1cba8 | |||
| b0aef35c63 | |||
| ac351b700c | |||
| d1e5d30ea9 | |||
| c73e84d992 | |||
| 5f0bd5119a | |||
| 8353352bda | |||
| 73845cbec5 | |||
| c2f94e9e8a | |||
| e54efda36f | |||
| d4bd19f6d8 | |||
| 4decbbbf18 | |||
| b15867f92e | |||
| a5e5fbc6e0 | |||
| 1b1471b6d8 | |||
| 5280bffde2 | |||
| db0fc94b39 |
@@ -33,9 +33,6 @@ TRIGGER_URL=http://localhost:5001
# The time in seconds after the signature is rejected
FILES_ACCESS_TIMEOUT=300

# Collaboration mode toggle
ENABLE_COLLABORATION_MODE=false

# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60
@@ -1,52 +0,0 @@
## Purpose

`api/controllers/console/datasets/datasets_document.py` contains the console (authenticated) APIs for managing dataset documents (list/create/update/delete, processing controls, estimates, etc.).

## Storage model (uploaded files)

- For local file uploads into a knowledge base, the binary is stored via `extensions.ext_storage.storage` under the key:
  - `upload_files/<tenant_id>/<uuid>.<ext>`
- File metadata is stored in the `upload_files` table (`UploadFile` model), keyed by `UploadFile.id`.
- Dataset `Document` records reference the uploaded file via:
  - `Document.data_source_info.upload_file_id`

## Download endpoint

- `GET /datasets/<dataset_id>/documents/<document_id>/download`
  - Only supported when `Document.data_source_type == "upload_file"`.
  - Performs dataset permission + tenant checks via `DocumentResource.get_document(...)`.
  - Delegates `Document -> UploadFile` validation and signed URL generation to `DocumentService.get_document_download_url(...)`.
  - Applies `cloud_edition_billing_rate_limit_check("knowledge")` to match other KB operations.
  - Response body is **only**: `{ "url": "<signed-url>" }`.

- `POST /datasets/<dataset_id>/documents/download-zip`
  - Accepts `{ "document_ids": ["..."] }` (upload-file only).
  - Returns `application/zip` as a single attachment download.
  - Rationale: browsers often block multiple automatic downloads; a ZIP avoids that limitation.
  - Applies `cloud_edition_billing_rate_limit_check("knowledge")`.
  - Delegates dataset permission checks, document/upload-file validation, and download-name generation to `DocumentService.prepare_document_batch_download_zip(...)` before streaming the ZIP.
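For illustration, a minimal client-side sketch of the two endpoints described above. It assumes the `requests` library and that console authentication is already handled; the base URL, header name, and placeholder IDs are illustrative assumptions, not taken from this diff.

```python
import requests

BASE = "http://localhost:5001/console/api"  # assumed console API prefix
HEADERS = {"Authorization": "Bearer <console-access-token>"}  # assumed auth scheme

# Single document: the endpoint returns {"url": "<signed-url>"}, which the
# client then fetches to download the actual bytes.
resp = requests.get(
    f"{BASE}/datasets/<dataset_id>/documents/<document_id>/download",
    headers=HEADERS,
)
signed_url = resp.json()["url"]
file_bytes = requests.get(signed_url).content

# Batch: a single application/zip attachment comes back in the response body.
resp = requests.post(
    f"{BASE}/datasets/<dataset_id>/documents/download-zip",
    headers=HEADERS,
    json={"document_ids": ["<doc-id-1>", "<doc-id-2>"]},
)
with open("documents.zip", "wb") as f:
    f.write(resp.content)
```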
## Verification plan

- Upload a document from a local file into a dataset.
- Call the download endpoint and confirm it returns a signed URL.
- Open the URL and confirm:
  - Response headers force download (`Content-Disposition`), and
  - Downloaded bytes match the uploaded file.
- Select multiple uploaded-file documents and download as ZIP; confirm all selected files exist in the archive.

## Shared helper

- `DocumentService.get_document_download_url(document)` resolves the `UploadFile` and signs a download URL.
- `DocumentService.prepare_document_batch_download_zip(...)` performs dataset permission checks, batches document + upload file lookups, preserves request order, and generates the client-visible ZIP filename.
- Internal helpers now live in `DocumentService` (`_get_upload_file_id_for_upload_file_document(...)`, `_get_upload_file_for_upload_file_document(...)`, `_get_upload_files_by_document_id_for_zip_download(...)`).
- ZIP packing is handled by `FileService.build_upload_files_zip_tempfile(...)`, which also:
  - sanitizes entry names to avoid path traversal, and
  - deduplicates names while preserving extensions (e.g., `doc.txt` → `doc (1).txt`).

Streaming the response and deferring cleanup is handled by the route via `send_file(path, ...)` + `ExitStack` + `response.call_on_close(...)` (the file is deleted when the response is closed).
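A minimal sketch of that streaming/cleanup pattern in a Flask route. The helper name and the `zip_path` argument are illustrative assumptions; only the `send_file` + `ExitStack` + `call_on_close` combination is taken from the description above.

```python
import os
from contextlib import ExitStack

from flask import send_file


def stream_zip_response(zip_path: str, download_name: str):
    """Stream a ZIP tempfile and delete it once the response is closed."""
    stack = ExitStack()
    # Register cleanup now; it only runs when the stack is closed.
    stack.callback(lambda: os.path.exists(zip_path) and os.remove(zip_path))

    response = send_file(
        zip_path,
        mimetype="application/zip",
        as_attachment=True,
        download_name=download_name,
    )
    # Defer cleanup until Flask has finished sending the file.
    response.call_on_close(stack.close)
    return response
```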
@@ -1,18 +0,0 @@
## Purpose

`api/services/dataset_service.py` hosts dataset/document service logic used by console and API controllers.

## Batch document operations

- Batch document workflows should avoid N+1 database queries by using set-based lookups.
- Tenant checks must be enforced consistently across dataset/document operations.
- `DocumentService.get_documents_by_ids(...)` fetches documents for a dataset using `id.in_(...)`.
- `FileService.get_upload_files_by_ids(...)` performs tenant-scoped batch lookup for `UploadFile` (dedupes ids with `set(...)`).
- `DocumentService.get_document_download_url(...)` and `prepare_document_batch_download_zip(...)` handle dataset/document permission checks plus `Document -> UploadFile` validation for download endpoints.
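A minimal sketch of the set-based lookup idea, assuming SQLAlchemy 2.0-style queries; the method body is an illustration, not the code removed in this diff.

```python
from sqlalchemy import select

from models import Document  # same model import used elsewhere in this diff


def get_documents_by_ids(session, dataset_id: str, document_ids: list[str]):
    """Fetch all requested documents for one dataset in a single query."""
    if not document_ids:
        return {}
    unique_ids = set(document_ids)  # dedupe before hitting the database
    stmt = select(Document).where(
        Document.dataset_id == dataset_id,
        Document.id.in_(unique_ids),
    )
    docs = session.scalars(stmt).all()
    # Key by id so callers can preserve the original request order.
    return {doc.id: doc for doc in docs}
```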
## Verification plan

- Exercise document list and download endpoints that use the service helpers.
- Confirm batch download uses a constant query count for documents + upload files.
- Request a ZIP with a missing document id and confirm a 404 is returned.
@@ -1,35 +0,0 @@
## Purpose

`api/services/file_service.py` owns business logic around `UploadFile` objects: upload validation, storage persistence, previews/generators, and deletion.

## Key invariants

- All storage I/O goes through `extensions.ext_storage.storage`.
- Uploaded file keys follow: `upload_files/<tenant_id>/<uuid>.<ext>`.
- Upload validation is enforced in `FileService.upload_file(...)` (blocked extensions, size limits, dataset-only types).

## Batch lookup helpers

- `FileService.get_upload_files_by_ids(tenant_id, upload_file_ids)` is the canonical tenant-scoped batch loader for `UploadFile`.

## Dataset document download helpers

The dataset document download/ZIP endpoints now delegate “Document → UploadFile” validation and permission checks to `DocumentService` (`api/services/dataset_service.py`). `FileService` stays focused on generic `UploadFile` operations (uploading, previews, deletion), plus generic ZIP serving.

### ZIP serving

- `FileService.build_upload_files_zip_tempfile(...)` builds a ZIP from `UploadFile` objects and yields a seeked tempfile **path** so callers can stream it (e.g., `send_file(path, ...)`) without hitting "read of closed file" issues from file-handle lifecycle during streamed responses.
- Flask `send_file(...)` and the `ExitStack`/`call_on_close(...)` cleanup pattern are handled in the route layer.
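A minimal sketch of such a tempfile-based ZIP builder. The function name and signature are illustrative assumptions (the real helper consumes `UploadFile` objects and `storage.load(...)` streams); only the sanitize/dedupe/yield-a-path behavior mirrors the description above.

```python
import os
import re
import tempfile
import zipfile
from contextlib import contextmanager


def _safe_entry_name(name: str) -> str:
    # Keep only the basename and strip path-traversal characters.
    base = os.path.basename(name.replace("\\", "/"))
    return re.sub(r"[^\w.\- ()]", "_", base) or "file"


@contextmanager
def build_zip_tempfile(named_blobs):
    """named_blobs: iterable of (filename, bytes). Yields a ZIP file path."""
    used: dict[str, int] = {}
    fd, path = tempfile.mkstemp(suffix=".zip")
    os.close(fd)
    try:
        with zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED) as zf:
            for raw_name, data in named_blobs:
                name = _safe_entry_name(raw_name)
                stem, ext = os.path.splitext(name)
                count = used.get(name, 0)
                used[name] = count + 1
                if count:  # dedupe while preserving the extension: doc.txt -> doc (1).txt
                    name = f"{stem} ({count}){ext}"
                zf.writestr(name, data)
        yield path  # the caller streams the path, e.g. via send_file(path, ...)
    finally:
        # Runs when the caller closes the context (e.g. via call_on_close).
        if os.path.exists(path):
            os.remove(path)
```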
## Verification plan

- Unit: `api/tests/unit_tests/controllers/console/datasets/test_datasets_document_download.py`
  - Verify signed URL generation for upload-file documents and ZIP download behavior for multiple documents.
- Unit: `api/tests/unit_tests/services/test_file_service_zip_and_lookup.py`
  - Verify ZIP packing produces a valid, openable archive and preserves file content.
@@ -1,28 +0,0 @@
## Purpose

Unit tests for the console dataset document download endpoint:

- `GET /datasets/<dataset_id>/documents/<document_id>/download`

## Testing approach

- Uses `Flask.test_request_context()` and calls the `Resource.get(...)` method directly.
- Monkeypatches console decorators (`login_required`, `setup_required`, rate limit) to no-ops to keep the test focused.
- Mocks:
  - `DatasetService.get_dataset` / `check_dataset_permission`
  - `DocumentService.get_document` for single-file download tests
  - `DocumentService.get_documents_by_ids` + `FileService.get_upload_files_by_ids` for ZIP download tests
  - `FileService.get_upload_files_by_ids` for `UploadFile` lookups in single-file tests
  - `services.dataset_service.file_helpers.get_signed_file_url` to return a deterministic URL
- Document mocks include `id` fields so batch lookups can map documents by id.
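A self-contained sketch of that approach: build a Flask app, open `test_request_context()`, and call a resource's `get(...)` directly while the service layer is monkeypatched. The resource and service here are stand-ins defined for illustration, not the real console classes.

```python
from unittest.mock import MagicMock

import pytest
from flask import Flask
from flask_restx import Resource


class FakeDocumentService:  # stand-in for services.dataset_service.DocumentService
    @staticmethod
    def get_document_download_url(document):
        raise NotImplementedError  # patched in the test


class DownloadResource(Resource):  # stand-in for the real console resource
    def get(self, dataset_id, document_id):
        document = MagicMock(id=document_id, data_source_type="upload_file")
        return {"url": FakeDocumentService.get_document_download_url(document)}


@pytest.fixture
def app():
    return Flask(__name__)


def test_download_returns_signed_url(app, monkeypatch):
    monkeypatch.setattr(
        FakeDocumentService,
        "get_document_download_url",
        staticmethod(lambda document: "https://example.test/signed"),
    )
    with app.test_request_context("/datasets/ds-1/documents/doc-1/download"):
        payload = DownloadResource().get(dataset_id="ds-1", document_id="doc-1")
    assert payload == {"url": "https://example.test/signed"}
```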
## Covered cases

- Success returns `{ "url": "<signed>" }` for upload-file documents.
- 404 when document is not `upload_file`.
- 404 when `upload_file_id` is missing.
- 404 when referenced `UploadFile` row does not exist.
- 403 when document tenant does not match current tenant.
- Batch ZIP download returns `application/zip` for upload-file documents.
- Batch ZIP download rejects non-upload-file documents.
- Batch ZIP download uses a random `.zip` attachment name (`download_name`), so tests only assert the suffix.
@@ -1,18 +0,0 @@
## Purpose

Unit tests for `api/services/file_service.py` helper methods that are not covered by higher-level controller tests.

## What’s covered

- `FileService.build_upload_files_zip_tempfile(...)`
  - ZIP entry name sanitization (no directory components / traversal)
  - name deduplication while preserving extensions
  - writing streamed bytes from `storage.load(...)` into ZIP entries
  - yields a tempfile path so callers can open/stream the ZIP without holding a live file handle
- `FileService.get_upload_files_by_ids(...)`
  - returns `{}` for empty id lists
  - returns an id-keyed mapping for non-empty lists

## Notes

- These tests intentionally stub `storage.load` and `db.session.scalars(...).all()` to avoid needing a real DB/storage.
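A minimal sketch of that stubbing idea using `MagicMock` stand-ins for the real `storage` and `db` objects (the real tests patch the actual module attributes; the objects below are local illustrations).

```python
from unittest.mock import MagicMock


def test_stubbed_storage_and_db_return_deterministic_data():
    storage = MagicMock()
    storage.load.return_value = b"hello world"  # stubbed blob bytes

    db = MagicMock()
    fake_rows = [MagicMock(id="file-1"), MagicMock(id="file-2")]
    db.session.scalars.return_value.all.return_value = fake_rows

    # Code under test that receives these objects now sees deterministic data
    # without touching a real database or object store.
    assert storage.load("upload_files/tenant/abc.txt") == b"hello world"
    assert [row.id for row in db.session.scalars(object()).all()] == ["file-1", "file-2"]
```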
api/app.py

@@ -1,4 +1,3 @@
import os
import sys


@@ -9,15 +8,10 @@ def is_db_command() -> bool:

# create app
flask_app = None
socketio_app = None

if is_db_command():
    from app_factory import create_migrations_app

    app = create_migrations_app()
    socketio_app = app
    flask_app = app
else:
    # Gunicorn and Celery handle monkey patching automatically in production by
    # specifying the `gevent` worker class. Manual monkey patching is not required here.

@@ -28,15 +22,8 @@ else:

    from app_factory import create_app

    socketio_app, flask_app = create_app()
    app = flask_app
    celery = flask_app.extensions["celery"]
    app = create_app()
    celery = app.extensions["celery"]

if __name__ == "__main__":
    from gevent import pywsgi
    from geventwebsocket.handler import WebSocketHandler  # type: ignore[reportMissingTypeStubs]

    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", 5001))
    server = pywsgi.WSGIServer((host, port), socketio_app, handler_class=WebSocketHandler)
    server.serve_forever()
    app.run(host="0.0.0.0", port=5001)
@@ -1,7 +1,6 @@
import logging
import time

import socketio  # type: ignore[reportMissingTypeStubs]
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID

@@ -9,7 +8,6 @@ from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from core.logging.context import init_request_context
from dify_app import DifyApp
from extensions.ext_socketio import sio

logger = logging.getLogger(__name__)

@@ -62,23 +60,17 @@ def create_flask_app_with_configs() -> DifyApp:
    return dify_app


def create_app() -> tuple[socketio.WSGIApp, DifyApp]:
def create_app() -> DifyApp:
    start_time = time.perf_counter()
    app = create_flask_app_with_configs()
    initialize_extensions(app)

    sio.app = app
    socketio_app = socketio.WSGIApp(sio, app)

    end_time = time.perf_counter()
    if dify_config.DEBUG:
        logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
    return socketio_app, app
    return app


def initialize_extensions(app: DifyApp):
    # Initialize Flask context capture for workflow execution
    from context.flask_app_context import init_flask_context
    from extensions import (
        ext_app_metrics,
        ext_blueprints,

@@ -108,8 +100,6 @@ def initialize_extensions(app: DifyApp):
        ext_warnings,
    )

    init_flask_context()

    extensions = [
        ext_timezone,
        ext_logging,
@@ -862,27 +862,8 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[


@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.")
@click.option(
    "--before-days",
    "--days",
    default=30,
    show_default=True,
    type=click.IntRange(min=0),
    help="Delete workflow runs created before N days ago.",
)
@click.option("--days", default=30, show_default=True, help="Delete workflow runs created before N days ago.")
@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.")
@click.option(
    "--from-days-ago",
    default=None,
    type=click.IntRange(min=0),
    help="Lower bound in days ago (older). Must be paired with --to-days-ago.",
)
@click.option(
    "--to-days-ago",
    default=None,
    type=click.IntRange(min=0),
    help="Upper bound in days ago (newer). Must be paired with --from-days-ago.",
)
@click.option(
    "--start-from",
    type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),

@@ -901,10 +882,8 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
    help="Preview cleanup results without deleting any workflow run data.",
)
def clean_workflow_runs(
    before_days: int,
    days: int,
    batch_size: int,
    from_days_ago: int | None,
    to_days_ago: int | None,
    start_from: datetime.datetime | None,
    end_before: datetime.datetime | None,
    dry_run: bool,

@@ -915,24 +894,11 @@ def clean_workflow_runs(
    if (start_from is None) ^ (end_before is None):
        raise click.UsageError("--start-from and --end-before must be provided together.")

    if (from_days_ago is None) ^ (to_days_ago is None):
        raise click.UsageError("--from-days-ago and --to-days-ago must be provided together.")

    if from_days_ago is not None and to_days_ago is not None:
        if start_from or end_before:
            raise click.UsageError("Choose either day offsets or explicit dates, not both.")
        if from_days_ago <= to_days_ago:
            raise click.UsageError("--from-days-ago must be greater than --to-days-ago.")
        now = datetime.datetime.now()
        start_from = now - datetime.timedelta(days=from_days_ago)
        end_before = now - datetime.timedelta(days=to_days_ago)
        before_days = 0

    start_time = datetime.datetime.now(datetime.UTC)
    click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))

    WorkflowRunCleanup(
        days=before_days,
        days=days,
        batch_size=batch_size,
        start_from=start_from,
        end_before=end_before,
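For orientation, a minimal sketch of invoking the command above with click's test runner. The `commands` import path is an assumption; in a real deployment the command is typically run through the Flask CLI instead.

```python
from click.testing import CliRunner

from commands import clean_workflow_runs  # assumed module path for the command above

runner = CliRunner()

# Preview a 30-day cleanup without deleting anything.
result = runner.invoke(clean_workflow_runs, ["--days", "30", "--dry-run"])
print(result.output)

# Clean only runs created between 60 and 30 days ago (offsets must satisfy from > to).
result = runner.invoke(
    clean_workflow_runs,
    ["--from-days-ago", "60", "--to-days-ago", "30", "--batch-size", "500"],
)
print(result.exit_code)
```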
@@ -1219,13 +1219,6 @@ class PositionConfig(BaseSettings):
        return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}


class CollaborationConfig(BaseSettings):
    ENABLE_COLLABORATION_MODE: bool = Field(
        description="Whether to enable collaboration mode features across the workspace",
        default=False,
    )


class LoginConfig(BaseSettings):
    ENABLE_EMAIL_CODE_LOGIN: bool = Field(
        description="whether to enable email code login",

@@ -1340,7 +1333,6 @@ class FeatureConfig(
    WorkflowConfig,
    WorkflowNodeExecutionConfig,
    WorkspaceConfig,
    CollaborationConfig,
    LoginConfig,
    AccountConfig,
    SwaggerUIConfig,
@@ -1,74 +0,0 @@
"""
Core Context - Framework-agnostic context management.

This module provides context management that is independent of any specific
web framework. Framework-specific implementations register their context
capture functions at application initialization time.

This ensures the workflow layer remains completely decoupled from Flask
or any other web framework.
"""

import contextvars
from collections.abc import Callable

from core.workflow.context.execution_context import (
    ExecutionContext,
    IExecutionContext,
    NullAppContext,
)

# Global capturer function - set by framework-specific modules
_capturer: Callable[[], IExecutionContext] | None = None


def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
    """
    Register a context capture function.

    This should be called by framework-specific modules (e.g., Flask)
    during application initialization.

    Args:
        capturer: Function that captures current context and returns IExecutionContext
    """
    global _capturer
    _capturer = capturer


def capture_current_context() -> IExecutionContext:
    """
    Capture current execution context.

    This function uses the registered context capturer. If no capturer
    is registered, it returns a minimal context with only contextvars
    (suitable for non-framework environments like tests or standalone scripts).

    Returns:
        IExecutionContext with captured context
    """
    if _capturer is None:
        # No framework registered - return minimal context
        return ExecutionContext(
            app_context=NullAppContext(),
            context_vars=contextvars.copy_context(),
        )

    return _capturer()


def reset_context_provider() -> None:
    """
    Reset the context capturer.

    This is primarily useful for testing to ensure a clean state.
    """
    global _capturer
    _capturer = None


__all__ = [
    "capture_current_context",
    "register_context_capturer",
    "reset_context_provider",
]
@ -1,198 +0,0 @@
|
||||
"""
|
||||
Flask App Context - Flask implementation of AppContext interface.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, final
|
||||
|
||||
from flask import Flask, current_app, g
|
||||
|
||||
from context import register_context_capturer
|
||||
from core.workflow.context.execution_context import (
|
||||
AppContext,
|
||||
IExecutionContext,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class FlaskAppContext(AppContext):
|
||||
"""
|
||||
Flask implementation of AppContext.
|
||||
|
||||
This adapts Flask's app context to the AppContext interface.
|
||||
"""
|
||||
|
||||
def __init__(self, flask_app: Flask) -> None:
|
||||
"""
|
||||
Initialize Flask app context.
|
||||
|
||||
Args:
|
||||
flask_app: The Flask application instance
|
||||
"""
|
||||
self._flask_app = flask_app
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get configuration value from Flask app config."""
|
||||
return self._flask_app.config.get(key, default)
|
||||
|
||||
def get_extension(self, name: str) -> Any:
|
||||
"""Get Flask extension by name."""
|
||||
return self._flask_app.extensions.get(name)
|
||||
|
||||
@contextmanager
|
||||
def enter(self) -> Generator[None, None, None]:
|
||||
"""Enter Flask app context."""
|
||||
with self._flask_app.app_context():
|
||||
yield
|
||||
|
||||
@property
|
||||
def flask_app(self) -> Flask:
|
||||
"""Get the underlying Flask app instance."""
|
||||
return self._flask_app
|
||||
|
||||
|
||||
def capture_flask_context(user: Any = None) -> IExecutionContext:
|
||||
"""
|
||||
Capture current Flask execution context.
|
||||
|
||||
This function captures the Flask app context and contextvars from the
|
||||
current environment. It should be called from within a Flask request or
|
||||
app context.
|
||||
|
||||
Args:
|
||||
user: Optional user object to include in context
|
||||
|
||||
Returns:
|
||||
IExecutionContext with captured Flask context
|
||||
|
||||
Raises:
|
||||
RuntimeError: If called outside Flask context
|
||||
"""
|
||||
# Get Flask app instance
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
|
||||
# Save current user if available
|
||||
saved_user = user
|
||||
if saved_user is None:
|
||||
# Check for user in g (flask-login)
|
||||
if hasattr(g, "_login_user"):
|
||||
saved_user = g._login_user
|
||||
|
||||
# Capture contextvars
|
||||
context_vars = contextvars.copy_context()
|
||||
|
||||
return FlaskExecutionContext(
|
||||
flask_app=flask_app,
|
||||
context_vars=context_vars,
|
||||
user=saved_user,
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class FlaskExecutionContext:
|
||||
"""
|
||||
Flask-specific execution context.
|
||||
|
||||
This is a specialized version of ExecutionContext that includes Flask app
|
||||
context. It provides the same interface as ExecutionContext but with
|
||||
Flask-specific implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
context_vars: contextvars.Context,
|
||||
user: Any = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize Flask execution context.
|
||||
|
||||
Args:
|
||||
flask_app: Flask application instance
|
||||
context_vars: Python contextvars
|
||||
user: Optional user object
|
||||
"""
|
||||
self._app_context = FlaskAppContext(flask_app)
|
||||
self._context_vars = context_vars
|
||||
self._user = user
|
||||
self._flask_app = flask_app
|
||||
|
||||
@property
|
||||
def app_context(self) -> FlaskAppContext:
|
||||
"""Get Flask app context."""
|
||||
return self._app_context
|
||||
|
||||
@property
|
||||
def context_vars(self) -> contextvars.Context:
|
||||
"""Get context variables."""
|
||||
return self._context_vars
|
||||
|
||||
@property
|
||||
def user(self) -> Any:
|
||||
"""Get user object."""
|
||||
return self._user
|
||||
|
||||
def __enter__(self) -> "FlaskExecutionContext":
|
||||
"""Enter the Flask execution context."""
|
||||
# Restore context variables
|
||||
for var, val in self._context_vars.items():
|
||||
var.set(val)
|
||||
|
||||
# Save current user from g if available
|
||||
saved_user = None
|
||||
if hasattr(g, "_login_user"):
|
||||
saved_user = g._login_user
|
||||
|
||||
# Enter Flask app context
|
||||
self._cm = self._app_context.enter()
|
||||
self._cm.__enter__()
|
||||
|
||||
# Restore user in new app context
|
||||
if saved_user is not None:
|
||||
g._login_user = saved_user
|
||||
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
"""Exit the Flask execution context."""
|
||||
if hasattr(self, "_cm"):
|
||||
self._cm.__exit__(*args)
|
||||
|
||||
@contextmanager
|
||||
def enter(self) -> Generator[None, None, None]:
|
||||
"""Enter Flask execution context as context manager."""
|
||||
# Restore context variables
|
||||
for var, val in self._context_vars.items():
|
||||
var.set(val)
|
||||
|
||||
# Save current user from g if available
|
||||
saved_user = None
|
||||
if hasattr(g, "_login_user"):
|
||||
saved_user = g._login_user
|
||||
|
||||
# Enter Flask app context
|
||||
with self._flask_app.app_context():
|
||||
# Restore user in new app context
|
||||
if saved_user is not None:
|
||||
g._login_user = saved_user
|
||||
yield
|
||||
|
||||
|
||||
def init_flask_context() -> None:
|
||||
"""
|
||||
Initialize Flask context capture by registering the capturer.
|
||||
|
||||
This function should be called during Flask application initialization
|
||||
to register the Flask-specific context capturer with the core context module.
|
||||
|
||||
Example:
|
||||
app = Flask(__name__)
|
||||
init_flask_context() # Register Flask context capturer
|
||||
|
||||
Note:
|
||||
This function does not need the app instance as it uses Flask's
|
||||
`current_app` to get the app when capturing context.
|
||||
"""
|
||||
register_context_capturer(capture_flask_context)
|
||||
@@ -63,7 +63,6 @@ from .app import (
    statistic,
    workflow,
    workflow_app_log,
    workflow_comment,
    workflow_draft_variable,
    workflow_run,
    workflow_statistic,

@@ -113,7 +112,6 @@ from .explore import (
    recommended_app,
    saved_message,
)
from .socketio import workflow as socketio_workflow  # pyright: ignore[reportUnusedImport]

# Import tag controllers
from .tag import tags

@@ -205,7 +203,6 @@ __all__ = [
    "website",
    "workflow",
    "workflow_app_log",
    "workflow_comment",
    "workflow_draft_variable",
    "workflow_run",
    "workflow_statistic",
@@ -1,3 +1,4 @@
import re
import uuid
from datetime import datetime
from typing import Any, Literal, TypeAlias

@@ -67,6 +68,48 @@ class AppListQuery(BaseModel):
            raise ValueError("Invalid UUID format in tag_ids.") from exc


# XSS prevention: patterns that could lead to XSS attacks
# Includes: script tags, iframe tags, javascript: protocol, SVG with onload, etc.
_XSS_PATTERNS = [
    r"<script[^>]*>.*?</script>",  # Script tags
    r"<iframe\b[^>]*?(?:/>|>.*?</iframe>)",  # Iframe tags (including self-closing)
    r"javascript:",  # JavaScript protocol
    r"<svg[^>]*?\s+onload\s*=[^>]*>",  # SVG with onload handler (attribute-aware, flexible whitespace)
    r"<.*?on\s*\w+\s*=",  # Event handlers like onclick, onerror, etc.
    r"<object\b[^>]*(?:\s*/>|>.*?</object\s*>)",  # Object tags (opening tag)
    r"<embed[^>]*>",  # Embed tags (self-closing)
    r"<link[^>]*>",  # Link tags with javascript
]


def _validate_xss_safe(value: str | None, field_name: str = "Field") -> str | None:
    """
    Validate that a string value doesn't contain potential XSS payloads.

    Args:
        value: The string value to validate
        field_name: Name of the field for error messages

    Returns:
        The original value if safe

    Raises:
        ValueError: If the value contains XSS patterns
    """
    if value is None:
        return None

    value_lower = value.lower()
    for pattern in _XSS_PATTERNS:
        if re.search(pattern, value_lower, re.DOTALL | re.IGNORECASE):
            raise ValueError(
                f"{field_name} contains invalid characters or patterns. "
                "HTML tags, JavaScript, and other potentially dangerous content are not allowed."
            )

    return value


class CreateAppPayload(BaseModel):
    name: str = Field(..., min_length=1, description="App name")
    description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)

@@ -75,6 +118,11 @@ class CreateAppPayload(BaseModel):
    icon: str | None = Field(default=None, description="Icon")
    icon_background: str | None = Field(default=None, description="Icon background color")

    @field_validator("name", "description", mode="before")
    @classmethod
    def validate_xss_safe(cls, value: str | None, info) -> str | None:
        return _validate_xss_safe(value, info.field_name)


class UpdateAppPayload(BaseModel):
    name: str = Field(..., min_length=1, description="App name")

@@ -85,6 +133,11 @@ class UpdateAppPayload(BaseModel):
    use_icon_as_answer_icon: bool | None = Field(default=None, description="Use icon as answer icon")
    max_active_requests: int | None = Field(default=None, description="Maximum active requests")

    @field_validator("name", "description", mode="before")
    @classmethod
    def validate_xss_safe(cls, value: str | None, info) -> str | None:
        return _validate_xss_safe(value, info.field_name)


class CopyAppPayload(BaseModel):
    name: str | None = Field(default=None, description="Name for the copied app")

@@ -93,6 +146,11 @@ class CopyAppPayload(BaseModel):
    icon: str | None = Field(default=None, description="Icon")
    icon_background: str | None = Field(default=None, description="Icon background color")

    @field_validator("name", "description", mode="before")
    @classmethod
    def validate_xss_safe(cls, value: str | None, info) -> str | None:
        return _validate_xss_safe(value, info.field_name)


class AppExportQuery(BaseModel):
    include_secret: bool = Field(default=False, description="Include secrets in export")
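For clarity, a minimal sketch of how a `before`-mode validator like the one above behaves on a payload model. The model and the single pattern here are defined standalone for illustration; they mirror, but are not, the `CreateAppPayload` and `_XSS_PATTERNS` in this diff.

```python
import re

from pydantic import BaseModel, Field, ValidationError, field_validator

_SCRIPT_TAG = re.compile(r"<script[^>]*>.*?</script>", re.DOTALL | re.IGNORECASE)


class DemoPayload(BaseModel):
    name: str = Field(..., min_length=1)

    @field_validator("name", mode="before")
    @classmethod
    def reject_script_tags(cls, value):
        # Runs before type coercion, matching the mode="before" usage above.
        if isinstance(value, str) and _SCRIPT_TAG.search(value):
            raise ValueError("name contains potentially dangerous content.")
        return value


DemoPayload(name="My App")  # passes validation
try:
    DemoPayload(name="<script>alert(1)</script>")
except ValidationError as exc:
    print(exc.errors()[0]["msg"])  # surfaces the validator's ValueError
```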
@ -32,10 +32,8 @@ from core.trigger.debug.event_selectors import (
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import file_factory, variable_factory
|
||||
from fields.member_fields import simple_account_fields
|
||||
from fields.online_user_fields import online_user_list_fields
|
||||
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
|
||||
from fields.workflow_run_fields import workflow_run_node_execution_fields
|
||||
from libs import helper
|
||||
@ -45,7 +43,6 @@ from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
from models.workflow import Workflow
|
||||
from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
@ -183,14 +180,6 @@ class WorkflowUpdatePayload(BaseModel):
|
||||
marked_comment: str | None = Field(default=None, max_length=100)
|
||||
|
||||
|
||||
class WorkflowFeaturesPayload(BaseModel):
|
||||
features: dict[str, Any] = Field(..., description="Workflow feature configuration")
|
||||
|
||||
|
||||
class WorkflowOnlineUsersQuery(BaseModel):
|
||||
workflow_ids: str = Field(..., description="Comma-separated workflow IDs")
|
||||
|
||||
|
||||
class DraftWorkflowTriggerRunPayload(BaseModel):
|
||||
node_id: str
|
||||
|
||||
@ -214,8 +203,6 @@ reg(DefaultBlockConfigQuery)
|
||||
reg(ConvertToWorkflowPayload)
|
||||
reg(WorkflowListQuery)
|
||||
reg(WorkflowUpdatePayload)
|
||||
reg(WorkflowFeaturesPayload)
|
||||
reg(WorkflowOnlineUsersQuery)
|
||||
reg(DraftWorkflowTriggerRunPayload)
|
||||
reg(DraftWorkflowTriggerRunAllPayload)
|
||||
|
||||
@ -331,7 +318,6 @@ class DraftWorkflowApi(Resource):
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
force_upload=args.get("force_upload", False),
|
||||
)
|
||||
except WorkflowHashNotEqualError:
|
||||
raise DraftWorkflowNotSync()
|
||||
@ -805,31 +791,6 @@ class ConvertToWorkflowApi(Resource):
|
||||
}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/features")
|
||||
class WorkflowFeaturesApi(Resource):
|
||||
"""Update draft workflow features."""
|
||||
|
||||
@console_ns.expect(console_ns.models[WorkflowFeaturesPayload.__name__])
|
||||
@console_ns.doc("update_workflow_features")
|
||||
@console_ns.doc(description="Update draft workflow features")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Workflow features updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
args = WorkflowFeaturesPayload.model_validate(console_ns.payload or {})
|
||||
features = args.features
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows")
|
||||
class PublishedAllWorkflowApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
|
||||
@ -1205,32 +1166,3 @@ class DraftWorkflowTriggerRunAllApi(Resource):
|
||||
"status": "error",
|
||||
}
|
||||
), 400
|
||||
|
||||
|
||||
@console_ns.route("/apps/workflows/online-users")
|
||||
class WorkflowOnlineUsersApi(Resource):
|
||||
@console_ns.expect(console_ns.models[WorkflowOnlineUsersQuery.__name__])
|
||||
@console_ns.doc("get_workflow_online_users")
|
||||
@console_ns.doc(description="Get workflow online users")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(online_user_list_fields)
|
||||
def get(self):
|
||||
args = WorkflowOnlineUsersQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
workflow_ids = [workflow_id.strip() for workflow_id in args.workflow_ids.split(",") if workflow_id.strip()]
|
||||
|
||||
results = []
|
||||
for workflow_id in workflow_ids:
|
||||
users_json = redis_client.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}")
|
||||
|
||||
users = []
|
||||
for _, user_info_json in users_json.items():
|
||||
try:
|
||||
users.append(json.loads(user_info_json))
|
||||
except Exception:
|
||||
continue
|
||||
results.append({"workflow_id": workflow_id, "users": users})
|
||||
|
||||
return {"data": results}
|
||||
|
||||
@ -1,317 +0,0 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.member_fields import account_with_role_fields
|
||||
from fields.workflow_comment_fields import (
|
||||
workflow_comment_basic_fields,
|
||||
workflow_comment_create_fields,
|
||||
workflow_comment_detail_fields,
|
||||
workflow_comment_reply_create_fields,
|
||||
workflow_comment_reply_update_fields,
|
||||
workflow_comment_resolve_fields,
|
||||
workflow_comment_update_fields,
|
||||
)
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowCommentCreatePayload(BaseModel):
|
||||
position_x: float = Field(..., description="Comment X position")
|
||||
position_y: float = Field(..., description="Comment Y position")
|
||||
content: str = Field(..., description="Comment content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Comment content")
|
||||
position_x: float | None = Field(default=None, description="Comment X position")
|
||||
position_y: float | None = Field(default=None, description="Comment Y position")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentReplyCreatePayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentReplyUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
for model in (
|
||||
WorkflowCommentCreatePayload,
|
||||
WorkflowCommentUpdatePayload,
|
||||
WorkflowCommentReplyCreatePayload,
|
||||
WorkflowCommentReplyUpdatePayload,
|
||||
):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
|
||||
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
|
||||
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
|
||||
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
|
||||
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
|
||||
workflow_comment_reply_create_model = console_ns.model(
|
||||
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
|
||||
)
|
||||
workflow_comment_reply_update_model = console_ns.model(
|
||||
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
|
||||
)
|
||||
workflow_comment_mention_users_model = console_ns.model(
|
||||
"WorkflowCommentMentionUsers",
|
||||
{"users": fields.List(fields.Nested(account_with_role_fields))},
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments")
|
||||
class WorkflowCommentListApi(Resource):
|
||||
"""API for listing and creating workflow comments."""
|
||||
|
||||
@console_ns.doc("list_workflow_comments")
|
||||
@console_ns.doc(description="Get all comments for a workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_basic_model, envelope="data")
|
||||
def get(self, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return comments
|
||||
|
||||
@console_ns.doc("create_workflow_comment")
|
||||
@console_ns.doc(description="Create a new workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
|
||||
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_create_model)
|
||||
def post(self, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
|
||||
class WorkflowCommentDetailApi(Resource):
|
||||
"""API for managing individual workflow comments."""
|
||||
|
||||
@console_ns.doc("get_workflow_comment")
|
||||
@console_ns.doc(description="Get a specific workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_detail_model)
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
@console_ns.doc("update_workflow_comment")
|
||||
@console_ns.doc(description="Update a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_update_model)
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.update_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@console_ns.doc("delete_workflow_comment")
|
||||
@console_ns.doc(description="Delete a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(204, "Comment deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def delete(self, app_model: App, comment_id: str):
|
||||
"""Delete a workflow comment."""
|
||||
WorkflowCommentService.delete_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
|
||||
class WorkflowCommentResolveApi(Resource):
|
||||
"""API for resolving and reopening workflow comments."""
|
||||
|
||||
@console_ns.doc("resolve_workflow_comment")
|
||||
@console_ns.doc(description="Resolve a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_resolve_model)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
comment = WorkflowCommentService.resolve_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
|
||||
class WorkflowCommentReplyApi(Resource):
|
||||
"""API for managing comment replies."""
|
||||
|
||||
@console_ns.doc("create_workflow_comment_reply")
|
||||
@console_ns.doc(description="Add a reply to a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyCreatePayload.__name__])
|
||||
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_create_model)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_reply(
|
||||
comment_id=comment_id,
|
||||
content=payload.content,
|
||||
created_by=current_user.id,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
|
||||
class WorkflowCommentReplyDetailApi(Resource):
|
||||
"""API for managing individual comment replies."""
|
||||
|
||||
@console_ns.doc("update_workflow_comment_reply")
|
||||
@console_ns.doc(description="Update a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_update_model)
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
reply = WorkflowCommentService.update_reply(
|
||||
reply_id=reply_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return reply
|
||||
|
||||
@console_ns.doc("delete_workflow_comment_reply")
|
||||
@console_ns.doc(description="Delete a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.response(204, "Reply deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def delete(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Delete a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
|
||||
class WorkflowCommentMentionUsersApi(Resource):
|
||||
"""API for getting mentionable users for workflow comments."""
|
||||
|
||||
@console_ns.doc("workflow_comment_mention_users")
|
||||
@console_ns.doc(description="Get all users in current tenant for mentions")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Mentionable users retrieved successfully", workflow_comment_mention_users_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_mention_users_model)
|
||||
def get(self, app_model: App):
|
||||
"""Get all users in current tenant for mentions."""
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
return {"users": members}
|
||||
@ -21,9 +21,9 @@ from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.variables.types import SegmentType
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from extensions.ext_database import db
|
||||
from factories import variable_factory
|
||||
from factories.file_factory import build_from_mapping, build_from_mappings
|
||||
from libs.login import current_user, login_required
|
||||
from factories.variable_factory import build_segment_with_type
|
||||
from libs.login import login_required
|
||||
from models import App, AppMode
|
||||
from models.workflow import WorkflowDraftVariable
|
||||
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||
@ -43,16 +43,6 @@ class WorkflowDraftVariableUpdatePayload(BaseModel):
|
||||
value: Any | None = Field(default=None, description="Variable value")
|
||||
|
||||
|
||||
class ConversationVariableUpdatePayload(BaseModel):
|
||||
conversation_variables: list[dict[str, Any]] = Field(
|
||||
..., description="Conversation variables for the draft workflow"
|
||||
)
|
||||
|
||||
|
||||
class EnvironmentVariableUpdatePayload(BaseModel):
|
||||
environment_variables: list[dict[str, Any]] = Field(..., description="Environment variables for the draft workflow")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
WorkflowDraftVariableListQuery.__name__,
|
||||
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
@ -61,14 +51,6 @@ console_ns.schema_model(
|
||||
WorkflowDraftVariableUpdatePayload.__name__,
|
||||
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ConversationVariableUpdatePayload.__name__,
|
||||
ConversationVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
console_ns.schema_model(
|
||||
EnvironmentVariableUpdatePayload.__name__,
|
||||
EnvironmentVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
def _convert_values_to_json_serializable_object(value: Segment):
|
||||
@ -401,7 +383,7 @@ class VariableApi(Resource):
|
||||
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
|
||||
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
|
||||
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
|
||||
new_value = variable_factory.build_segment_with_type(variable.value_type, raw_value)
|
||||
new_value = build_segment_with_type(variable.value_type, raw_value)
|
||||
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||
db.session.commit()
|
||||
return variable
|
||||
@ -494,34 +476,6 @@ class ConversationVariableCollectionApi(Resource):
|
||||
db.session.commit()
|
||||
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||
|
||||
@console_ns.expect(console_ns.models[ConversationVariableUpdatePayload.__name__])
|
||||
@console_ns.doc("update_conversation_variables")
|
||||
@console_ns.doc(description="Update conversation variables for workflow draft")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Conversation variables updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=AppMode.ADVANCED_CHAT)
|
||||
def post(self, app_model: App):
|
||||
payload = ConversationVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
conversation_variables_list = payload.conversation_variables
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
|
||||
workflow_service.update_draft_workflow_conversation_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||
class SystemVariableCollectionApi(Resource):
|
||||
@ -573,31 +527,3 @@ class EnvironmentVariableCollectionApi(Resource):
|
||||
)
|
||||
|
||||
return {"items": env_vars_list}
|
||||
|
||||
@console_ns.expect(console_ns.models[EnvironmentVariableUpdatePayload.__name__])
|
||||
@console_ns.doc("update_environment_variables")
|
||||
@console_ns.doc(description="Update environment variables for workflow draft")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Environment variables updated successfully")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
def post(self, app_model: App):
|
||||
payload = EnvironmentVariableUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
|
||||
environment_variables_list = payload.environment_variables
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
|
||||
workflow_service.update_draft_workflow_environment_variables(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
environment_variables=environment_variables,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -69,13 +69,6 @@ class ActivateCheckApi(Resource):
|
||||
if invitation:
|
||||
data = invitation.get("data", {})
|
||||
tenant = invitation.get("tenant", None)
|
||||
|
||||
# Check workspace permission
|
||||
if tenant:
|
||||
from libs.workspace_permission import check_workspace_member_invite_permission
|
||||
|
||||
check_workspace_member_invite_permission(tenant.id)
|
||||
|
||||
workspace_name = tenant.name if tenant else None
|
||||
workspace_id = tenant.id if tenant else None
|
||||
invitee_email = data.get("email") if data else None
|
||||
|
||||
@ -146,6 +146,7 @@ class DatasetUpdatePayload(BaseModel):
|
||||
embedding_model: str | None = None
|
||||
embedding_model_provider: str | None = None
|
||||
retrieval_model: dict[str, Any] | None = None
|
||||
summary_index_setting: dict[str, Any] | None = None
|
||||
partial_member_list: list[dict[str, str]] | None = None
|
||||
external_retrieval_model: dict[str, Any] | None = None
|
||||
external_knowledge_id: str | None = None
|
||||
|
||||
@ -2,12 +2,10 @@ import json
|
||||
import logging
|
||||
from argparse import ArgumentTypeError
|
||||
from collections.abc import Sequence
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Literal, cast
|
||||
from uuid import UUID
|
||||
from typing import Literal, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request, send_file
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import asc, desc, select
|
||||
@ -41,10 +39,10 @@ from fields.document_fields import (
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DocumentPipelineExecutionLog
|
||||
from models.dataset import DocumentPipelineExecutionLog, DocumentSegmentSummary
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
|
||||
from services.file_service import FileService
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
|
||||
from ..app.error import (
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
@ -68,9 +66,6 @@ from ..wraps import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# NOTE: Keep constants near the top of the module for discoverability.
|
||||
DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS = 100
|
||||
|
||||
|
||||
def _get_or_create_model(model_name: str, field_def):
|
||||
existing = console_ns.models.get(model_name)
|
||||
@ -110,10 +105,8 @@ class DocumentRenamePayload(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class DocumentBatchDownloadZipPayload(BaseModel):
|
||||
"""Request payload for bulk downloading documents as a zip archive."""
|
||||
|
||||
document_ids: list[UUID] = Field(..., min_length=1, max_length=DOCUMENT_BATCH_DOWNLOAD_ZIP_MAX_DOCS)
|
||||
class GenerateSummaryPayload(BaseModel):
|
||||
document_list: list[str]
|
||||
|
||||
|
||||
class DocumentDatasetListParam(BaseModel):
|
||||
@ -132,7 +125,7 @@ register_schema_models(
|
||||
RetrievalModel,
|
||||
DocumentRetryPayload,
|
||||
DocumentRenamePayload,
|
||||
DocumentBatchDownloadZipPayload,
|
||||
GenerateSummaryPayload,
|
||||
)
|
||||
|
||||
|
||||
@ -319,6 +312,89 @@ class DatasetDocumentListApi(Resource):
|
||||
|
||||
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
documents = paginated_documents.items
|
||||
|
||||
# Check if dataset has summary index enabled
|
||||
has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True
|
||||
|
||||
# Filter documents that need summary calculation
|
||||
documents_need_summary = [doc for doc in documents if doc.need_summary is True]
|
||||
document_ids_need_summary = [str(doc.id) for doc in documents_need_summary]
|
||||
|
||||
# Calculate summary_index_status for documents that need summary (only if dataset summary index is enabled)
|
||||
summary_status_map = {}
|
||||
if has_summary_index and document_ids_need_summary:
|
||||
# Get all segments for these documents (excluding re_segment chunks)
|
||||
segments = (
|
||||
db.session.query(DocumentSegment.id, DocumentSegment.document_id)
|
||||
.where(
|
||||
DocumentSegment.document_id.in_(document_ids_need_summary),
|
||||
DocumentSegment.status != "re_segment",
|
||||
DocumentSegment.tenant_id == current_tenant_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Group segments by document_id
|
||||
document_segments_map = {}
|
||||
for segment in segments:
|
||||
doc_id = str(segment.document_id)
|
||||
if doc_id not in document_segments_map:
|
||||
document_segments_map[doc_id] = []
|
||||
document_segments_map[doc_id].append(segment.id)
|
||||
|
||||
# Get all summary records for these segments
|
||||
all_segment_ids = [seg.id for seg in segments]
|
||||
summaries = {}
|
||||
if all_segment_ids:
|
||||
summary_records = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id.in_(all_segment_ids),
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
DocumentSegmentSummary.enabled == True, # Only count enabled summaries
|
||||
)
|
||||
.all()
|
||||
)
|
||||
summaries = {summary.chunk_id: summary.status for summary in summary_records}
|
||||
|
||||
# Calculate summary_index_status for each document
|
||||
for doc_id in document_ids_need_summary:
|
||||
segment_ids = document_segments_map.get(doc_id, [])
|
||||
if not segment_ids:
|
||||
# No segments, status is None (not started)
|
||||
summary_status_map[doc_id] = None
|
||||
continue
|
||||
|
||||
# Count summary statuses for this document's segments
|
||||
status_counts = {"completed": 0, "generating": 0, "error": 0, "not_started": 0}
|
||||
for segment_id in segment_ids:
|
||||
status = summaries.get(segment_id, "not_started")
|
||||
if status in status_counts:
|
||||
status_counts[status] += 1
|
||||
else:
|
||||
status_counts["not_started"] += 1
|
||||
|
||||
generating_count = status_counts["generating"]

# Determine overall status:
# - "SUMMARIZING" only when task is queued and at least one summary is generating
# - None (empty) for all other cases (not queued, all completed/error)
if generating_count > 0:
# Task is queued and at least one summary is still generating
summary_status_map[doc_id] = "SUMMARIZING"
else:
# Task not queued yet, or all summaries are completed/error (task finished)
summary_status_map[doc_id] = None

# Add summary_index_status to each document
for document in documents:
if has_summary_index and document.need_summary is True:
# Get status from map, default to None (not queued yet)
document.summary_index_status = summary_status_map.get(str(document.id))
else:
# Return null if summary index is not enabled or document doesn't need summary
document.summary_index_status = None
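# A minimal sketch of the document-level rule above; the helper name and
# signature are illustrative assumptions, not part of the module.
def _aggregate_document_summary_status(segment_statuses: list[str]) -> str | None:
    # A document reports "SUMMARIZING" only while at least one of its segment
    # summaries is still generating; in every other case the status stays None.
    if any(status == "generating" for status in segment_statuses):
        return "SUMMARIZING"
    return None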
|
||||
if fetch:
|
||||
for document in documents:
|
||||
completed_segments = (
|
||||
@ -804,6 +880,7 @@ class DocumentApi(DocumentResource):
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
"need_summary": document.need_summary if document.need_summary is not None else False,
|
||||
}
|
||||
else:
|
||||
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
|
||||
@ -839,6 +916,7 @@ class DocumentApi(DocumentResource):
|
||||
"display_status": document.display_status,
|
||||
"doc_form": document.doc_form,
|
||||
"doc_language": document.doc_language,
|
||||
"need_summary": document.need_summary if document.need_summary is not None else False,
|
||||
}
|
||||
|
||||
return response, 200
|
||||
@ -866,62 +944,6 @@ class DocumentApi(DocumentResource):
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/download")
|
||||
class DocumentDownloadApi(DocumentResource):
|
||||
"""Return a signed download URL for a dataset document's original uploaded file."""
|
||||
|
||||
@console_ns.doc("get_dataset_document_download_url")
|
||||
@console_ns.doc(description="Get a signed download URL for a dataset document's original uploaded file")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def get(self, dataset_id: str, document_id: str) -> dict[str, Any]:
|
||||
# Reuse the shared permission/tenant checks implemented in DocumentResource.
|
||||
document = self.get_document(str(dataset_id), str(document_id))
|
||||
return {"url": DocumentService.get_document_download_url(document)}
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/download-zip")
|
||||
class DocumentBatchDownloadZipApi(DocumentResource):
|
||||
"""Download multiple uploaded-file documents as a single ZIP (avoids browser multi-download limits)."""
|
||||
|
||||
@console_ns.doc("download_dataset_documents_as_zip")
|
||||
@console_ns.doc(description="Download selected dataset documents as a single ZIP archive (upload-file only)")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
@console_ns.expect(console_ns.models[DocumentBatchDownloadZipPayload.__name__])
|
||||
def post(self, dataset_id: str):
|
||||
"""Stream a ZIP archive containing the requested uploaded documents."""
|
||||
# Parse and validate request payload.
|
||||
payload = DocumentBatchDownloadZipPayload.model_validate(console_ns.payload or {})
|
||||
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
document_ids: list[str] = [str(document_id) for document_id in payload.document_ids]
|
||||
upload_files, download_name = DocumentService.prepare_document_batch_download_zip(
|
||||
dataset_id=dataset_id,
|
||||
document_ids=document_ids,
|
||||
tenant_id=current_tenant_id,
|
||||
current_user=current_user,
|
||||
)
|
||||
|
||||
# Delegate ZIP packing to FileService, but keep Flask response+cleanup in the route.
|
||||
with ExitStack() as stack:
|
||||
zip_path = stack.enter_context(FileService.build_upload_files_zip_tempfile(upload_files=upload_files))
|
||||
response = send_file(
|
||||
zip_path,
|
||||
mimetype="application/zip",
|
||||
as_attachment=True,
|
||||
download_name=download_name,
|
||||
)
|
||||
cleanup = stack.pop_all()
|
||||
response.call_on_close(cleanup.close)
|
||||
return response
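# A condensed sketch of the cleanup pattern used above: the temporary ZIP must
# outlive this function, so its removal is handed to Flask and only runs after
# the response has been fully sent. Parameter names are placeholders.
from contextlib import ExitStack

from flask import send_file

def stream_tempfile_zip(zip_path_cm, download_name: str):
    with ExitStack() as stack:
        # zip_path_cm is any context manager yielding a temp-file path,
        # such as the FileService helper used above.
        zip_path = stack.enter_context(zip_path_cm)
        response = send_file(zip_path, mimetype="application/zip", as_attachment=True, download_name=download_name)
        # pop_all() detaches the callbacks so leaving the with-block does not delete the file yet.
        response.call_on_close(stack.pop_all().close)
        return response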
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/<string:action>")
|
||||
class DocumentProcessingApi(DocumentResource):
|
||||
@console_ns.doc("update_document_processing")
|
||||
@ -1262,3 +1284,216 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
|
||||
"input_data": log.input_data,
|
||||
"datasource_node_id": log.datasource_node_id,
|
||||
}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/generate-summary")
|
||||
class DocumentGenerateSummaryApi(Resource):
|
||||
@console_ns.doc("generate_summary_for_documents")
|
||||
@console_ns.doc(description="Generate summary index for documents")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
|
||||
@console_ns.response(200, "Summary generation started successfully")
|
||||
@console_ns.response(400, "Invalid request or dataset configuration")
|
||||
@console_ns.response(403, "Permission denied")
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self, dataset_id):
|
||||
"""
|
||||
Generate summary index for specified documents.
|
||||
|
||||
This endpoint checks if the dataset configuration supports summary generation
|
||||
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
|
||||
then asynchronously generates summary indexes for the provided documents.
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
|
||||
# Get dataset
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# Check permissions
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
# Validate request payload
|
||||
payload = GenerateSummaryPayload.model_validate(console_ns.payload or {})
|
||||
document_list = payload.document_list
|
||||
|
||||
if not document_list:
|
||||
raise ValueError("document_list cannot be empty.")
|
||||
|
||||
# Check if dataset configuration supports summary generation
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
raise ValueError(
|
||||
f"Summary generation is only available for 'high_quality' indexing technique. "
|
||||
f"Current indexing technique: {dataset.indexing_technique}"
|
||||
)
|
||||
|
||||
summary_index_setting = dataset.summary_index_setting
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.")
|
||||
|
||||
# Verify all documents exist and belong to the dataset
|
||||
documents = (
|
||||
db.session.query(Document)
|
||||
.filter(
|
||||
Document.id.in_(document_list),
|
||||
Document.dataset_id == dataset_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if len(documents) != len(document_list):
|
||||
found_ids = {doc.id for doc in documents}
|
||||
missing_ids = set(document_list) - found_ids
|
||||
raise NotFound(f"Some documents not found: {list(missing_ids)}")
|
||||
|
||||
# Dispatch async tasks for each document
|
||||
for document in documents:
|
||||
# Skip qa_model documents as they don't generate summaries
|
||||
if document.doc_form == "qa_model":
|
||||
logger.info("Skipping summary generation for qa_model document %s", document.id)
|
||||
continue
|
||||
|
||||
# Dispatch async task
|
||||
generate_summary_index_task.delay(dataset_id, document.id)
|
||||
logger.info(
|
||||
"Dispatched summary generation task for document %s in dataset %s",
|
||||
document.id,
|
||||
dataset_id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 200
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/summary-status")
|
||||
class DocumentSummaryStatusApi(DocumentResource):
|
||||
@console_ns.doc("get_document_summary_status")
|
||||
@console_ns.doc(description="Get summary index generation status for a document")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
|
||||
@console_ns.response(200, "Summary status retrieved successfully")
|
||||
@console_ns.response(404, "Document not found")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, dataset_id, document_id):
|
||||
"""
|
||||
Get summary index generation status for a document.
|
||||
|
||||
Returns:
|
||||
- total_segments: Total number of segments in the document
|
||||
- summary_status: Dictionary with status counts
|
||||
- completed: Number of summaries completed
|
||||
- generating: Number of summaries being generated
|
||||
- error: Number of summaries with errors
|
||||
- not_started: Number of segments without summary records
|
||||
- summaries: List of summary records with status and content preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
dataset_id = str(dataset_id)
|
||||
document_id = str(document_id)
|
||||
|
||||
# Get document
|
||||
document = self.get_document(dataset_id, document_id)
|
||||
|
||||
# Get dataset
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
|
||||
# Check permissions
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except services.errors.account.NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
# Get all segments for this document
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.document_id == document_id,
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
total_segments = len(segments)
|
||||
|
||||
# Get all summary records for these segments
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
summaries = []
|
||||
if segment_ids:
|
||||
summaries = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.filter(
|
||||
DocumentSegmentSummary.document_id == document_id,
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
DocumentSegmentSummary.chunk_id.in_(segment_ids),
|
||||
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Create a mapping of chunk_id to summary
|
||||
summary_map = {summary.chunk_id: summary for summary in summaries}
|
||||
|
||||
# Count statuses
|
||||
status_counts = {
|
||||
"completed": 0,
|
||||
"generating": 0,
|
||||
"error": 0,
|
||||
"not_started": 0,
|
||||
}
|
||||
|
||||
summary_list = []
|
||||
for segment in segments:
|
||||
summary = summary_map.get(segment.id)
|
||||
if summary:
|
||||
status = summary.status
|
||||
status_counts[status] = status_counts.get(status, 0) + 1
|
||||
summary_list.append(
|
||||
{
|
||||
"segment_id": segment.id,
|
||||
"segment_position": segment.position,
|
||||
"status": summary.status,
|
||||
"summary_preview": (
|
||||
summary.summary_content[:100] + "..."
|
||||
if summary.summary_content and len(summary.summary_content) > 100
|
||||
else summary.summary_content
|
||||
),
|
||||
"error": summary.error,
|
||||
"created_at": int(summary.created_at.timestamp()) if summary.created_at else None,
|
||||
"updated_at": int(summary.updated_at.timestamp()) if summary.updated_at else None,
|
||||
}
|
||||
)
|
||||
else:
|
||||
status_counts["not_started"] += 1
|
||||
summary_list.append(
|
||||
{
|
||||
"segment_id": segment.id,
|
||||
"segment_position": segment.position,
|
||||
"status": "not_started",
|
||||
"summary_preview": None,
|
||||
"error": None,
|
||||
"created_at": None,
|
||||
"updated_at": None,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"total_segments": total_segments,
|
||||
"summary_status": status_counts,
|
||||
"summaries": summary_list,
|
||||
}, 200
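# The response assembled above, sketched as a Python literal; the keys mirror
# the code, the values are illustrative.
example_summary_status_response = {
    "total_segments": 2,
    "summary_status": {"completed": 1, "generating": 1, "error": 0, "not_started": 0},
    "summaries": [
        {
            "segment_id": "segment-uuid",
            "segment_position": 1,
            "status": "completed",
            "summary_preview": "Key points of the first chunk...",
            "error": None,
            "created_at": 1735689600,
            "updated_at": 1735689600,
        },
    ],
}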
|
||||
@ -32,7 +32,7 @@ from extensions.ext_redis import redis_client
|
||||
from fields.segment_fields import child_chunk_fields, segment_fields
|
||||
from libs.helper import escape_like_pattern
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import ChildChunk, DocumentSegment
|
||||
from models.dataset import ChildChunk, DocumentSegment, DocumentSegmentSummary
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
|
||||
@ -41,6 +41,23 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
|
||||
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
|
||||
|
||||
|
||||
def _get_segment_with_summary(segment, dataset_id):
|
||||
"""Helper function to marshal segment and add summary information."""
|
||||
segment_dict = marshal(segment, segment_fields)
|
||||
# Query summary for this segment (only enabled summaries)
|
||||
summary = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
DocumentSegmentSummary.enabled == True, # Only return enabled summaries
|
||||
)
|
||||
.first()
|
||||
)
|
||||
segment_dict["summary"] = summary.summary_content if summary else None
|
||||
return segment_dict
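# A brief usage sketch of the helper above (argument names assumed): it stands
# in for a bare marshal(segment, segment_fields) wherever the response should
# also carry the segment's enabled summary.
def _segment_response(segment, dataset_id: str, doc_form: str) -> dict:
    return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": doc_form}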
|
||||
|
||||
class SegmentListQuery(BaseModel):
|
||||
limit: int = Field(default=20, ge=1, le=100)
|
||||
status: list[str] = Field(default_factory=list)
|
||||
@ -63,6 +80,7 @@ class SegmentUpdatePayload(BaseModel):
|
||||
keywords: list[str] | None = None
|
||||
regenerate_child_chunks: bool = False
|
||||
attachment_ids: list[str] | None = None
|
||||
summary: str | None = None # Summary content for summary index
|
||||
|
||||
|
||||
class BatchImportPayload(BaseModel):
|
||||
@ -180,8 +198,32 @@ class DatasetDocumentSegmentListApi(Resource):
|
||||
|
||||
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
|
||||
|
||||
# Query summaries for all segments in this page (batch query for efficiency)
|
||||
segment_ids = [segment.id for segment in segments.items]
|
||||
summaries = {}
|
||||
if segment_ids:
|
||||
summary_records = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id.in_(segment_ids),
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
# Only include enabled summaries
|
||||
summaries = {
|
||||
summary.chunk_id: summary.summary_content for summary in summary_records if summary.enabled is True
|
||||
}
|
||||
|
||||
# Add summary to each segment
|
||||
segments_with_summary = []
|
||||
for segment in segments.items:
|
||||
segment_dict = marshal(segment, segment_fields)
|
||||
segment_dict["summary"] = summaries.get(segment.id)
|
||||
segments_with_summary.append(segment_dict)
|
||||
|
||||
response = {
|
||||
"data": marshal(segments.items, segment_fields),
|
||||
"data": segments_with_summary,
|
||||
"limit": limit,
|
||||
"total": segments.total,
|
||||
"total_pages": segments.pages,
|
||||
@ -327,7 +369,7 @@ class DatasetDocumentSegmentAddApi(Resource):
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||
segment = SegmentService.create_segment(payload_dict, document, dataset)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
|
||||
@ -389,10 +431,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
|
||||
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
payload_dict = payload.model_dump(exclude_none=True)
|
||||
SegmentService.segment_create_args_validate(payload_dict, document)
|
||||
|
||||
# Update segment (summary update with change detection is handled in SegmentService.update_segment)
|
||||
segment = SegmentService.update_segment(
|
||||
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
|
||||
)
|
||||
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
|
||||
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@ -1,6 +1,13 @@
|
||||
from flask_restx import Resource
|
||||
from flask_restx import Resource, fields
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
from fields.hit_testing_fields import (
|
||||
child_chunk_fields,
|
||||
document_fields,
|
||||
files_fields,
|
||||
hit_testing_record_fields,
|
||||
segment_fields,
|
||||
)
|
||||
from libs.login import login_required
|
||||
|
||||
from .. import console_ns
|
||||
@ -14,13 +21,45 @@ from ..wraps import (
|
||||
register_schema_model(console_ns, HitTestingPayload)
|
||||
|
||||
|
||||
def _get_or_create_model(model_name: str, field_def):
|
||||
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
|
||||
existing = console_ns.models.get(model_name)
|
||||
if existing is None:
|
||||
existing = console_ns.model(model_name, field_def)
|
||||
return existing
|
||||
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
document_model = _get_or_create_model("HitTestingDocument", document_fields)
|
||||
|
||||
segment_fields_copy = segment_fields.copy()
|
||||
segment_fields_copy["document"] = fields.Nested(document_model)
|
||||
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
|
||||
|
||||
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
|
||||
files_model = _get_or_create_model("HitTestingFile", files_fields)
|
||||
|
||||
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
|
||||
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
|
||||
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
|
||||
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
|
||||
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
|
||||
|
||||
# Response model for hit testing API
|
||||
hit_testing_response_fields = {
|
||||
"query": fields.String,
|
||||
"records": fields.List(fields.Nested(hit_testing_record_model)),
|
||||
}
|
||||
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
|
||||
|
||||
|
||||
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
|
||||
class HitTestingApi(Resource, DatasetsHitTestingBase):
|
||||
@console_ns.doc("test_dataset_retrieval")
|
||||
@console_ns.doc(description="Test dataset knowledge retrieval")
|
||||
@console_ns.doc(params={"dataset_id": "Dataset ID"})
|
||||
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
|
||||
@console_ns.response(200, "Hit testing completed successfully")
|
||||
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
|
||||
@console_ns.response(404, "Dataset not found")
|
||||
@console_ns.response(400, "Invalid parameters")
|
||||
@setup_required
|
||||
|
||||
@ -1 +0,0 @@
|
||||
|
||||
@ -1,107 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from flask import Request as FlaskRequest
|
||||
|
||||
from extensions.ext_socketio import sio
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token
|
||||
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository
|
||||
from services.account_service import AccountService
|
||||
from services.workflow_collaboration_service import WorkflowCollaborationService
|
||||
|
||||
repository = WorkflowCollaborationRepository()
|
||||
collaboration_service = WorkflowCollaborationService(repository, sio)
|
||||
|
||||
|
||||
def _sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]:
|
||||
return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event))
|
||||
|
||||
|
||||
@_sio_on("connect")
|
||||
def socket_connect(sid, environ, auth):
|
||||
"""
|
||||
WebSocket connect event, do authentication here.
|
||||
"""
|
||||
token = None
|
||||
if auth and isinstance(auth, dict):
|
||||
token = auth.get("token")
|
||||
|
||||
if not token:
|
||||
try:
|
||||
request_environ = FlaskRequest(environ)
|
||||
token = extract_access_token(request_environ)
|
||||
except Exception:
|
||||
logging.exception("Failed to extract token")
|
||||
token = None
|
||||
|
||||
if not token:
|
||||
return False
|
||||
|
||||
try:
|
||||
decoded = PassportService().verify(token)
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
with sio.app.app_context():
|
||||
user = AccountService.load_logged_in_account(account_id=user_id)
|
||||
if not user:
|
||||
return False
|
||||
|
||||
collaboration_service.save_session(sid, user)
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logging.exception("Socket authentication failed")
|
||||
return False
|
||||
|
||||
|
||||
@_sio_on("user_connect")
|
||||
def handle_user_connect(sid, data):
|
||||
"""
|
||||
Handle user connect event. Each session (tab) is treated as an independent collaborator.
|
||||
"""
|
||||
workflow_id = data.get("workflow_id")
|
||||
if not workflow_id:
|
||||
return {"msg": "workflow_id is required"}, 400
|
||||
|
||||
result = collaboration_service.register_session(workflow_id, sid)
|
||||
if not result:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
user_id, is_leader = result
|
||||
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
|
||||
|
||||
|
||||
@_sio_on("disconnect")
|
||||
def handle_disconnect(sid):
|
||||
"""
|
||||
Handle session disconnect event. Remove the specific session from online users.
|
||||
"""
|
||||
collaboration_service.disconnect_session(sid)
|
||||
|
||||
|
||||
@_sio_on("collaboration_event")
|
||||
def handle_collaboration_event(sid, data):
|
||||
"""
|
||||
Handle general collaboration events, including:
|
||||
1. mouse_move
|
||||
2. vars_and_features_update
|
||||
3. sync_request (ask leader to update graph)
|
||||
4. app_state_update
|
||||
5. mcp_server_update
|
||||
6. workflow_update
|
||||
7. comments_update
|
||||
8. node_panel_presence
|
||||
"""
|
||||
return collaboration_service.relay_collaboration_event(sid, data)
|
||||
|
||||
|
||||
@_sio_on("graph_event")
|
||||
def handle_graph_event(sid, data):
|
||||
"""
|
||||
Handle graph events - simple broadcast relay.
|
||||
"""
|
||||
return collaboration_service.relay_graph_event(sid, data)
|
||||
@ -36,7 +36,6 @@ from controllers.console.wraps import (
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from extensions.ext_database import db
|
||||
from fields.member_fields import account_fields
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -74,10 +73,6 @@ class AccountAvatarPayload(BaseModel):
|
||||
avatar: str
|
||||
|
||||
|
||||
class AccountAvatarQuery(BaseModel):
|
||||
avatar: str = Field(..., description="Avatar file ID")
|
||||
|
||||
|
||||
class AccountInterfaceLanguagePayload(BaseModel):
|
||||
interface_language: str
|
||||
|
||||
@ -163,7 +158,6 @@ def reg(cls: type[BaseModel]):
|
||||
reg(AccountInitPayload)
|
||||
reg(AccountNamePayload)
|
||||
reg(AccountAvatarPayload)
|
||||
reg(AccountAvatarQuery)
|
||||
reg(AccountInterfaceLanguagePayload)
|
||||
reg(AccountInterfaceThemePayload)
|
||||
reg(AccountTimezonePayload)
|
||||
@ -254,18 +248,6 @@ class AccountNameApi(Resource):
|
||||
|
||||
@console_ns.route("/account/avatar")
|
||||
class AccountAvatarApi(Resource):
|
||||
@console_ns.expect(console_ns.models[AccountAvatarQuery.__name__])
|
||||
@console_ns.doc("get_account_avatar")
|
||||
@console_ns.doc(description="Get account avatar url")
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self):
|
||||
args = AccountAvatarQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
avatar_url = file_helpers.get_signed_file_url(args.avatar)
|
||||
return {"avatar_url": avatar_url}
|
||||
|
||||
@console_ns.expect(console_ns.models[AccountAvatarPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
|
||||
@ -107,12 +107,6 @@ class MemberInviteEmailApi(Resource):
|
||||
inviter = current_user
|
||||
if not inviter.current_tenant:
|
||||
raise ValueError("No current tenant")
|
||||
|
||||
# Check workspace permission for member invitations
|
||||
from libs.workspace_permission import check_workspace_member_invite_permission
|
||||
|
||||
check_workspace_member_invite_permission(inviter.current_tenant.id)
|
||||
|
||||
invitation_results = []
|
||||
console_web_url = dify_config.CONSOLE_WEB_URL
|
||||
|
||||
|
||||
@ -20,7 +20,6 @@ from controllers.console.error import AccountNotLinkTenantError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
only_edition_enterprise,
|
||||
setup_required,
|
||||
)
|
||||
from enums.cloud_plan import CloudPlan
|
||||
@ -29,7 +28,6 @@ from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Tenant, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
from services.file_service import FileService
|
||||
from services.workspace_service import WorkspaceService
|
||||
@ -290,31 +288,3 @@ class WorkspaceInfoApi(Resource):
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/permission")
|
||||
class WorkspacePermissionApi(Resource):
|
||||
"""Get workspace permissions for the current workspace."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_enterprise
|
||||
def get(self):
|
||||
"""
|
||||
Get workspace permission settings.
|
||||
Returns permission flags that control workspace features like member invitations and owner transfer.
|
||||
"""
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if not current_tenant_id:
|
||||
raise ValueError("No current tenant")
|
||||
|
||||
# Get workspace permissions from enterprise service
|
||||
permission = EnterpriseService.WorkspacePermissionService.get_permission(current_tenant_id)
|
||||
|
||||
return {
|
||||
"workspace_id": permission.workspace_id,
|
||||
"allow_member_invite": permission.allow_member_invite,
|
||||
"allow_owner_transfer": permission.allow_owner_transfer,
|
||||
}, 200
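# Sketch of the JSON returned above; the keys come directly from the
# permission object, the values are illustrative.
example_workspace_permission_response = {
    "workspace_id": "workspace-uuid",
    "allow_member_invite": True,
    "allow_owner_transfer": False,
}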
|
||||
@ -286,12 +286,13 @@ def enable_change_email(view: Callable[P, R]):
|
||||
def is_allow_transfer_owner(view: Callable[P, R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
from libs.workspace_permission import check_workspace_owner_transfer_permission
|
||||
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
# Check both billing/plan level and workspace policy level permissions
|
||||
check_workspace_owner_transfer_permission(current_tenant_id)
|
||||
return view(*args, **kwargs)
|
||||
features = FeatureService.get_features(current_tenant_id)
|
||||
if features.is_allow_transfer_workspace:
|
||||
return view(*args, **kwargs)
|
||||
|
||||
# otherwise, return 403
|
||||
abort(403)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
@ -8,7 +8,7 @@ from typing import Any, Literal, Union, overload
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
import contexts
|
||||
from configs import dify_config
|
||||
@ -23,7 +23,6 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
@ -477,7 +476,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
:return:
|
||||
"""
|
||||
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||
with session_factory.create_session() as session:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
workflow = session.scalar(
|
||||
select(Workflow).where(
|
||||
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
|
||||
|
||||
@ -3,6 +3,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
class PreviewDetail(BaseModel):
|
||||
content: str
|
||||
summary: str | None = None
|
||||
child_chunks: list[str] | None = None
|
||||
|
||||
|
||||
|
||||
@ -311,14 +311,18 @@ class IndexingRunner:
|
||||
qa_preview_texts: list[QAPreviewDetail] = []
|
||||
|
||||
total_segments = 0
|
||||
# doc_form represents the segmentation method (general, parent-child, QA)
|
||||
index_type = doc_form
|
||||
index_processor = IndexProcessorFactory(index_type).init_index_processor()
|
||||
# one extract_setting is one source document
|
||||
for extract_setting in extract_settings:
|
||||
# extract
|
||||
processing_rule = DatasetProcessRule(
|
||||
mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
|
||||
)
|
||||
# Extract document content
|
||||
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
|
||||
# Cleaning and segmentation
|
||||
documents = index_processor.transform(
|
||||
text_docs,
|
||||
current_user=None,
|
||||
@ -361,6 +365,12 @@ class IndexingRunner:
|
||||
|
||||
if doc_form and doc_form == "qa_model":
|
||||
return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])
|
||||
|
||||
# Generate summary preview
|
||||
summary_index_setting = tmp_processing_rule.get("summary_index_setting")
|
||||
if summary_index_setting and summary_index_setting.get("enable") and preview_texts:
|
||||
preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
|
||||
|
||||
return IndexingEstimate(total_segments=total_segments, preview=preview_texts)
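# A condensed sketch of the gate applied above, assuming a processing-rule
# mapping of the same shape; the function name is illustrative.
def _should_generate_summary_preview(processing_rule: dict, preview_texts: list) -> bool:
    # Summary previews are produced only when the rule enables the summary
    # index and there is at least one preview chunk to summarize.
    setting = processing_rule.get("summary_index_setting")
    return bool(setting and setting.get("enable") and preview_texts)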
|
||||
def _extract(
|
||||
|
||||
@ -72,7 +72,7 @@ class LLMGenerator:
|
||||
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
|
||||
)
|
||||
answer = response.message.get_text_content()
|
||||
if answer == "":
|
||||
if answer is None:
|
||||
return ""
|
||||
try:
|
||||
result_dict = json.loads(answer)
|
||||
@ -113,9 +113,11 @@ class LLMGenerator:
|
||||
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
|
||||
prompt_template = PromptTemplateParser(template="{{histories}}\n{{format_instructions}}\nquestions:\n")
|
||||
prompt_template = PromptTemplateParser(
|
||||
template="{{histories}}\n{{format_instructions}}\nquestions:\n")
|
||||
|
||||
prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions})
|
||||
prompt = prompt_template.format(
|
||||
{"histories": histories, "format_instructions": format_instructions})
|
||||
|
||||
try:
|
||||
model_manager = ModelManager()
|
||||
@ -141,11 +143,13 @@ class LLMGenerator:
|
||||
)
|
||||
|
||||
text_content = response.message.get_text_content()
|
||||
questions = output_parser.parse(text_content) if text_content else []
|
||||
questions = output_parser.parse(
|
||||
text_content) if text_content else []
|
||||
except InvokeError:
|
||||
questions = []
|
||||
except Exception:
|
||||
logger.exception("Failed to generate suggested questions after answer")
|
||||
logger.exception(
|
||||
"Failed to generate suggested questions after answer")
|
||||
questions = []
|
||||
|
||||
return questions
|
||||
@ -156,10 +160,12 @@ class LLMGenerator:
|
||||
|
||||
error = ""
|
||||
error_step = ""
|
||||
rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""}
|
||||
rule_config = {"prompt": "", "variables": [],
|
||||
"opening_statement": "", "error": ""}
|
||||
model_parameters = model_config.get("completion_params", {})
|
||||
if no_variable:
|
||||
prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)
|
||||
prompt_template = PromptTemplateParser(
|
||||
WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)
|
||||
|
||||
prompt_generate = prompt_template.format(
|
||||
inputs={
|
||||
@ -190,7 +196,8 @@ class LLMGenerator:
|
||||
error = str(e)
|
||||
error_step = "generate rule config"
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
|
||||
logger.exception(
|
||||
"Failed to generate rule config, model: %s", model_config.get("name"))
|
||||
rule_config["error"] = str(e)
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
@ -245,7 +252,8 @@ class LLMGenerator:
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)]
|
||||
parameter_messages = [UserPromptMessage(
|
||||
content=parameter_generate_prompt)]
|
||||
|
||||
# the second step to generate the task_parameter and task_statement
|
||||
statement_generate_prompt = statement_template.format(
|
||||
@ -255,13 +263,15 @@ class LLMGenerator:
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
|
||||
statement_messages = [UserPromptMessage(
|
||||
content=statement_generate_prompt)]
|
||||
|
||||
try:
|
||||
parameter_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content())
|
||||
rule_config["variables"] = re.findall(
|
||||
r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content())
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate variables"
|
||||
@ -270,13 +280,15 @@ class LLMGenerator:
|
||||
statement_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
rule_config["opening_statement"] = statement_content.message.get_text_content()
|
||||
rule_config["opening_statement"] = statement_content.message.get_text_content(
|
||||
)
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate conversation opener"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
|
||||
logger.exception(
|
||||
"Failed to generate rule config, model: %s", model_config.get("name"))
|
||||
rule_config["error"] = str(e)
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
@ -286,9 +298,11 @@ class LLMGenerator:
|
||||
@classmethod
|
||||
def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"):
|
||||
if code_language == "python":
|
||||
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
|
||||
prompt_template = PromptTemplateParser(
|
||||
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
|
||||
else:
|
||||
prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE)
|
||||
prompt_template = PromptTemplateParser(
|
||||
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE)
|
||||
|
||||
prompt = prompt_template.format(
|
||||
inputs={
|
||||
@ -321,7 +335,8 @@ class LLMGenerator:
|
||||
return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language
|
||||
"Failed to invoke LLM model, model: %s, language: %s", model_config.get(
|
||||
"name"), code_language
|
||||
)
|
||||
return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
@ -335,7 +350,8 @@ class LLMGenerator:
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
prompt_messages: list[PromptMessage] = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]
|
||||
prompt_messages: list[PromptMessage] = [SystemPromptMessage(
|
||||
content=prompt), UserPromptMessage(content=query)]
|
||||
|
||||
# Explicitly use the non-streaming overload
|
||||
result = model_instance.invoke_llm(
|
||||
@ -381,16 +397,19 @@ class LLMGenerator:
|
||||
parsed_content = json_repair.loads(raw_content)
|
||||
|
||||
if not isinstance(parsed_content, dict | list):
|
||||
raise ValueError(f"Failed to parse structured output from llm: {raw_content}")
|
||||
raise ValueError(
|
||||
f"Failed to parse structured output from llm: {raw_content}")
|
||||
|
||||
generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False)
|
||||
generated_json_schema = json.dumps(
|
||||
parsed_content, indent=2, ensure_ascii=False)
|
||||
return {"output": generated_json_schema, "error": ""}
|
||||
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
|
||||
logger.exception(
|
||||
"Failed to invoke LLM model, model: %s", model_config.get("name"))
|
||||
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
@staticmethod
|
||||
@ -398,7 +417,8 @@ class LLMGenerator:
|
||||
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
|
||||
):
|
||||
last_run: Message | None = (
|
||||
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
|
||||
db.session.query(Message).where(Message.app_id == flow_id).order_by(
|
||||
Message.created_at.desc()).first()
|
||||
)
|
||||
if not last_run:
|
||||
return LLMGenerator.__instruction_modify_common(
|
||||
@ -446,7 +466,8 @@ class LLMGenerator:
|
||||
workflow = workflow_service.get_draft_workflow(app_model=app)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found for the given app model.")
|
||||
last_run = workflow_service.get_node_last_run(app_model=app, workflow=workflow, node_id=node_id)
|
||||
last_run = workflow_service.get_node_last_run(
|
||||
app_model=app, workflow=workflow, node_id=node_id)
|
||||
try:
|
||||
node_type = cast(WorkflowNodeExecutionModel, last_run).node_type
|
||||
except Exception:
|
||||
@ -470,7 +491,8 @@ class LLMGenerator:
|
||||
)
|
||||
|
||||
def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence:
|
||||
raw_agent_log = node_execution.execution_metadata_dict.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG, [])
|
||||
raw_agent_log = node_execution.execution_metadata_dict.get(
|
||||
WorkflowNodeExecutionMetadataKey.AGENT_LOG, [])
|
||||
if not raw_agent_log:
|
||||
return []
|
||||
|
||||
@ -518,11 +540,14 @@ class LLMGenerator:
|
||||
ERROR_MESSAGE = "{{#error_message#}}"
|
||||
injected_instruction = instruction
|
||||
if LAST_RUN in injected_instruction:
|
||||
injected_instruction = injected_instruction.replace(LAST_RUN, json.dumps(last_run))
|
||||
injected_instruction = injected_instruction.replace(
|
||||
LAST_RUN, json.dumps(last_run))
|
||||
if CURRENT in injected_instruction:
|
||||
injected_instruction = injected_instruction.replace(CURRENT, current or "null")
|
||||
injected_instruction = injected_instruction.replace(
|
||||
CURRENT, current or "null")
|
||||
if ERROR_MESSAGE in injected_instruction:
|
||||
injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null")
|
||||
injected_instruction = injected_instruction.replace(
|
||||
ERROR_MESSAGE, error_message or "null")
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
@ -560,11 +585,13 @@ class LLMGenerator:
|
||||
first_brace = generated_raw.find("{")
|
||||
last_brace = generated_raw.rfind("}")
|
||||
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
|
||||
raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
|
||||
json_str = generated_raw[first_brace : last_brace + 1]
|
||||
raise ValueError(
|
||||
f"Could not find a valid JSON object in response: {generated_raw}")
|
||||
json_str = generated_raw[first_brace: last_brace + 1]
|
||||
data = json_repair.loads(json_str)
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
|
||||
raise TypeError(
|
||||
f"Expected a JSON object, but got {type(data).__name__}")
|
||||
return data
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
|
||||
@ -434,3 +434,20 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
|
||||
You should edit the prompt according to the IDEAL OUTPUT."""
|
||||
|
||||
INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""
|
||||
|
||||
DEFAULT_GENERATOR_SUMMARY_PROMPT = (
|
||||
"""Summarize the following content. Extract only the key information and main points. """
|
||||
"""Remove redundant details.
|
||||
|
||||
Requirements:
|
||||
1. Write a concise summary in plain text
|
||||
2. Use the same language as the input content
|
||||
3. Focus on important facts, concepts, and details
|
||||
4. If images are included, describe their key information
|
||||
5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions"
|
||||
6. Write directly without extra words
|
||||
|
||||
Output only the summary text. Start summarizing now:
|
||||
|
||||
"""
|
||||
)
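# A minimal sketch of pairing this prompt with a chunk before an LLM call; the
# helper and the message format are illustrative assumptions, not the
# pipeline's actual invocation code.
def build_summary_messages(chunk_text: str) -> list[dict[str, str]]:
    return [
        {"role": "system", "content": DEFAULT_GENERATOR_SUMMARY_PROMPT},
        {"role": "user", "content": chunk_text},
    ]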
@ -320,17 +320,18 @@ class BasePluginClient:
|
||||
case PluginInvokeError.__name__:
|
||||
error_object = json.loads(message)
|
||||
invoke_error_type = error_object.get("error_type")
|
||||
args = error_object.get("args")
|
||||
match invoke_error_type:
|
||||
case InvokeRateLimitError.__name__:
|
||||
raise InvokeRateLimitError(description=error_object.get("message"))
|
||||
raise InvokeRateLimitError(description=args.get("description"))
|
||||
case InvokeAuthorizationError.__name__:
|
||||
raise InvokeAuthorizationError(description=error_object.get("message"))
|
||||
raise InvokeAuthorizationError(description=args.get("description"))
|
||||
case InvokeBadRequestError.__name__:
|
||||
raise InvokeBadRequestError(description=error_object.get("message"))
|
||||
raise InvokeBadRequestError(description=args.get("description"))
|
||||
case InvokeConnectionError.__name__:
|
||||
raise InvokeConnectionError(description=error_object.get("message"))
|
||||
raise InvokeConnectionError(description=args.get("description"))
|
||||
case InvokeServerUnavailableError.__name__:
|
||||
raise InvokeServerUnavailableError(description=error_object.get("message"))
|
||||
raise InvokeServerUnavailableError(description=args.get("description"))
|
||||
case CredentialsValidateFailedError.__name__:
|
||||
raise CredentialsValidateFailedError(error_object.get("message"))
|
||||
case EndpointSetupFailedError.__name__:
|
||||
@ -338,11 +339,11 @@ class BasePluginClient:
|
||||
case TriggerProviderCredentialValidationError.__name__:
|
||||
raise TriggerProviderCredentialValidationError(error_object.get("message"))
|
||||
case TriggerPluginInvokeError.__name__:
|
||||
raise TriggerPluginInvokeError(description=error_object.get("message"))
|
||||
raise TriggerPluginInvokeError(description=error_object.get("description"))
|
||||
case TriggerInvokeError.__name__:
|
||||
raise TriggerInvokeError(error_object.get("message"))
|
||||
case EventIgnoreError.__name__:
|
||||
raise EventIgnoreError(description=error_object.get("message"))
|
||||
raise EventIgnoreError(description=error_object.get("description"))
|
||||
case _:
|
||||
raise PluginInvokeError(description=message)
|
||||
case PluginDaemonInternalServerError.__name__:
|
||||
|
||||
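# Sketch of the plugin error payload this branch now reads: the human-readable
# description comes from the nested "args" mapping rather than the top-level
# "message". The keys mirror the code above; the values are illustrative.
example_plugin_invoke_error = {
    "error_type": "InvokeRateLimitError",
    "message": "plugin invoke failed",
    "args": {"description": "Rate limit exceeded, please retry later"},
}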
@ -389,15 +389,14 @@ class RetrievalService:
|
||||
.all()
|
||||
}
|
||||
|
||||
records = []
|
||||
include_segment_ids = set()
|
||||
segment_child_map = {}
|
||||
|
||||
valid_dataset_documents = {}
|
||||
image_doc_ids: list[Any] = []
|
||||
child_index_node_ids = []
|
||||
index_node_ids = []
|
||||
doc_to_document_map = {}
|
||||
summary_segment_ids = set() # Track segments retrieved via summary
|
||||
|
||||
# First pass: collect all document IDs and identify summary documents
|
||||
for document in documents:
|
||||
document_id = document.metadata.get("document_id")
|
||||
if document_id not in dataset_documents:
|
||||
@ -408,16 +407,24 @@ class RetrievalService:
|
||||
continue
|
||||
valid_dataset_documents[document_id] = dataset_document
|
||||
|
||||
doc_id = document.metadata.get("doc_id") or ""
|
||||
doc_to_document_map[doc_id] = document
|
||||
|
||||
# Check if this is a summary document
|
||||
is_summary = document.metadata.get("is_summary", False)
|
||||
if is_summary:
|
||||
# For summary documents, find the original chunk via original_chunk_id
|
||||
original_chunk_id = document.metadata.get("original_chunk_id")
|
||||
if original_chunk_id:
|
||||
summary_segment_ids.add(original_chunk_id)
|
||||
continue # Skip adding to other lists for summary documents
|
||||
|
||||
if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
|
||||
doc_id = document.metadata.get("doc_id") or ""
|
||||
doc_to_document_map[doc_id] = document
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
image_doc_ids.append(doc_id)
|
||||
else:
|
||||
child_index_node_ids.append(doc_id)
|
||||
else:
|
||||
doc_id = document.metadata.get("doc_id") or ""
|
||||
doc_to_document_map[doc_id] = document
|
||||
if document.metadata.get("doc_type") == DocType.IMAGE:
|
||||
image_doc_ids.append(doc_id)
|
||||
else:
|
||||
@ -433,6 +440,7 @@ class RetrievalService:
|
||||
attachment_map: dict[str, list[dict[str, Any]]] = {}
|
||||
child_chunk_map: dict[str, list[ChildChunk]] = {}
|
||||
doc_segment_map: dict[str, list[str]] = {}
|
||||
segment_summary_map: dict[str, str] = {} # Map segment_id to summary content
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
attachments = cls.get_segment_attachment_infos(image_doc_ids, session)
|
||||
@ -447,6 +455,7 @@ class RetrievalService:
|
||||
doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
|
||||
else:
|
||||
doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
|
||||
|
||||
child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
|
||||
child_index_nodes = session.execute(child_chunk_stmt).scalars().all()
|
||||
|
||||
@ -470,6 +479,7 @@ class RetrievalService:
|
||||
index_node_segments = session.execute(document_segment_stmt).scalars().all() # type: ignore
|
||||
for index_node_segment in index_node_segments:
|
||||
doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
|
||||
|
||||
if segment_ids:
|
||||
document_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.enabled == True,
|
||||
@ -481,6 +491,42 @@ class RetrievalService:
|
||||
if index_node_segments:
|
||||
segments.extend(index_node_segments)
|
||||
|
||||
# Handle summary documents: query segments by original_chunk_id
|
||||
if summary_segment_ids:
|
||||
summary_segment_ids_list = list(summary_segment_ids)
|
||||
summary_segment_stmt = select(DocumentSegment).where(
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.id.in_(summary_segment_ids_list),
|
||||
)
|
||||
summary_segments = session.execute(summary_segment_stmt).scalars().all() # type: ignore
|
||||
segments.extend(summary_segments)
|
||||
# Add summary segment IDs to segment_ids for summary query
|
||||
for seg in summary_segments:
|
||||
if seg.id not in segment_ids:
|
||||
segment_ids.append(seg.id)
|
||||
|
||||
# Batch query summaries for segments retrieved via summary (only enabled summaries)
|
||||
if summary_segment_ids:
|
||||
from models.dataset import DocumentSegmentSummary
|
||||
|
||||
summaries = (
|
||||
session.query(DocumentSegmentSummary)
|
||||
.filter(
|
||||
DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)),
|
||||
DocumentSegmentSummary.status == "completed",
|
||||
DocumentSegmentSummary.enabled == True, # Only retrieve enabled summaries
|
||||
)
|
||||
.all()
|
||||
)
|
||||
for summary in summaries:
|
||||
if summary.summary_content:
|
||||
segment_summary_map[summary.chunk_id] = summary.summary_content
|
||||
|
||||
include_segment_ids = set()
|
||||
segment_child_map: dict[str, dict[str, Any]] = {}
|
||||
records: list[dict[str, Any]] = []
|
||||
|
||||
for segment in segments:
|
||||
child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
|
||||
attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])
|
||||
@ -493,7 +539,7 @@ class RetrievalService:
|
||||
child_chunk_details = []
|
||||
max_score = 0.0
|
||||
for child_chunk in child_chunks:
|
||||
document = doc_to_document_map[child_chunk.index_node_id]
|
||||
document = doc_to_document_map.get(child_chunk.index_node_id)
|
||||
child_chunk_detail = {
|
||||
"id": child_chunk.id,
|
||||
"content": child_chunk.content,
|
||||
@ -503,7 +549,7 @@ class RetrievalService:
|
||||
child_chunk_details.append(child_chunk_detail)
|
||||
max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0)
|
||||
for attachment_info in attachment_infos:
|
||||
file_document = doc_to_document_map[attachment_info["id"]]
|
||||
file_document = doc_to_document_map.get(attachment_info["id"])
|
||||
max_score = max(
|
||||
max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
|
||||
)
|
||||
@@ -576,9 +622,16 @@ class RetrievalService:
else None
)

# Extract summary if this segment was retrieved via summary
summary_content = segment_summary_map.get(segment.id)

# Create RetrievalSegments object
retrieval_segment = RetrievalSegments(
segment=segment, child_chunks=child_chunks_list, score=score, files=files
segment=segment,
child_chunks=child_chunks_list,
score=score,
files=files,
summary=summary_content,
)
result.append(retrieval_segment)


@@ -20,3 +20,4 @@ class RetrievalSegments(BaseModel):
child_chunks: list[RetrievalChildChunk] | None = None
score: float | None = None
files: list[dict[str, str | int]] | None = None
summary: str | None = None # Summary content if retrieved via summary index

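
A small usage sketch of the extended model, assuming only the pydantic fields declared above; the segment and file payloads are simplified so the example is self-contained:

    from pydantic import BaseModel

    class RetrievalSegmentsSketch(BaseModel):  # simplified stand-in; the real model also carries the ORM segment
        segment_id: str
        score: float | None = None
        summary: str | None = None  # populated only when the hit came through the summary index

    hit = RetrievalSegmentsSketch(segment_id="seg-123", score=0.82, summary="Concise LLM summary of the chunk.")
    # Downstream consumers can prefer the summary when present, e.g. for compact citation display:
    display_text = hit.summary or "<full chunk content>"
    print(display_text)
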
@@ -13,6 +13,7 @@ from urllib.parse import unquote, urlparse
import httpx

from configs import dify_config
from core.entities.knowledge_entities import PreviewDetail
from core.helper import ssrf_proxy
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.doc_type import DocType
@@ -45,6 +46,17 @@ class BaseIndexProcessor(ABC):
def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
raise NotImplementedError

@abstractmethod
def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
For each segment in preview_texts, generate a summary using LLM and attach it to the segment.
The summary can be stored in a new attribute, e.g., summary.
This method should be implemented by subclasses.
"""
raise NotImplementedError
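
The contract added here is deliberately small: take the preview chunks, attach a summary to each, and return them. A self-contained sketch of that contract with hypothetical stand-in types (the real PreviewDetail and concrete processors appear later in this diff):

    from abc import ABC, abstractmethod
    from dataclasses import dataclass

    @dataclass
    class PreviewItem:  # hypothetical stand-in for core.entities.knowledge_entities.PreviewDetail
        content: str
        summary: str | None = None

    class SummaryPreviewCapable(ABC):
        @abstractmethod
        def generate_summary_preview(
            self, tenant_id: str, preview_texts: list[PreviewItem], summary_index_setting: dict
        ) -> list[PreviewItem]:
            raise NotImplementedError

    class EchoProcessor(SummaryPreviewCapable):
        def generate_summary_preview(self, tenant_id, preview_texts, summary_index_setting):
            if not summary_index_setting.get("enable"):
                return preview_texts
            for item in preview_texts:
                item.summary = item.content[:50]  # a real processor calls the LLM here
            return preview_texts

    items = EchoProcessor().generate_summary_preview("tenant-1", [PreviewItem("Some chunk text...")], {"enable": True})
    print(items[0].summary)
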
|
||||
@abstractmethod
|
||||
def load(
|
||||
self,
|
||||
|
||||
@ -1,9 +1,25 @@
|
||||
"""Paragraph index processor."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.file import File, FileTransferMethod, FileType, file_manager
|
||||
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
ImagePromptMessageContent,
|
||||
PromptMessageContentUnionTypes,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.keyword.keyword_factory import Keyword
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
@ -17,12 +33,16 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
|
||||
from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import build_from_mapping
|
||||
from libs import helper
|
||||
from models import UploadFile
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, DatasetProcessRule
|
||||
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.account_service import AccountService
|
||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
|
||||
class ParagraphIndexProcessor(BaseIndexProcessor):
|
||||
@@ -108,6 +128,29 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
keyword.add_texts(documents)

def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
# For disable operations, disable_summaries_for_segments is called directly in the task.
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
delete_summaries = kwargs.get("delete_summaries", False)
if delete_summaries:
if node_ids:
# Find segments by index_node_id
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
)
.all()
)
segment_ids = [segment.id for segment in segments]
if segment_ids:
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
else:
# Delete all summaries for the dataset
SummaryIndexService.delete_summaries_for_segments(dataset, None)

if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
if node_ids:
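
A distilled, runnable sketch of the delete_summaries gate introduced above, with the database and SummaryIndexService replaced by a plain dict so the behaviour is easy to see; the names here are illustrative, not Dify APIs:

    def clean_sketch(node_ids: list | None, summary_by_node: dict, **kwargs) -> list:
        """Return the summary ids that would be deleted for the given index node ids."""
        if not kwargs.get("delete_summaries", False):
            return []  # default: disabling a segment keeps (merely disables) its summary
        targets = node_ids if node_ids is not None else list(summary_by_node)
        return [summary_by_node[n] for n in targets if n in summary_by_node]

    store = {"node-1": "summary-1", "node-2": "summary-2"}
    print(clean_sketch(["node-1"], store))                          # [] -> summaries kept
    print(clean_sketch(["node-1"], store, delete_summaries=True))   # ['summary-1'] -> hard delete
    print(clean_sketch(None, store, delete_summaries=True))         # both summaries -> dataset-wide delete
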
@@ -227,3 +270,263 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
}
else:
raise ValueError("Chunks is not a list")

def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
For each segment, concurrently call generate_summary to generate a summary
and write it to the summary attribute of PreviewDetail.
"""
import concurrent.futures

from flask import current_app

# Capture Flask app context for worker threads
flask_app = None
try:
flask_app = current_app._get_current_object() # type: ignore
except RuntimeError:
logger.warning("No Flask application context available, summary generation may fail")

def process(preview: PreviewDetail) -> None:
"""Generate summary for a single preview item."""
try:
if flask_app:
# Ensure Flask app context in worker thread
with flask_app.app_context():
summary = self.generate_summary(tenant_id, preview.content, summary_index_setting)
preview.summary = summary
else:
# Fallback: try without app context (may fail)
summary = self.generate_summary(tenant_id, preview.content, summary_index_setting)
preview.summary = summary
except Exception:
logger.exception("Failed to generate summary for preview")
# Don't fail the entire preview if summary generation fails
preview.summary = None

with concurrent.futures.ThreadPoolExecutor() as executor:
list(executor.map(process, preview_texts))
return preview_texts

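
This thread-pool plus Flask-app-context pattern recurs in several places in this change set (here, in the parent-child processor, and in the knowledge index node). A compact, self-contained sketch of the same pattern, assuming Flask is installed and using a placeholder in place of the LLM call:

    import concurrent.futures
    from flask import Flask, current_app

    app = Flask(__name__)

    def summarize(text: str) -> str:
        # placeholder for the LLM call; current_app works because the worker pushed an app context
        return f"[{current_app.name}] {text[:20]}"

    def run_in_threads(texts: list) -> list:
        flask_app = current_app._get_current_object()  # capture the real app object, not the proxy
        def worker(text: str) -> str:
            with flask_app.app_context():              # each worker thread needs its own app context
                return summarize(text)
        with concurrent.futures.ThreadPoolExecutor() as executor:
            return list(executor.map(worker, texts))

    with app.app_context():
        print(run_in_threads(["first chunk of text", "second chunk of text"]))
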
@staticmethod
def generate_summary(
tenant_id: str,
text: str,
summary_index_setting: dict | None = None,
segment_id: str | None = None,
) -> str:
"""
Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt,
and supports vision models by including images from the segment attachments or text content.

Args:
tenant_id: Tenant ID
text: Text content to summarize
summary_index_setting: Summary index configuration
segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table
"""
if not summary_index_setting or not summary_index_setting.get("enable"):
raise ValueError("summary_index_setting is required and must be enabled to generate summary.")

model_name = summary_index_setting.get("model_name")
model_provider_name = summary_index_setting.get("model_provider_name")
summary_prompt = summary_index_setting.get("summary_prompt")

# Import default summary prompt
if not summary_prompt:
summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT

provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id, model_provider_name, ModelType.LLM
)
model_instance = ModelInstance(provider_model_bundle, model_name)

# Get model schema to check if vision is supported
model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
supports_vision = model_schema and model_schema.features and ModelFeature.VISION in model_schema.features

# Extract images if model supports vision
image_files = []
if supports_vision:
# First, try to get images from SegmentAttachmentBinding (preferred method)
if segment_id:
image_files = ParagraphIndexProcessor._extract_images_from_segment_attachments(tenant_id, segment_id)

# If no images from attachments, fall back to extracting from text
if not image_files:
image_files = ParagraphIndexProcessor._extract_images_from_text(tenant_id, text)

# Build prompt messages
prompt_messages = []

if image_files:
# If we have images, create a UserPromptMessage with both text and images
prompt_message_contents: list[PromptMessageContentUnionTypes] = []

# Add images first
for file in image_files:
try:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=ImagePromptMessageContent.DETAIL.LOW
)
prompt_message_contents.append(file_content)
except Exception as e:
logger.warning("Failed to convert image file to prompt message content: %s", str(e))
continue

# Add text content
if prompt_message_contents: # Only add text if we successfully added images
prompt_message_contents.append(TextPromptMessageContent(data=f"{summary_prompt}\n{text}"))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
# If image conversion failed, fall back to text-only
prompt = f"{summary_prompt}\n{text}"
prompt_messages.append(UserPromptMessage(content=prompt))
else:
# No images, use simple text prompt
prompt = f"{summary_prompt}\n{text}"
prompt_messages.append(UserPromptMessage(content=prompt))

result = model_instance.invoke_llm(prompt_messages=prompt_messages, model_parameters={}, stream=False)

return getattr(result.message, "content", "")

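
A distilled version of the settings handling in generate_summary (the enable gate plus the default-prompt fallback), with a fake model call so the sketch runs anywhere; the constant and function names here are illustrative rather than Dify APIs:

    DEFAULT_SUMMARY_PROMPT = "Summarize the following content concisely:"  # stands in for DEFAULT_GENERATOR_SUMMARY_PROMPT

    def generate_summary_sketch(text: str, summary_index_setting: dict | None = None) -> str:
        if not summary_index_setting or not summary_index_setting.get("enable"):
            raise ValueError("summary_index_setting is required and must be enabled to generate summary.")
        prompt = summary_index_setting.get("summary_prompt") or DEFAULT_SUMMARY_PROMPT
        # The real method builds UserPromptMessage(s), optionally with image parts, and calls
        # ModelInstance.invoke_llm(..., stream=False); here a trivial stub plays the model's role.
        fake_llm = lambda full_prompt: "summary: " + full_prompt.splitlines()[-1][:60]
        return fake_llm(f"{prompt}\n{text}")

    setting = {"enable": True, "model_provider_name": "openai", "model_name": "gpt-4o-mini"}
    print(generate_summary_sketch("A long chunk about vector databases and ANN search.", setting))
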
@staticmethod
|
||||
def _extract_images_from_text(tenant_id: str, text: str) -> list[File]:
|
||||
"""
|
||||
Extract images from markdown text and convert them to File objects.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
text: Text content that may contain markdown image links
|
||||
|
||||
Returns:
|
||||
List of File objects representing images found in the text
|
||||
"""
|
||||
# Extract markdown images using regex pattern
|
||||
pattern = r"!\[.*?\]\((.*?)\)"
|
||||
images = re.findall(pattern, text)
|
||||
|
||||
if not images:
|
||||
return []
|
||||
|
||||
upload_file_id_list = []
|
||||
|
||||
for image in images:
|
||||
# For data before v0.10.0
|
||||
pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
|
||||
match = re.search(pattern, image)
|
||||
if match:
|
||||
upload_file_id = match.group(1)
|
||||
upload_file_id_list.append(upload_file_id)
|
||||
continue
|
||||
|
||||
# For data after v0.10.0
|
||||
pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
|
||||
match = re.search(pattern, image)
|
||||
if match:
|
||||
upload_file_id = match.group(1)
|
||||
upload_file_id_list.append(upload_file_id)
|
||||
continue
|
||||
|
||||
# For tools directory - direct file formats (e.g., .png, .jpg, etc.)
|
||||
pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
|
||||
match = re.search(pattern, image)
|
||||
if match:
|
||||
# Tool files are handled differently, skip for now
|
||||
continue
|
||||
|
||||
if not upload_file_id_list:
|
||||
return []
|
||||
|
||||
# Get unique IDs for database query
|
||||
unique_upload_file_ids = list(set(upload_file_id_list))
|
||||
upload_files = (
|
||||
db.session.query(UploadFile)
|
||||
.where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Create File objects from UploadFile records
|
||||
file_objects = []
|
||||
for upload_file in upload_files:
|
||||
# Only process image files
|
||||
if not upload_file.mime_type or "image" not in upload_file.mime_type:
|
||||
continue
|
||||
|
||||
mapping = {
|
||||
"upload_file_id": upload_file.id,
|
||||
"transfer_method": FileTransferMethod.LOCAL_FILE.value,
|
||||
"type": FileType.IMAGE.value,
|
||||
}
|
||||
|
||||
try:
|
||||
file_obj = build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
file_objects.append(file_obj)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e))
|
||||
continue
|
||||
|
||||
return file_objects
|
||||
|
||||
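
The two upload-file URL shapes handled above (pre-v0.10.0 /image-preview links and newer /file-preview links) can be exercised in isolation. A short sketch using the same regexes on made-up URLs:

    import re

    text = (
        "Intro ![a](https://example.com/files/0a1b2c3d-4e5f-6789-abcd-ef0123456789/image-preview?ts=1) "
        "and ![b](https://example.com/files/11112222-3333-4444-5555-666677778888/file-preview)"
    )

    image_urls = re.findall(r"!\[.*?\]\((.*?)\)", text)  # markdown image links
    upload_file_ids = []
    for url in image_urls:
        match = re.search(r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?", url) or re.search(
            r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?", url
        )
        if match:
            upload_file_ids.append(match.group(1))

    print(upload_file_ids)
    # ['0a1b2c3d-4e5f-6789-abcd-ef0123456789', '11112222-3333-4444-5555-666677778888']
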
@staticmethod
|
||||
def _extract_images_from_segment_attachments(tenant_id: str, segment_id: str) -> list[File]:
|
||||
"""
|
||||
Extract images from SegmentAttachmentBinding table (preferred method).
|
||||
This matches how DatasetRetrieval gets segment attachments.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
segment_id: Segment ID to fetch attachments for
|
||||
|
||||
Returns:
|
||||
List of File objects representing images found in segment attachments
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
|
||||
# Query attachments from SegmentAttachmentBinding table
|
||||
attachments_with_bindings = db.session.execute(
|
||||
select(SegmentAttachmentBinding, UploadFile)
|
||||
.join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
|
||||
.where(
|
||||
SegmentAttachmentBinding.segment_id == segment_id,
|
||||
SegmentAttachmentBinding.tenant_id == tenant_id,
|
||||
)
|
||||
).all()
|
||||
|
||||
if not attachments_with_bindings:
|
||||
return []
|
||||
|
||||
file_objects = []
|
||||
for _, upload_file in attachments_with_bindings:
|
||||
# Only process image files
|
||||
if not upload_file.mime_type or "image" not in upload_file.mime_type:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Create File object directly (similar to DatasetRetrieval)
|
||||
file_obj = File(
|
||||
id=upload_file.id,
|
||||
filename=upload_file.name,
|
||||
extension="." + upload_file.extension,
|
||||
mime_type=upload_file.mime_type,
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
remote_url=upload_file.source_url,
|
||||
related_id=upload_file.id,
|
||||
size=upload_file.size,
|
||||
storage_key=upload_file.key,
|
||||
)
|
||||
file_objects.append(file_obj)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e))
|
||||
continue
|
||||
|
||||
return file_objects
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
"""Paragraph index processor."""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.model_manager import ModelInstance
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
@ -25,6 +27,9 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.account_service import AccountService
|
||||
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
@ -135,6 +140,29 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
||||
# node_ids is segment's node_ids
|
||||
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
|
||||
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
|
||||
# For disable operations, disable_summaries_for_segments is called directly in the task.
|
||||
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
|
||||
delete_summaries = kwargs.get("delete_summaries", False)
|
||||
if delete_summaries:
|
||||
if node_ids:
|
||||
# Find segments by index_node_id
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(node_ids),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
if segment_ids:
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
|
||||
else:
|
||||
# Delete all summaries for the dataset
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, None)
|
||||
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
delete_child_chunks = kwargs.get("delete_child_chunks") or False
|
||||
precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")
|
||||
@ -326,3 +354,57 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
|
||||
"preview": preview,
|
||||
"total_segments": len(parent_childs.parent_child_chunks),
|
||||
}
|
||||
|
||||
def generate_summary_preview(
|
||||
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
|
||||
) -> list[PreviewDetail]:
|
||||
"""
|
||||
For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary
|
||||
and write it to the summary attribute of PreviewDetail.
|
||||
|
||||
Note: For parent-child structure, we only generate summaries for parent chunks.
|
||||
"""
|
||||
import concurrent.futures
|
||||
|
||||
from flask import current_app
|
||||
|
||||
# Capture Flask app context for worker threads
|
||||
flask_app = None
|
||||
try:
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
except RuntimeError:
|
||||
logger.warning("No Flask application context available, summary generation may fail")
|
||||
|
||||
def process(preview: PreviewDetail) -> None:
|
||||
"""Generate summary for a single preview item (parent chunk)."""
|
||||
try:
|
||||
if flask_app:
|
||||
# Ensure Flask app context in worker thread
|
||||
with flask_app.app_context():
|
||||
# Use ParagraphIndexProcessor's generate_summary method
|
||||
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
summary = ParagraphIndexProcessor.generate_summary(
|
||||
tenant_id=tenant_id,
|
||||
text=preview.content,
|
||||
summary_index_setting=summary_index_setting,
|
||||
)
|
||||
if summary:
|
||||
preview.summary = summary
|
||||
else:
|
||||
# Fallback: try without app context (may fail)
|
||||
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
summary = ParagraphIndexProcessor.generate_summary(
|
||||
tenant_id=tenant_id,
|
||||
text=preview.content,
|
||||
summary_index_setting=summary_index_setting,
|
||||
)
|
||||
if summary:
|
||||
preview.summary = summary
|
||||
except Exception:
|
||||
logger.exception("Failed to generate summary for preview")
|
||||
# Don't fail the entire preview if summary generation fails
|
||||
preview.summary = None
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
list(executor.map(process, preview_texts))
|
||||
return preview_texts
|
||||
|
||||
@ -11,6 +11,7 @@ import pandas as pd
|
||||
from flask import Flask, current_app
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from core.entities.knowledge_entities import PreviewDetail
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.rag.cleaner.clean_processor import CleanProcessor
|
||||
from core.rag.datasource.retrieval_service import RetrievalService
|
||||
@ -25,9 +26,10 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from core.tools.utils.text_processing_utils import remove_leading_symbols
|
||||
from libs import helper
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset
|
||||
from models.dataset import Dataset, DocumentSegment
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.entities.knowledge_entities.knowledge_entities import Rule
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -144,6 +146,30 @@ class QAIndexProcessor(BaseIndexProcessor):
|
||||
vector.create_multimodal(multimodal_documents)
|
||||
|
||||
def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
|
||||
# Note: Summary indexes are now disabled (not deleted) when segments are disabled.
|
||||
# This method is called for actual deletion scenarios (e.g., when segment is deleted).
|
||||
# For disable operations, disable_summaries_for_segments is called directly in the task.
|
||||
# Note: qa_model doesn't generate summaries, but we clean them for completeness
|
||||
# Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
|
||||
delete_summaries = kwargs.get("delete_summaries", False)
|
||||
if delete_summaries:
|
||||
if node_ids:
|
||||
# Find segments by index_node_id
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter(
|
||||
DocumentSegment.dataset_id == dataset.id,
|
||||
DocumentSegment.index_node_id.in_(node_ids),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
segment_ids = [segment.id for segment in segments]
|
||||
if segment_ids:
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
|
||||
else:
|
||||
# Delete all summaries for the dataset
|
||||
SummaryIndexService.delete_summaries_for_segments(dataset, None)
|
||||
|
||||
vector = Vector(dataset)
|
||||
if node_ids:
|
||||
vector.delete_by_ids(node_ids)
|
||||
@@ -212,6 +238,17 @@ class QAIndexProcessor(BaseIndexProcessor):
"total_segments": len(qa_chunks.qa_chunks),
}

def generate_summary_preview(
self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
) -> list[PreviewDetail]:
"""
QA model doesn't generate summaries, so this method returns preview_texts unchanged.

Note: QA model uses question-answer pairs, which don't require summary generation.
"""
# QA model doesn't generate summaries, return as-is
return preview_texts

def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
|
||||
format_documents = []
|
||||
if document_node.page_content is None or not document_node.page_content.strip():
|
||||
|
||||
@ -5,6 +5,7 @@ import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import has_request_context
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
@ -28,21 +29,6 @@ from models.workflow import Workflow
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _try_resolve_user_from_request() -> Account | EndUser | None:
|
||||
"""
|
||||
Try to resolve user from Flask request context.
|
||||
|
||||
Returns None if not in a request context or if user is not available.
|
||||
"""
|
||||
# Note: `current_user` is a LocalProxy. Never compare it with None directly.
|
||||
# Use _get_current_object() to dereference the proxy
|
||||
user = getattr(current_user, "_get_current_object", lambda: current_user)()
|
||||
# Check if we got a valid user object
|
||||
if user is not None and hasattr(user, "id"):
|
||||
return user
|
||||
return None
|
||||
|
||||
|
||||
class WorkflowTool(Tool):
|
||||
"""
|
||||
Workflow tool.
|
||||
@@ -223,13 +209,21 @@ class WorkflowTool(Tool):
Returns:
Account | EndUser | None: The resolved user object, or None if resolution fails.
"""
# Try to resolve user from request context first
user = _try_resolve_user_from_request()
if user is not None:
return user
if has_request_context():
return self._resolve_user_from_request()
else:
return self._resolve_user_from_database(user_id=user_id)

# Fall back to database resolution
return self._resolve_user_from_database(user_id=user_id)
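
The rewritten branch simply gates on whether a Flask request context exists before touching the current_user proxy. A minimal sketch of that gate, assuming Flask; the database path is a placeholder:

    from flask import Flask, has_request_context

    app = Flask(__name__)

    def resolve_user(user_id: str) -> str:
        if has_request_context():
            return "user-from-request"              # the real code dereferences the current_user proxy here
        return f"user-from-database:{user_id}"      # placeholder for the lookup by user_id

    print(resolve_user("abc"))                       # no request context -> database path
    with app.test_request_context("/"):
        print(resolve_user("abc"))                   # inside a request -> request path
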
def _resolve_user_from_request(self) -> Account | EndUser | None:
|
||||
"""
|
||||
Resolve user from Flask request context.
|
||||
"""
|
||||
try:
|
||||
# Note: `current_user` is a LocalProxy. Never compare it with None directly.
|
||||
return getattr(current_user, "_get_current_object", lambda: current_user)()
|
||||
except Exception as e:
|
||||
logger.warning("Failed to resolve user from request context: %s", e)
|
||||
return None
|
||||
|
||||
def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None:
|
||||
"""
|
||||
|
||||
@ -1,22 +0,0 @@
|
||||
"""
|
||||
Execution Context - Context management for workflow execution.
|
||||
|
||||
This package provides Flask-independent context management for workflow
|
||||
execution in multi-threaded environments.
|
||||
"""
|
||||
|
||||
from core.workflow.context.execution_context import (
|
||||
AppContext,
|
||||
ExecutionContext,
|
||||
IExecutionContext,
|
||||
NullAppContext,
|
||||
capture_current_context,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AppContext",
|
||||
"ExecutionContext",
|
||||
"IExecutionContext",
|
||||
"NullAppContext",
|
||||
"capture_current_context",
|
||||
]
|
||||
@ -1,216 +0,0 @@
|
||||
"""
|
||||
Execution Context - Abstracted context management for workflow execution.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from typing import Any, Protocol, final, runtime_checkable
|
||||
|
||||
|
||||
class AppContext(ABC):
|
||||
"""
|
||||
Abstract application context interface.
|
||||
|
||||
This abstraction allows workflow execution to work with or without Flask
|
||||
by providing a common interface for application context management.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get configuration value by key."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_extension(self, name: str) -> Any:
|
||||
"""Get Flask extension by name (e.g., 'db', 'cache')."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def enter(self) -> AbstractContextManager[None]:
|
||||
"""Enter the application context."""
|
||||
pass
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class IExecutionContext(Protocol):
|
||||
"""
|
||||
Protocol for execution context.
|
||||
|
||||
This protocol defines the interface that all execution contexts must implement,
|
||||
allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably.
|
||||
"""
|
||||
|
||||
def __enter__(self) -> "IExecutionContext":
|
||||
"""Enter the execution context."""
|
||||
...
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
"""Exit the execution context."""
|
||||
...
|
||||
|
||||
@property
|
||||
def user(self) -> Any:
|
||||
"""Get user object."""
|
||||
...
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionContext:
|
||||
"""
|
||||
Execution context for workflow execution in worker threads.
|
||||
|
||||
This class encapsulates all context needed for workflow execution:
|
||||
- Application context (Flask app or standalone)
|
||||
- Context variables for Python contextvars
|
||||
- User information (optional)
|
||||
|
||||
It is designed to be serializable and passable to worker threads.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_context: AppContext | None = None,
|
||||
context_vars: contextvars.Context | None = None,
|
||||
user: Any = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize execution context.
|
||||
|
||||
Args:
|
||||
app_context: Application context (Flask or standalone)
|
||||
context_vars: Python contextvars to preserve
|
||||
user: User object (optional)
|
||||
"""
|
||||
self._app_context = app_context
|
||||
self._context_vars = context_vars
|
||||
self._user = user
|
||||
|
||||
@property
|
||||
def app_context(self) -> AppContext | None:
|
||||
"""Get application context."""
|
||||
return self._app_context
|
||||
|
||||
@property
|
||||
def context_vars(self) -> contextvars.Context | None:
|
||||
"""Get context variables."""
|
||||
return self._context_vars
|
||||
|
||||
@property
|
||||
def user(self) -> Any:
|
||||
"""Get user object."""
|
||||
return self._user
|
||||
|
||||
@contextmanager
|
||||
def enter(self) -> Generator[None, None, None]:
|
||||
"""
|
||||
Enter this execution context.
|
||||
|
||||
This is a convenience method that creates a context manager.
|
||||
"""
|
||||
# Restore context variables if provided
|
||||
if self._context_vars:
|
||||
for var, val in self._context_vars.items():
|
||||
var.set(val)
|
||||
|
||||
# Enter app context if available
|
||||
if self._app_context is not None:
|
||||
with self._app_context.enter():
|
||||
yield
|
||||
else:
|
||||
yield
|
||||
|
||||
def __enter__(self) -> "ExecutionContext":
|
||||
"""Enter the execution context."""
|
||||
self._cm = self.enter()
|
||||
self._cm.__enter__()
|
||||
return self
|
||||
|
||||
def __exit__(self, *args: Any) -> None:
|
||||
"""Exit the execution context."""
|
||||
if hasattr(self, "_cm"):
|
||||
self._cm.__exit__(*args)
|
||||
|
||||
|
||||
class NullAppContext(AppContext):
|
||||
"""
|
||||
Null implementation of AppContext for non-Flask environments.
|
||||
|
||||
This is used when running without Flask (e.g., in tests or standalone mode).
|
||||
"""
|
||||
|
||||
def __init__(self, config: dict[str, Any] | None = None) -> None:
|
||||
"""
|
||||
Initialize null app context.
|
||||
|
||||
Args:
|
||||
config: Optional configuration dictionary
|
||||
"""
|
||||
self._config = config or {}
|
||||
self._extensions: dict[str, Any] = {}
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
"""Get configuration value by key."""
|
||||
return self._config.get(key, default)
|
||||
|
||||
def get_extension(self, name: str) -> Any:
|
||||
"""Get extension by name."""
|
||||
return self._extensions.get(name)
|
||||
|
||||
def set_extension(self, name: str, extension: Any) -> None:
|
||||
"""Set extension by name."""
|
||||
self._extensions[name] = extension
|
||||
|
||||
@contextmanager
|
||||
def enter(self) -> Generator[None, None, None]:
|
||||
"""Enter null context (no-op)."""
|
||||
yield
|
||||
|
||||
|
||||
class ExecutionContextBuilder:
|
||||
"""
|
||||
Builder for creating ExecutionContext instances.
|
||||
|
||||
This provides a fluent API for building execution contexts.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._app_context: AppContext | None = None
|
||||
self._context_vars: contextvars.Context | None = None
|
||||
self._user: Any = None
|
||||
|
||||
def with_app_context(self, app_context: AppContext) -> "ExecutionContextBuilder":
|
||||
"""Set application context."""
|
||||
self._app_context = app_context
|
||||
return self
|
||||
|
||||
def with_context_vars(self, context_vars: contextvars.Context) -> "ExecutionContextBuilder":
|
||||
"""Set context variables."""
|
||||
self._context_vars = context_vars
|
||||
return self
|
||||
|
||||
def with_user(self, user: Any) -> "ExecutionContextBuilder":
|
||||
"""Set user."""
|
||||
self._user = user
|
||||
return self
|
||||
|
||||
def build(self) -> ExecutionContext:
|
||||
"""Build the execution context."""
|
||||
return ExecutionContext(
|
||||
app_context=self._app_context,
|
||||
context_vars=self._context_vars,
|
||||
user=self._user,
|
||||
)
|
||||
|
||||
|
||||
def capture_current_context() -> IExecutionContext:
|
||||
"""
|
||||
Capture current execution context from the calling environment.
|
||||
|
||||
Returns:
|
||||
IExecutionContext with captured context
|
||||
"""
|
||||
from context import capture_current_context
|
||||
|
||||
return capture_current_context()
|
||||
@ -7,13 +7,15 @@ Domain-Driven Design principles for improved maintainability and testability.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
from core.workflow.context import capture_current_context
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
@ -157,8 +159,17 @@ class GraphEngine:
|
||||
self._layers: list[GraphEngineLayer] = []
|
||||
|
||||
# === Worker Pool Setup ===
|
||||
# Capture execution context for worker threads
|
||||
execution_context = capture_current_context()
|
||||
# Capture Flask app context for worker threads
|
||||
flask_app: Flask | None = None
|
||||
try:
|
||||
app = current_app._get_current_object() # type: ignore
|
||||
if isinstance(app, Flask):
|
||||
flask_app = app
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
# Capture context variables for worker threads
|
||||
context_vars = contextvars.copy_context()
|
||||
|
||||
# Create worker pool for parallel node execution
|
||||
self._worker_pool = WorkerPool(
|
||||
@ -166,7 +177,8 @@ class GraphEngine:
|
||||
event_queue=self._event_queue,
|
||||
graph=self._graph,
|
||||
layers=self._layers,
|
||||
execution_context=execution_context,
|
||||
flask_app=flask_app,
|
||||
context_vars=context_vars,
|
||||
min_workers=self._min_workers,
|
||||
max_workers=self._max_workers,
|
||||
scale_up_threshold=self._scale_up_threshold,
|
||||
|
||||
@ -5,27 +5,26 @@ Workers pull node IDs from the ready_queue, execute nodes, and push events
|
||||
to the event_queue for the dispatcher to process.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, final
|
||||
from typing import final
|
||||
from uuid import uuid4
|
||||
|
||||
from flask import Flask
|
||||
from typing_extensions import override
|
||||
|
||||
from core.workflow.context import IExecutionContext
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
from .ready_queue import ReadyQueue
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
@final
|
||||
class Worker(threading.Thread):
|
||||
@ -45,7 +44,8 @@ class Worker(threading.Thread):
|
||||
layers: Sequence[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
worker_id: int = 0,
|
||||
execution_context: IExecutionContext | None = None,
|
||||
flask_app: Flask | None = None,
|
||||
context_vars: contextvars.Context | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize worker thread.
|
||||
@ -56,17 +56,19 @@ class Worker(threading.Thread):
|
||||
graph: Graph containing nodes to execute
|
||||
layers: Graph engine layers for node execution hooks
|
||||
worker_id: Unique identifier for this worker
|
||||
execution_context: Optional execution context for context preservation
|
||||
flask_app: Optional Flask application for context preservation
|
||||
context_vars: Optional context variables to preserve in worker thread
|
||||
"""
|
||||
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
|
||||
self._ready_queue = ready_queue
|
||||
self._event_queue = event_queue
|
||||
self._graph = graph
|
||||
self._worker_id = worker_id
|
||||
self._execution_context = execution_context
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
self._last_task_time = time.time()
|
||||
self._stop_event = stop_event
|
||||
self._layers = layers if layers is not None else []
|
||||
self._last_task_time = time.time()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Worker is controlled via shared stop_event from GraphEngine.
|
||||
@ -133,9 +135,11 @@ class Worker(threading.Thread):
|
||||
|
||||
error: Exception | None = None
|
||||
|
||||
# Execute the node with preserved context if execution context is provided
|
||||
if self._execution_context is not None:
|
||||
with self._execution_context:
|
||||
if self._flask_app and self._context_vars:
|
||||
with preserve_flask_contexts(
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
):
|
||||
self._invoke_node_run_start_hooks(node)
|
||||
try:
|
||||
node_events = node.run()
|
||||
|
||||
@ -8,10 +8,9 @@ DynamicScaler, and WorkerFactory into a single class.
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from typing import final
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.context import IExecutionContext
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import GraphNodeEventBase
|
||||
|
||||
@ -21,6 +20,11 @@ from ..worker import Worker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from contextvars import Context
|
||||
|
||||
from flask import Flask
|
||||
|
||||
|
||||
@final
|
||||
class WorkerPool:
|
||||
@ -38,7 +42,8 @@ class WorkerPool:
|
||||
graph: Graph,
|
||||
layers: list[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
execution_context: IExecutionContext | None = None,
|
||||
flask_app: "Flask | None" = None,
|
||||
context_vars: "Context | None" = None,
|
||||
min_workers: int | None = None,
|
||||
max_workers: int | None = None,
|
||||
scale_up_threshold: int | None = None,
|
||||
@ -52,7 +57,8 @@ class WorkerPool:
|
||||
event_queue: Queue for worker events
|
||||
graph: The workflow graph
|
||||
layers: Graph engine layers for node execution hooks
|
||||
execution_context: Optional execution context for context preservation
|
||||
flask_app: Optional Flask app for context preservation
|
||||
context_vars: Optional context variables
|
||||
min_workers: Minimum number of workers
|
||||
max_workers: Maximum number of workers
|
||||
scale_up_threshold: Queue depth to trigger scale up
|
||||
@ -61,7 +67,8 @@ class WorkerPool:
|
||||
self._ready_queue = ready_queue
|
||||
self._event_queue = event_queue
|
||||
self._graph = graph
|
||||
self._execution_context = execution_context
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
self._layers = layers
|
||||
|
||||
# Scaling parameters with defaults
|
||||
@ -145,7 +152,8 @@ class WorkerPool:
|
||||
graph=self._graph,
|
||||
layers=self._layers,
|
||||
worker_id=worker_id,
|
||||
execution_context=self._execution_context,
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
|
||||
@ -62,6 +62,21 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
|
||||
inputs = {"variable_selector": variable_selector}
|
||||
process_data = {"documents": value if isinstance(value, list) else [value]}
|
||||
|
||||
# Ensure storage_key is loaded for File objects
|
||||
files_to_check = value if isinstance(value, list) else [value]
|
||||
files_needing_storage_key = [
|
||||
f for f in files_to_check if isinstance(f, File) and not f.storage_key and f.related_id
|
||||
]
|
||||
if files_needing_storage_key:
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from factories.file_factory import StorageKeyLoader
|
||||
|
||||
with Session(bind=db.engine) as session:
|
||||
storage_key_loader = StorageKeyLoader(session, tenant_id=self.tenant_id)
|
||||
storage_key_loader.load_storage_keys(files_needing_storage_key)
|
||||
|
||||
try:
|
||||
if isinstance(value, list):
|
||||
extracted_text_list = list(map(_extract_text_from_file, value))
|
||||
@ -415,6 +430,16 @@ def _download_file_content(file: File) -> bytes:
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
else:
|
||||
# Check if storage_key is set
|
||||
if not file.storage_key:
|
||||
raise FileDownloadError(f"File storage_key is missing for file: {file.filename}")
|
||||
|
||||
# Check if file exists before downloading
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
if not storage.exists(file.storage_key):
|
||||
raise FileDownloadError(f"File not found in storage: {file.storage_key}")
|
||||
|
||||
return file_manager.download(file)
|
||||
except Exception as e:
|
||||
raise FileDownloadError(f"Error downloading file: {str(e)}") from e
|
||||
|
||||
@ -1,9 +1,11 @@
|
||||
import contextvars
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, NewType, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
@ -37,6 +39,7 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
|
||||
from core.workflow.runtime import VariablePool
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
from .exc import (
|
||||
InvalidIteratorValueError,
|
||||
@ -48,7 +51,6 @@ from .exc import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.context import IExecutionContext
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -250,7 +252,8 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
self._execute_single_iteration_parallel,
|
||||
index=index,
|
||||
item=item,
|
||||
execution_context=self._capture_execution_context(),
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
context_vars=contextvars.copy_context(),
|
||||
)
|
||||
future_to_index[future] = index
|
||||
|
||||
@ -303,10 +306,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
self,
|
||||
index: int,
|
||||
item: object,
|
||||
execution_context: "IExecutionContext",
|
||||
flask_app: Flask,
|
||||
context_vars: contextvars.Context,
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with execution_context:
|
||||
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
events: list[GraphNodeEventBase] = []
|
||||
outputs_temp: list[object] = []
|
||||
@ -335,12 +339,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
graph_engine.graph_runtime_state.llm_usage,
|
||||
)
|
||||
|
||||
def _capture_execution_context(self) -> "IExecutionContext":
|
||||
"""Capture current execution context for parallel iterations."""
|
||||
from core.workflow.context import capture_current_context
|
||||
|
||||
return capture_current_context()
|
||||
|
||||
def _handle_iteration_success(
|
||||
self,
|
||||
started_at: datetime,
|
||||
|
||||
@@ -158,3 +158,5 @@ class KnowledgeIndexNodeData(BaseNodeData):
type: str = "knowledge-index"
chunk_structure: str
index_chunk_variable_selector: list[str]
indexing_technique: str | None = None
summary_index_setting: dict | None = None

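
For reference, an illustrative node configuration showing the shape of the two new optional fields; the summary_index_setting keys mirror those read by ParagraphIndexProcessor.generate_summary earlier in this diff, and the provider/model values are placeholders:

    node_data = {
        "type": "knowledge-index",
        "chunk_structure": "parent_child_index",
        "index_chunk_variable_selector": ["sys", "chunks"],
        "indexing_technique": "high_quality",      # summaries are only generated for high_quality indexing
        "summary_index_setting": {
            "enable": True,
            "model_provider_name": "openai",       # illustrative
            "model_name": "gpt-4o-mini",           # illustrative
            "summary_prompt": None,                # None -> default summary prompt
        },
    }

    # The node-level setting wins; the dataset's summary_index_setting is only a fallback.
    effective_setting = node_data["summary_index_setting"] or {}
    print(effective_setting["enable"])
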
@ -1,9 +1,11 @@
|
||||
import concurrent.futures
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import current_app
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -16,7 +18,9 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
|
||||
from .entities import KnowledgeIndexNodeData
|
||||
from .exc import (
|
||||
@ -67,7 +71,20 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
# index knowledge
|
||||
try:
|
||||
if is_preview:
|
||||
outputs = self._get_preview_output(node_data.chunk_structure, chunks)
|
||||
# Preview mode: generate summaries for chunks directly without saving to database
|
||||
# Format preview and generate summaries on-the-fly
|
||||
# Get indexing_technique and summary_index_setting from node_data (workflow graph config)
|
||||
# or fallback to dataset if not available in node_data
|
||||
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
|
||||
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
|
||||
|
||||
outputs = self._get_preview_output_with_summaries(
|
||||
node_data.chunk_structure,
|
||||
chunks,
|
||||
dataset=dataset,
|
||||
indexing_technique=indexing_technique,
|
||||
summary_index_setting=summary_index_setting,
|
||||
)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
@ -163,6 +180,9 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# Generate summary index if enabled
|
||||
self._handle_summary_index_generation(dataset, document, variable_pool)
|
||||
|
||||
return {
|
||||
"dataset_id": ds_id_value,
|
||||
"dataset_name": dataset_name_value,
|
||||
@ -173,9 +193,307 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
"display_status": "completed",
|
||||
}
|
||||
|
||||
def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]:
|
||||
def _handle_summary_index_generation(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
document: Document,
|
||||
variable_pool: VariablePool,
|
||||
) -> None:
|
||||
"""
|
||||
Handle summary index generation based on mode (debug/preview or production).
|
||||
|
||||
Args:
|
||||
dataset: Dataset containing the document
|
||||
document: Document to generate summaries for
|
||||
variable_pool: Variable pool to check invoke_from
|
||||
"""
|
||||
# Only generate summary index for high_quality indexing technique
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
return
|
||||
|
||||
# Check if summary index is enabled
|
||||
summary_index_setting = dataset.summary_index_setting
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
return
|
||||
|
||||
# Skip qa_model documents
|
||||
if document.doc_form == "qa_model":
|
||||
return
|
||||
|
||||
# Determine if in preview/debug mode
|
||||
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
|
||||
is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER
|
||||
|
||||
# Determine if only parent chunks should be processed
|
||||
only_parent_chunks = dataset.chunk_structure == "parent_child_index"
|
||||
|
||||
if is_preview:
|
||||
try:
|
||||
# Query segments that need summary generation
|
||||
query = db.session.query(DocumentSegment).filter_by(
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
status="completed",
|
||||
enabled=True,
|
||||
)
|
||||
segments = query.all()
|
||||
|
||||
if not segments:
|
||||
logger.info("No segments found for document %s", document.id)
|
||||
return
|
||||
|
||||
# Filter segments based on mode
|
||||
segments_to_process = []
|
||||
for segment in segments:
|
||||
# Skip if summary already exists
|
||||
existing_summary = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed")
|
||||
.first()
|
||||
)
|
||||
if existing_summary:
|
||||
continue
|
||||
|
||||
# For parent-child mode, all segments are parent chunks, so process all
|
||||
segments_to_process.append(segment)
|
||||
|
||||
if not segments_to_process:
|
||||
logger.info("No segments need summary generation for document %s", document.id)
|
||||
return
|
||||
|
||||
# Use ThreadPoolExecutor for concurrent generation
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
max_workers = min(10, len(segments_to_process)) # Limit to 10 workers
|
||||
|
||||
def process_segment(segment: DocumentSegment) -> None:
|
||||
"""Process a single segment in a thread with Flask app context."""
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to generate summary for segment %s",
|
||||
segment.id,
|
||||
)
|
||||
# Continue processing other segments
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
futures = [executor.submit(process_segment, segment) for segment in segments_to_process]
|
||||
# Wait for all tasks to complete
|
||||
concurrent.futures.wait(futures)
|
||||
|
||||
logger.info(
|
||||
"Successfully generated summary index for %s segments in document %s",
|
||||
len(segments_to_process),
|
||||
document.id,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to generate summary index for document %s", document.id)
|
||||
# Don't fail the entire indexing process if summary generation fails
|
||||
else:
|
||||
# Production mode: asynchronous generation
|
||||
logger.info(
|
||||
"Queuing summary index generation task for document %s (production mode)",
|
||||
document.id,
|
||||
)
|
||||
try:
|
||||
generate_summary_index_task.delay(dataset.id, document.id, None)
|
||||
logger.info("Summary index generation task queued for document %s", document.id)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to queue summary index generation task for document %s",
|
||||
document.id,
|
||||
)
|
||||
# Don't fail the entire indexing process if task queuing fails
|
||||
|
||||
    def _get_preview_output_with_summaries(
        self,
        chunk_structure: str,
        chunks: Any,
        dataset: Dataset,
        indexing_technique: str | None = None,
        summary_index_setting: dict | None = None,
    ) -> Mapping[str, Any]:
        """
        Generate preview output with summaries for chunks in preview mode.
        This method generates summaries on the fly without saving them to the database.

        Args:
            chunk_structure: Chunk structure type
            chunks: Chunks to generate preview for
            dataset: Dataset object (for tenant_id)
            indexing_technique: Indexing technique from node config or dataset
            summary_index_setting: Summary index setting from node config or dataset
        """
        index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
        preview_output = index_processor.format_preview(chunks)

        # Check if summary index is enabled
        if indexing_technique != "high_quality":
            return preview_output

        if not summary_index_setting or not summary_index_setting.get("enable"):
            return preview_output

        # Generate summaries for chunks
        if "preview" in preview_output and isinstance(preview_output["preview"], list):
            chunk_count = len(preview_output["preview"])
            logger.info(
                "Generating summaries for %s chunks in preview mode (dataset: %s)",
                chunk_count,
                dataset.id,
            )
            # Use ParagraphIndexProcessor's generate_summary method
            from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor

            # Get Flask app for application context in worker threads
            flask_app = None
            try:
                flask_app = current_app._get_current_object()  # type: ignore
            except RuntimeError:
                logger.warning("No Flask application context available, summary generation may fail")

            def generate_summary_for_chunk(preview_item: dict) -> None:
                """Generate summary for a single chunk."""
                if "content" in preview_item:
                    # Set Flask application context in worker thread
                    if flask_app:
                        with flask_app.app_context():
                            summary = ParagraphIndexProcessor.generate_summary(
                                tenant_id=dataset.tenant_id,
                                text=preview_item["content"],
                                summary_index_setting=summary_index_setting,
                            )
                            if summary:
                                preview_item["summary"] = summary
                    else:
                        # Fallback: try without app context (may fail)
                        summary = ParagraphIndexProcessor.generate_summary(
                            tenant_id=dataset.tenant_id,
                            text=preview_item["content"],
                            summary_index_setting=summary_index_setting,
                        )
                        if summary:
                            preview_item["summary"] = summary

            # Generate summaries concurrently using ThreadPoolExecutor
            # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
            timeout_seconds = min(300, 60 * len(preview_output["preview"]))
            errors: list[Exception] = []

            with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor:
                futures = [
                    executor.submit(generate_summary_for_chunk, preview_item)
                    for preview_item in preview_output["preview"]
                ]
                # Wait for all tasks to complete with timeout
                done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)

                # Cancel tasks that didn't complete in time
                if not_done:
                    timeout_error_msg = (
                        f"Summary generation timeout: {len(not_done)} chunks did not complete within {timeout_seconds}s"
                    )
                    logger.warning("%s. Cancelling remaining tasks...", timeout_error_msg)
                    # In preview mode, timeout is also an error
                    errors.append(TimeoutError(timeout_error_msg))
                    for future in not_done:
                        future.cancel()
                    # Wait a bit for cancellation to take effect
                    concurrent.futures.wait(not_done, timeout=5)

                # Collect exceptions from completed futures
                for future in done:
                    try:
                        future.result()  # This will raise any exception that occurred
                    except Exception as e:
                        logger.exception("Error in summary generation future")
                        errors.append(e)

            # In preview mode, if there are any errors, fail the request
            if errors:
                error_messages = [str(e) for e in errors]
                error_summary = (
                    f"Failed to generate summaries for {len(errors)} chunk(s). "
                    f"Errors: {'; '.join(error_messages[:3])}"  # Show first 3 errors
                )
                if len(errors) > 3:
                    error_summary += f" (and {len(errors) - 3} more)"
                logger.error("Summary generation failed in preview mode: %s", error_summary)
                raise KnowledgeIndexNodeError(error_summary)

            completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None)
            logger.info(
                "Completed summary generation for preview chunks: %s/%s succeeded",
                completed_count,
                len(preview_output["preview"]),
            )

        return preview_output

    def _get_preview_output(
        self,
        chunk_structure: str,
        chunks: Any,
        dataset: Dataset | None = None,
        variable_pool: VariablePool | None = None,
    ) -> Mapping[str, Any]:
        index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
        preview_output = index_processor.format_preview(chunks)

        # If dataset is provided, try to enrich preview with summaries
        if dataset and variable_pool:
            document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
            if document_id:
                document = db.session.query(Document).filter_by(id=document_id.value).first()
                if document:
                    # Query summaries for this document
                    summaries = (
                        db.session.query(DocumentSegmentSummary)
                        .filter_by(
                            dataset_id=dataset.id,
                            document_id=document.id,
                            status="completed",
                            enabled=True,
                        )
                        .all()
                    )

                    if summaries:
                        # Create a map of segment content to summary for matching
                        # Use content matching as chunks in preview might not be indexed yet
                        summary_by_content = {}
                        for summary in summaries:
                            segment = (
                                db.session.query(DocumentSegment)
                                .filter_by(id=summary.chunk_id, dataset_id=dataset.id)
                                .first()
                            )
                            if segment:
                                # Normalize content for matching (strip whitespace)
                                normalized_content = segment.content.strip()
                                summary_by_content[normalized_content] = summary.summary_content

                        # Enrich preview with summaries by content matching
                        if "preview" in preview_output and isinstance(preview_output["preview"], list):
                            matched_count = 0
                            for preview_item in preview_output["preview"]:
                                if "content" in preview_item:
                                    # Normalize content for matching
                                    normalized_chunk_content = preview_item["content"].strip()
                                    if normalized_chunk_content in summary_by_content:
                                        preview_item["summary"] = summary_by_content[normalized_chunk_content]
                                        matched_count += 1

                            if matched_count > 0:
                                logger.info(
                                    "Enriched preview with %s existing summaries (dataset: %s, document: %s)",
                                    matched_count,
                                    dataset.id,
                                    document.id,
                                )

        return preview_output

    @classmethod
    def version(cls) -> str:

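The preview branch above fans chunk summarization out to a bounded thread pool, waits with a global timeout, cancels stragglers, and fails the request if any chunk errored. A minimal, standalone sketch of that pattern follows; the summarize function and sample chunks here are placeholders, not part of this diff.

import concurrent.futures

def summarize(item: dict) -> None:
    # Placeholder for the real summary call; mutates the item in place.
    item["summary"] = item["content"][:80]

items = [{"content": f"chunk {i}"} for i in range(5)]
timeout_seconds = min(300, 60 * len(items))
errors: list[Exception] = []

with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(items))) as executor:
    futures = [executor.submit(summarize, item) for item in items]
    done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)
    if not_done:
        # Timed-out chunks count as errors; cancel whatever has not started yet.
        errors.append(TimeoutError(f"{len(not_done)} chunks timed out"))
        for future in not_done:
            future.cancel()
    for future in done:
        try:
            future.result()
        except Exception as e:  # collect instead of swallowing
            errors.append(e)

if errors:
    raise RuntimeError(f"failed to summarize {len(errors)} chunk(s)")
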
@@ -119,16 +119,14 @@ elif [[ "${MODE}" == "job" ]]; then

else
  if [[ "${DEBUG}" == "true" ]]; then
    export HOST=${DIFY_BIND_ADDRESS:-0.0.0.0}
    export PORT=${DIFY_PORT:-5001}
    exec python -m app
    exec flask run --host=${DIFY_BIND_ADDRESS:-0.0.0.0} --port=${DIFY_PORT:-5001} --debug
  else
    exec gunicorn \
      --bind "${DIFY_BIND_ADDRESS:-0.0.0.0}:${DIFY_PORT:-5001}" \
      --workers ${SERVER_WORKER_AMOUNT:-1} \
      --worker-class ${SERVER_WORKER_CLASS:-geventwebsocket.gunicorn.workers.GeventWebSocketWorker} \
      --worker-class ${SERVER_WORKER_CLASS:-gevent} \
      --worker-connections ${SERVER_WORKER_CONNECTIONS:-10} \
      --timeout ${GUNICORN_TIMEOUT:-200} \
      app:socketio_app
      app:app
  fi
fi

@@ -102,6 +102,8 @@ def init_app(app: DifyApp) -> Celery:
    imports = [
        "tasks.async_workflow_tasks",  # trigger workers
        "tasks.trigger_processing_tasks",  # async trigger processing
        "tasks.generate_summary_index_task",  # summary index generation
        "tasks.regenerate_summary_index_task",  # summary index regeneration
    ]
    day = dify_config.CELERY_BEAT_SCHEDULER_TIME


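Registering these task modules in the Celery imports list is what lets producers elsewhere in this diff enqueue them asynchronously. A hedged sketch of the enqueue side; the module path and argument shape are assumed from the .delay() calls shown in this diff.

# Assumed to run inside a configured Flask/Celery app context.
from tasks.generate_summary_index_task import generate_summary_index_task  # assumed task name

generate_summary_index_task.delay("dataset-uuid", "document-uuid", None)  # hypothetical IDs
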
@@ -1,5 +0,0 @@
import socketio  # type: ignore[reportMissingTypeStubs]

from configs import dify_config

sio = socketio.Server(async_mode="gevent", cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS)
@@ -39,6 +39,14 @@ dataset_retrieval_model_fields = {
    "score_threshold_enabled": fields.Boolean,
    "score_threshold": fields.Float,
}

dataset_summary_index_fields = {
    "enable": fields.Boolean,
    "model_name": fields.String,
    "model_provider_name": fields.String,
    "summary_prompt": fields.String,
}

external_retrieval_model_fields = {
    "top_k": fields.Integer,
    "score_threshold": fields.Float,
@@ -83,6 +91,7 @@ dataset_detail_fields = {
    "embedding_model_provider": fields.String,
    "embedding_available": fields.Boolean,
    "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
    "summary_index_setting": fields.Nested(dataset_summary_index_fields),
    "tags": fields.List(fields.Nested(tag_fields)),
    "doc_form": fields.String,
    "external_knowledge_info": fields.Nested(external_knowledge_info_fields),

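For illustration, a hedged sketch of how the new dataset_summary_index_fields could be used to marshal a summary-index setting with flask-restx; the sample model and provider names are hypothetical.

from flask_restx import fields, marshal

dataset_summary_index_fields = {
    "enable": fields.Boolean,
    "model_name": fields.String,
    "model_provider_name": fields.String,
    "summary_prompt": fields.String,
}

setting = {
    "enable": True,
    "model_name": "gpt-4o-mini",          # hypothetical model name
    "model_provider_name": "openai",      # hypothetical provider
    "summary_prompt": "Summarize the chunk in two sentences.",
}

print(marshal(setting, dataset_summary_index_fields))
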
@@ -33,6 +33,11 @@ document_fields = {
    "hit_count": fields.Integer,
    "doc_form": fields.String,
    "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
    # Summary index generation status:
    # "SUMMARIZING" (when task is queued and generating)
    "summary_index_status": fields.String,
    # Whether this document needs summary index generation
    "need_summary": fields.Boolean,
}

document_with_segments_fields = {
@@ -60,6 +65,10 @@ document_with_segments_fields = {
    "completed_segments": fields.Integer,
    "total_segments": fields.Integer,
    "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
    # Summary index generation status:
    # "SUMMARIZING" (when task is queued and generating)
    "summary_index_status": fields.String,
    "need_summary": fields.Boolean,  # Whether this document needs summary index generation
}

dataset_and_document_fields = {

@@ -58,4 +58,5 @@ hit_testing_record_fields = {
    "score": fields.Float,
    "tsne_position": fields.Raw,
    "files": fields.List(fields.Nested(files_fields)),
    "summary": fields.String,  # Summary content if retrieved via summary index
}

@@ -1,17 +0,0 @@
from flask_restx import fields

online_user_partial_fields = {
    "user_id": fields.String,
    "username": fields.String,
    "avatar": fields.String,
    "sid": fields.String,
}

workflow_online_users_fields = {
    "workflow_id": fields.String,
    "users": fields.List(fields.Nested(online_user_partial_fields)),
}

online_user_list_fields = {
    "data": fields.List(fields.Nested(workflow_online_users_fields)),
}
@@ -49,4 +49,5 @@ segment_fields = {
    "stopped_at": TimestampField,
    "child_chunks": fields.List(fields.Nested(child_chunk_fields)),
    "attachments": fields.List(fields.Nested(attachment_fields)),
    "summary": fields.String,  # Summary content for the segment
}

@ -1,96 +0,0 @@
|
||||
from flask_restx import fields
|
||||
|
||||
from libs.helper import AvatarUrlField, TimestampField
|
||||
|
||||
# basic account fields for comments
|
||||
account_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"email": fields.String,
|
||||
"avatar_url": AvatarUrlField,
|
||||
}
|
||||
|
||||
# Comment mention fields
|
||||
workflow_comment_mention_fields = {
|
||||
"mentioned_user_id": fields.String,
|
||||
"mentioned_user_account": fields.Nested(account_fields, allow_null=True),
|
||||
"reply_id": fields.String,
|
||||
}
|
||||
|
||||
# Comment reply fields
|
||||
workflow_comment_reply_fields = {
|
||||
"id": fields.String,
|
||||
"content": fields.String,
|
||||
"created_by": fields.String,
|
||||
"created_by_account": fields.Nested(account_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
# Basic comment fields (for list views)
|
||||
workflow_comment_basic_fields = {
|
||||
"id": fields.String,
|
||||
"position_x": fields.Float,
|
||||
"position_y": fields.Float,
|
||||
"content": fields.String,
|
||||
"created_by": fields.String,
|
||||
"created_by_account": fields.Nested(account_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"resolved": fields.Boolean,
|
||||
"resolved_at": TimestampField,
|
||||
"resolved_by": fields.String,
|
||||
"resolved_by_account": fields.Nested(account_fields, allow_null=True),
|
||||
"reply_count": fields.Integer,
|
||||
"mention_count": fields.Integer,
|
||||
"participants": fields.List(fields.Nested(account_fields)),
|
||||
}
|
||||
|
||||
# Detailed comment fields (for single comment view)
|
||||
workflow_comment_detail_fields = {
|
||||
"id": fields.String,
|
||||
"position_x": fields.Float,
|
||||
"position_y": fields.Float,
|
||||
"content": fields.String,
|
||||
"created_by": fields.String,
|
||||
"created_by_account": fields.Nested(account_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"resolved": fields.Boolean,
|
||||
"resolved_at": TimestampField,
|
||||
"resolved_by": fields.String,
|
||||
"resolved_by_account": fields.Nested(account_fields, allow_null=True),
|
||||
"replies": fields.List(fields.Nested(workflow_comment_reply_fields)),
|
||||
"mentions": fields.List(fields.Nested(workflow_comment_mention_fields)),
|
||||
}
|
||||
|
||||
# Comment creation response fields (simplified)
|
||||
workflow_comment_create_fields = {
|
||||
"id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
# Comment update response fields (simplified)
|
||||
workflow_comment_update_fields = {
|
||||
"id": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
# Comment resolve response fields
|
||||
workflow_comment_resolve_fields = {
|
||||
"id": fields.String,
|
||||
"resolved": fields.Boolean,
|
||||
"resolved_at": TimestampField,
|
||||
"resolved_by": fields.String,
|
||||
}
|
||||
|
||||
# Reply creation response fields (simplified)
|
||||
workflow_comment_reply_create_fields = {
|
||||
"id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
# Reply update response fields
|
||||
workflow_comment_reply_update_fields = {
|
||||
"id": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
@ -1,74 +0,0 @@
|
||||
"""
|
||||
Workspace permission helper functions.
|
||||
|
||||
These helpers check both billing/plan level and workspace-specific policy level permissions.
|
||||
Checks are performed at two levels:
|
||||
1. Billing/plan level - via FeatureService (e.g., SANDBOX plan restrictions)
|
||||
2. Workspace policy level - via EnterpriseService (admin-configured per workspace)
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_workspace_member_invite_permission(workspace_id: str) -> None:
|
||||
"""
|
||||
Check if workspace allows member invitations at both billing and policy levels.
|
||||
|
||||
Checks performed:
|
||||
1. Billing/plan level - For future expansion (currently no plan-level restriction)
|
||||
2. Enterprise policy level - Admin-configured workspace permission
|
||||
|
||||
Args:
|
||||
workspace_id: The workspace ID to check permissions for
|
||||
|
||||
Raises:
|
||||
Forbidden: If either billing plan or workspace policy prohibits member invitations
|
||||
"""
|
||||
# Check enterprise workspace policy level (only if enterprise enabled)
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
try:
|
||||
permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
|
||||
if not permission.allow_member_invite:
|
||||
raise Forbidden("Workspace policy prohibits member invitations")
|
||||
except Forbidden:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to check workspace invite permission for %s", workspace_id)
|
||||
|
||||
|
||||
def check_workspace_owner_transfer_permission(workspace_id: str) -> None:
|
||||
"""
|
||||
Check if workspace allows owner transfer at both billing and policy levels.
|
||||
|
||||
Checks performed:
|
||||
1. Billing/plan level - SANDBOX plan blocks owner transfer
|
||||
2. Enterprise policy level - Admin-configured workspace permission
|
||||
|
||||
Args:
|
||||
workspace_id: The workspace ID to check permissions for
|
||||
|
||||
Raises:
|
||||
Forbidden: If either billing plan or workspace policy prohibits ownership transfer
|
||||
"""
|
||||
features = FeatureService.get_features(workspace_id)
|
||||
if not features.is_allow_transfer_workspace:
|
||||
raise Forbidden("Your current plan does not allow workspace ownership transfer")
|
||||
|
||||
# Check enterprise workspace policy level (only if enterprise enabled)
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
try:
|
||||
permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
|
||||
if not permission.allow_owner_transfer:
|
||||
raise Forbidden("Workspace policy prohibits ownership transfer")
|
||||
except Forbidden:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to check workspace transfer permission for %s", workspace_id)
|
||||
@@ -0,0 +1,69 @@
"""add SummaryIndex feature

Revision ID: 562dcce7d77c
Revises: 03ea244985ce
Create Date: 2026-01-12 13:58:40.584802

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '562dcce7d77c'
down_revision = '03ea244985ce'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('document_segment_summary',
    sa.Column('id', models.types.StringUUID(), nullable=False),
    sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
    sa.Column('document_id', models.types.StringUUID(), nullable=False),
    sa.Column('chunk_id', models.types.StringUUID(), nullable=False),
    sa.Column('summary_content', models.types.LongText(), nullable=True),
    sa.Column('summary_index_node_id', sa.String(length=255), nullable=True),
    sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True),
    sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False),
    sa.Column('error', models.types.LongText(), nullable=True),
    sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
    sa.Column('disabled_at', sa.DateTime(), nullable=True),
    sa.Column('disabled_by', models.types.StringUUID(), nullable=True),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.PrimaryKeyConstraint('id', name='document_segment_summary_pkey')
    )
    with op.batch_alter_table('document_segment_summary', schema=None) as batch_op:
        batch_op.create_index('document_segment_summary_chunk_id_idx', ['chunk_id'], unique=False)
        batch_op.create_index('document_segment_summary_dataset_id_idx', ['dataset_id'], unique=False)
        batch_op.create_index('document_segment_summary_document_id_idx', ['document_id'], unique=False)
        batch_op.create_index('document_segment_summary_status_idx', ['status'], unique=False)

    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True))

    with op.batch_alter_table('documents', schema=None) as batch_op:
        batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('documents', schema=None) as batch_op:
        batch_op.drop_column('need_summary')

    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.drop_column('summary_index_setting')

    with op.batch_alter_table('document_segment_summary', schema=None) as batch_op:
        batch_op.drop_index('document_segment_summary_status_idx')
        batch_op.drop_index('document_segment_summary_document_id_idx')
        batch_op.drop_index('document_segment_summary_dataset_id_idx')
        batch_op.drop_index('document_segment_summary_chunk_id_idx')

    op.drop_table('document_segment_summary')
    # ### end Alembic commands ###

@ -1,35 +0,0 @@
|
||||
"""change workflow node execution workflow_run index
|
||||
|
||||
Revision ID: 288345cd01d1
|
||||
Revises: 3334862ee907
|
||||
Create Date: 2026-01-16 17:15:00.000000
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "288345cd01d1"
|
||||
down_revision = "3334862ee907"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op:
|
||||
batch_op.drop_index("workflow_node_execution_workflow_run_idx")
|
||||
batch_op.create_index(
|
||||
"workflow_node_execution_workflow_run_id_idx",
|
||||
["workflow_run_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade():
|
||||
with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op:
|
||||
batch_op.drop_index("workflow_node_execution_workflow_run_id_idx")
|
||||
batch_op.create_index(
|
||||
"workflow_node_execution_workflow_run_idx",
|
||||
["tenant_id", "app_id", "workflow_id", "triggered_from", "workflow_run_id"],
|
||||
unique=False,
|
||||
)
|
||||
@ -1,90 +0,0 @@
|
||||
"""Add workflow comments table
|
||||
|
||||
Revision ID: 227822d22895
|
||||
Revises: 288345cd01d1
|
||||
Create Date: 2025-08-22 17:26:15.255980
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '227822d22895'
|
||||
down_revision = '288345cd01d1'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('workflow_comments',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('app_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('position_x', sa.Float(), nullable=False),
|
||||
sa.Column('position_y', sa.Float(), nullable=False),
|
||||
sa.Column('content', sa.Text(), nullable=False),
|
||||
sa.Column('created_by', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('resolved', sa.Boolean(), server_default=sa.text('false'), nullable=False),
|
||||
sa.Column('resolved_at', sa.DateTime(), nullable=True),
|
||||
sa.Column('resolved_by', models.types.StringUUID(), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id', name='workflow_comments_pkey')
|
||||
)
|
||||
with op.batch_alter_table('workflow_comments', schema=None) as batch_op:
|
||||
batch_op.create_index('workflow_comments_app_idx', ['tenant_id', 'app_id'], unique=False)
|
||||
batch_op.create_index('workflow_comments_created_at_idx', ['created_at'], unique=False)
|
||||
|
||||
op.create_table('workflow_comment_replies',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||
sa.Column('comment_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('content', sa.Text(), nullable=False),
|
||||
sa.Column('created_by', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_replies_comment_id_fkey'), ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id', name='workflow_comment_replies_pkey')
|
||||
)
|
||||
with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op:
|
||||
batch_op.create_index('comment_replies_comment_idx', ['comment_id'], unique=False)
|
||||
batch_op.create_index('comment_replies_created_at_idx', ['created_at'], unique=False)
|
||||
|
||||
op.create_table('workflow_comment_mentions',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuidv7()'), nullable=False),
|
||||
sa.Column('comment_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('reply_id', models.types.StringUUID(), nullable=True),
|
||||
sa.Column('mentioned_user_id', models.types.StringUUID(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['comment_id'], ['workflow_comments.id'], name=op.f('workflow_comment_mentions_comment_id_fkey'), ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['reply_id'], ['workflow_comment_replies.id'], name=op.f('workflow_comment_mentions_reply_id_fkey'), ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id', name='workflow_comment_mentions_pkey')
|
||||
)
|
||||
with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op:
|
||||
batch_op.create_index('comment_mentions_comment_idx', ['comment_id'], unique=False)
|
||||
batch_op.create_index('comment_mentions_reply_idx', ['reply_id'], unique=False)
|
||||
batch_op.create_index('comment_mentions_user_idx', ['mentioned_user_id'], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('workflow_comment_mentions', schema=None) as batch_op:
|
||||
batch_op.drop_index('comment_mentions_user_idx')
|
||||
batch_op.drop_index('comment_mentions_reply_idx')
|
||||
batch_op.drop_index('comment_mentions_comment_idx')
|
||||
|
||||
op.drop_table('workflow_comment_mentions')
|
||||
with op.batch_alter_table('workflow_comment_replies', schema=None) as batch_op:
|
||||
batch_op.drop_index('comment_replies_created_at_idx')
|
||||
batch_op.drop_index('comment_replies_comment_idx')
|
||||
|
||||
op.drop_table('workflow_comment_replies')
|
||||
with op.batch_alter_table('workflow_comments', schema=None) as batch_op:
|
||||
batch_op.drop_index('workflow_comments_created_at_idx')
|
||||
batch_op.drop_index('workflow_comments_app_idx')
|
||||
|
||||
op.drop_table('workflow_comments')
|
||||
# ### end Alembic commands ###
|
||||
@@ -9,11 +9,6 @@ from .account import (
    TenantStatus,
)
from .api_based_extension import APIBasedExtension, APIBasedExtensionPoint
from .comment import (
    WorkflowComment,
    WorkflowCommentMention,
    WorkflowCommentReply,
)
from .dataset import (
    AppDatasetJoin,
    Dataset,
@@ -202,9 +197,6 @@ __all__ = [
    "Workflow",
    "WorkflowAppLog",
    "WorkflowAppLogCreatedFrom",
    "WorkflowComment",
    "WorkflowCommentMention",
    "WorkflowCommentReply",
    "WorkflowNodeExecutionModel",
    "WorkflowNodeExecutionOffload",
    "WorkflowNodeExecutionTriggeredFrom",

@ -1,210 +0,0 @@
|
||||
"""Workflow comment models."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Index, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from .account import Account
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class WorkflowComment(Base):
|
||||
"""Workflow comment model for canvas commenting functionality.
|
||||
|
||||
Comments are associated with apps rather than specific workflow versions,
|
||||
since an app has only one draft workflow at a time and comments should persist
|
||||
across workflow version changes.
|
||||
|
||||
Attributes:
|
||||
id: Comment ID
|
||||
tenant_id: Workspace ID
|
||||
app_id: App ID (primary association, comments belong to apps)
|
||||
position_x: X coordinate on canvas
|
||||
position_y: Y coordinate on canvas
|
||||
content: Comment content
|
||||
created_by: Creator account ID
|
||||
created_at: Creation time
|
||||
updated_at: Last update time
|
||||
resolved: Whether comment is resolved
|
||||
resolved_at: Resolution time
|
||||
resolved_by: Resolver account ID
|
||||
"""
|
||||
|
||||
__tablename__ = "workflow_comments"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
|
||||
Index("workflow_comments_app_idx", "tenant_id", "app_id"),
|
||||
Index("workflow_comments_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
position_x: Mapped[float] = mapped_column(db.Float)
|
||||
position_y: Mapped[float] = mapped_column(db.Float)
|
||||
content: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
resolved_at: Mapped[datetime | None] = mapped_column(db.DateTime)
|
||||
resolved_by: Mapped[str | None] = mapped_column(StringUUID)
|
||||
|
||||
# Relationships
|
||||
replies: Mapped[list["WorkflowCommentReply"]] = relationship(
|
||||
"WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan"
|
||||
)
|
||||
mentions: Mapped[list["WorkflowCommentMention"]] = relationship(
|
||||
"WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
"""Get creator account."""
|
||||
if hasattr(self, "_created_by_account_cache"):
|
||||
return self._created_by_account_cache
|
||||
return db.session.get(Account, self.created_by)
|
||||
|
||||
def cache_created_by_account(self, account: Account | None) -> None:
|
||||
"""Cache creator account to avoid extra queries."""
|
||||
self._created_by_account_cache = account
|
||||
|
||||
@property
|
||||
def resolved_by_account(self):
|
||||
"""Get resolver account."""
|
||||
if hasattr(self, "_resolved_by_account_cache"):
|
||||
return self._resolved_by_account_cache
|
||||
if self.resolved_by:
|
||||
return db.session.get(Account, self.resolved_by)
|
||||
return None
|
||||
|
||||
def cache_resolved_by_account(self, account: Account | None) -> None:
|
||||
"""Cache resolver account to avoid extra queries."""
|
||||
self._resolved_by_account_cache = account
|
||||
|
||||
@property
|
||||
def reply_count(self):
|
||||
"""Get reply count."""
|
||||
return len(self.replies)
|
||||
|
||||
@property
|
||||
def mention_count(self):
|
||||
"""Get mention count."""
|
||||
return len(self.mentions)
|
||||
|
||||
@property
|
||||
def participants(self):
|
||||
"""Get all participants (creator + repliers + mentioned users)."""
|
||||
participant_ids = set()
|
||||
|
||||
# Add comment creator
|
||||
participant_ids.add(self.created_by)
|
||||
|
||||
# Add reply creators
|
||||
participant_ids.update(reply.created_by for reply in self.replies)
|
||||
|
||||
# Add mentioned users
|
||||
participant_ids.update(mention.mentioned_user_id for mention in self.mentions)
|
||||
|
||||
# Get account objects
|
||||
participants = []
|
||||
for user_id in participant_ids:
|
||||
account = db.session.get(Account, user_id)
|
||||
if account:
|
||||
participants.append(account)
|
||||
|
||||
return participants
|
||||
|
||||
|
||||
class WorkflowCommentReply(Base):
|
||||
"""Workflow comment reply model.
|
||||
|
||||
Attributes:
|
||||
id: Reply ID
|
||||
comment_id: Parent comment ID
|
||||
content: Reply content
|
||||
created_by: Creator account ID
|
||||
created_at: Creation time
|
||||
"""
|
||||
|
||||
__tablename__ = "workflow_comment_replies"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
|
||||
Index("comment_replies_comment_idx", "comment_id"),
|
||||
Index("comment_replies_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
comment_id: Mapped[str] = mapped_column(
|
||||
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
content: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
# Relationships
|
||||
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
"""Get creator account."""
|
||||
if hasattr(self, "_created_by_account_cache"):
|
||||
return self._created_by_account_cache
|
||||
return db.session.get(Account, self.created_by)
|
||||
|
||||
def cache_created_by_account(self, account: Account | None) -> None:
|
||||
"""Cache creator account to avoid extra queries."""
|
||||
self._created_by_account_cache = account
|
||||
|
||||
|
||||
class WorkflowCommentMention(Base):
|
||||
"""Workflow comment mention model.
|
||||
|
||||
Mentions are only for internal accounts since end users
|
||||
cannot access workflow canvas and commenting features.
|
||||
|
||||
Attributes:
|
||||
id: Mention ID
|
||||
comment_id: Parent comment ID
|
||||
mentioned_user_id: Mentioned account ID
|
||||
"""
|
||||
|
||||
__tablename__ = "workflow_comment_mentions"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
|
||||
Index("comment_mentions_comment_idx", "comment_id"),
|
||||
Index("comment_mentions_reply_idx", "reply_id"),
|
||||
Index("comment_mentions_user_idx", "mentioned_user_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
comment_id: Mapped[str] = mapped_column(
|
||||
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
reply_id: Mapped[str | None] = mapped_column(
|
||||
StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
||||
# Relationships
|
||||
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
|
||||
reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")
|
||||
|
||||
@property
|
||||
def mentioned_user_account(self):
|
||||
"""Get mentioned account."""
|
||||
if hasattr(self, "_mentioned_user_account_cache"):
|
||||
return self._mentioned_user_account_cache
|
||||
return db.session.get(Account, self.mentioned_user_id)
|
||||
|
||||
def cache_mentioned_user_account(self, account: Account | None) -> None:
|
||||
"""Cache mentioned account to avoid extra queries."""
|
||||
self._mentioned_user_account_cache = account
|
||||
@@ -72,6 +72,7 @@ class Dataset(Base):
    keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10"))
    collection_binding_id = mapped_column(StringUUID, nullable=True)
    retrieval_model = mapped_column(AdjustedJSON, nullable=True)
    summary_index_setting = mapped_column(AdjustedJSON, nullable=True)
    built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
    icon_info = mapped_column(AdjustedJSON, nullable=True)
    runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))
@@ -419,6 +420,7 @@ class Document(Base):
    doc_metadata = mapped_column(AdjustedJSON, nullable=True)
    doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
    doc_language = mapped_column(String(255), nullable=True)
    need_summary: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))

    DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]

@@ -1575,3 +1577,35 @@ class SegmentAttachmentBinding(Base):
    segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())


class DocumentSegmentSummary(Base):
    __tablename__ = "document_segment_summary"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="document_segment_summary_pkey"),
        sa.Index("document_segment_summary_dataset_id_idx", "dataset_id"),
        sa.Index("document_segment_summary_document_id_idx", "document_id"),
        sa.Index("document_segment_summary_chunk_id_idx", "chunk_id"),
        sa.Index("document_segment_summary_status_idx", "status"),
    )

    id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
    dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    # corresponds to DocumentSegment.id or parent chunk id
    chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    summary_content: Mapped[str] = mapped_column(LongText, nullable=True)
    summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
    summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
    status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'"))
    error: Mapped[str] = mapped_column(LongText, nullable=True)
    enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
    disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    disabled_by = mapped_column(StringUUID, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )

    def __repr__(self):
        return f"<DocumentSegmentSummary id={self.id} chunk_id={self.chunk_id} status={self.status}>"

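A hedged usage sketch of the new DocumentSegmentSummary model, mirroring the completed-summary lookup performed in the preview enrichment earlier in this diff; the import path for the model is assumed.

from sqlalchemy.orm import Session

from models.dataset import DocumentSegmentSummary  # assumed module path for the model added above


def completed_summaries_by_chunk(session: Session, dataset_id: str, document_id: str) -> dict[str, str]:
    """Map chunk_id -> summary_content for summaries that finished generating and are still enabled."""
    rows = (
        session.query(DocumentSegmentSummary)
        .filter_by(dataset_id=dataset_id, document_id=document_id, status="completed", enabled=True)
        .all()
    )
    return {row.chunk_id: row.summary_content for row in rows}
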
@@ -401,7 +401,7 @@ class Workflow(Base): # bug

        :return: hash
        """
        entity = {"graph": self.graph_dict}
        entity = {"graph": self.graph_dict, "features": self.features_dict}

        return helper.generate_text_hash(json.dumps(entity, sort_keys=True))

@@ -781,7 +781,11 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
        return (
            PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
            Index(
                "workflow_node_execution_workflow_run_id_idx",
                "workflow_node_execution_workflow_run_idx",
                "tenant_id",
                "app_id",
                "workflow_id",
                "triggered_from",
                "workflow_run_id",
            ),
            Index(

@@ -21,7 +21,6 @@ dependencies = [
    "flask-orjson~=2.0.0",
    "flask-sqlalchemy~=3.1.1",
    "gevent~=25.9.1",
    "gevent-websocket~=0.10.1",
    "gmpy2~=2.2.1",
    "google-api-core==2.18.0",
    "google-api-python-client==2.90.0",
@@ -73,7 +72,6 @@ dependencies = [
    "pypdfium2==5.2.0",
    "python-docx~=1.1.0",
    "python-dotenv==1.0.1",
    "python-socketio~=5.13.0",
    "pyyaml~=6.0.1",
    "readabilipy~=0.3.0",
    "redis[hiredis]~=6.1.0",

@@ -13,8 +13,6 @@ from collections.abc import Sequence
from datetime import datetime
from typing import Protocol

from sqlalchemy.orm import Session

from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from models.workflow import WorkflowNodeExecutionModel

@@ -132,18 +130,6 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
        """
        ...

    def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
        """
        Count node executions and offloads for the given workflow run ids.
        """
        ...

    def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
        """
        Delete node executions and offloads for the given workflow run ids.
        """
        ...

    def delete_executions_by_app(
        self,
        tenant_id: str,

@@ -7,15 +7,17 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.

from collections.abc import Sequence
from datetime import datetime
from typing import cast
from typing import TypedDict, cast

from sqlalchemy import asc, delete, desc, func, select
from sqlalchemy import asc, delete, desc, func, select, tuple_
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, sessionmaker

from models.enums import WorkflowRunTriggeredFrom
from models.workflow import (
    WorkflowNodeExecutionModel,
    WorkflowNodeExecutionOffload,
    WorkflowNodeExecutionTriggeredFrom,
)
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository

@@ -47,6 +49,26 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
        """
        self._session_maker = session_maker

    @staticmethod
    def _map_run_triggered_from_to_node_triggered_from(triggered_from: str) -> str:
        """
        Map workflow run triggered_from values to workflow node execution triggered_from values.
        """
        if triggered_from in {
            WorkflowRunTriggeredFrom.APP_RUN.value,
            WorkflowRunTriggeredFrom.DEBUGGING.value,
            WorkflowRunTriggeredFrom.SCHEDULE.value,
            WorkflowRunTriggeredFrom.PLUGIN.value,
            WorkflowRunTriggeredFrom.WEBHOOK.value,
        }:
            return WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
        if triggered_from in {
            WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
            WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
        }:
            return WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN.value
        return ""

    def get_node_last_execution(
        self,
        tenant_id: str,
@@ -294,16 +316,51 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
        session.commit()
        return result.rowcount

    def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
    class RunContext(TypedDict):
        run_id: str
        tenant_id: str
        app_id: str
        workflow_id: str
        triggered_from: str

    @staticmethod
    def delete_by_runs(session: Session, runs: Sequence[RunContext]) -> tuple[int, int]:
        """
        Delete node executions (and offloads) for the given workflow runs using workflow_run_id.
        Delete node executions (and offloads) for the given workflow runs using indexed columns.

        Uses the composite index on (tenant_id, app_id, workflow_id, triggered_from, workflow_run_id)
        by filtering on those columns with tuple IN.
        """
        if not run_ids:
        if not runs:
            return 0, 0

        run_ids = list(run_ids)
        run_id_filter = WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)
        node_execution_ids = select(WorkflowNodeExecutionModel.id).where(run_id_filter)
        tuple_values = [
            (
                run["tenant_id"],
                run["app_id"],
                run["workflow_id"],
                DifyAPISQLAlchemyWorkflowNodeExecutionRepository._map_run_triggered_from_to_node_triggered_from(
                    run["triggered_from"]
                ),
                run["run_id"],
            )
            for run in runs
        ]

        node_execution_ids = session.scalars(
            select(WorkflowNodeExecutionModel.id).where(
                tuple_(
                    WorkflowNodeExecutionModel.tenant_id,
                    WorkflowNodeExecutionModel.app_id,
                    WorkflowNodeExecutionModel.workflow_id,
                    WorkflowNodeExecutionModel.triggered_from,
                    WorkflowNodeExecutionModel.workflow_run_id,
                ).in_(tuple_values)
            )
        ).all()

        if not node_execution_ids:
            return 0, 0

        offloads_deleted = (
            cast(
@@ -320,32 +377,55 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
        node_executions_deleted = (
            cast(
                CursorResult,
                session.execute(delete(WorkflowNodeExecutionModel).where(run_id_filter)),
                session.execute(
                    delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(node_execution_ids))
                ),
            ).rowcount
            or 0
        )

        return node_executions_deleted, offloads_deleted

    def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
    @staticmethod
    def count_by_runs(session: Session, runs: Sequence[RunContext]) -> tuple[int, int]:
        """
        Count node executions (and offloads) for the given workflow runs using workflow_run_id.
        Count node executions (and offloads) for the given workflow runs using indexed columns.
        """
        if not run_ids:
        if not runs:
            return 0, 0

        run_ids = list(run_ids)
        run_id_filter = WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)
        tuple_values = [
            (
                run["tenant_id"],
                run["app_id"],
                run["workflow_id"],
                DifyAPISQLAlchemyWorkflowNodeExecutionRepository._map_run_triggered_from_to_node_triggered_from(
                    run["triggered_from"]
                ),
                run["run_id"],
            )
            for run in runs
        ]
        tuple_filter = tuple_(
            WorkflowNodeExecutionModel.tenant_id,
            WorkflowNodeExecutionModel.app_id,
            WorkflowNodeExecutionModel.workflow_id,
            WorkflowNodeExecutionModel.triggered_from,
            WorkflowNodeExecutionModel.workflow_run_id,
        ).in_(tuple_values)

        node_executions_count = (
            session.scalar(select(func.count()).select_from(WorkflowNodeExecutionModel).where(run_id_filter)) or 0
            session.scalar(select(func.count()).select_from(WorkflowNodeExecutionModel).where(tuple_filter)) or 0
        )
        node_execution_ids = select(WorkflowNodeExecutionModel.id).where(run_id_filter)
        offloads_count = (
            session.scalar(
                select(func.count())
                .select_from(WorkflowNodeExecutionOffload)
                .where(WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids))
                .join(
                    WorkflowNodeExecutionModel,
                    WorkflowNodeExecutionOffload.node_execution_id == WorkflowNodeExecutionModel.id,
                )
                .where(tuple_filter)
            )
            or 0
        )

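The rewritten delete_by_runs and count_by_runs above filter with a tuple IN over (tenant_id, app_id, workflow_id, triggered_from, workflow_run_id) so the composite index can be used. A minimal standalone sketch of SQLAlchemy's tuple_().in_() construct on a toy table; the table and column names here are illustrative only, not from the diff.

from sqlalchemy import Column, Integer, String, select, tuple_
from sqlalchemy.orm import DeclarativeBase


class Base(DeclarativeBase):
    pass


class Run(Base):
    __tablename__ = "runs"
    id = Column(Integer, primary_key=True)
    tenant_id = Column(String)
    app_id = Column(String)
    workflow_run_id = Column(String)


# Each tuple lists values in the same order as the columns passed to tuple_().
wanted = [("t1", "a1", "run-1"), ("t1", "a2", "run-7")]
stmt = select(Run.id).where(
    tuple_(Run.tenant_id, Run.app_id, Run.workflow_run_id).in_(wanted)
)
print(stmt)  # shows the composite (col, col, col) IN (...) predicate
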
@ -1,147 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TypedDict
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
SESSION_STATE_TTL_SECONDS = 3600
|
||||
WORKFLOW_ONLINE_USERS_PREFIX = "workflow_online_users:"
|
||||
WORKFLOW_LEADER_PREFIX = "workflow_leader:"
|
||||
WS_SID_MAP_PREFIX = "ws_sid_map:"
|
||||
|
||||
|
||||
class WorkflowSessionInfo(TypedDict):
|
||||
user_id: str
|
||||
username: str
|
||||
avatar: str | None
|
||||
sid: str
|
||||
connected_at: int
|
||||
|
||||
|
||||
class SidMapping(TypedDict):
|
||||
workflow_id: str
|
||||
user_id: str
|
||||
|
||||
|
||||
class WorkflowCollaborationRepository:
|
||||
def __init__(self) -> None:
|
||||
self._redis = redis_client
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(redis_client={self._redis})"
|
||||
|
||||
@staticmethod
|
||||
def workflow_key(workflow_id: str) -> str:
|
||||
return f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}"
|
||||
|
||||
@staticmethod
|
||||
def leader_key(workflow_id: str) -> str:
|
||||
return f"{WORKFLOW_LEADER_PREFIX}{workflow_id}"
|
||||
|
||||
@staticmethod
|
||||
def sid_key(sid: str) -> str:
|
||||
return f"{WS_SID_MAP_PREFIX}{sid}"
|
||||
|
||||
@staticmethod
|
||||
def _decode(value: str | bytes | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, bytes):
|
||||
return value.decode("utf-8")
|
||||
return value
|
||||
|
||||
def refresh_session_state(self, workflow_id: str, sid: str) -> None:
|
||||
workflow_key = self.workflow_key(workflow_id)
|
||||
sid_key = self.sid_key(sid)
|
||||
if self._redis.exists(workflow_key):
|
||||
self._redis.expire(workflow_key, SESSION_STATE_TTL_SECONDS)
|
||||
if self._redis.exists(sid_key):
|
||||
self._redis.expire(sid_key, SESSION_STATE_TTL_SECONDS)
|
||||
|
||||
def set_session_info(self, workflow_id: str, session_info: WorkflowSessionInfo) -> None:
|
||||
workflow_key = self.workflow_key(workflow_id)
|
||||
self._redis.hset(workflow_key, session_info["sid"], json.dumps(session_info))
|
||||
self._redis.set(
|
||||
self.sid_key(session_info["sid"]),
|
||||
json.dumps({"workflow_id": workflow_id, "user_id": session_info["user_id"]}),
|
||||
ex=SESSION_STATE_TTL_SECONDS,
|
||||
)
|
||||
self.refresh_session_state(workflow_id, session_info["sid"])
|
||||
|
||||
def get_sid_mapping(self, sid: str) -> SidMapping | None:
|
||||
raw = self._redis.get(self.sid_key(sid))
|
||||
if not raw:
|
||||
return None
|
||||
value = self._decode(raw)
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
def delete_session(self, workflow_id: str, sid: str) -> None:
|
||||
self._redis.hdel(self.workflow_key(workflow_id), sid)
|
||||
self._redis.delete(self.sid_key(sid))
|
||||
|
||||
def session_exists(self, workflow_id: str, sid: str) -> bool:
|
||||
return bool(self._redis.hexists(self.workflow_key(workflow_id), sid))
|
||||
|
||||
def sid_mapping_exists(self, sid: str) -> bool:
|
||||
return bool(self._redis.exists(self.sid_key(sid)))
|
||||
|
||||
def get_session_sids(self, workflow_id: str) -> list[str]:
|
||||
raw_sids = self._redis.hkeys(self.workflow_key(workflow_id))
|
||||
decoded_sids: list[str] = []
|
||||
for sid in raw_sids:
|
||||
decoded = self._decode(sid)
|
||||
if decoded:
|
||||
decoded_sids.append(decoded)
|
||||
return decoded_sids
|
||||
|
||||
def list_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]:
|
||||
sessions_json = self._redis.hgetall(self.workflow_key(workflow_id))
|
||||
users: list[WorkflowSessionInfo] = []
|
||||
|
||||
for session_info_json in sessions_json.values():
|
||||
value = self._decode(session_info_json)
|
||||
if not value:
|
||||
continue
|
||||
try:
|
||||
session_info = json.loads(value)
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
continue
|
||||
|
||||
if not isinstance(session_info, dict):
|
||||
continue
|
||||
if "user_id" not in session_info or "username" not in session_info or "sid" not in session_info:
|
||||
continue
|
||||
|
||||
users.append(
|
||||
{
|
||||
"user_id": str(session_info["user_id"]),
|
||||
"username": str(session_info["username"]),
|
||||
"avatar": session_info.get("avatar"),
|
||||
"sid": str(session_info["sid"]),
|
||||
"connected_at": int(session_info.get("connected_at") or 0),
|
||||
}
|
||||
)
|
||||
|
||||
return users
|
||||
|
||||
def get_current_leader(self, workflow_id: str) -> str | None:
|
||||
raw = self._redis.get(self.leader_key(workflow_id))
|
||||
return self._decode(raw)
|
||||
|
||||
def set_leader_if_absent(self, workflow_id: str, sid: str) -> bool:
|
||||
return bool(self._redis.set(self.leader_key(workflow_id), sid, nx=True, ex=SESSION_STATE_TTL_SECONDS))
|
||||
|
||||
def set_leader(self, workflow_id: str, sid: str) -> None:
|
||||
self._redis.set(self.leader_key(workflow_id), sid, ex=SESSION_STATE_TTL_SECONDS)
|
||||
|
||||
def delete_leader(self, workflow_id: str) -> None:
|
||||
self._redis.delete(self.leader_key(workflow_id))
|
||||
|
||||
def expire_leader(self, workflow_id: str) -> None:
|
||||
self._redis.expire(self.leader_key(workflow_id), SESSION_STATE_TTL_SECONDS)
|
||||
@@ -1381,11 +1381,6 @@ class RegisterService:
        normalized_email = email.lower()

        """Invite new member"""
        # Check workspace permission for member invitations
        from libs.workspace_permission import check_workspace_member_invite_permission

        check_workspace_member_invite_permission(tenant.id)

        with Session(db.engine) as session:
            account = AccountService.get_account_by_email_with_case_fallback(email, session=session)


@@ -13,11 +13,10 @@ import sqlalchemy as sa
from redis.exceptions import LockNotOwnedError
from sqlalchemy import exists, func, select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, NotFound
from werkzeug.exceptions import NotFound

from configs import dify_config
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.file import helpers as file_helpers
from core.helper.name_generator import generate_incremental_name
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
@@ -74,7 +73,6 @@ from services.errors.document import DocumentIndexingError
from services.errors.file import FileNotExistsError
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureModel, FeatureService
from services.file_service import FileService
from services.rag_pipeline.rag_pipeline import RagPipelineService
from services.tag_service import TagService
from services.vector_service import VectorService
@@ -89,6 +87,7 @@ from tasks.disable_segments_from_index_task import disable_segments_from_index_t
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.regenerate_summary_index_task import regenerate_summary_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task
from tasks.retry_document_indexing_task import retry_document_indexing_task
from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task

@ -476,6 +475,11 @@ class DatasetService:
|
||||
if external_retrieval_model:
|
||||
dataset.retrieval_model = external_retrieval_model
|
||||
|
||||
# Update summary index setting if provided
|
||||
summary_index_setting = data.get("summary_index_setting", None)
|
||||
if summary_index_setting is not None:
|
||||
dataset.summary_index_setting = summary_index_setting
|
||||
|
||||
# Update basic dataset properties
|
||||
dataset.name = data.get("name", dataset.name)
|
||||
dataset.description = data.get("description", dataset.description)
|
||||
@ -558,12 +562,18 @@ class DatasetService:
|
||||
# Handle indexing technique changes and embedding model updates
|
||||
action = DatasetService._handle_indexing_technique_change(dataset, data, filtered_data)
|
||||
|
||||
# Check if summary_index_setting model changed (before updating database)
|
||||
summary_model_changed = DatasetService._check_summary_index_setting_model_changed(dataset, data)
|
||||
|
||||
# Add metadata fields
|
||||
filtered_data["updated_by"] = user.id
|
||||
filtered_data["updated_at"] = naive_utc_now()
|
||||
# update Retrieval model
|
||||
if data.get("retrieval_model"):
|
||||
filtered_data["retrieval_model"] = data["retrieval_model"]
|
||||
# update summary index setting
|
||||
if data.get("summary_index_setting"):
|
||||
filtered_data["summary_index_setting"] = data.get("summary_index_setting")
|
||||
# update icon info
|
||||
if data.get("icon_info"):
|
||||
filtered_data["icon_info"] = data.get("icon_info")
|
||||
@ -572,12 +582,30 @@ class DatasetService:
|
||||
db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
|
||||
db.session.commit()
|
||||
|
||||
# Reload dataset to get updated values
|
||||
db.session.refresh(dataset)
|
||||
|
||||
# update pipeline knowledge base node data
|
||||
DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id)
|
||||
|
||||
# Trigger vector index task if indexing technique changed
|
||||
if action:
|
||||
deal_dataset_vector_index_task.delay(dataset.id, action)
|
||||
# If embedding_model changed, also regenerate summary vectors
|
||||
if action == "update":
|
||||
regenerate_summary_index_task.delay(
|
||||
dataset.id,
|
||||
regenerate_reason="embedding_model_changed",
|
||||
regenerate_vectors_only=True,
|
||||
)
|
||||
|
||||
# Trigger summary index regeneration if summary model changed
|
||||
if summary_model_changed:
|
||||
regenerate_summary_index_task.delay(
|
||||
dataset.id,
|
||||
regenerate_reason="summary_model_changed",
|
||||
regenerate_vectors_only=False,
|
||||
)
|
||||
|
||||
return dataset
|
||||
|
||||
@ -616,6 +644,7 @@ class DatasetService:
|
||||
knowledge_index_node_data["chunk_structure"] = dataset.chunk_structure
|
||||
knowledge_index_node_data["indexing_technique"] = dataset.indexing_technique # pyright: ignore[reportAttributeAccessIssue]
|
||||
knowledge_index_node_data["keyword_number"] = dataset.keyword_number
|
||||
knowledge_index_node_data["summary_index_setting"] = dataset.summary_index_setting
|
||||
node["data"] = knowledge_index_node_data
|
||||
updated = True
|
||||
except Exception:
|
||||
@ -854,6 +883,53 @@ class DatasetService:
|
||||
)
|
||||
filtered_data["collection_binding_id"] = dataset_collection_binding.id
|
||||
|
||||
@staticmethod
|
||||
def _check_summary_index_setting_model_changed(dataset: Dataset, data: dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if summary_index_setting model (model_name or model_provider_name) has changed.
|
||||
|
||||
Args:
|
||||
dataset: Current dataset object
|
||||
data: Update data dictionary
|
||||
|
||||
Returns:
|
||||
bool: True if summary model changed, False otherwise
|
||||
"""
|
||||
# Check if summary_index_setting is being updated
|
||||
if "summary_index_setting" not in data or data.get("summary_index_setting") is None:
|
||||
return False
|
||||
|
||||
new_summary_setting = data.get("summary_index_setting")
|
||||
old_summary_setting = dataset.summary_index_setting
|
||||
|
||||
# If old setting doesn't exist or is disabled, no need to regenerate
|
||||
if not old_summary_setting or not old_summary_setting.get("enable"):
|
||||
return False
|
||||
|
||||
# If new setting is disabled, no need to regenerate
|
||||
if not new_summary_setting or not new_summary_setting.get("enable"):
|
||||
return False
|
||||
|
||||
# Compare model_name and model_provider_name
|
||||
old_model_name = old_summary_setting.get("model_name")
|
||||
old_model_provider = old_summary_setting.get("model_provider_name")
|
||||
new_model_name = new_summary_setting.get("model_name")
|
||||
new_model_provider = new_summary_setting.get("model_provider_name")
|
||||
|
||||
# Check if model changed
|
||||
if old_model_name != new_model_name or old_model_provider != new_model_provider:
|
||||
logger.info(
|
||||
"Summary index setting model changed for dataset %s: old=%s/%s, new=%s/%s",
|
||||
dataset.id,
|
||||
old_model_provider,
|
||||
old_model_name,
|
||||
new_model_provider,
|
||||
new_model_name,
|
||||
)
|
||||
return True
|
||||
|
||||
return False
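For illustration, a minimal sketch of the inputs this check compares; the provider/model values are hypothetical, and only a change in model_name or model_provider_name while both settings are enabled triggers regeneration:
# Hypothetical settings; the provider/model values are illustrative only.
old_setting = {"enable": True, "model_provider_name": "openai", "model_name": "gpt-4o-mini"}
new_setting = {"enable": True, "model_provider_name": "openai", "model_name": "gpt-4o"}
# model_name differs while both sides are enabled, so the method returns True and
# update_dataset later enqueues regenerate_summary_index_task with
# regenerate_reason="summary_model_changed".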
|
||||
|
||||
@staticmethod
|
||||
def update_rag_pipeline_dataset_settings(
|
||||
session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
|
||||
@ -889,6 +965,9 @@ class DatasetService:
|
||||
else:
|
||||
raise ValueError("Invalid index method")
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
|
||||
# Update summary_index_setting if provided
|
||||
if knowledge_configuration.summary_index_setting is not None:
|
||||
dataset.summary_index_setting = knowledge_configuration.summary_index_setting
|
||||
session.add(dataset)
|
||||
else:
|
||||
if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure:
|
||||
@ -994,6 +1073,9 @@ class DatasetService:
|
||||
if dataset.keyword_number != knowledge_configuration.keyword_number:
|
||||
dataset.keyword_number = knowledge_configuration.keyword_number
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
|
||||
# Update summary_index_setting if provided
|
||||
if knowledge_configuration.summary_index_setting is not None:
|
||||
dataset.summary_index_setting = knowledge_configuration.summary_index_setting
|
||||
session.add(dataset)
|
||||
session.commit()
|
||||
if action:
|
||||
@ -1164,7 +1246,6 @@ class DocumentService:
|
||||
Document.archived.is_(True),
|
||||
),
|
||||
}
|
||||
DOCUMENT_BATCH_DOWNLOAD_ZIP_FILENAME_EXTENSION = ".zip"
|
||||
|
||||
@classmethod
|
||||
def normalize_display_status(cls, status: str | None) -> str | None:
|
||||
@ -1291,143 +1372,6 @@ class DocumentService:
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_documents_by_ids(dataset_id: str, document_ids: Sequence[str]) -> Sequence[Document]:
|
||||
"""Fetch documents for a dataset in a single batch query."""
|
||||
if not document_ids:
|
||||
return []
|
||||
document_id_list: list[str] = [str(document_id) for document_id in document_ids]
|
||||
# Fetch all requested documents in one query to avoid N+1 lookups.
|
||||
documents: Sequence[Document] = db.session.scalars(
|
||||
select(Document).where(
|
||||
Document.dataset_id == dataset_id,
|
||||
Document.id.in_(document_id_list),
|
||||
)
|
||||
).all()
|
||||
return documents
|
||||
|
||||
@staticmethod
|
||||
def get_document_download_url(document: Document) -> str:
|
||||
"""
|
||||
Return a signed download URL for an upload-file document.
|
||||
"""
|
||||
upload_file = DocumentService._get_upload_file_for_upload_file_document(document)
|
||||
return file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
|
||||
|
||||
@staticmethod
|
||||
def prepare_document_batch_download_zip(
|
||||
*,
|
||||
dataset_id: str,
|
||||
document_ids: Sequence[str],
|
||||
tenant_id: str,
|
||||
current_user: Account,
|
||||
) -> tuple[list[UploadFile], str]:
|
||||
"""
|
||||
Resolve upload files for batch ZIP downloads and generate a client-visible filename.
|
||||
"""
|
||||
dataset = DatasetService.get_dataset(dataset_id)
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
try:
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
except NoPermissionError as e:
|
||||
raise Forbidden(str(e))
|
||||
|
||||
upload_files_by_document_id = DocumentService._get_upload_files_by_document_id_for_zip_download(
|
||||
dataset_id=dataset_id,
|
||||
document_ids=document_ids,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
upload_files = [upload_files_by_document_id[document_id] for document_id in document_ids]
|
||||
download_name = DocumentService._generate_document_batch_download_zip_filename()
|
||||
return upload_files, download_name
|
||||
|
||||
@staticmethod
|
||||
def _generate_document_batch_download_zip_filename() -> str:
|
||||
"""
|
||||
Generate a random attachment filename for the batch download ZIP.
|
||||
"""
|
||||
return f"{uuid.uuid4().hex}{DocumentService.DOCUMENT_BATCH_DOWNLOAD_ZIP_FILENAME_EXTENSION}"
|
||||
|
||||
@staticmethod
|
||||
def _get_upload_file_id_for_upload_file_document(
|
||||
document: Document,
|
||||
*,
|
||||
invalid_source_message: str,
|
||||
missing_file_message: str,
|
||||
) -> str:
|
||||
"""
|
||||
Normalize and validate `Document -> UploadFile` linkage for download flows.
|
||||
"""
|
||||
if document.data_source_type != "upload_file":
|
||||
raise NotFound(invalid_source_message)
|
||||
|
||||
data_source_info: dict[str, Any] = document.data_source_info_dict or {}
|
||||
upload_file_id: str | None = data_source_info.get("upload_file_id")
|
||||
if not upload_file_id:
|
||||
raise NotFound(missing_file_message)
|
||||
|
||||
return str(upload_file_id)
|
||||
|
||||
@staticmethod
|
||||
def _get_upload_file_for_upload_file_document(document: Document) -> UploadFile:
|
||||
"""
|
||||
Load the `UploadFile` row for an upload-file document.
|
||||
"""
|
||||
upload_file_id = DocumentService._get_upload_file_id_for_upload_file_document(
|
||||
document,
|
||||
invalid_source_message="Document does not have an uploaded file to download.",
|
||||
missing_file_message="Uploaded file not found.",
|
||||
)
|
||||
upload_files_by_id = FileService.get_upload_files_by_ids(document.tenant_id, [upload_file_id])
|
||||
upload_file = upload_files_by_id.get(upload_file_id)
|
||||
if not upload_file:
|
||||
raise NotFound("Uploaded file not found.")
|
||||
return upload_file
|
||||
|
||||
@staticmethod
|
||||
def _get_upload_files_by_document_id_for_zip_download(
|
||||
*,
|
||||
dataset_id: str,
|
||||
document_ids: Sequence[str],
|
||||
tenant_id: str,
|
||||
) -> dict[str, UploadFile]:
|
||||
"""
|
||||
Batch load upload files keyed by document id for ZIP downloads.
|
||||
"""
|
||||
document_id_list: list[str] = [str(document_id) for document_id in document_ids]
|
||||
|
||||
documents = DocumentService.get_documents_by_ids(dataset_id, document_id_list)
|
||||
documents_by_id: dict[str, Document] = {str(document.id): document for document in documents}
|
||||
|
||||
missing_document_ids: set[str] = set(document_id_list) - set(documents_by_id.keys())
|
||||
if missing_document_ids:
|
||||
raise NotFound("Document not found.")
|
||||
|
||||
upload_file_ids: list[str] = []
|
||||
upload_file_ids_by_document_id: dict[str, str] = {}
|
||||
for document_id, document in documents_by_id.items():
|
||||
if document.tenant_id != tenant_id:
|
||||
raise Forbidden("No permission.")
|
||||
|
||||
upload_file_id = DocumentService._get_upload_file_id_for_upload_file_document(
|
||||
document,
|
||||
invalid_source_message="Only uploaded-file documents can be downloaded as ZIP.",
|
||||
missing_file_message="Only uploaded-file documents can be downloaded as ZIP.",
|
||||
)
|
||||
upload_file_ids.append(upload_file_id)
|
||||
upload_file_ids_by_document_id[document_id] = upload_file_id
|
||||
|
||||
upload_files_by_id = FileService.get_upload_files_by_ids(tenant_id, upload_file_ids)
|
||||
missing_upload_file_ids: set[str] = set(upload_file_ids) - set(upload_files_by_id.keys())
|
||||
if missing_upload_file_ids:
|
||||
raise NotFound("Only uploaded-file documents can be downloaded as ZIP.")
|
||||
|
||||
return {
|
||||
document_id: upload_files_by_id[upload_file_id]
|
||||
for document_id, upload_file_id in upload_file_ids_by_document_id.items()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_document_by_id(document_id: str) -> Document | None:
|
||||
document = db.session.query(Document).where(Document.id == document_id).first()
|
||||
@ -1964,6 +1908,8 @@ class DocumentService:
|
||||
DuplicateDocumentIndexingTaskProxy(
|
||||
dataset.tenant_id, dataset.id, duplicate_document_ids
|
||||
).delay()
|
||||
# Note: Summary index generation is triggered in document_indexing_task after indexing completes
|
||||
# to ensure segments are available. See tasks/document_indexing_task.py
|
||||
except LockNotOwnedError:
|
||||
pass
|
||||
|
||||
@ -2268,6 +2214,11 @@ class DocumentService:
|
||||
name: str,
|
||||
batch: str,
|
||||
):
|
||||
# Set need_summary based on dataset's summary_index_setting
|
||||
need_summary = False
|
||||
if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True:
|
||||
need_summary = True
|
||||
|
||||
document = Document(
|
||||
tenant_id=dataset.tenant_id,
|
||||
dataset_id=dataset.id,
|
||||
@ -2281,6 +2232,7 @@ class DocumentService:
|
||||
created_by=account.id,
|
||||
doc_form=document_form,
|
||||
doc_language=document_language,
|
||||
need_summary=need_summary,
|
||||
)
|
||||
doc_metadata = {}
|
||||
if dataset.built_in_field_enabled:
|
||||
@ -2505,6 +2457,7 @@ class DocumentService:
|
||||
embedding_model_provider=knowledge_config.embedding_model_provider,
|
||||
collection_binding_id=dataset_collection_binding_id,
|
||||
retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
|
||||
summary_index_setting=knowledge_config.summary_index_setting,
|
||||
is_multimodal=knowledge_config.is_multimodal,
|
||||
)
|
||||
|
||||
@ -2686,6 +2639,14 @@ class DocumentService:
|
||||
if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int):
|
||||
raise ValueError("Process rule segmentation max_tokens is invalid")
|
||||
|
||||
# validate summary index setting

|
||||
summary_index_setting = args["process_rule"].get("summary_index_setting")
|
||||
if summary_index_setting and summary_index_setting.get("enable"):
|
||||
if "model_name" not in summary_index_setting or not summary_index_setting["model_name"]:
|
||||
raise ValueError("Summary index model name is required")
|
||||
if "model_provider_name" not in summary_index_setting or not summary_index_setting["model_provider_name"]:
|
||||
raise ValueError("Summary index model provider name is required")
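A hedged example of a process_rule fragment that passes this validation (the model values are placeholders, not recommendations):
args["process_rule"]["summary_index_setting"] = {
    "enable": True,
    "model_provider_name": "openai",  # placeholder provider
    "model_name": "gpt-4o-mini",      # placeholder model
}
# With "enable": False (or the key omitted) the model fields are not required.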
|
||||
|
||||
@staticmethod
|
||||
def batch_update_document_status(
|
||||
dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user
|
||||
@ -3154,6 +3115,39 @@ class SegmentService:
|
||||
if args.enabled or keyword_changed:
|
||||
# update segment vector index
|
||||
VectorService.update_segment_vector(args.keywords, segment, dataset)
|
||||
# update summary index if summary is provided and has changed
|
||||
if args.summary is not None:
|
||||
# Check if summary index is enabled
|
||||
has_summary_index = (
|
||||
dataset.indexing_technique == "high_quality"
|
||||
and dataset.summary_index_setting
|
||||
and dataset.summary_index_setting.get("enable") is True
|
||||
)
|
||||
|
||||
if has_summary_index:
|
||||
# Query existing summary from database
|
||||
from models.dataset import DocumentSegmentSummary
|
||||
|
||||
existing_summary = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.where(
|
||||
DocumentSegmentSummary.chunk_id == segment.id,
|
||||
DocumentSegmentSummary.dataset_id == dataset.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Check if summary has changed
|
||||
existing_summary_content = existing_summary.summary_content if existing_summary else None
|
||||
if existing_summary_content != args.summary:
|
||||
# Summary has changed, update it
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
try:
|
||||
SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary)
|
||||
except Exception:
|
||||
logger.exception("Failed to update summary for segment %s", segment.id)
|
||||
# Don't fail the entire update if summary update fails
|
||||
else:
|
||||
segment_hash = helper.generate_text_hash(content)
|
||||
tokens = 0
|
||||
@ -3228,6 +3222,15 @@ class SegmentService:
|
||||
elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
|
||||
# update segment vector index
|
||||
VectorService.update_segment_vector(args.keywords, segment, dataset)
|
||||
# update summary index if summary is provided
|
||||
if args.summary is not None:
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
try:
|
||||
SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary)
|
||||
except Exception:
|
||||
logger.exception("Failed to update summary for segment %s", segment.id)
|
||||
# Don't fail the entire update if summary update fails
|
||||
# update multimodel vector index
|
||||
VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset)
|
||||
except Exception as e:
|
||||
|
||||
@ -13,23 +13,6 @@ class WebAppSettings(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class WorkspacePermission(BaseModel):
|
||||
workspace_id: str = Field(
|
||||
description="The ID of the workspace.",
|
||||
alias="workspaceId",
|
||||
)
|
||||
allow_member_invite: bool = Field(
|
||||
description="Whether to allow members to invite new members to the workspace.",
|
||||
default=False,
|
||||
alias="allowMemberInvite",
|
||||
)
|
||||
allow_owner_transfer: bool = Field(
|
||||
description="Whether to allow owners to transfer ownership of the workspace.",
|
||||
default=False,
|
||||
alias="allowOwnerTransfer",
|
||||
)
|
||||
|
||||
|
||||
class EnterpriseService:
|
||||
@classmethod
|
||||
def get_info(cls):
|
||||
@ -61,16 +44,6 @@ class EnterpriseService:
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Invalid date format: {data}") from e
|
||||
|
||||
class WorkspacePermissionService:
|
||||
@classmethod
|
||||
def get_permission(cls, workspace_id: str):
|
||||
if not workspace_id:
|
||||
raise ValueError("workspace_id must be provided.")
|
||||
data = EnterpriseRequest.send_request("GET", f"/workspaces/{workspace_id}/permission")
|
||||
if not data or "permission" not in data:
|
||||
raise ValueError("No data found.")
|
||||
return WorkspacePermission.model_validate(data["permission"])
|
||||
|
||||
class WebAppAuth:
|
||||
@classmethod
|
||||
def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str):
|
||||
|
||||
@ -119,6 +119,7 @@ class KnowledgeConfig(BaseModel):
|
||||
data_source: DataSource | None = None
|
||||
process_rule: ProcessRule | None = None
|
||||
retrieval_model: RetrievalModel | None = None
|
||||
summary_index_setting: dict | None = None
|
||||
doc_form: str = "text_model"
|
||||
doc_language: str = "English"
|
||||
embedding_model: str | None = None
|
||||
@ -141,6 +142,7 @@ class SegmentUpdateArgs(BaseModel):
|
||||
regenerate_child_chunks: bool = False
|
||||
enabled: bool | None = None
|
||||
attachment_ids: list[str] | None = None
|
||||
summary: str | None = None # Summary content for summary index
|
||||
|
||||
|
||||
class ChildChunkUpdateArgs(BaseModel):
|
||||
|
||||
@ -116,6 +116,8 @@ class KnowledgeConfiguration(BaseModel):
|
||||
embedding_model: str = ""
|
||||
keyword_number: int | None = 10
|
||||
retrieval_model: RetrievalSetting
|
||||
# add summary index setting
|
||||
summary_index_setting: dict | None = None
|
||||
|
||||
@field_validator("embedding_model_provider", mode="before")
|
||||
@classmethod
|
||||
|
||||
@ -161,7 +161,6 @@ class SystemFeatureModel(BaseModel):
|
||||
enable_email_code_login: bool = False
|
||||
enable_email_password_login: bool = True
|
||||
enable_social_oauth_login: bool = False
|
||||
enable_collaboration_mode: bool = False
|
||||
is_allow_register: bool = False
|
||||
is_allow_create_workspace: bool = False
|
||||
is_email_setup: bool = False
|
||||
@ -223,7 +222,6 @@ class FeatureService:
|
||||
system_features.enable_email_code_login = dify_config.ENABLE_EMAIL_CODE_LOGIN
|
||||
system_features.enable_email_password_login = dify_config.ENABLE_EMAIL_PASSWORD_LOGIN
|
||||
system_features.enable_social_oauth_login = dify_config.ENABLE_SOCIAL_OAUTH_LOGIN
|
||||
system_features.enable_collaboration_mode = dify_config.ENABLE_COLLABORATION_MODE
|
||||
system_features.is_allow_register = dify_config.ALLOW_REGISTER
|
||||
system_features.is_allow_create_workspace = dify_config.ALLOW_CREATE_WORKSPACE
|
||||
system_features.is_email_setup = dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != ""
|
||||
|
||||
@ -2,11 +2,7 @@ import base64
|
||||
import hashlib
|
||||
import os
|
||||
import uuid
|
||||
from collections.abc import Iterator, Sequence
|
||||
from contextlib import contextmanager, suppress
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Literal, Union
|
||||
from zipfile import ZIP_DEFLATED, ZipFile
|
||||
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
@ -21,7 +17,6 @@ from constants import (
|
||||
)
|
||||
from core.file import helpers as file_helpers
|
||||
from core.rag.extractor.extract_processor import ExtractProcessor
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import extract_tenant_id
|
||||
@ -172,9 +167,6 @@ class FileService:
|
||||
return upload_file
|
||||
|
||||
def get_file_preview(self, file_id: str):
|
||||
"""
|
||||
Return a short text preview extracted from a document file.
|
||||
"""
|
||||
with self._session_maker(expire_on_commit=False) as session:
|
||||
upload_file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||
|
||||
@ -261,101 +253,3 @@ class FileService:
|
||||
return
|
||||
storage.delete(upload_file.key)
|
||||
session.delete(upload_file)
|
||||
|
||||
@staticmethod
|
||||
def get_upload_files_by_ids(tenant_id: str, upload_file_ids: Sequence[str]) -> dict[str, UploadFile]:
|
||||
"""
|
||||
Fetch `UploadFile` rows for a tenant in a single batch query.
|
||||
|
||||
This is a generic `UploadFile` lookup helper (not dataset/document specific), so it lives in `FileService`.
|
||||
"""
|
||||
if not upload_file_ids:
|
||||
return {}
|
||||
|
||||
# Normalize and deduplicate ids before using them in the IN clause.
|
||||
upload_file_id_list: list[str] = [str(upload_file_id) for upload_file_id in upload_file_ids]
|
||||
unique_upload_file_ids: list[str] = list(set(upload_file_id_list))
|
||||
|
||||
# Fetch upload files in one query for efficient batch access.
|
||||
upload_files: Sequence[UploadFile] = db.session.scalars(
|
||||
select(UploadFile).where(
|
||||
UploadFile.tenant_id == tenant_id,
|
||||
UploadFile.id.in_(unique_upload_file_ids),
|
||||
)
|
||||
).all()
|
||||
return {str(upload_file.id): upload_file for upload_file in upload_files}
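A small usage sketch (the ids are hypothetical); ids that do not exist for the tenant are simply absent from the returned mapping, which is what the ZIP-download path relies on to detect missing files:
files_by_id = FileService.get_upload_files_by_ids(tenant_id, ["file-a", "file-b", "file-a"])
missing_ids = {"file-a", "file-b"} - files_by_id.keys()  # empty when everything was found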
|
||||
|
||||
@staticmethod
|
||||
def _sanitize_zip_entry_name(name: str) -> str:
|
||||
"""
|
||||
Sanitize a ZIP entry name to avoid path traversal and weird separators.
|
||||
|
||||
We keep this conservative: the upload flow already rejects `/` and `\\`, but older rows (or imported data)
|
||||
could still contain unsafe names.
|
||||
"""
|
||||
# Drop any directory components and prevent empty names.
|
||||
base = os.path.basename(name).strip() or "file"
|
||||
|
||||
# ZIP uses forward slashes as separators; remove any residual separator characters.
|
||||
return base.replace("/", "_").replace("\\", "_")
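For illustration, how the sanitizer behaves on a couple of representative inputs:
FileService._sanitize_zip_entry_name("../secret/../etc/passwd")  # -> "passwd"
FileService._sanitize_zip_entry_name("   ")                      # -> "file" (empty names fall back)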
|
||||
|
||||
@staticmethod
|
||||
def _dedupe_zip_entry_name(original_name: str, used_names: set[str]) -> str:
|
||||
"""
|
||||
Return a unique ZIP entry name, inserting suffixes before the extension.
|
||||
"""
|
||||
# Keep the original name when it's not already used.
|
||||
if original_name not in used_names:
|
||||
return original_name
|
||||
|
||||
# Insert suffixes before the extension (e.g., "doc.txt" -> "doc (1).txt").
|
||||
stem, extension = os.path.splitext(original_name)
|
||||
suffix = 1
|
||||
while True:
|
||||
candidate = f"{stem} ({suffix}){extension}"
|
||||
if candidate not in used_names:
|
||||
return candidate
|
||||
suffix += 1
|
||||
|
||||
@staticmethod
|
||||
@contextmanager
|
||||
def build_upload_files_zip_tempfile(
|
||||
*,
|
||||
upload_files: Sequence[UploadFile],
|
||||
) -> Iterator[str]:
|
||||
"""
|
||||
Build a ZIP from `UploadFile`s and yield a tempfile path.
|
||||
|
||||
We yield a path (rather than an open file handle) to avoid "read of closed file" issues when Flask/Werkzeug
|
||||
streams responses. The caller is expected to keep this context open until the response is fully sent, then
|
||||
close it (e.g., via `response.call_on_close(...)`) to delete the tempfile.
|
||||
"""
|
||||
used_names: set[str] = set()
|
||||
|
||||
# Build a ZIP in a temp file and keep it on disk until the caller finishes streaming it.
|
||||
tmp_path: str | None = None
|
||||
try:
|
||||
with NamedTemporaryFile(mode="w+b", suffix=".zip", delete=False) as tmp:
|
||||
tmp_path = tmp.name
|
||||
with ZipFile(tmp, mode="w", compression=ZIP_DEFLATED) as zf:
|
||||
for upload_file in upload_files:
|
||||
# Ensure the entry name is safe and unique.
|
||||
safe_name = FileService._sanitize_zip_entry_name(upload_file.name)
|
||||
arcname = FileService._dedupe_zip_entry_name(safe_name, used_names)
|
||||
used_names.add(arcname)
|
||||
|
||||
# Stream file bytes from storage into the ZIP entry.
|
||||
with zf.open(arcname, "w") as entry:
|
||||
for chunk in storage.load(upload_file.key, stream=True):
|
||||
entry.write(chunk)
|
||||
|
||||
# Flush so `send_file(path, ...)` can re-open it safely on all platforms.
|
||||
tmp.flush()
|
||||
|
||||
assert tmp_path is not None
|
||||
yield tmp_path
|
||||
finally:
|
||||
# Remove the temp file when the context is closed (typically after the response finishes streaming).
|
||||
if tmp_path is not None:
|
||||
with suppress(FileNotFoundError):
|
||||
os.remove(tmp_path)
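A hedged caller sketch for the context-manager contract described above, assuming a Flask view; the helper name and surrounding endpoint are illustrative:
from contextlib import ExitStack
from flask import send_file

def _send_documents_zip(upload_files, download_name):  # hypothetical helper around a Flask response
    stack = ExitStack()
    zip_path = stack.enter_context(FileService.build_upload_files_zip_tempfile(upload_files=upload_files))
    response = send_file(zip_path, as_attachment=True, download_name=download_name, mimetype="application/zip")
    # Keep the tempfile alive until streaming finishes, then let the context manager delete it.
    response.call_on_close(stack.close)
    return response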
|
||||
|
||||
@ -10,7 +10,9 @@ from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
|
||||
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
|
||||
)
|
||||
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
|
||||
from services.billing_service import BillingService, SubscriptionPlan
|
||||
|
||||
@ -90,12 +92,9 @@ class WorkflowRunCleanup:
|
||||
paid_or_skipped = len(run_rows) - len(free_runs)
|
||||
|
||||
if not free_runs:
|
||||
skipped_message = (
|
||||
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
|
||||
)
|
||||
click.echo(
|
||||
click.style(
|
||||
skipped_message,
|
||||
f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)",
|
||||
fg="yellow",
|
||||
)
|
||||
)
|
||||
@ -256,6 +255,21 @@ class WorkflowRunCleanup:
|
||||
trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
|
||||
return trigger_repo.count_by_run_ids(run_ids)
|
||||
|
||||
@staticmethod
|
||||
def _build_run_contexts(
|
||||
runs: Sequence[WorkflowRun],
|
||||
) -> list[DifyAPISQLAlchemyWorkflowNodeExecutionRepository.RunContext]:
|
||||
return [
|
||||
{
|
||||
"run_id": run.id,
|
||||
"tenant_id": run.tenant_id,
|
||||
"app_id": run.app_id,
|
||||
"workflow_id": run.workflow_id,
|
||||
"triggered_from": run.triggered_from,
|
||||
}
|
||||
for run in runs
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _empty_related_counts() -> dict[str, int]:
|
||||
return {
|
||||
@ -279,15 +293,9 @@ class WorkflowRunCleanup:
|
||||
)
|
||||
|
||||
def _count_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
|
||||
run_ids = [run.id for run in runs]
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
|
||||
session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
|
||||
)
|
||||
return repo.count_by_runs(session, run_ids)
|
||||
run_contexts = self._build_run_contexts(runs)
|
||||
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository.count_by_runs(session, run_contexts)
|
||||
|
||||
def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
|
||||
run_ids = [run.id for run in runs]
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
|
||||
session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
|
||||
)
|
||||
return repo.delete_by_runs(session, run_ids)
|
||||
run_contexts = self._build_run_contexts(runs)
|
||||
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository.delete_by_runs(session, run_contexts)
|
||||
|
||||
api/services/summary_index_service.py (new file, 626 lines)
@ -0,0 +1,626 @@
|
||||
"""Summary index service for generating and managing document segment summaries."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
from core.rag.index_processor.constant.doc_type import DocType
|
||||
from core.rag.models.document import Document
|
||||
from extensions.ext_database import db
|
||||
from libs import helper
|
||||
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
|
||||
from models.dataset import Document as DatasetDocument
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SummaryIndexService:
|
||||
"""Service for generating and managing summary indexes."""
|
||||
|
||||
@staticmethod
|
||||
def generate_summary_for_segment(
|
||||
segment: DocumentSegment,
|
||||
dataset: Dataset,
|
||||
summary_index_setting: dict,
|
||||
) -> str:
|
||||
"""
|
||||
Generate summary for a single segment.
|
||||
|
||||
Args:
|
||||
segment: DocumentSegment to generate summary for
|
||||
dataset: Dataset containing the segment
|
||||
summary_index_setting: Summary index configuration
|
||||
|
||||
Returns:
|
||||
Generated summary text
|
||||
|
||||
Raises:
|
||||
ValueError: If summary_index_setting is invalid or generation fails
|
||||
"""
|
||||
# Reuse the existing generate_summary method from ParagraphIndexProcessor
|
||||
# Use lazy import to avoid circular import
|
||||
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
|
||||
|
||||
summary_content = ParagraphIndexProcessor.generate_summary(
|
||||
tenant_id=dataset.tenant_id,
|
||||
text=segment.content,
|
||||
summary_index_setting=summary_index_setting,
|
||||
segment_id=segment.id,
|
||||
)
|
||||
|
||||
if not summary_content:
|
||||
raise ValueError("Generated summary is empty")
|
||||
|
||||
return summary_content
|
||||
|
||||
@staticmethod
|
||||
def create_summary_record(
|
||||
segment: DocumentSegment,
|
||||
dataset: Dataset,
|
||||
summary_content: str,
|
||||
status: str = "generating",
|
||||
) -> DocumentSegmentSummary:
|
||||
"""
|
||||
Create or update a DocumentSegmentSummary record.
|
||||
If a summary record already exists for this segment, it will be updated instead of creating a new one.
|
||||
|
||||
Args:
|
||||
segment: DocumentSegment to create summary for
|
||||
dataset: Dataset containing the segment
|
||||
summary_content: Generated summary content
|
||||
status: Summary status (default: "generating")
|
||||
|
||||
Returns:
|
||||
Created or updated DocumentSegmentSummary instance
|
||||
"""
|
||||
# Check if summary record already exists
|
||||
existing_summary = (
|
||||
db.session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
)
|
||||
|
||||
if existing_summary:
|
||||
# Update existing record
|
||||
existing_summary.summary_content = summary_content
|
||||
existing_summary.status = status
|
||||
existing_summary.error = None # Clear any previous errors
|
||||
# Re-enable if it was disabled
|
||||
if not existing_summary.enabled:
|
||||
existing_summary.enabled = True
|
||||
existing_summary.disabled_at = None
|
||||
existing_summary.disabled_by = None
|
||||
db.session.add(existing_summary)
|
||||
db.session.flush()
|
||||
return existing_summary
|
||||
else:
|
||||
# Create new record (enabled by default)
|
||||
summary_record = DocumentSegmentSummary(
|
||||
dataset_id=dataset.id,
|
||||
document_id=segment.document_id,
|
||||
chunk_id=segment.id,
|
||||
summary_content=summary_content,
|
||||
status=status,
|
||||
enabled=True, # Explicitly set enabled to True
|
||||
)
|
||||
db.session.add(summary_record)
|
||||
db.session.flush()
|
||||
return summary_record
|
||||
|
||||
@staticmethod
|
||||
def vectorize_summary(
|
||||
summary_record: DocumentSegmentSummary,
|
||||
segment: DocumentSegment,
|
||||
dataset: Dataset,
|
||||
) -> None:
|
||||
"""
|
||||
Vectorize summary and store in vector database.
|
||||
|
||||
Args:
|
||||
summary_record: DocumentSegmentSummary record
|
||||
segment: Original DocumentSegment
|
||||
dataset: Dataset containing the segment
|
||||
"""
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
logger.warning(
|
||||
"Summary vectorization skipped for dataset %s: indexing_technique is not high_quality",
|
||||
dataset.id,
|
||||
)
|
||||
return
|
||||
|
||||
# Reuse existing index_node_id if available (like segment does), otherwise generate new one
|
||||
old_summary_node_id = summary_record.summary_index_node_id
|
||||
if old_summary_node_id:
|
||||
# Reuse existing index_node_id (like segment behavior)
|
||||
summary_index_node_id = old_summary_node_id
|
||||
else:
|
||||
# Generate new index node ID only for new summaries
|
||||
summary_index_node_id = str(uuid.uuid4())
|
||||
|
||||
# Always regenerate hash (in case summary content changed)
|
||||
summary_hash = helper.generate_text_hash(summary_record.summary_content)
|
||||
|
||||
# Delete old vector only if we're reusing the same index_node_id (to overwrite)
|
||||
# If index_node_id changed, the old vector should have been deleted elsewhere
|
||||
if old_summary_node_id and old_summary_node_id == summary_index_node_id:
|
||||
try:
|
||||
vector = Vector(dataset)
|
||||
vector.delete_by_ids([old_summary_node_id])
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to delete old summary vector for segment %s: %s. Continuing with new vectorization.",
|
||||
segment.id,
|
||||
str(e),
|
||||
)
|
||||
|
||||
# Create document with summary content and metadata
|
||||
summary_document = Document(
|
||||
page_content=summary_record.summary_content,
|
||||
metadata={
|
||||
"doc_id": summary_index_node_id,
|
||||
"doc_hash": summary_hash,
|
||||
"dataset_id": dataset.id,
|
||||
"document_id": segment.document_id,
|
||||
"original_chunk_id": segment.id, # Key: link to original chunk
|
||||
"doc_type": DocType.TEXT,
|
||||
"is_summary": True, # Identifier for summary documents
|
||||
},
|
||||
)
|
||||
|
||||
# Vectorize and store with retry mechanism for connection errors
|
||||
max_retries = 3
|
||||
retry_delay = 2.0
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
vector = Vector(dataset)
|
||||
vector.add_texts([summary_document], duplicate_check=True)
|
||||
|
||||
# Success - update summary record with index node info
|
||||
summary_record.summary_index_node_id = summary_index_node_id
|
||||
summary_record.summary_index_node_hash = summary_hash
|
||||
summary_record.status = "completed"
|
||||
db.session.add(summary_record)
|
||||
db.session.flush()
|
||||
return # Success, exit function
|
||||
|
||||
except Exception as e:  # ConnectionError is already covered by Exception
|
||||
error_str = str(e).lower()
|
||||
# Check if it's a connection-related error that might be transient
|
||||
is_connection_error = any(
|
||||
keyword in error_str
|
||||
for keyword in [
|
||||
"connection",
|
||||
"disconnected",
|
||||
"timeout",
|
||||
"network",
|
||||
"could not connect",
|
||||
"server disconnected",
|
||||
"weaviate",
|
||||
]
|
||||
)
|
||||
|
||||
if is_connection_error and attempt < max_retries - 1:
|
||||
# Retry for connection errors
|
||||
wait_time = retry_delay * (2**attempt) # Exponential backoff
|
||||
logger.warning(
|
||||
"Vectorization attempt %s/%s failed for segment %s: %s. Retrying in %.1f seconds...",
|
||||
attempt + 1,
|
||||
max_retries,
|
||||
segment.id,
|
||||
str(e),
|
||||
wait_time,
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
else:
|
||||
# Final attempt failed or non-connection error - log and update status
|
||||
logger.error(
|
||||
"Failed to vectorize summary for segment %s after %s attempts: %s",
|
||||
segment.id,
|
||||
attempt + 1,
|
||||
str(e),
|
||||
exc_info=True,
|
||||
)
|
||||
summary_record.status = "error"
|
||||
summary_record.error = f"Vectorization failed: {str(e)}"
|
||||
db.session.add(summary_record)
|
||||
db.session.flush()
|
||||
raise
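With the defaults above (max_retries = 3, retry_delay = 2.0), only transient connection-style failures are retried, and the waits grow exponentially:
# Waits applied before the 2nd and 3rd attempts; the 3rd failure (or any
# non-connection error) re-raises after marking the record as "error".
waits = [2.0 * (2**attempt) for attempt in range(3 - 1)]  # [2.0, 4.0]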
|
||||
|
||||
@staticmethod
|
||||
def generate_and_vectorize_summary(
|
||||
segment: DocumentSegment,
|
||||
dataset: Dataset,
|
||||
summary_index_setting: dict,
|
||||
) -> DocumentSegmentSummary:
|
||||
"""
|
||||
Generate summary for a segment and vectorize it.
|
||||
|
||||
Args:
|
||||
segment: DocumentSegment to generate summary for
|
||||
dataset: Dataset containing the segment
|
||||
summary_index_setting: Summary index configuration
|
||||
|
||||
Returns:
|
||||
Created DocumentSegmentSummary instance
|
||||
|
||||
Raises:
|
||||
ValueError: If summary generation fails
|
||||
"""
|
||||
try:
|
||||
# Generate summary
|
||||
summary_content = SummaryIndexService.generate_summary_for_segment(segment, dataset, summary_index_setting)
|
||||
|
||||
# Create or update summary record (will handle overwrite internally)
|
||||
summary_record = SummaryIndexService.create_summary_record(
|
||||
segment, dataset, summary_content, status="generating"
|
||||
)
|
||||
|
||||
# Vectorize summary (will delete old vector if exists before creating new one)
|
||||
SummaryIndexService.vectorize_summary(summary_record, segment, dataset)
|
||||
|
||||
db.session.commit()
|
||||
logger.info("Successfully generated and vectorized summary for segment %s", segment.id)
|
||||
return summary_record
|
||||
|
||||
except Exception as e:  # bind the exception so its message can be stored on the summary record
|
||||
logger.exception("Failed to generate summary for segment %s", segment.id)
|
||||
# Update summary record with error status if it exists
|
||||
summary_record = (
|
||||
db.session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
)
|
||||
if summary_record:
|
||||
summary_record.status = "error"
|
||||
summary_record.error = str(e)
|
||||
db.session.add(summary_record)
|
||||
db.session.commit()
|
||||
raise
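A minimal sketch of calling this entry point directly; in practice it is driven by the indexing and segment-update flows shown elsewhere in this diff:
if dataset.summary_index_setting and dataset.summary_index_setting.get("enable"):
    summary_record = SummaryIndexService.generate_and_vectorize_summary(
        segment, dataset, dataset.summary_index_setting
    )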
|
||||
|
||||
@staticmethod
|
||||
def generate_summaries_for_document(
|
||||
dataset: Dataset,
|
||||
document: DatasetDocument,
|
||||
summary_index_setting: dict,
|
||||
segment_ids: list[str] | None = None,
|
||||
only_parent_chunks: bool = False,
|
||||
) -> list[DocumentSegmentSummary]:
|
||||
"""
|
||||
Generate summaries for all segments in a document including vectorization.
|
||||
|
||||
Args:
|
||||
dataset: Dataset containing the document
|
||||
document: DatasetDocument to generate summaries for
|
||||
summary_index_setting: Summary index configuration
|
||||
segment_ids: Optional list of specific segment IDs to process
|
||||
only_parent_chunks: If True, only process parent chunks (for parent-child mode)
|
||||
|
||||
Returns:
|
||||
List of created DocumentSegmentSummary instances
|
||||
"""
|
||||
# Only generate summary index for high_quality indexing technique
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
logger.info(
|
||||
"Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'",
|
||||
dataset.id,
|
||||
dataset.indexing_technique,
|
||||
)
|
||||
return []
|
||||
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
logger.info("Summary index is disabled for dataset %s", dataset.id)
|
||||
return []
|
||||
|
||||
# Skip qa_model documents
|
||||
if document.doc_form == "qa_model":
|
||||
logger.info("Skipping summary generation for qa_model document %s", document.id)
|
||||
return []
|
||||
|
||||
logger.info(
|
||||
"Starting summary generation for document %s in dataset %s, segment_ids: %s, only_parent_chunks: %s",
|
||||
document.id,
|
||||
dataset.id,
|
||||
len(segment_ids) if segment_ids else "all",
|
||||
only_parent_chunks,
|
||||
)
|
||||
|
||||
# Query segments (only enabled segments)
|
||||
query = db.session.query(DocumentSegment).filter_by(
|
||||
dataset_id=dataset.id,
|
||||
document_id=document.id,
|
||||
status="completed",
|
||||
enabled=True, # Only generate summaries for enabled segments
|
||||
)
|
||||
|
||||
if segment_ids:
|
||||
query = query.filter(DocumentSegment.id.in_(segment_ids))
|
||||
|
||||
segments = query.all()
|
||||
|
||||
if not segments:
|
||||
logger.info("No segments found for document %s", document.id)
|
||||
return []
|
||||
|
||||
summary_records = []
|
||||
|
||||
for segment in segments:
|
||||
# For parent-child mode, only process parent chunks
|
||||
# In parent-child mode, all DocumentSegments are parent chunks,
|
||||
# so we process all of them. Child chunks are stored in ChildChunk table
|
||||
# and are not DocumentSegments, so they won't be in the segments list.
|
||||
# This check is mainly for clarity and future-proofing.
|
||||
if only_parent_chunks:
|
||||
# In parent-child mode, all segments in the query are parent chunks
|
||||
# Child chunks are not DocumentSegments, so they won't appear here
|
||||
# We can process all segments
|
||||
pass
|
||||
|
||||
try:
|
||||
summary_record = SummaryIndexService.generate_and_vectorize_summary(
|
||||
segment, dataset, summary_index_setting
|
||||
)
|
||||
summary_records.append(summary_record)
|
||||
except Exception:
|
||||
logger.exception("Failed to generate summary for segment %s", segment.id)
|
||||
# Continue with other segments
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
"Completed summary generation for document %s: %s summaries generated and vectorized",
|
||||
document.id,
|
||||
len(summary_records),
|
||||
)
|
||||
return summary_records
|
||||
|
||||
@staticmethod
|
||||
def disable_summaries_for_segments(
|
||||
dataset: Dataset,
|
||||
segment_ids: list[str] | None = None,
|
||||
disabled_by: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Disable summary records and remove vectors from vector database for segments.
|
||||
Unlike delete, this preserves the summary records but marks them as disabled.
|
||||
|
||||
Args:
|
||||
dataset: Dataset containing the segments
|
||||
segment_ids: List of segment IDs to disable summaries for. If None, disable all.
|
||||
disabled_by: User ID who disabled the summaries
|
||||
"""
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
query = db.session.query(DocumentSegmentSummary).filter_by(
|
||||
dataset_id=dataset.id,
|
||||
enabled=True, # Only disable enabled summaries
|
||||
)
|
||||
|
||||
if segment_ids:
|
||||
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
|
||||
summaries = query.all()
|
||||
|
||||
if not summaries:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Disabling %s summary records for dataset %s, segment_ids: %s",
|
||||
len(summaries),
|
||||
dataset.id,
|
||||
len(segment_ids) if segment_ids else "all",
|
||||
)
|
||||
|
||||
# Remove from vector database (but keep records)
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
|
||||
if summary_node_ids:
|
||||
try:
|
||||
vector = Vector(dataset)
|
||||
vector.delete_by_ids(summary_node_ids)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to remove summary vectors: %s", str(e))
|
||||
|
||||
# Disable summary records (don't delete)
|
||||
now = naive_utc_now()
|
||||
for summary in summaries:
|
||||
summary.enabled = False
|
||||
summary.disabled_at = now
|
||||
summary.disabled_by = disabled_by
|
||||
db.session.add(summary)
|
||||
|
||||
db.session.commit()
|
||||
logger.info("Disabled %s summary records for dataset %s", len(summaries), dataset.id)
|
||||
|
||||
@staticmethod
|
||||
def enable_summaries_for_segments(
|
||||
dataset: Dataset,
|
||||
segment_ids: list[str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Enable summary records and re-add vectors to vector database for segments.
|
||||
|
||||
Note: This method enables summaries based on chunk status, not summary_index_setting.enable.
|
||||
The summary_index_setting.enable flag only controls automatic generation,
|
||||
not whether existing summaries can be used.
|
||||
Summary.enabled should always be kept in sync with chunk.enabled.
|
||||
|
||||
Args:
|
||||
dataset: Dataset containing the segments
|
||||
segment_ids: List of segment IDs to enable summaries for. If None, enable all.
|
||||
"""
|
||||
# Only enable summary index for high_quality indexing technique
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
return
|
||||
|
||||
query = db.session.query(DocumentSegmentSummary).filter_by(
|
||||
dataset_id=dataset.id,
|
||||
enabled=False, # Only enable disabled summaries
|
||||
)
|
||||
|
||||
if segment_ids:
|
||||
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
|
||||
summaries = query.all()
|
||||
|
||||
if not summaries:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Enabling %s summary records for dataset %s, segment_ids: %s",
|
||||
len(summaries),
|
||||
dataset.id,
|
||||
len(segment_ids) if segment_ids else "all",
|
||||
)
|
||||
|
||||
# Re-vectorize and re-add to vector database
|
||||
enabled_count = 0
|
||||
for summary in summaries:
|
||||
# Get the original segment
|
||||
segment = (
|
||||
db.session.query(DocumentSegment)
|
||||
.filter_by(
|
||||
id=summary.chunk_id,
|
||||
dataset_id=dataset.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Summary.enabled stays in sync with chunk.enabled, only enable summary if the associated chunk is enabled.
|
||||
if not segment or not segment.enabled or segment.status != "completed":
|
||||
continue
|
||||
|
||||
if not summary.summary_content:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Re-vectorize summary
|
||||
SummaryIndexService.vectorize_summary(summary, segment, dataset)
|
||||
|
||||
# Enable summary record
|
||||
summary.enabled = True
|
||||
summary.disabled_at = None
|
||||
summary.disabled_by = None
|
||||
db.session.add(summary)
|
||||
enabled_count += 1
|
||||
except Exception:
|
||||
logger.exception("Failed to re-vectorize summary %s", summary.id)
|
||||
# Keep it disabled if vectorization fails
|
||||
continue
|
||||
|
||||
db.session.commit()
|
||||
logger.info("Enabled %s summary records for dataset %s", enabled_count, dataset.id)
|
||||
|
||||
@staticmethod
|
||||
def delete_summaries_for_segments(
|
||||
dataset: Dataset,
|
||||
segment_ids: list[str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Delete summary records and vectors for segments (used only for actual deletion scenarios).
|
||||
For disable/enable operations, use disable_summaries_for_segments/enable_summaries_for_segments.
|
||||
|
||||
Args:
|
||||
dataset: Dataset containing the segments
|
||||
segment_ids: List of segment IDs to delete summaries for. If None, delete all.
|
||||
"""
|
||||
query = db.session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id)
|
||||
|
||||
if segment_ids:
|
||||
query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))
|
||||
|
||||
summaries = query.all()
|
||||
|
||||
if not summaries:
|
||||
return
|
||||
|
||||
# Delete from vector database
|
||||
if dataset.indexing_technique == "high_quality":
|
||||
summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
|
||||
if summary_node_ids:
|
||||
vector = Vector(dataset)
|
||||
vector.delete_by_ids(summary_node_ids)
|
||||
|
||||
# Delete summary records
|
||||
for summary in summaries:
|
||||
db.session.delete(summary)
|
||||
|
||||
db.session.commit()
|
||||
logger.info("Deleted %s summary records for dataset %s", len(summaries), dataset.id)
|
||||
|
||||
@staticmethod
|
||||
def update_summary_for_segment(
|
||||
segment: DocumentSegment,
|
||||
dataset: Dataset,
|
||||
summary_content: str,
|
||||
) -> DocumentSegmentSummary | None:
|
||||
"""
|
||||
Update summary for a segment and re-vectorize it.
|
||||
|
||||
Args:
|
||||
segment: DocumentSegment to update summary for
|
||||
dataset: Dataset containing the segment
|
||||
summary_content: New summary content
|
||||
|
||||
Returns:
|
||||
Updated DocumentSegmentSummary instance, or None if summary index is not enabled
|
||||
"""
|
||||
# Only update summary index for high_quality indexing technique
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
return None
|
||||
|
||||
# Check if summary index is enabled
|
||||
summary_index_setting = dataset.summary_index_setting
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
return None
|
||||
|
||||
# Skip qa_model documents
|
||||
if segment.document and segment.document.doc_form == "qa_model":
|
||||
return None
|
||||
|
||||
try:
|
||||
# Find existing summary record
|
||||
summary_record = (
|
||||
db.session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
)
|
||||
|
||||
if summary_record:
|
||||
# Update existing summary
|
||||
old_summary_node_id = summary_record.summary_index_node_id
|
||||
|
||||
# Update summary content
|
||||
summary_record.summary_content = summary_content
|
||||
summary_record.status = "generating"
|
||||
db.session.add(summary_record)
|
||||
db.session.flush()
|
||||
|
||||
# Delete old vector if exists
|
||||
if old_summary_node_id:
|
||||
vector = Vector(dataset)
|
||||
vector.delete_by_ids([old_summary_node_id])
|
||||
|
||||
# Re-vectorize summary
|
||||
SummaryIndexService.vectorize_summary(summary_record, segment, dataset)
|
||||
|
||||
db.session.commit()
|
||||
logger.info("Successfully updated and re-vectorized summary for segment %s", segment.id)
|
||||
return summary_record
|
||||
else:
|
||||
# Create new summary record if doesn't exist
|
||||
summary_record = SummaryIndexService.create_summary_record(
|
||||
segment, dataset, summary_content, status="generating"
|
||||
)
|
||||
SummaryIndexService.vectorize_summary(summary_record, segment, dataset)
|
||||
db.session.commit()
|
||||
logger.info("Successfully created and vectorized summary for segment %s", segment.id)
|
||||
return summary_record
|
||||
|
||||
except Exception as e:  # bind the exception so its message can be stored on the summary record
|
||||
logger.exception("Failed to update summary for segment %s", segment.id)
|
||||
# Update summary record with error status if it exists
|
||||
summary_record = (
|
||||
db.session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
|
||||
)
|
||||
if summary_record:
|
||||
summary_record.status = "error"
|
||||
summary_record.error = str(e)
|
||||
db.session.add(summary_record)
|
||||
db.session.commit()
|
||||
raise
|
||||
@ -1,183 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
|
||||
from models.account import Account
|
||||
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository, WorkflowSessionInfo
|
||||
|
||||
|
||||
class WorkflowCollaborationService:
|
||||
def __init__(self, repository: WorkflowCollaborationRepository, socketio) -> None:
|
||||
self._repository = repository
|
||||
self._socketio = socketio
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(repository={self._repository})"
|
||||
|
||||
def save_session(self, sid: str, user: Account) -> None:
|
||||
self._socketio.save_session(
|
||||
sid,
|
||||
{
|
||||
"user_id": user.id,
|
||||
"username": user.name,
|
||||
"avatar": user.avatar,
|
||||
},
|
||||
)
|
||||
|
||||
def register_session(self, workflow_id: str, sid: str) -> tuple[str, bool] | None:
|
||||
session = self._socketio.get_session(sid)
|
||||
user_id = session.get("user_id")
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
session_info: WorkflowSessionInfo = {
|
||||
"user_id": str(user_id),
|
||||
"username": str(session.get("username", "Unknown")),
|
||||
"avatar": session.get("avatar"),
|
||||
"sid": sid,
|
||||
"connected_at": int(time.time()),
|
||||
}
|
||||
|
||||
self._repository.set_session_info(workflow_id, session_info)
|
||||
|
||||
leader_sid = self.get_or_set_leader(workflow_id, sid)
|
||||
is_leader = leader_sid == sid
|
||||
|
||||
self._socketio.enter_room(sid, workflow_id)
|
||||
self.broadcast_online_users(workflow_id)
|
||||
|
||||
self._socketio.emit("status", {"isLeader": is_leader}, room=sid)
|
||||
|
||||
return str(user_id), is_leader
|
||||
|
||||
def disconnect_session(self, sid: str) -> None:
|
||||
mapping = self._repository.get_sid_mapping(sid)
|
||||
if not mapping:
|
||||
return
|
||||
|
||||
workflow_id = mapping["workflow_id"]
|
||||
self._repository.delete_session(workflow_id, sid)
|
||||
|
||||
self.handle_leader_disconnect(workflow_id, sid)
|
||||
self.broadcast_online_users(workflow_id)
|
||||
|
||||
def relay_collaboration_event(self, sid: str, data: Mapping[str, object]) -> tuple[dict[str, str], int]:
|
||||
mapping = self._repository.get_sid_mapping(sid)
|
||||
if not mapping:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
workflow_id = mapping["workflow_id"]
|
||||
user_id = mapping["user_id"]
|
||||
self._repository.refresh_session_state(workflow_id, sid)
|
||||
|
||||
event_type = data.get("type")
|
||||
event_data = data.get("data")
|
||||
timestamp = data.get("timestamp", int(time.time()))
|
||||
|
||||
if not event_type:
|
||||
return {"msg": "invalid event type"}, 400
|
||||
|
||||
self._socketio.emit(
|
||||
"collaboration_update",
|
||||
{"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp},
|
||||
room=workflow_id,
|
||||
skip_sid=sid,
|
||||
)
|
||||
|
||||
return {"msg": "event_broadcasted"}, 200
|
||||
|
||||
def relay_graph_event(self, sid: str, data: object) -> tuple[dict[str, str], int]:
|
||||
mapping = self._repository.get_sid_mapping(sid)
|
||||
if not mapping:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
workflow_id = mapping["workflow_id"]
|
||||
self._repository.refresh_session_state(workflow_id, sid)
|
||||
|
||||
self._socketio.emit("graph_update", data, room=workflow_id, skip_sid=sid)
|
||||
|
||||
return {"msg": "graph_update_broadcasted"}, 200
|
||||
|
||||
def get_or_set_leader(self, workflow_id: str, sid: str) -> str:
|
||||
current_leader = self._repository.get_current_leader(workflow_id)
|
||||
|
||||
if current_leader:
|
||||
if self.is_session_active(workflow_id, current_leader):
|
||||
return current_leader
|
||||
self._repository.delete_session(workflow_id, current_leader)
|
||||
self._repository.delete_leader(workflow_id)
|
||||
|
||||
was_set = self._repository.set_leader_if_absent(workflow_id, sid)
|
||||
|
||||
if was_set:
|
||||
if current_leader:
|
||||
self.broadcast_leader_change(workflow_id, sid)
|
||||
return sid
|
||||
|
||||
current_leader = self._repository.get_current_leader(workflow_id)
|
||||
if current_leader:
|
||||
return current_leader
|
||||
|
||||
return sid
|
||||
|
||||
def handle_leader_disconnect(self, workflow_id: str, disconnected_sid: str) -> None:
|
||||
current_leader = self._repository.get_current_leader(workflow_id)
|
||||
if not current_leader:
|
||||
return
|
||||
|
||||
if current_leader != disconnected_sid:
|
||||
return
|
||||
|
||||
session_sids = self._repository.get_session_sids(workflow_id)
|
||||
if session_sids:
|
||||
new_leader_sid = session_sids[0]
|
||||
self._repository.set_leader(workflow_id, new_leader_sid)
|
||||
self.broadcast_leader_change(workflow_id, new_leader_sid)
|
||||
else:
|
||||
self._repository.delete_leader(workflow_id)
|
||||
|
||||
def broadcast_leader_change(self, workflow_id: str, new_leader_sid: str) -> None:
|
||||
for sid in self._repository.get_session_sids(workflow_id):
|
||||
try:
|
||||
is_leader = sid == new_leader_sid
|
||||
self._socketio.emit("status", {"isLeader": is_leader}, room=sid)
|
||||
except Exception:
|
||||
logging.exception("Failed to emit leader status to session %s", sid)
|
||||
|
||||
def get_current_leader(self, workflow_id: str) -> str | None:
|
||||
return self._repository.get_current_leader(workflow_id)
|
||||
|
||||
def broadcast_online_users(self, workflow_id: str) -> None:
|
||||
users = self._repository.list_sessions(workflow_id)
|
||||
users.sort(key=lambda x: x.get("connected_at") or 0)
|
||||
|
||||
leader_sid = self.get_current_leader(workflow_id)
|
||||
|
||||
self._socketio.emit(
|
||||
"online_users",
|
||||
{"workflow_id": workflow_id, "users": users, "leader": leader_sid},
|
||||
room=workflow_id,
|
||||
)
|
||||
|
||||
def refresh_session_state(self, workflow_id: str, sid: str) -> None:
|
||||
self._repository.refresh_session_state(workflow_id, sid)
|
||||
|
||||
def is_session_active(self, workflow_id: str, sid: str) -> bool:
|
||||
if not sid:
|
||||
return False
|
||||
|
||||
try:
|
||||
if not self._socketio.manager.is_connected(sid, "/"):
|
||||
return False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
if not self._repository.session_exists(workflow_id, sid):
|
||||
return False
|
||||
|
||||
if not self._repository.sid_mapping_exists(sid):
|
||||
return False
|
||||
|
||||
return True
|
||||
@ -1,345 +0,0 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import uuid_value
|
||||
from models import WorkflowComment, WorkflowCommentMention, WorkflowCommentReply
|
||||
from models.account import Account
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowCommentService:
|
||||
"""Service for managing workflow comments."""
|
||||
|
||||
@staticmethod
|
||||
def _validate_content(content: str) -> None:
|
||||
if len(content.strip()) == 0:
|
||||
raise ValueError("Comment content cannot be empty")
|
||||
|
||||
if len(content) > 1000:
|
||||
raise ValueError("Comment content cannot exceed 1000 characters")
|
||||
|
||||
@staticmethod
|
||||
def get_comments(tenant_id: str, app_id: str) -> Sequence[WorkflowComment]:
|
||||
"""Get all comments for a workflow."""
|
||||
with Session(db.engine) as session:
|
||||
# Get all comments with eager loading
|
||||
stmt = (
|
||||
select(WorkflowComment)
|
||||
.options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
|
||||
.where(WorkflowComment.tenant_id == tenant_id, WorkflowComment.app_id == app_id)
|
||||
.order_by(desc(WorkflowComment.created_at))
|
||||
)
|
||||
|
||||
comments = session.scalars(stmt).all()
|
||||
|
||||
# Batch preload all Account objects to avoid N+1 queries
|
||||
WorkflowCommentService._preload_accounts(session, comments)
|
||||
|
||||
return comments
|
||||
|
||||
@staticmethod
|
||||
def _preload_accounts(session: Session, comments: Sequence[WorkflowComment]) -> None:
|
||||
"""Batch preload Account objects for comments, replies, and mentions."""
|
||||
# Collect all user IDs
|
||||
user_ids: set[str] = set()
|
||||
for comment in comments:
|
||||
user_ids.add(comment.created_by)
|
||||
if comment.resolved_by:
|
||||
user_ids.add(comment.resolved_by)
|
||||
user_ids.update(reply.created_by for reply in comment.replies)
|
||||
user_ids.update(mention.mentioned_user_id for mention in comment.mentions)
|
||||
|
||||
if not user_ids:
|
||||
return
|
||||
|
||||
# Batch query all accounts
|
||||
accounts = session.scalars(select(Account).where(Account.id.in_(user_ids))).all()
|
||||
account_map = {str(account.id): account for account in accounts}
|
||||
|
||||
# Cache accounts on objects
|
||||
for comment in comments:
|
||||
comment.cache_created_by_account(account_map.get(comment.created_by))
|
||||
comment.cache_resolved_by_account(account_map.get(comment.resolved_by) if comment.resolved_by else None)
|
||||
for reply in comment.replies:
|
||||
reply.cache_created_by_account(account_map.get(reply.created_by))
|
||||
for mention in comment.mentions:
|
||||
mention.cache_mentioned_user_account(account_map.get(mention.mentioned_user_id))
|
||||
|
||||
@staticmethod
|
||||
def get_comment(tenant_id: str, app_id: str, comment_id: str, session: Session | None = None) -> WorkflowComment:
|
||||
"""Get a specific comment."""
|
||||
|
||||
def _get_comment(session: Session) -> WorkflowComment:
|
||||
stmt = (
|
||||
select(WorkflowComment)
|
||||
.options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
|
||||
.where(
|
||||
WorkflowComment.id == comment_id,
|
||||
WorkflowComment.tenant_id == tenant_id,
|
||||
WorkflowComment.app_id == app_id,
|
||||
)
|
||||
)
|
||||
comment = session.scalar(stmt)
|
||||
|
||||
if not comment:
|
||||
raise NotFound("Comment not found")
|
||||
|
||||
# Preload accounts to avoid N+1 queries
|
||||
WorkflowCommentService._preload_accounts(session, [comment])
|
||||
|
||||
return comment
|
||||
|
||||
if session is not None:
|
||||
return _get_comment(session)
|
||||
else:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
return _get_comment(session)
|
||||
|
||||
@staticmethod
|
||||
def create_comment(
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
created_by: str,
|
||||
content: str,
|
||||
position_x: float,
|
||||
position_y: float,
|
||||
mentioned_user_ids: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""Create a new workflow comment."""
|
||||
WorkflowCommentService._validate_content(content)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
comment = WorkflowComment(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
position_x=position_x,
|
||||
position_y=position_y,
|
||||
content=content,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
session.add(comment)
|
||||
session.flush() # Get the comment ID for mentions
|
||||
|
||||
# Create mentions if specified
|
||||
mentioned_user_ids = mentioned_user_ids or []
|
||||
for user_id in mentioned_user_ids:
|
||||
if isinstance(user_id, str) and uuid_value(user_id):
|
||||
mention = WorkflowCommentMention(
|
||||
comment_id=comment.id,
|
||||
reply_id=None, # This is a comment mention, not reply mention
|
||||
mentioned_user_id=user_id,
|
||||
)
|
||||
session.add(mention)
|
||||
|
||||
session.commit()
|
||||
|
||||
# Return only what we need - id and created_at
|
||||
return {"id": comment.id, "created_at": comment.created_at}
|
||||
|
||||
@staticmethod
|
||||
def update_comment(
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
comment_id: str,
|
||||
user_id: str,
|
||||
content: str,
|
||||
position_x: float | None = None,
|
||||
position_y: float | None = None,
|
||||
mentioned_user_ids: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""Update a workflow comment."""
|
||||
WorkflowCommentService._validate_content(content)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get comment with validation
|
||||
stmt = select(WorkflowComment).where(
|
||||
WorkflowComment.id == comment_id,
|
||||
WorkflowComment.tenant_id == tenant_id,
|
||||
WorkflowComment.app_id == app_id,
|
||||
)
|
||||
comment = session.scalar(stmt)
|
||||
|
||||
if not comment:
|
||||
raise NotFound("Comment not found")
|
||||
|
||||
# Only the creator can update the comment
|
||||
if comment.created_by != user_id:
|
||||
raise Forbidden("Only the comment creator can update it")
|
||||
|
||||
# Update comment fields
|
||||
comment.content = content
|
||||
if position_x is not None:
|
||||
comment.position_x = position_x
|
||||
if position_y is not None:
|
||||
comment.position_y = position_y
|
||||
|
||||
# Update mentions - first remove existing mentions for this comment only (not replies)
|
||||
existing_mentions = session.scalars(
|
||||
select(WorkflowCommentMention).where(
|
||||
WorkflowCommentMention.comment_id == comment.id,
|
||||
WorkflowCommentMention.reply_id.is_(None), # Only comment mentions, not reply mentions
|
||||
)
|
||||
).all()
|
||||
for mention in existing_mentions:
|
||||
session.delete(mention)
|
||||
|
||||
# Add new mentions
|
||||
mentioned_user_ids = mentioned_user_ids or []
|
||||
for user_id_str in mentioned_user_ids:
|
||||
if isinstance(user_id_str, str) and uuid_value(user_id_str):
|
||||
mention = WorkflowCommentMention(
|
||||
comment_id=comment.id,
|
||||
reply_id=None, # This is a comment mention
|
||||
mentioned_user_id=user_id_str,
|
||||
)
|
||||
session.add(mention)
|
||||
|
||||
session.commit()
|
||||
|
||||
return {"id": comment.id, "updated_at": comment.updated_at}
|
||||
|
||||
@staticmethod
|
||||
def delete_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> None:
|
||||
"""Delete a workflow comment."""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
|
||||
|
||||
# Only the creator can delete the comment
|
||||
if comment.created_by != user_id:
|
||||
raise Forbidden("Only the comment creator can delete it")
|
||||
|
||||
# Delete associated mentions (both comment and reply mentions)
|
||||
mentions = session.scalars(
|
||||
select(WorkflowCommentMention).where(WorkflowCommentMention.comment_id == comment_id)
|
||||
).all()
|
||||
for mention in mentions:
|
||||
session.delete(mention)
|
||||
|
||||
# Delete associated replies
|
||||
replies = session.scalars(
|
||||
select(WorkflowCommentReply).where(WorkflowCommentReply.comment_id == comment_id)
|
||||
).all()
|
||||
for reply in replies:
|
||||
session.delete(reply)
|
||||
|
||||
session.delete(comment)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def resolve_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> WorkflowComment:
|
||||
"""Resolve a workflow comment."""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
|
||||
if comment.resolved:
|
||||
return comment
|
||||
|
||||
comment.resolved = True
|
||||
comment.resolved_at = naive_utc_now()
|
||||
comment.resolved_by = user_id
|
||||
session.commit()
|
||||
|
||||
return comment
|
||||
|
||||
@staticmethod
|
||||
def create_reply(
|
||||
comment_id: str, content: str, created_by: str, mentioned_user_ids: list[str] | None = None
|
||||
) -> dict:
|
||||
"""Add a reply to a workflow comment."""
|
||||
WorkflowCommentService._validate_content(content)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Check if comment exists
|
||||
comment = session.get(WorkflowComment, comment_id)
|
||||
if not comment:
|
||||
raise NotFound("Comment not found")
|
||||
|
||||
reply = WorkflowCommentReply(comment_id=comment_id, content=content, created_by=created_by)
|
||||
|
||||
session.add(reply)
|
||||
session.flush() # Get the reply ID for mentions
|
||||
|
||||
# Create mentions if specified
|
||||
mentioned_user_ids = mentioned_user_ids or []
|
||||
for user_id in mentioned_user_ids:
|
||||
if isinstance(user_id, str) and uuid_value(user_id):
|
||||
# Create mention linking to specific reply
|
||||
mention = WorkflowCommentMention(
|
||||
comment_id=comment_id, reply_id=reply.id, mentioned_user_id=user_id
|
||||
)
|
||||
session.add(mention)
|
||||
|
||||
session.commit()
|
||||
|
||||
return {"id": reply.id, "created_at": reply.created_at}
|
||||
|
||||
@staticmethod
|
||||
def update_reply(reply_id: str, user_id: str, content: str, mentioned_user_ids: list[str] | None = None) -> dict:
|
||||
"""Update a comment reply."""
|
||||
WorkflowCommentService._validate_content(content)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
reply = session.get(WorkflowCommentReply, reply_id)
|
||||
if not reply:
|
||||
raise NotFound("Reply not found")
|
||||
|
||||
# Only the creator can update the reply
|
||||
if reply.created_by != user_id:
|
||||
raise Forbidden("Only the reply creator can update it")
|
||||
|
||||
reply.content = content
|
||||
|
||||
# Update mentions - first remove existing mentions for this reply
|
||||
existing_mentions = session.scalars(
|
||||
select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply.id)
|
||||
).all()
|
||||
for mention in existing_mentions:
|
||||
session.delete(mention)
|
||||
|
||||
# Add mentions
|
||||
mentioned_user_ids = mentioned_user_ids or []
|
||||
for user_id_str in mentioned_user_ids:
|
||||
if isinstance(user_id_str, str) and uuid_value(user_id_str):
|
||||
mention = WorkflowCommentMention(
|
||||
comment_id=reply.comment_id, reply_id=reply.id, mentioned_user_id=user_id_str
|
||||
)
|
||||
session.add(mention)
|
||||
|
||||
session.commit()
|
||||
session.refresh(reply) # Refresh to get updated timestamp
|
||||
|
||||
return {"id": reply.id, "updated_at": reply.updated_at}
|
||||
|
||||
@staticmethod
|
||||
def delete_reply(reply_id: str, user_id: str) -> None:
|
||||
"""Delete a comment reply."""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
reply = session.get(WorkflowCommentReply, reply_id)
|
||||
if not reply:
|
||||
raise NotFound("Reply not found")
|
||||
|
||||
# Only the creator can delete the reply
|
||||
if reply.created_by != user_id:
|
||||
raise Forbidden("Only the reply creator can delete it")
|
||||
|
||||
# Delete associated mentions first
|
||||
mentions = session.scalars(
|
||||
select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply_id)
|
||||
).all()
|
||||
for mention in mentions:
|
||||
session.delete(mention)
|
||||
|
||||
session.delete(reply)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def validate_comment_access(comment_id: str, tenant_id: str, app_id: str) -> WorkflowComment:
|
||||
"""Validate that a comment belongs to the specified tenant and app."""
|
||||
return WorkflowCommentService.get_comment(tenant_id, app_id, comment_id)
|
||||
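As a reference, a short usage sketch of the comment service shown above. The tenant, app, and account IDs are placeholder UUIDs and the module path services.workflow_comment_service is assumed; none of this appears in the diff itself.

# Illustrative call sequence; all identifiers are placeholders.
from services.workflow_comment_service import WorkflowCommentService  # assumed module path

tenant_id = "00000000-0000-0000-0000-000000000001"
app_id = "00000000-0000-0000-0000-000000000002"
user_id = "00000000-0000-0000-0000-000000000003"
mentioned_id = "00000000-0000-0000-0000-000000000004"

# Create a comment pinned to a canvas position, mentioning another user.
created = WorkflowCommentService.create_comment(
    tenant_id=tenant_id,
    app_id=app_id,
    created_by=user_id,
    content="Should this branch handle empty input?",
    position_x=120.0,
    position_y=80.5,
    mentioned_user_ids=[mentioned_id],
)

# Reply to the thread, then mark it resolved.
WorkflowCommentService.create_reply(
    comment_id=created["id"],
    content="Good catch, fixed in the next draft.",
    created_by=mentioned_id,
)
WorkflowCommentService.resolve_comment(tenant_id, app_id, created["id"], user_id)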
@ -200,17 +200,15 @@ class WorkflowService:
        account: Account,
        environment_variables: Sequence[VariableBase],
        conversation_variables: Sequence[VariableBase],
        force_upload: bool = False,
    ) -> Workflow:
        """
        Sync draft workflow
        :param force_upload: Skip hash validation when True (for restore operations)
        :raises WorkflowHashNotEqualError
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(app_model=app_model)

        if workflow and workflow.unique_hash != unique_hash and not force_upload:
        if workflow and workflow.unique_hash != unique_hash:
            raise WorkflowHashNotEqualError()

        # validate features structure
@ -251,78 +249,6 @@ class WorkflowService:
        # return draft workflow
        return workflow

    def update_draft_workflow_environment_variables(
        self,
        *,
        app_model: App,
        environment_variables: Sequence[VariableBase],
        account: Account,
    ):
        """
        Update draft workflow environment variables
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(app_model=app_model)

        if not workflow:
            raise ValueError("No draft workflow found.")

        workflow.environment_variables = environment_variables
        workflow.updated_by = account.id
        workflow.updated_at = naive_utc_now()

        # commit db session changes
        db.session.commit()

    def update_draft_workflow_conversation_variables(
        self,
        *,
        app_model: App,
        conversation_variables: Sequence[VariableBase],
        account: Account,
    ):
        """
        Update draft workflow conversation variables
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(app_model=app_model)

        if not workflow:
            raise ValueError("No draft workflow found.")

        workflow.conversation_variables = conversation_variables
        workflow.updated_by = account.id
        workflow.updated_at = naive_utc_now()

        # commit db session changes
        db.session.commit()

    def update_draft_workflow_features(
        self,
        *,
        app_model: App,
        features: dict,
        account: Account,
    ):
        """
        Update draft workflow features
        """
        # fetch draft workflow by app_model
        workflow = self.get_draft_workflow(app_model=app_model)

        if not workflow:
            raise ValueError("No draft workflow found.")

        # validate features structure
        self.validate_features_structure(app_model=app_model, features=features)

        workflow.features = json.dumps(features)
        workflow.updated_by = account.id
        workflow.updated_at = naive_utc_now()

        # commit db session changes
        db.session.commit()

    def publish_workflow(
        self,
        *,

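The draft-variable updaters above all follow the same fetch-mutate-commit pattern. A minimal caller sketch, assuming the service instance, app model, and account are already available (they are not shown in this hunk):

# Hypothetical caller; `workflow_service`, `app_model`, and `current_account` are assumptions.
from collections.abc import Sequence


def sync_env_vars(workflow_service, app_model, current_account, env_vars: Sequence) -> None:
    # Replaces the draft workflow's environment variables and stamps editor and time.
    workflow_service.update_draft_workflow_environment_variables(
        app_model=app_model,
        environment_variables=env_vars,
        account=current_account,
    )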
@ -117,6 +117,19 @@ def add_document_to_index_task(dataset_document_id: str):
        )
        db.session.commit()

        # Enable summary indexes for all segments in this document
        from services.summary_index_service import SummaryIndexService

        segment_ids_list = [segment.id for segment in segments]
        if segment_ids_list:
            try:
                SummaryIndexService.enable_summaries_for_segments(
                    dataset=dataset,
                    segment_ids=segment_ids_list,
                )
            except Exception as e:
                logger.warning("Failed to enable summaries for document %s: %s", dataset_document.id, str(e))

        end_at = time.perf_counter()
        logger.info(
            click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")

@ -42,6 +42,7 @@ def delete_segment_from_index_task(
        doc_form = dataset_document.doc_form

        # Proceed with index cleanup using the index_node_ids directly
        # For actual deletion, we should delete summaries (not just disable them)
        index_processor = IndexProcessorFactory(doc_form).init_index_processor()
        index_processor.clean(
            dataset,
@ -49,6 +50,7 @@ def delete_segment_from_index_task(
            with_keywords=True,
            delete_child_chunks=True,
            precomputed_child_node_ids=child_node_ids,
            delete_summaries=True,  # Actually delete summaries when segment is deleted
        )
        if dataset.is_multimodal:
            # delete segment attachment binding

@ -53,6 +53,18 @@ def disable_segment_from_index_task(segment_id: str):
            logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
            return

        # Disable summary index for this segment
        from services.summary_index_service import SummaryIndexService

        try:
            SummaryIndexService.disable_summaries_for_segments(
                dataset=dataset,
                segment_ids=[segment.id],
                disabled_by=segment.disabled_by,
            )
        except Exception as e:
            logger.warning("Failed to disable summary for segment %s: %s", segment.id, str(e))

        index_type = dataset_document.doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        index_processor.clean(dataset, [segment.index_node_id])

@ -58,12 +58,26 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
        return

    try:
        # Disable summary indexes for these segments
        from services.summary_index_service import SummaryIndexService

        segment_ids_list = [segment.id for segment in segments]
        try:
            # Get disabled_by from first segment (they should all have the same disabled_by)
            disabled_by = segments[0].disabled_by if segments else None
            SummaryIndexService.disable_summaries_for_segments(
                dataset=dataset,
                segment_ids=segment_ids_list,
                disabled_by=disabled_by,
            )
        except Exception as e:
            logger.warning("Failed to disable summaries for segments: %s", str(e))

        index_node_ids = [segment.index_node_id for segment in segments]
        if dataset.is_multimodal:
            segment_ids = [segment.id for segment in segments]
            segment_attachment_bindings = (
                db.session.query(SegmentAttachmentBinding)
                .where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
                .where(SegmentAttachmentBinding.segment_id.in_(segment_ids_list))
                .all()
            )
            if segment_attachment_bindings:

@ -14,6 +14,7 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from services.feature_service import FeatureService
from tasks.generate_summary_index_task import generate_summary_index_task

logger = logging.getLogger(__name__)

@ -100,6 +101,69 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
        indexing_runner.run(documents)
        end_at = time.perf_counter()
        logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))

        # Trigger summary index generation for completed documents if enabled
        # Only generate for high_quality indexing technique and when summary_index_setting is enabled
        # Re-query dataset to get latest summary_index_setting (in case it was updated)
        dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
        if not dataset:
            logger.warning("Dataset %s not found after indexing", dataset_id)
            return

        if dataset.indexing_technique == "high_quality":
            summary_index_setting = dataset.summary_index_setting
            if summary_index_setting and summary_index_setting.get("enable"):
                # Check each document's indexing status and trigger summary generation if completed
                for document_id in document_ids:
                    # Re-query document to get latest status (IndexingRunner may have updated it)
                    document = (
                        db.session.query(Document)
                        .where(Document.id == document_id, Document.dataset_id == dataset_id)
                        .first()
                    )
                    if document:
                        logger.info(
                            "Checking document %s for summary generation: status=%s, doc_form=%s",
                            document_id,
                            document.indexing_status,
                            document.doc_form,
                        )
                        if document.indexing_status == "completed" and document.doc_form != "qa_model":
                            try:
                                generate_summary_index_task.delay(dataset.id, document_id, None)
                                logger.info(
                                    "Queued summary index generation task for document %s in dataset %s "
                                    "after indexing completed",
                                    document_id,
                                    dataset.id,
                                )
                            except Exception:
                                logger.exception(
                                    "Failed to queue summary index generation task for document %s",
                                    document_id,
                                )
                                # Don't fail the entire indexing process if summary task queuing fails
                        else:
                            logger.info(
                                "Skipping summary generation for document %s: status=%s, doc_form=%s",
                                document_id,
                                document.indexing_status,
                                document.doc_form,
                            )
                    else:
                        logger.warning("Document %s not found after indexing", document_id)
            else:
                logger.info(
                    "Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
                    dataset.id,
                    summary_index_setting.get("enable") if summary_index_setting else None,
                )
        else:
            logger.info(
                "Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
                dataset.id,
                dataset.indexing_technique,
            )
    except DocumentIsPausedError as ex:
        logger.info(click.style(str(ex), fg="yellow"))
    except Exception:

@ -103,6 +103,17 @@ def enable_segment_to_index_task(segment_id: str):
        # save vector index
        index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)

        # Enable summary index for this segment
        from services.summary_index_service import SummaryIndexService

        try:
            SummaryIndexService.enable_summaries_for_segments(
                dataset=dataset,
                segment_ids=[segment.id],
            )
        except Exception as e:
            logger.warning("Failed to enable summary for segment %s: %s", segment.id, str(e))

        end_at = time.perf_counter()
        logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
    except Exception as e:

@ -108,6 +108,18 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
        # save vector index
        index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)

        # Enable summary indexes for these segments
        from services.summary_index_service import SummaryIndexService

        segment_ids_list = [segment.id for segment in segments]
        try:
            SummaryIndexService.enable_summaries_for_segments(
                dataset=dataset,
                segment_ids=segment_ids_list,
            )
        except Exception as e:
            logger.warning("Failed to enable summaries for segments: %s", str(e))

        end_at = time.perf_counter()
        logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
    except Exception as e:

api/tasks/generate_summary_index_task.py (new file, 112 lines)
@ -0,0 +1,112 @@
"""Async task for generating summary indexes."""

import logging
import time

import click
from celery import shared_task

from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.summary_index_service import SummaryIndexService

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: list[str] | None = None):
    """
    Async generate summary index for document segments.

    Args:
        dataset_id: Dataset ID
        document_id: Document ID
        segment_ids: Optional list of specific segment IDs to process. If None, process all segments.

    Usage:
        generate_summary_index_task.delay(dataset_id, document_id)
        generate_summary_index_task.delay(dataset_id, document_id, segment_ids)
    """
    logger.info(
        click.style(
            f"Start generating summary index for document {document_id} in dataset {dataset_id}",
            fg="green",
        )
    )
    start_at = time.perf_counter()

    try:
        dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
        if not dataset:
            logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red"))
            db.session.close()
            return

        document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
        if not document:
            logger.error(click.style(f"Document not found: {document_id}", fg="red"))
            db.session.close()
            return

        # Only generate summary index for high_quality indexing technique
        if dataset.indexing_technique != "high_quality":
            logger.info(
                click.style(
                    f"Skipping summary generation for dataset {dataset_id}: "
                    f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'",
                    fg="cyan",
                )
            )
            db.session.close()
            return

        # Check if summary index is enabled
        summary_index_setting = dataset.summary_index_setting
        if not summary_index_setting or not summary_index_setting.get("enable"):
            logger.info(
                click.style(
                    f"Summary index is disabled for dataset {dataset_id}",
                    fg="cyan",
                )
            )
            db.session.close()
            return

        # Determine if only parent chunks should be processed
        only_parent_chunks = dataset.chunk_structure == "parent_child_index"

        # Generate summaries
        summary_records = SummaryIndexService.generate_summaries_for_document(
            dataset=dataset,
            document=document,
            summary_index_setting=summary_index_setting,
            segment_ids=segment_ids,
            only_parent_chunks=only_parent_chunks,
        )

        end_at = time.perf_counter()
        logger.info(
            click.style(
                f"Summary index generation completed for document {document_id}: "
                f"{len(summary_records)} summaries generated, latency: {end_at - start_at}",
                fg="green",
            )
        )

    except Exception as e:
        logger.exception("Failed to generate summary index for document %s", document_id)
        # Update document segments with error status if needed
        if segment_ids:
            db.session.query(DocumentSegment).filter(
                DocumentSegment.id.in_(segment_ids),
                DocumentSegment.dataset_id == dataset_id,
            ).update(
                {
                    DocumentSegment.error: f"Summary generation failed: {str(e)}",
                },
                synchronize_session=False,
            )
            db.session.commit()
    finally:
        db.session.close()
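A sketch of how a caller might queue this task after checking the dataset's opt-in settings, mirroring the trigger added to _document_indexing earlier in this diff; the helper function itself is hypothetical.

# Illustrative trigger; assumes an open SQLAlchemy session via `db`, as in the tasks above.
from extensions.ext_database import db
from models.dataset import Dataset, Document
from tasks.generate_summary_index_task import generate_summary_index_task


def queue_summary_generation(dataset_id: str, document_id: str) -> bool:
    """Queue summary generation if the dataset opts in; returns True when queued."""
    dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
    document = db.session.query(Document).where(Document.id == document_id).first()
    if not dataset or not document:
        return False

    setting = dataset.summary_index_setting
    if dataset.indexing_technique != "high_quality" or not (setting and setting.get("enable")):
        return False
    if document.indexing_status != "completed" or document.doc_form == "qa_model":
        return False

    # Third argument None means "all segments of the document".
    generate_summary_index_task.delay(dataset.id, document_id, None)
    return True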
api/tasks/regenerate_summary_index_task.py (new file, 221 lines)
@ -0,0 +1,221 @@
|
||||
"""Task for regenerating summary indexes when dataset settings change."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
from sqlalchemy import select
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="dataset")
|
||||
def regenerate_summary_index_task(
|
||||
dataset_id: str,
|
||||
regenerate_reason: str = "summary_model_changed",
|
||||
regenerate_vectors_only: bool = False,
|
||||
):
|
||||
"""
|
||||
Regenerate summary indexes for all documents in a dataset.
|
||||
|
||||
This task is triggered when:
|
||||
1. summary_index_setting model changes (regenerate_reason="summary_model_changed")
|
||||
- Regenerates summary content and vectors for all existing summaries
|
||||
2. embedding_model changes (regenerate_reason="embedding_model_changed")
|
||||
- Only regenerates vectors for existing summaries (keeps summary content)
|
||||
|
||||
Args:
|
||||
dataset_id: Dataset ID
|
||||
regenerate_reason: Reason for regeneration ("summary_model_changed" or "embedding_model_changed")
|
||||
regenerate_vectors_only: If True, only regenerate vectors without regenerating summary content
|
||||
"""
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Start regenerate summary index for dataset {dataset_id}, reason: {regenerate_reason}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
|
||||
if not dataset:
|
||||
logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red"))
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
# Only regenerate summary index for high_quality indexing technique
|
||||
if dataset.indexing_technique != "high_quality":
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Skipping summary regeneration for dataset {dataset_id}: "
|
||||
f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'",
|
||||
fg="cyan",
|
||||
)
|
||||
)
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
# Check if summary index is enabled
|
||||
summary_index_setting = dataset.summary_index_setting
|
||||
if not summary_index_setting or not summary_index_setting.get("enable"):
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Summary index is disabled for dataset {dataset_id}",
|
||||
fg="cyan",
|
||||
)
|
||||
)
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
# Get all documents with completed indexing status
|
||||
dataset_documents = db.session.scalars(
|
||||
select(DatasetDocument).where(
|
||||
DatasetDocument.dataset_id == dataset_id,
|
||||
DatasetDocument.indexing_status == "completed",
|
||||
DatasetDocument.enabled == True,
|
||||
DatasetDocument.archived == False,
|
||||
)
|
||||
).all()
|
||||
|
||||
if not dataset_documents:
|
||||
logger.info(
|
||||
click.style(
|
||||
f"No documents found for summary regeneration in dataset {dataset_id}",
|
||||
fg="cyan",
|
||||
)
|
||||
)
|
||||
db.session.close()
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Found %s documents for summary regeneration in dataset %s",
|
||||
len(dataset_documents),
|
||||
dataset_id,
|
||||
)
|
||||
|
||||
total_segments_processed = 0
|
||||
total_segments_failed = 0
|
||||
|
||||
for dataset_document in dataset_documents:
|
||||
# Skip qa_model documents
|
||||
if dataset_document.doc_form == "qa_model":
|
||||
continue
|
||||
|
||||
try:
|
||||
# Get all segments with existing summaries
|
||||
segments = (
|
||||
db.session.query(DocumentSegment)
|
||||
.join(
|
||||
DocumentSegmentSummary,
|
||||
DocumentSegment.id == DocumentSegmentSummary.chunk_id,
|
||||
)
|
||||
.where(
|
||||
DocumentSegment.document_id == dataset_document.id,
|
||||
DocumentSegment.dataset_id == dataset_id,
|
||||
DocumentSegment.status == "completed",
|
||||
DocumentSegment.enabled == True,
|
||||
DocumentSegmentSummary.dataset_id == dataset_id,
|
||||
)
|
||||
.order_by(DocumentSegment.position.asc())
|
||||
.all()
|
||||
)
|
||||
|
||||
if not segments:
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
"Regenerating summaries for %s segments in document %s",
|
||||
len(segments),
|
||||
dataset_document.id,
|
||||
)
|
||||
|
||||
for segment in segments:
|
||||
try:
|
||||
# Get existing summary record
|
||||
summary_record = (
|
||||
db.session.query(DocumentSegmentSummary)
|
||||
.filter_by(
|
||||
chunk_id=segment.id,
|
||||
dataset_id=dataset_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not summary_record:
|
||||
logger.warning("Summary record not found for segment %s, skipping", segment.id)
|
||||
continue
|
||||
|
||||
if regenerate_vectors_only:
|
||||
# Only regenerate vectors (for embedding_model change)
|
||||
# Delete old vector
|
||||
if summary_record.summary_index_node_id:
|
||||
try:
|
||||
from core.rag.datasource.vdb.vector_factory import Vector
|
||||
|
||||
vector = Vector(dataset)
|
||||
vector.delete_by_ids([summary_record.summary_index_node_id])
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to delete old summary vector for segment %s: %s",
|
||||
segment.id,
|
||||
str(e),
|
||||
)
|
||||
|
||||
# Re-vectorize with new embedding model
|
||||
SummaryIndexService.vectorize_summary(summary_record, segment, dataset)
|
||||
db.session.commit()
|
||||
else:
|
||||
# Regenerate both summary content and vectors (for summary_model change)
|
||||
SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting)
|
||||
db.session.commit()
|
||||
|
||||
total_segments_processed += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to regenerate summary for segment %s: %s",
|
||||
segment.id,
|
||||
str(e),
|
||||
exc_info=True,
|
||||
)
|
||||
total_segments_failed += 1
|
||||
# Update summary record with error status
|
||||
if summary_record:
|
||||
summary_record.status = "error"
|
||||
summary_record.error = f"Regeneration failed: {str(e)}"
|
||||
db.session.add(summary_record)
|
||||
db.session.commit()
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to process document %s for summary regeneration: %s",
|
||||
dataset_document.id,
|
||||
str(e),
|
||||
exc_info=True,
|
||||
)
|
||||
continue
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Summary index regeneration completed for dataset {dataset_id}: "
|
||||
f"{total_segments_processed} segments processed successfully, "
|
||||
f"{total_segments_failed} segments failed, "
|
||||
f"total documents: {len(dataset_documents)}, "
|
||||
f"latency: {end_at - start_at:.2f}s",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Regenerate summary index failed for dataset %s", dataset_id)
|
||||
finally:
|
||||
db.session.close()
|
||||
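The regenerate task above distinguishes two triggers: a summary-model change regenerates both text and vectors, while an embedding-model change only rebuilds vectors. A minimal sketch of a dispatch path is shown below; only the task signature comes from this diff, the settings-update hook is hypothetical.

# Hypothetical dispatch; only regenerate_summary_index_task's signature is taken from the diff.
from tasks.regenerate_summary_index_task import regenerate_summary_index_task


def on_dataset_settings_updated(dataset_id: str, summary_model_changed: bool, embedding_model_changed: bool) -> None:
    if summary_model_changed:
        # New summary model: rewrite summary text and re-embed it.
        regenerate_summary_index_task.delay(
            dataset_id,
            regenerate_reason="summary_model_changed",
            regenerate_vectors_only=False,
        )
    elif embedding_model_changed:
        # New embedding model: keep summary text, only rebuild vectors.
        regenerate_summary_index_task.delay(
            dataset_id,
            regenerate_reason="embedding_model_changed",
            regenerate_vectors_only=True,
        )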
@ -47,6 +47,21 @@ def remove_document_from_index_task(document_id: str):
        index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

        segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()

        # Disable summary indexes for all segments in this document
        from services.summary_index_service import SummaryIndexService

        segment_ids_list = [segment.id for segment in segments]
        if segment_ids_list:
            try:
                SummaryIndexService.disable_summaries_for_segments(
                    dataset=dataset,
                    segment_ids=segment_ids_list,
                    disabled_by=document.disabled_by,
                )
            except Exception as e:
                logger.warning("Failed to disable summaries for document %s: %s", document.id, str(e))

        index_node_ids = [segment.index_node_id for segment in segments]
        if index_node_ids:
            try:

@ -83,30 +83,7 @@
|
||||
<p class="content1">Dear {{ to }},</p>
|
||||
<p class="content2">{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.</p>
|
||||
<p class="content2">Click the button below to log in to Dify and join the workspace.</p>
|
||||
<div style="text-align: center; margin-bottom: 32px;">
|
||||
<a href="{{ url }}"
|
||||
style="background-color:#2563eb;
|
||||
color:#ffffff !important;
|
||||
text-decoration:none;
|
||||
display:inline-block;
|
||||
font-weight:600;
|
||||
border-radius:4px;
|
||||
font-size:14px;
|
||||
line-height:18px;
|
||||
font-family: Helvetica, Arial, sans-serif;
|
||||
text-align:center;
|
||||
border-top: 10px solid #2563eb;
|
||||
border-bottom: 10px solid #2563eb;
|
||||
border-left: 20px solid #2563eb;
|
||||
border-right: 20px solid #2563eb;
|
||||
">Login Here</a>
|
||||
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
|
||||
If the button doesn't work, copy and paste this link into your browser:<br>
|
||||
<a href="{{ url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
|
||||
{{ url }}
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
<p style="text-align: center; margin: 0; margin-bottom: 32px;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">Login Here</a></p>
|
||||
<p class="content2">Best regards,</p>
|
||||
<p class="content2">Dify Team</p>
|
||||
</div>
|
||||
|
||||
@ -83,30 +83,7 @@
|
||||
<p class="content1">尊敬的 {{ to }},</p>
|
||||
<p class="content2">{{ inviter_name }} 现邀请您加入我们在 Dify 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 Dify 上,您可以探索、创造和合作,构建和运营 AI 应用。</p>
|
||||
<p class="content2">点击下方按钮即可登录 Dify 并且加入空间。</p>
|
||||
<div style="text-align: center; margin-bottom: 32px;">
|
||||
<a href="{{ url }}"
|
||||
style="background-color:#2563eb;
|
||||
color:#ffffff !important;
|
||||
text-decoration:none;
|
||||
display:inline-block;
|
||||
font-weight:600;
|
||||
border-radius:4px;
|
||||
font-size:14px;
|
||||
line-height:18px;
|
||||
font-family: Helvetica, Arial, sans-serif;
|
||||
text-align:center;
|
||||
border-top: 10px solid #2563eb;
|
||||
border-bottom: 10px solid #2563eb;
|
||||
border-left: 20px solid #2563eb;
|
||||
border-right: 20px solid #2563eb;
|
||||
">在此登录</a>
|
||||
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
|
||||
如果按钮无法使用,请将以下链接复制到浏览器打开:<br>
|
||||
<a href="{{ url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
|
||||
{{ url }}
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
<p style="text-align: center; margin: 0; margin-bottom: 32px;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">在此登录</a></p>
|
||||
<p class="content2">此致,</p>
|
||||
<p class="content2">Dify 团队</p>
|
||||
</div>
|
||||
|
||||
@ -115,30 +115,7 @@
|
||||
We noticed you tried to sign up, but this email is already registered with an existing account.
|
||||
|
||||
Please log in here: </p>
|
||||
<div style="text-align: center; margin-bottom: 20px;">
|
||||
<a href="{{ login_url }}"
|
||||
style="background-color:#2563eb;
|
||||
color:#ffffff !important;
|
||||
text-decoration:none;
|
||||
display:inline-block;
|
||||
font-weight:600;
|
||||
border-radius:4px;
|
||||
font-size:14px;
|
||||
line-height:18px;
|
||||
font-family: Helvetica, Arial, sans-serif;
|
||||
text-align:center;
|
||||
border-top: 10px solid #2563eb;
|
||||
border-bottom: 10px solid #2563eb;
|
||||
border-left: 20px solid #2563eb;
|
||||
border-right: 20px solid #2563eb;
|
||||
">Log In</a>
|
||||
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
|
||||
If the button doesn't work, copy and paste this link into your browser:<br>
|
||||
<a href="{{ login_url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
|
||||
{{ login_url }}
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
<a href="{{ login_url }}" class="button">Log In</a>
|
||||
<p class="description">
|
||||
If you forgot your password, you can reset it here: <a href="{{ reset_password_url }}"
|
||||
class="reset-btn">Reset Password</a>
|
||||
|
||||
@ -115,30 +115,7 @@
|
||||
我们注意到您尝试注册,但此电子邮件已注册。
|
||||
|
||||
请在此登录: </p>
|
||||
<div style="text-align: center; margin-bottom: 20px;">
|
||||
<a href="{{ login_url }}"
|
||||
style="background-color:#2563eb;
|
||||
color:#ffffff !important;
|
||||
text-decoration:none;
|
||||
display:inline-block;
|
||||
font-weight:600;
|
||||
border-radius:4px;
|
||||
font-size:14px;
|
||||
line-height:18px;
|
||||
font-family: Helvetica, Arial, sans-serif;
|
||||
text-align:center;
|
||||
border-top: 10px solid #2563eb;
|
||||
border-bottom: 10px solid #2563eb;
|
||||
border-left: 20px solid #2563eb;
|
||||
border-right: 20px solid #2563eb;
|
||||
">登录</a>
|
||||
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
|
||||
如果按钮无法使用,请将以下链接复制到浏览器打开:<br>
|
||||
<a href="{{ login_url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
|
||||
{{ login_url }}
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
<a href="{{ login_url }}" class="button">登录</a>
|
||||
<p class="description">
|
||||
如果您忘记了密码,可以在此重置: <a href="{{ reset_password_url }}" class="reset-btn">重置密码</a>
|
||||
</p>
|
||||
|
||||
@ -92,34 +92,12 @@
|
||||
platform specifically designed for LLM application development. On {{application_title}}, you can explore,
|
||||
create, and collaborate to build and operate AI applications.</p>
|
||||
<p class="content2">Click the button below to log in to {{application_title}} and join the workspace.</p>
|
||||
<div style="text-align: center; margin-bottom: 32px;">
|
||||
<a href="{{ url }}"
|
||||
style="background-color:#2563eb;
|
||||
color:#ffffff !important;
|
||||
text-decoration:none;
|
||||
display:inline-block;
|
||||
font-weight:600;
|
||||
border-radius:4px;
|
||||
font-size:14px;
|
||||
line-height:18px;
|
||||
font-family: Helvetica, Arial, sans-serif;
|
||||
text-align:center;
|
||||
border-top: 10px solid #2563eb;
|
||||
border-bottom: 10px solid #2563eb;
|
||||
border-left: 20px solid #2563eb;
|
||||
border-right: 20px solid #2563eb;
|
||||
">Login Here</a>
|
||||
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
|
||||
If the button doesn't work, copy and paste this link into your browser:<br>
|
||||
<a href="{{ url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
|
||||
{{ url }}
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
<p style="text-align: center; margin: 0; margin-bottom: 32px;"><a style="color: #fff; text-decoration: none"
|
||||
class="button" href="{{ url }}">Login Here</a></p>
|
||||
<p class="content2">Best regards,</p>
|
||||
<p class="content2">{{application_title}} Team</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
|
||||
</html>
|
||||
</html>
|
||||
@ -81,30 +81,7 @@
|
||||
<p class="content1">尊敬的 {{ to }},</p>
|
||||
<p class="content2">{{ inviter_name }} 现邀请您加入我们在 {{application_title}} 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 {{application_title}} 上,您可以探索、创造和合作,构建和运营 AI 应用。</p>
|
||||
<p class="content2">点击下方按钮即可登录 {{application_title}} 并且加入空间。</p>
|
||||
<div style="text-align: center; margin-bottom: 32px;">
|
||||
<a href="{{ url }}"
|
||||
style="background-color:#2563eb;
|
||||
color:#ffffff !important;
|
||||
text-decoration:none;
|
||||
display:inline-block;
|
||||
font-weight:600;
|
||||
border-radius:4px;
|
||||
font-size:14px;
|
||||
line-height:18px;
|
||||
font-family: Helvetica, Arial, sans-serif;
|
||||
text-align:center;
|
||||
border-top: 10px solid #2563eb;
|
||||
border-bottom: 10px solid #2563eb;
|
||||
border-left: 20px solid #2563eb;
|
||||
border-right: 20px solid #2563eb;
|
||||
">在此登录</a>
|
||||
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
|
||||
如果按钮无法使用,请将以下链接复制到浏览器打开:<br>
|
||||
<a href="{{ url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
|
||||
{{ url }}
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
<p style="text-align: center; margin: 0; margin-bottom: 32px;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">在此登录</a></p>
|
||||
<p class="content2">此致,</p>
|
||||
<p class="content2">{{application_title}} 团队</p>
|
||||
</div>
|
||||
|
||||
@ -111,30 +111,7 @@
|
||||
We noticed you tried to sign up, but this email is already registered with an existing account.
|
||||
|
||||
Please log in here: </p>
|
||||
<div style="text-align: center; margin-bottom: 20px;">
|
||||
<a href="{{ login_url }}"
|
||||
style="background-color:#2563eb;
|
||||
color:#ffffff !important;
|
||||
text-decoration:none;
|
||||
display:inline-block;
|
||||
font-weight:600;
|
||||
border-radius:4px;
|
||||
font-size:14px;
|
||||
line-height:18px;
|
||||
font-family: Helvetica, Arial, sans-serif;
|
||||
text-align:center;
|
||||
border-top: 10px solid #2563eb;
|
||||
border-bottom: 10px solid #2563eb;
|
||||
border-left: 20px solid #2563eb;
|
||||
border-right: 20px solid #2563eb;
|
||||
">Log In</a>
|
||||
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
|
||||
If the button doesn't work, copy and paste this link into your browser:<br>
|
||||
<a href="{{ login_url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
|
||||
{{ login_url }}
|
||||
</a>
|
||||
</p>
|
||||
</div>
|
||||
<a href="{{ login_url }}" class="button">Log In</a>
|
||||
<p class="description">
|
||||
If you forgot your password, you can reset it here: <a href="{{ reset_password_url }}"
|
||||
class="reset-btn">Reset Password</a>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.