Mirror of https://github.com/langgenius/dify.git (synced 2026-01-19 19:55:06 +08:00)

Compare commits: 208 commits, deploy/dev ... feat/hitl-
| SHA1 | Author | Date |
|---|---|---|
| 5e644315e4 | |||
| e3a22e5027 | |||
| 977cef5a22 | |||
| b3902374ac | |||
| 1353eec9ca | |||
| 3b225c01da | |||
| 72ce6ca437 | |||
| 269c85d5a3 | |||
| b0545635b8 | |||
| 13d648cf7b | |||
| 73343f03c1 | |||
| d401a29bd9 | |||
| 3bf8f19874 | |||
| 51a7ddba81 | |||
| bd634b165d | |||
| a298140d8f | |||
| 3db3a18eff | |||
| 91c35c2f0a | |||
| 61c7fdc614 | |||
| 88c2483192 | |||
| ca58055a39 | |||
| 368e38d593 | |||
| 7463bc9199 | |||
| 35a707199f | |||
| dfb25df5ec | |||
| cdc24696e4 | |||
| 8f6a0e2d8e | |||
| 20ba0de47d | |||
| f13ace9dd9 | |||
| 23f5427349 | |||
| 7deaab116a | |||
| b6db6d9305 | |||
| 2fa688cefd | |||
| b24fafe901 | |||
| c837def205 | |||
| 1a7eac192c | |||
| d21cf87bb6 | |||
| ef41325ad1 | |||
| 57c33d5869 | |||
| 4484261023 | |||
| 72bc396646 | |||
| 68885afac6 | |||
| 18e57096d2 | |||
| b6c6d52725 | |||
| bd104557ef | |||
| 55b18bb79e | |||
| 3f8f158b04 | |||
| 471d14f882 | |||
| a280df2c07 | |||
| b479a36273 | |||
| 9136bf48f5 | |||
| ac70069847 | |||
| 27da1e72eb | |||
| 467119d186 | |||
| 1cdbfa2539 | |||
| d3299db915 | |||
| ebb816b90b | |||
| 0b794ad9cc | |||
| a19c0023f9 | |||
| c7fa7009b1 | |||
| 3e1a96ad64 | |||
| d500a631f0 | |||
| 607e77e0a7 | |||
| 46ec24cf8a | |||
| f654a7f704 | |||
| cf9b72d574 | |||
| 5a92bbdab5 | |||
| d0a713e117 | |||
| 9ec127ea8f | |||
| d478822fe1 | |||
| 4eac271adc | |||
| 0f13b6b26d | |||
| 52ce46c364 | |||
| f0f1ae0b49 | |||
| 0c69466b0f | |||
| 44a688cb81 | |||
| 0e2b59d661 | |||
| 501c3bcc94 | |||
| bf6a2c22eb | |||
| 8c0365c71e | |||
| c0a4f3b715 | |||
| d308c34842 | |||
| 68615eec04 | |||
| eca3e23af0 | |||
| c716c4ccbe | |||
| 32366774a1 | |||
| 640661983a | |||
| f528f2eafc | |||
| 0994953728 | |||
| d80167d9ec | |||
| 2969a77b15 | |||
| 062896cb9c | |||
| 5a1e6269d5 | |||
| e744d4de80 | |||
| be0f493e61 | |||
| b8d4f60782 | |||
| 8b65b689f7 | |||
| 8b9846f52b | |||
| 1c4c1b5cb1 | |||
| ec0c144eb2 | |||
| afddc56bb4 | |||
| 57dde4a4d8 | |||
| 3c0fd213bf | |||
| ddfd1cb1f5 | |||
| 138922f3f4 | |||
| aa9d0bd655 | |||
| ccef15aafa | |||
| 202f924d99 | |||
| 882d2f5b4c | |||
| 93ec42344c | |||
| 330585ec9b | |||
| 8f2660cea8 | |||
| 87f6ac78c5 | |||
| 6c7d58aa11 | |||
| 8bca22a94b | |||
| 17aa11da79 | |||
| 63bbcff496 | |||
| f71a632cdc | |||
| 4e3bded902 | |||
| e75bbdbc0d | |||
| 331a5edbf9 | |||
| 6afc99a5ad | |||
| a4e2ef6b0c | |||
| 3632b473df | |||
| e35dd14c59 | |||
| 495f901798 | |||
| 8703515153 | |||
| be3c6da654 | |||
| d51db3afb3 | |||
| 22683fba3f | |||
| ed16265eee | |||
| 463ea14d44 | |||
| 527736b8e4 | |||
| f22dcee6d9 | |||
| 9bdd7e5465 | |||
| ad0e79372f | |||
| a362114486 | |||
| 79ab253c26 | |||
| 783b78cc0a | |||
| b1e123c3aa | |||
| 84709a7941 | |||
| ec07636ce9 | |||
| 4d4c8b21ac | |||
| ec8754173f | |||
| 7920b89714 | |||
| 19e152fd0c | |||
| 076a8ecff4 | |||
| 40591b2196 | |||
| 5156b8f9c9 | |||
| 286ab0d468 | |||
| 71a511a470 | |||
| 6b11973151 | |||
| 949a894f03 | |||
| 305b5da764 | |||
| fda19d3f0e | |||
| e5a2172a85 | |||
| 2d89d59d74 | |||
| 4f56acd432 | |||
| c1b7412465 | |||
| 62b9a20115 | |||
| 5392401e60 | |||
| baa77d3cda | |||
| a41176b66d | |||
| 465e978209 | |||
| a9ea8cfd1c | |||
| c771f4dbc7 | |||
| ebbed8f863 | |||
| bdf1e9ed3b | |||
| 36acd0b9dd | |||
| a4049e1ea7 | |||
| 114dfe038c | |||
| 81f6344aaa | |||
| 89963ecf59 | |||
| a18bcf3957 | |||
| c28720529e | |||
| 242826013e | |||
| 05453cb22f | |||
| da211d3009 | |||
| f8a249de03 | |||
| e2e5dedceb | |||
| bd7ba85471 | |||
| 7d5f6bc255 | |||
| fcc8789cc3 | |||
| 6da1a48cad | |||
| 465ff7838a | |||
| a4bf493343 | |||
| c27f20b4b7 | |||
| a9e6140dc6 | |||
| 792f28451c | |||
| 82530df38f | |||
| ce8325c83c | |||
| 3ed561d943 | |||
| bb8d54c48b | |||
| 922aeb7c21 | |||
| 177be06d09 | |||
| 736ec55f86 | |||
| ab373197f9 | |||
| a6d2392c6c | |||
| 3371989572 | |||
| f04daf056d | |||
| fb6c8fa01f | |||
| b02199145e | |||
| e257455c9c | |||
| 63af5305e6 | |||
| 95b88a0621 | |||
| 1ab02c6e9a | |||
| 1099ab5d91 | |||
| 6485adae35 |
.github/workflows/build-push.yml (vendored, 1 change)
@@ -8,6 +8,7 @@ on:
      - "build/**"
      - "release/e-*"
      - "hotfix/**"
+      - "feat/hitl-frontend"
    tags:
      - "*"
@@ -71,6 +71,8 @@ def create_app() -> DifyApp:


def initialize_extensions(app: DifyApp):
+    # Initialize Flask context capture for workflow execution
+    from context.flask_app_context import init_flask_context
    from extensions import (
        ext_app_metrics,
        ext_blueprints,
@@ -100,6 +102,8 @@ def initialize_extensions(app: DifyApp):
        ext_warnings,
    )

+    init_flask_context()
+
    extensions = [
        ext_timezone,
        ext_logging,
@@ -862,8 +862,27 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[


@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.")
-@click.option("--days", default=30, show_default=True, help="Delete workflow runs created before N days ago.")
+@click.option(
+    "--before-days",
+    "--days",
+    default=30,
+    show_default=True,
+    type=click.IntRange(min=0),
+    help="Delete workflow runs created before N days ago.",
+)
@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.")
+@click.option(
+    "--from-days-ago",
+    default=None,
+    type=click.IntRange(min=0),
+    help="Lower bound in days ago (older). Must be paired with --to-days-ago.",
+)
+@click.option(
+    "--to-days-ago",
+    default=None,
+    type=click.IntRange(min=0),
+    help="Upper bound in days ago (newer). Must be paired with --from-days-ago.",
+)
@click.option(
    "--start-from",
    type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
@@ -882,8 +901,10 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
    help="Preview cleanup results without deleting any workflow run data.",
)
def clean_workflow_runs(
-    days: int,
+    before_days: int,
    batch_size: int,
+    from_days_ago: int | None,
+    to_days_ago: int | None,
    start_from: datetime.datetime | None,
    end_before: datetime.datetime | None,
    dry_run: bool,
@@ -894,11 +915,24 @@ def clean_workflow_runs(
    if (start_from is None) ^ (end_before is None):
        raise click.UsageError("--start-from and --end-before must be provided together.")

+    if (from_days_ago is None) ^ (to_days_ago is None):
+        raise click.UsageError("--from-days-ago and --to-days-ago must be provided together.")
+
+    if from_days_ago is not None and to_days_ago is not None:
+        if start_from or end_before:
+            raise click.UsageError("Choose either day offsets or explicit dates, not both.")
+        if from_days_ago <= to_days_ago:
+            raise click.UsageError("--from-days-ago must be greater than --to-days-ago.")
+        now = datetime.datetime.now()
+        start_from = now - datetime.timedelta(days=from_days_ago)
+        end_before = now - datetime.timedelta(days=to_days_ago)
+        before_days = 0
+
    start_time = datetime.datetime.now(datetime.UTC)
    click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))

    WorkflowRunCleanup(
-        days=days,
+        days=before_days,
        batch_size=batch_size,
        start_from=start_from,
        end_before=end_before,
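For reference, a minimal sketch of how the new option pairing behaves. The `commands` import path and the `--dry-run` flag spelling are assumptions inferred from the option help text above, not confirmed by this diff:

```python
# Hypothetical smoke test for the option pairing; assumes the command lives in
# api/commands.py and that --dry-run previews without deleting.
from click.testing import CliRunner

from commands import clean_workflow_runs  # assumed import path

runner = CliRunner()

# Valid: both offsets given, older bound strictly greater than the newer one.
ok = runner.invoke(clean_workflow_runs, ["--from-days-ago", "90", "--to-days-ago", "30", "--dry-run"])

# Invalid: offsets must be paired, so this exits with a UsageError.
bad = runner.invoke(clean_workflow_runs, ["--from-days-ago", "90"])
assert bad.exit_code != 0
```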
api/context/__init__.py (new file, 74 lines)
@@ -0,0 +1,74 @@
"""
Core Context - Framework-agnostic context management.

This module provides context management that is independent of any specific
web framework. Framework-specific implementations register their context
capture functions at application initialization time.

This ensures the workflow layer remains completely decoupled from Flask
or any other web framework.
"""

import contextvars
from collections.abc import Callable

from core.workflow.context.execution_context import (
    ExecutionContext,
    IExecutionContext,
    NullAppContext,
)

# Global capturer function - set by framework-specific modules
_capturer: Callable[[], IExecutionContext] | None = None


def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
    """
    Register a context capture function.

    This should be called by framework-specific modules (e.g., Flask)
    during application initialization.

    Args:
        capturer: Function that captures current context and returns IExecutionContext
    """
    global _capturer
    _capturer = capturer


def capture_current_context() -> IExecutionContext:
    """
    Capture current execution context.

    This function uses the registered context capturer. If no capturer
    is registered, it returns a minimal context with only contextvars
    (suitable for non-framework environments like tests or standalone scripts).

    Returns:
        IExecutionContext with captured context
    """
    if _capturer is None:
        # No framework registered - return minimal context
        return ExecutionContext(
            app_context=NullAppContext(),
            context_vars=contextvars.copy_context(),
        )

    return _capturer()


def reset_context_provider() -> None:
    """
    Reset the context capturer.

    This is primarily useful for testing to ensure a clean state.
    """
    global _capturer
    _capturer = None


__all__ = [
    "capture_current_context",
    "register_context_capturer",
    "reset_context_provider",
]
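A short usage sketch of the registration contract defined above; `my_capturer` is hypothetical, while `ExecutionContext` and `NullAppContext` are the real classes the module itself imports:

```python
# Hypothetical capturer illustrating the registration contract.
import contextvars

from context import capture_current_context, register_context_capturer, reset_context_provider
from core.workflow.context.execution_context import ExecutionContext, NullAppContext


def my_capturer() -> ExecutionContext:
    # A framework adapter would snapshot its own app context here.
    return ExecutionContext(
        app_context=NullAppContext(),
        context_vars=contextvars.copy_context(),
    )


register_context_capturer(my_capturer)
ctx = capture_current_context()  # now routed through my_capturer
reset_context_provider()         # back to the built-in minimal context
```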
api/context/flask_app_context.py (new file, 198 lines)
@@ -0,0 +1,198 @@
"""
Flask App Context - Flask implementation of AppContext interface.
"""

import contextvars
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, final

from flask import Flask, current_app, g

from context import register_context_capturer
from core.workflow.context.execution_context import (
    AppContext,
    IExecutionContext,
)


@final
class FlaskAppContext(AppContext):
    """
    Flask implementation of AppContext.

    This adapts Flask's app context to the AppContext interface.
    """

    def __init__(self, flask_app: Flask) -> None:
        """
        Initialize Flask app context.

        Args:
            flask_app: The Flask application instance
        """
        self._flask_app = flask_app

    def get_config(self, key: str, default: Any = None) -> Any:
        """Get configuration value from Flask app config."""
        return self._flask_app.config.get(key, default)

    def get_extension(self, name: str) -> Any:
        """Get Flask extension by name."""
        return self._flask_app.extensions.get(name)

    @contextmanager
    def enter(self) -> Generator[None, None, None]:
        """Enter Flask app context."""
        with self._flask_app.app_context():
            yield

    @property
    def flask_app(self) -> Flask:
        """Get the underlying Flask app instance."""
        return self._flask_app


def capture_flask_context(user: Any = None) -> IExecutionContext:
    """
    Capture current Flask execution context.

    This function captures the Flask app context and contextvars from the
    current environment. It should be called from within a Flask request or
    app context.

    Args:
        user: Optional user object to include in context

    Returns:
        IExecutionContext with captured Flask context

    Raises:
        RuntimeError: If called outside Flask context
    """
    # Get Flask app instance
    flask_app = current_app._get_current_object()  # type: ignore

    # Save current user if available
    saved_user = user
    if saved_user is None:
        # Check for user in g (flask-login)
        if hasattr(g, "_login_user"):
            saved_user = g._login_user

    # Capture contextvars
    context_vars = contextvars.copy_context()

    return FlaskExecutionContext(
        flask_app=flask_app,
        context_vars=context_vars,
        user=saved_user,
    )


@final
class FlaskExecutionContext:
    """
    Flask-specific execution context.

    This is a specialized version of ExecutionContext that includes Flask app
    context. It provides the same interface as ExecutionContext but with
    Flask-specific implementation.
    """

    def __init__(
        self,
        flask_app: Flask,
        context_vars: contextvars.Context,
        user: Any = None,
    ) -> None:
        """
        Initialize Flask execution context.

        Args:
            flask_app: Flask application instance
            context_vars: Python contextvars
            user: Optional user object
        """
        self._app_context = FlaskAppContext(flask_app)
        self._context_vars = context_vars
        self._user = user
        self._flask_app = flask_app

    @property
    def app_context(self) -> FlaskAppContext:
        """Get Flask app context."""
        return self._app_context

    @property
    def context_vars(self) -> contextvars.Context:
        """Get context variables."""
        return self._context_vars

    @property
    def user(self) -> Any:
        """Get user object."""
        return self._user

    def __enter__(self) -> "FlaskExecutionContext":
        """Enter the Flask execution context."""
        # Restore context variables
        for var, val in self._context_vars.items():
            var.set(val)

        # Save current user from g if available
        saved_user = None
        if hasattr(g, "_login_user"):
            saved_user = g._login_user

        # Enter Flask app context
        self._cm = self._app_context.enter()
        self._cm.__enter__()

        # Restore user in new app context
        if saved_user is not None:
            g._login_user = saved_user

        return self

    def __exit__(self, *args: Any) -> None:
        """Exit the Flask execution context."""
        if hasattr(self, "_cm"):
            self._cm.__exit__(*args)

    @contextmanager
    def enter(self) -> Generator[None, None, None]:
        """Enter Flask execution context as context manager."""
        # Restore context variables
        for var, val in self._context_vars.items():
            var.set(val)

        # Save current user from g if available
        saved_user = None
        if hasattr(g, "_login_user"):
            saved_user = g._login_user

        # Enter Flask app context
        with self._flask_app.app_context():
            # Restore user in new app context
            if saved_user is not None:
                g._login_user = saved_user
            yield


def init_flask_context() -> None:
    """
    Initialize Flask context capture by registering the capturer.

    This function should be called during Flask application initialization
    to register the Flask-specific context capturer with the core context module.

    Example:
        app = Flask(__name__)
        init_flask_context()  # Register Flask context capturer

    Note:
        This function does not need the app instance as it uses Flask's
        `current_app` to get the app when capturing context.
    """
    register_context_capturer(capture_flask_context)
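A sketch of the intended round trip, assuming the capture happens inside a live Flask request; the handler and thread wiring below are illustrative, not part of the diff:

```python
# Illustrative only: capture inside a request, re-enter in a worker thread.
import threading

from context.flask_app_context import capture_flask_context


def handle_request():  # hypothetical request handler
    ctx = capture_flask_context()  # snapshot app context, contextvars, user

    def worker():
        # Re-enter the captured Flask app context in the new thread.
        with ctx.enter():
            ...  # workflow execution can use current_app, g, etc.

    threading.Thread(target=worker).start()
```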
@@ -69,6 +69,13 @@ class ActivateCheckApi(Resource):
        if invitation:
            data = invitation.get("data", {})
            tenant = invitation.get("tenant", None)
+
+            # Check workspace permission
+            if tenant:
+                from libs.workspace_permission import check_workspace_member_invite_permission
+
+                check_workspace_member_invite_permission(tenant.id)
+
            workspace_name = tenant.name if tenant else None
            workspace_id = tenant.id if tenant else None
            invitee_email = data.get("email") if data else None
@@ -146,7 +146,6 @@ class DatasetUpdatePayload(BaseModel):
    embedding_model: str | None = None
    embedding_model_provider: str | None = None
    retrieval_model: dict[str, Any] | None = None
-    summary_index_setting: dict[str, Any] | None = None
    partial_member_list: list[dict[str, str]] | None = None
    external_retrieval_model: dict[str, Any] | None = None
    external_knowledge_id: str | None = None
@@ -39,10 +39,9 @@ from fields.document_fields import (
from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import DatasetProcessRule, Document, DocumentSegment, UploadFile
-from models.dataset import DocumentPipelineExecutionLog, DocumentSegmentSummary
+from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
-from tasks.generate_summary_index_task import generate_summary_index_task

from ..app.error import (
    ProviderModelCurrentlyNotSupportError,
@@ -105,10 +104,6 @@ class DocumentRenamePayload(BaseModel):
    name: str


-class GenerateSummaryPayload(BaseModel):
-    document_list: list[str]
-
-
class DocumentDatasetListParam(BaseModel):
    page: int = Field(1, title="Page", description="Page number.")
    limit: int = Field(20, title="Limit", description="Page size.")
@@ -125,7 +120,6 @@ register_schema_models(
    RetrievalModel,
    DocumentRetryPayload,
    DocumentRenamePayload,
-    GenerateSummaryPayload,
)
@@ -312,94 +306,6 @@ class DatasetDocumentListApi(Resource):

        paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
        documents = paginated_documents.items

-        # Check if dataset has summary index enabled
-        has_summary_index = dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True
-
-        # Filter documents that need summary calculation
-        documents_need_summary = [doc for doc in documents if doc.need_summary is True]
-        document_ids_need_summary = [str(doc.id) for doc in documents_need_summary]
-
-        # Calculate summary_index_status for documents that need summary (only if dataset summary index is enabled)
-        summary_status_map = {}
-        if has_summary_index and document_ids_need_summary:
-            # Get all segments for these documents (excluding qa_model and re_segment)
-            segments = (
-                db.session.query(DocumentSegment.id, DocumentSegment.document_id)
-                .where(
-                    DocumentSegment.document_id.in_(document_ids_need_summary),
-                    DocumentSegment.status != "re_segment",
-                    DocumentSegment.tenant_id == current_tenant_id,
-                )
-                .all()
-            )
-
-            # Group segments by document_id
-            document_segments_map = {}
-            for segment in segments:
-                doc_id = str(segment.document_id)
-                if doc_id not in document_segments_map:
-                    document_segments_map[doc_id] = []
-                document_segments_map[doc_id].append(segment.id)
-
-            # Get all summary records for these segments
-            all_segment_ids = [seg.id for seg in segments]
-            summaries = {}
-            if all_segment_ids:
-                summary_records = (
-                    db.session.query(DocumentSegmentSummary)
-                    .where(
-                        DocumentSegmentSummary.chunk_id.in_(all_segment_ids),
-                        DocumentSegmentSummary.dataset_id == dataset_id,
-                        DocumentSegmentSummary.enabled == True,  # Only count enabled summaries
-                    )
-                    .all()
-                )
-                summaries = {summary.chunk_id: summary.status for summary in summary_records}
-
-            # Calculate summary_index_status for each document
-            for doc_id in document_ids_need_summary:
-                segment_ids = document_segments_map.get(doc_id, [])
-                if not segment_ids:
-                    # No segments, status is "GENERATING" (waiting to generate)
-                    summary_status_map[doc_id] = "GENERATING"
-                    continue
-
-                # Count summary statuses for this document's segments
-                status_counts = {"completed": 0, "generating": 0, "error": 0, "not_started": 0}
-                for segment_id in segment_ids:
-                    status = summaries.get(segment_id, "not_started")
-                    if status in status_counts:
-                        status_counts[status] += 1
-                    else:
-                        status_counts["not_started"] += 1
-
-                total_segments = len(segment_ids)
-                completed_count = status_counts["completed"]
-                generating_count = status_counts["generating"]
-                error_count = status_counts["error"]
-
-                # Determine overall status (only three states: GENERATING, COMPLETED, ERROR)
-                if completed_count == total_segments:
-                    summary_status_map[doc_id] = "COMPLETED"
-                elif error_count > 0:
-                    # Has errors (even if some are completed or generating)
-                    summary_status_map[doc_id] = "ERROR"
-                elif generating_count > 0 or status_counts["not_started"] > 0:
-                    # Still generating or not started
-                    summary_status_map[doc_id] = "GENERATING"
-                else:
-                    # Default to generating
-                    summary_status_map[doc_id] = "GENERATING"
-
-        # Add summary_index_status to each document
-        for document in documents:
-            if has_summary_index and document.need_summary is True:
-                document.summary_index_status = summary_status_map.get(str(document.id), "GENERATING")
-            else:
-                # Return null if summary index is not enabled or document doesn't need summary
-                document.summary_index_status = None
-
        if fetch:
            for document in documents:
                completed_segments = (
@@ -885,7 +791,6 @@ class DocumentApi(DocumentResource):
                "display_status": document.display_status,
                "doc_form": document.doc_form,
                "doc_language": document.doc_language,
-                "need_summary": document.need_summary if document.need_summary is not None else False,
            }
        else:
            dataset_process_rules = DatasetService.get_process_rules(dataset_id)
@@ -921,7 +826,6 @@ class DocumentApi(DocumentResource):
                "display_status": document.display_status,
                "doc_form": document.doc_form,
                "doc_language": document.doc_language,
-                "need_summary": document.need_summary if document.need_summary is not None else False,
            }

        return response, 200
@@ -1289,216 +1193,3 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
            "input_data": log.input_data,
            "datasource_node_id": log.datasource_node_id,
        }, 200
-
-
-@console_ns.route("/datasets/<uuid:dataset_id>/documents/generate-summary")
-class DocumentGenerateSummaryApi(Resource):
-    @console_ns.doc("generate_summary_for_documents")
-    @console_ns.doc(description="Generate summary index for documents")
-    @console_ns.doc(params={"dataset_id": "Dataset ID"})
-    @console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
-    @console_ns.response(200, "Summary generation started successfully")
-    @console_ns.response(400, "Invalid request or dataset configuration")
-    @console_ns.response(403, "Permission denied")
-    @console_ns.response(404, "Dataset not found")
-    @setup_required
-    @login_required
-    @account_initialization_required
-    @cloud_edition_billing_rate_limit_check("knowledge")
-    def post(self, dataset_id):
-        """
-        Generate summary index for specified documents.
-
-        This endpoint checks if the dataset configuration supports summary generation
-        (indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
-        then asynchronously generates summary indexes for the provided documents.
-        """
-        current_user, _ = current_account_with_tenant()
-        dataset_id = str(dataset_id)
-
-        # Get dataset
-        dataset = DatasetService.get_dataset(dataset_id)
-        if not dataset:
-            raise NotFound("Dataset not found.")
-
-        # Check permissions
-        if not current_user.is_dataset_editor:
-            raise Forbidden()
-
-        try:
-            DatasetService.check_dataset_permission(dataset, current_user)
-        except services.errors.account.NoPermissionError as e:
-            raise Forbidden(str(e))
-
-        # Validate request payload
-        payload = GenerateSummaryPayload.model_validate(console_ns.payload or {})
-        document_list = payload.document_list
-
-        if not document_list:
-            raise ValueError("document_list cannot be empty.")
-
-        # Check if dataset configuration supports summary generation
-        if dataset.indexing_technique != "high_quality":
-            raise ValueError(
-                f"Summary generation is only available for 'high_quality' indexing technique. "
-                f"Current indexing technique: {dataset.indexing_technique}"
-            )
-
-        summary_index_setting = dataset.summary_index_setting
-        if not summary_index_setting or not summary_index_setting.get("enable"):
-            raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.")
-
-        # Verify all documents exist and belong to the dataset
-        documents = (
-            db.session.query(Document)
-            .filter(
-                Document.id.in_(document_list),
-                Document.dataset_id == dataset_id,
-            )
-            .all()
-        )
-
-        if len(documents) != len(document_list):
-            found_ids = {doc.id for doc in documents}
-            missing_ids = set(document_list) - found_ids
-            raise NotFound(f"Some documents not found: {list(missing_ids)}")
-
-        # Dispatch async tasks for each document
-        for document in documents:
-            # Skip qa_model documents as they don't generate summaries
-            if document.doc_form == "qa_model":
-                logger.info("Skipping summary generation for qa_model document %s", document.id)
-                continue
-
-            # Dispatch async task
-            generate_summary_index_task(dataset_id, document.id)
-            logger.info(
-                "Dispatched summary generation task for document %s in dataset %s",
-                document.id,
-                dataset_id,
-            )
-
-        return {"result": "success"}, 200
-
-
-@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/summary-status")
-class DocumentSummaryStatusApi(DocumentResource):
-    @console_ns.doc("get_document_summary_status")
-    @console_ns.doc(description="Get summary index generation status for a document")
-    @console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
-    @console_ns.response(200, "Summary status retrieved successfully")
-    @console_ns.response(404, "Document not found")
-    @setup_required
-    @login_required
-    @account_initialization_required
-    def get(self, dataset_id, document_id):
-        """
-        Get summary index generation status for a document.
-
-        Returns:
-        - total_segments: Total number of segments in the document
-        - summary_status: Dictionary with status counts
-            - completed: Number of summaries completed
-            - generating: Number of summaries being generated
-            - error: Number of summaries with errors
-            - not_started: Number of segments without summary records
-        - summaries: List of summary records with status and content preview
-        """
-        current_user, _ = current_account_with_tenant()
-        dataset_id = str(dataset_id)
-        document_id = str(document_id)
-
-        # Get document
-        document = self.get_document(dataset_id, document_id)
-
-        # Get dataset
-        dataset = DatasetService.get_dataset(dataset_id)
-        if not dataset:
-            raise NotFound("Dataset not found.")
-
-        # Check permissions
-        try:
-            DatasetService.check_dataset_permission(dataset, current_user)
-        except services.errors.account.NoPermissionError as e:
-            raise Forbidden(str(e))
-
-        # Get all segments for this document
-        segments = (
-            db.session.query(DocumentSegment)
-            .filter(
-                DocumentSegment.document_id == document_id,
-                DocumentSegment.dataset_id == dataset_id,
-                DocumentSegment.status == "completed",
-                DocumentSegment.enabled == True,
-            )
-            .all()
-        )
-
-        total_segments = len(segments)
-
-        # Get all summary records for these segments
-        segment_ids = [segment.id for segment in segments]
-        summaries = []
-        if segment_ids:
-            summaries = (
-                db.session.query(DocumentSegmentSummary)
-                .filter(
-                    DocumentSegmentSummary.document_id == document_id,
-                    DocumentSegmentSummary.dataset_id == dataset_id,
-                    DocumentSegmentSummary.chunk_id.in_(segment_ids),
-                    DocumentSegmentSummary.enabled == True,  # Only return enabled summaries
-                )
-                .all()
-            )
-
-        # Create a mapping of chunk_id to summary
-        summary_map = {summary.chunk_id: summary for summary in summaries}
-
-        # Count statuses
-        status_counts = {
-            "completed": 0,
-            "generating": 0,
-            "error": 0,
-            "not_started": 0,
-        }
-
-        summary_list = []
-        for segment in segments:
-            summary = summary_map.get(segment.id)
-            if summary:
-                status = summary.status
-                status_counts[status] = status_counts.get(status, 0) + 1
-                summary_list.append(
-                    {
-                        "segment_id": segment.id,
-                        "segment_position": segment.position,
-                        "status": summary.status,
-                        "summary_preview": (
-                            summary.summary_content[:100] + "..."
-                            if summary.summary_content and len(summary.summary_content) > 100
-                            else summary.summary_content
-                        ),
-                        "error": summary.error,
-                        "created_at": int(summary.created_at.timestamp()) if summary.created_at else None,
-                        "updated_at": int(summary.updated_at.timestamp()) if summary.updated_at else None,
-                    }
-                )
-            else:
-                status_counts["not_started"] += 1
-                summary_list.append(
-                    {
-                        "segment_id": segment.id,
-                        "segment_position": segment.position,
-                        "status": "not_started",
-                        "summary_preview": None,
-                        "error": None,
-                        "created_at": None,
-                        "updated_at": None,
-                    }
-                )
-
-        return {
-            "total_segments": total_segments,
-            "summary_status": status_counts,
-            "summaries": summary_list,
-        }, 200
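The removed status endpoint reduces per-segment summary states to a single document-level flag. A standalone restatement of that reduction rule (the function name is ours; the rule follows the removed code above):

```python
# Standalone restatement of the removed document-level status rule.
def aggregate_summary_status(segment_statuses: list[str]) -> str:
    """Reduce per-segment summary statuses to COMPLETED / ERROR / GENERATING."""
    if not segment_statuses:
        return "GENERATING"  # no segments yet: still waiting to generate
    if all(s == "completed" for s in segment_statuses):
        return "COMPLETED"
    if any(s == "error" for s in segment_statuses):
        return "ERROR"  # errors win even if some segments completed
    return "GENERATING"  # anything else counts as still in progress


assert aggregate_summary_status(["completed", "completed"]) == "COMPLETED"
assert aggregate_summary_status(["completed", "error"]) == "ERROR"
assert aggregate_summary_status(["completed", "generating"]) == "GENERATING"
```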
@@ -32,7 +32,7 @@ from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields
from libs.helper import escape_like_pattern
from libs.login import current_account_with_tenant, login_required
-from models.dataset import ChildChunk, DocumentSegment, DocumentSegmentSummary
+from models.dataset import ChildChunk, DocumentSegment
from models.model import UploadFile
from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
@@ -41,23 +41,6 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task


-def _get_segment_with_summary(segment, dataset_id):
-    """Helper function to marshal segment and add summary information."""
-    segment_dict = marshal(segment, segment_fields)
-    # Query summary for this segment (only enabled summaries)
-    summary = (
-        db.session.query(DocumentSegmentSummary)
-        .where(
-            DocumentSegmentSummary.chunk_id == segment.id,
-            DocumentSegmentSummary.dataset_id == dataset_id,
-            DocumentSegmentSummary.enabled == True,  # Only return enabled summaries
-        )
-        .first()
-    )
-    segment_dict["summary"] = summary.summary_content if summary else None
-    return segment_dict
-
-
class SegmentListQuery(BaseModel):
    limit: int = Field(default=20, ge=1, le=100)
    status: list[str] = Field(default_factory=list)
@@ -80,7 +63,6 @@ class SegmentUpdatePayload(BaseModel):
    keywords: list[str] | None = None
    regenerate_child_chunks: bool = False
    attachment_ids: list[str] | None = None
-    summary: str | None = None  # Summary content for summary index


class BatchImportPayload(BaseModel):
@@ -198,32 +180,8 @@ class DatasetDocumentSegmentListApi(Resource):

        segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)

-        # Query summaries for all segments in this page (batch query for efficiency)
-        segment_ids = [segment.id for segment in segments.items]
-        summaries = {}
-        if segment_ids:
-            summary_records = (
-                db.session.query(DocumentSegmentSummary)
-                .where(
-                    DocumentSegmentSummary.chunk_id.in_(segment_ids),
-                    DocumentSegmentSummary.dataset_id == dataset_id,
-                )
-                .all()
-            )
-            # Only include enabled summaries
-            summaries = {
-                summary.chunk_id: summary.summary_content for summary in summary_records if summary.enabled is True
-            }
-
-        # Add summary to each segment
-        segments_with_summary = []
-        for segment in segments.items:
-            segment_dict = marshal(segment, segment_fields)
-            segment_dict["summary"] = summaries.get(segment.id)
-            segments_with_summary.append(segment_dict)
-
        response = {
-            "data": segments_with_summary,
+            "data": marshal(segments.items, segment_fields),
            "limit": limit,
            "total": segments.total,
            "total_pages": segments.pages,
@@ -369,7 +327,7 @@ class DatasetDocumentSegmentAddApi(Resource):
        payload_dict = payload.model_dump(exclude_none=True)
        SegmentService.segment_create_args_validate(payload_dict, document)
        segment = SegmentService.create_segment(payload_dict, document, dataset)
-        return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
+        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200


@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
@@ -431,12 +389,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
        payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
        payload_dict = payload.model_dump(exclude_none=True)
        SegmentService.segment_create_args_validate(payload_dict, document)
-
-        # Update segment (summary update with change detection is handled in SegmentService.update_segment)
        segment = SegmentService.update_segment(
            SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
        )
-        return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
+        return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200

    @setup_required
    @login_required
@@ -1,13 +1,6 @@
-from flask_restx import Resource, fields
+from flask_restx import Resource

from controllers.common.schema import register_schema_model
-from fields.hit_testing_fields import (
-    child_chunk_fields,
-    document_fields,
-    files_fields,
-    hit_testing_record_fields,
-    segment_fields,
-)
from libs.login import login_required

from .. import console_ns
@@ -21,45 +14,13 @@ from ..wraps import (
register_schema_model(console_ns, HitTestingPayload)


-def _get_or_create_model(model_name: str, field_def):
-    """Get or create a flask_restx model to avoid dict type issues in Swagger."""
-    existing = console_ns.models.get(model_name)
-    if existing is None:
-        existing = console_ns.model(model_name, field_def)
-    return existing
-
-
-# Register models for flask_restx to avoid dict type issues in Swagger
-document_model = _get_or_create_model("HitTestingDocument", document_fields)
-
-segment_fields_copy = segment_fields.copy()
-segment_fields_copy["document"] = fields.Nested(document_model)
-segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
-
-child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
-files_model = _get_or_create_model("HitTestingFile", files_fields)
-
-hit_testing_record_fields_copy = hit_testing_record_fields.copy()
-hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
-hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
-hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
-hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
-
-# Response model for hit testing API
-hit_testing_response_fields = {
-    "query": fields.String,
-    "records": fields.List(fields.Nested(hit_testing_record_model)),
-}
-hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
-
-
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
class HitTestingApi(Resource, DatasetsHitTestingBase):
    @console_ns.doc("test_dataset_retrieval")
    @console_ns.doc(description="Test dataset knowledge retrieval")
    @console_ns.doc(params={"dataset_id": "Dataset ID"})
    @console_ns.expect(console_ns.models[HitTestingPayload.__name__])
-    @console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
+    @console_ns.response(200, "Hit testing completed successfully")
    @console_ns.response(404, "Dataset not found")
    @console_ns.response(400, "Invalid parameters")
    @setup_required
@@ -107,6 +107,12 @@ class MemberInviteEmailApi(Resource):
        inviter = current_user
        if not inviter.current_tenant:
            raise ValueError("No current tenant")

+        # Check workspace permission for member invitations
+        from libs.workspace_permission import check_workspace_member_invite_permission
+
+        check_workspace_member_invite_permission(inviter.current_tenant.id)
+
        invitation_results = []
        console_web_url = dify_config.CONSOLE_WEB_URL

@@ -20,6 +20,7 @@ from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
    account_initialization_required,
    cloud_edition_billing_resource_check,
+    only_edition_enterprise,
    setup_required,
)
from enums.cloud_plan import CloudPlan
@@ -28,6 +29,7 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
+from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workspace_service import WorkspaceService
@@ -288,3 +290,31 @@ class WorkspaceInfoApi(Resource):
        db.session.commit()

        return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
+
+
+@console_ns.route("/workspaces/current/permission")
+class WorkspacePermissionApi(Resource):
+    """Get workspace permissions for the current workspace."""
+
+    @setup_required
+    @login_required
+    @account_initialization_required
+    @only_edition_enterprise
+    def get(self):
+        """
+        Get workspace permission settings.
+        Returns permission flags that control workspace features like member invitations and owner transfer.
+        """
+        _, current_tenant_id = current_account_with_tenant()
+
+        if not current_tenant_id:
+            raise ValueError("No current tenant")
+
+        # Get workspace permissions from enterprise service
+        permission = EnterpriseService.WorkspacePermissionService.get_permission(current_tenant_id)
+
+        return {
+            "workspace_id": permission.workspace_id,
+            "allow_member_invite": permission.allow_member_invite,
+            "allow_owner_transfer": permission.allow_owner_transfer,
+        }, 200
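For orientation, the payload a client receives from the new endpoint has this shape; the values below are illustrative only:

```python
# Illustrative response body from GET /workspaces/current/permission
# (enterprise edition only; field values here are made up).
example_permission_response = {
    "workspace_id": "b0e2c3d4-0000-0000-0000-000000000000",  # hypothetical ID
    "allow_member_invite": True,   # gates MemberInviteEmailApi and invite checks
    "allow_owner_transfer": False, # gates the is_allow_transfer_owner decorator
}
```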
@@ -286,13 +286,12 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]):
    @wraps(view)
    def decorated(*args: P.args, **kwargs: P.kwargs):
-        _, current_tenant_id = current_account_with_tenant()
-        features = FeatureService.get_features(current_tenant_id)
-        if features.is_allow_transfer_workspace:
-            return view(*args, **kwargs)
+        from libs.workspace_permission import check_workspace_owner_transfer_permission

-        # otherwise, return 403
-        abort(403)
+        _, current_tenant_id = current_account_with_tenant()
+        # Check both billing/plan level and workspace policy level permissions
+        check_workspace_owner_transfer_permission(current_tenant_id)
+        return view(*args, **kwargs)

    return decorated
@@ -8,7 +8,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
-from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.orm import sessionmaker

import contexts
from configs import dify_config
@@ -23,6 +23,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
+from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
@@ -476,7 +477,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
        :return:
        """
        with preserve_flask_contexts(flask_app, context_vars=context):
-            with Session(db.engine, expire_on_commit=False) as session:
+            with session_factory.create_session() as session:
                workflow = session.scalar(
                    select(Workflow).where(
                        Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
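The diff replaces ad-hoc `Session(db.engine, expire_on_commit=False)` construction with a shared `session_factory`. A minimal sketch of what such a factory typically wraps; this is our assumption about the shape, not the actual contents of `core.db.session_factory`:

```python
# Assumed shape of a session factory like core.db.session_factory; the real
# module may differ. It centralizes engine binding and session defaults.
from collections.abc import Iterator
from contextlib import contextmanager

from sqlalchemy import Engine
from sqlalchemy.orm import Session, sessionmaker


class SessionFactory:
    def __init__(self, engine: Engine) -> None:
        # expire_on_commit=False mirrors the call site being replaced above.
        self._maker = sessionmaker(bind=engine, expire_on_commit=False)

    @contextmanager
    def create_session(self) -> Iterator[Session]:
        with self._maker() as session:
            yield session
```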
@@ -3,7 +3,6 @@ from pydantic import BaseModel, Field, field_validator

class PreviewDetail(BaseModel):
    content: str
-    summary: str | None = None
    child_chunks: list[str] | None = None


@@ -311,18 +311,14 @@ class IndexingRunner:
        qa_preview_texts: list[QAPreviewDetail] = []

        total_segments = 0
        # doc_form represents the segmentation method (general, parent-child, QA)
        index_type = doc_form
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        # one extract_setting is one source document
        for extract_setting in extract_settings:
            # extract
            processing_rule = DatasetProcessRule(
                mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
            )
            # Extract document content
            text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
            # Cleaning and segmentation
            documents = index_processor.transform(
                text_docs,
                current_user=None,
@@ -365,12 +361,6 @@ class IndexingRunner:

        if doc_form and doc_form == "qa_model":
            return IndexingEstimate(total_segments=total_segments * 20, qa_preview=qa_preview_texts, preview=[])

-        # Generate summary preview
-        summary_index_setting = tmp_processing_rule.get("summary_index_setting")
-        if summary_index_setting and summary_index_setting.get("enable") and preview_texts:
-            preview_texts = index_processor.generate_summary_preview(tenant_id, preview_texts, summary_index_setting)
-
        return IndexingEstimate(total_segments=total_segments, preview=preview_texts)

    def _extract(
@@ -72,7 +72,7 @@ class LLMGenerator:
            prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
        )
        answer = response.message.get_text_content()
-        if answer is None:
+        if answer == "":
            return ""
        try:
            result_dict = json.loads(answer)
@@ -113,11 +113,9 @@ class LLMGenerator:
        output_parser = SuggestedQuestionsAfterAnswerOutputParser()
        format_instructions = output_parser.get_format_instructions()

-        prompt_template = PromptTemplateParser(
-            template="{{histories}}\n{{format_instructions}}\nquestions:\n")
+        prompt_template = PromptTemplateParser(template="{{histories}}\n{{format_instructions}}\nquestions:\n")

-        prompt = prompt_template.format(
-            {"histories": histories, "format_instructions": format_instructions})
+        prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions})

        try:
            model_manager = ModelManager()
@@ -143,13 +141,11 @@ class LLMGenerator:
            )

            text_content = response.message.get_text_content()
-            questions = output_parser.parse(
-                text_content) if text_content else []
+            questions = output_parser.parse(text_content) if text_content else []
        except InvokeError:
            questions = []
        except Exception:
-            logger.exception(
-                "Failed to generate suggested questions after answer")
+            logger.exception("Failed to generate suggested questions after answer")
            questions = []

        return questions
@@ -160,12 +156,10 @@ class LLMGenerator:

        error = ""
        error_step = ""
-        rule_config = {"prompt": "", "variables": [],
-                       "opening_statement": "", "error": ""}
+        rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""}
        model_parameters = model_config.get("completion_params", {})
        if no_variable:
-            prompt_template = PromptTemplateParser(
-                WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)
+            prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)

            prompt_generate = prompt_template.format(
                inputs={
@@ -196,8 +190,7 @@ class LLMGenerator:
            error = str(e)
            error_step = "generate rule config"
        except Exception as e:
-            logger.exception(
-                "Failed to generate rule config, model: %s", model_config.get("name"))
+            logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
            rule_config["error"] = str(e)

        rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@@ -252,8 +245,7 @@ class LLMGenerator:
            },
            remove_template_variables=False,
        )
-        parameter_messages = [UserPromptMessage(
-            content=parameter_generate_prompt)]
+        parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)]

        # the second step to generate the task_parameter and task_statement
        statement_generate_prompt = statement_template.format(
@@ -263,15 +255,13 @@ class LLMGenerator:
            },
            remove_template_variables=False,
        )
-        statement_messages = [UserPromptMessage(
-            content=statement_generate_prompt)]
+        statement_messages = [UserPromptMessage(content=statement_generate_prompt)]

        try:
            parameter_content: LLMResult = model_instance.invoke_llm(
                prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
            )
-            rule_config["variables"] = re.findall(
-                r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content())
+            rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content())
        except InvokeError as e:
            error = str(e)
            error_step = "generate variables"
@@ -280,15 +270,13 @@ class LLMGenerator:
            statement_content: LLMResult = model_instance.invoke_llm(
                prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
            )
-            rule_config["opening_statement"] = statement_content.message.get_text_content(
-            )
+            rule_config["opening_statement"] = statement_content.message.get_text_content()
        except InvokeError as e:
            error = str(e)
            error_step = "generate conversation opener"

        except Exception as e:
-            logger.exception(
-                "Failed to generate rule config, model: %s", model_config.get("name"))
+            logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
            rule_config["error"] = str(e)

        rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
@@ -298,11 +286,9 @@ class LLMGenerator:
    @classmethod
    def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"):
        if code_language == "python":
-            prompt_template = PromptTemplateParser(
-                PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
+            prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
        else:
-            prompt_template = PromptTemplateParser(
-                JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE)
+            prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE)

        prompt = prompt_template.format(
            inputs={
@@ -335,8 +321,7 @@ class LLMGenerator:
            return {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
        except Exception as e:
            logger.exception(
-                "Failed to invoke LLM model, model: %s, language: %s", model_config.get(
-                    "name"), code_language
+                "Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language
            )
            return {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
@@ -350,8 +335,7 @@ class LLMGenerator:
            model_type=ModelType.LLM,
        )

-        prompt_messages: list[PromptMessage] = [SystemPromptMessage(
-            content=prompt), UserPromptMessage(content=query)]
+        prompt_messages: list[PromptMessage] = [SystemPromptMessage(content=prompt), UserPromptMessage(content=query)]

        # Explicitly use the non-streaming overload
        result = model_instance.invoke_llm(
@@ -397,19 +381,16 @@ class LLMGenerator:
            parsed_content = json_repair.loads(raw_content)

            if not isinstance(parsed_content, dict | list):
-                raise ValueError(
-                    f"Failed to parse structured output from llm: {raw_content}")
+                raise ValueError(f"Failed to parse structured output from llm: {raw_content}")

-            generated_json_schema = json.dumps(
-                parsed_content, indent=2, ensure_ascii=False)
+            generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False)
            return {"output": generated_json_schema, "error": ""}

        except InvokeError as e:
            error = str(e)
            return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
        except Exception as e:
-            logger.exception(
-                "Failed to invoke LLM model, model: %s", model_config.get("name"))
+            logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
            return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}

    @staticmethod
@@ -417,8 +398,7 @@ class LLMGenerator:
        tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
    ):
        last_run: Message | None = (
-            db.session.query(Message).where(Message.app_id == flow_id).order_by(
-                Message.created_at.desc()).first()
+            db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
        )
        if not last_run:
            return LLMGenerator.__instruction_modify_common(
@@ -466,8 +446,7 @@ class LLMGenerator:
        workflow = workflow_service.get_draft_workflow(app_model=app)
        if not workflow:
            raise ValueError("Workflow not found for the given app model.")
-        last_run = workflow_service.get_node_last_run(
-            app_model=app, workflow=workflow, node_id=node_id)
+        last_run = workflow_service.get_node_last_run(app_model=app, workflow=workflow, node_id=node_id)
        try:
            node_type = cast(WorkflowNodeExecutionModel, last_run).node_type
        except Exception:
@@ -491,8 +470,7 @@ class LLMGenerator:
        )

        def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence:
-            raw_agent_log = node_execution.execution_metadata_dict.get(
-                WorkflowNodeExecutionMetadataKey.AGENT_LOG, [])
+            raw_agent_log = node_execution.execution_metadata_dict.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG, [])
            if not raw_agent_log:
                return []

@@ -540,14 +518,11 @@ class LLMGenerator:
        ERROR_MESSAGE = "{{#error_message#}}"
        injected_instruction = instruction
        if LAST_RUN in injected_instruction:
-            injected_instruction = injected_instruction.replace(
-                LAST_RUN, json.dumps(last_run))
+            injected_instruction = injected_instruction.replace(LAST_RUN, json.dumps(last_run))
        if CURRENT in injected_instruction:
-            injected_instruction = injected_instruction.replace(
-                CURRENT, current or "null")
+            injected_instruction = injected_instruction.replace(CURRENT, current or "null")
        if ERROR_MESSAGE in injected_instruction:
-            injected_instruction = injected_instruction.replace(
-                ERROR_MESSAGE, error_message or "null")
+            injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null")
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
@@ -585,13 +560,11 @@ class LLMGenerator:
            first_brace = generated_raw.find("{")
            last_brace = generated_raw.rfind("}")
            if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
-                raise ValueError(
-                    f"Could not find a valid JSON object in response: {generated_raw}")
-            json_str = generated_raw[first_brace: last_brace + 1]
+                raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
+            json_str = generated_raw[first_brace : last_brace + 1]
            data = json_repair.loads(json_str)
            if not isinstance(data, dict):
-                raise TypeError(
-                    f"Expected a JSON object, but got {type(data).__name__}")
+                raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
            return data
        except InvokeError as e:
            error = str(e)
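The brace-slicing pattern above is easy to exercise in isolation. A standalone restatement (the function name is ours; the logic mirrors the new lines):

```python
import json_repair  # same repair library the generator uses


def extract_json_object(generated_raw: str) -> dict:
    """Slice the outermost {...} span out of LLM text, then repair-parse it."""
    first_brace = generated_raw.find("{")
    last_brace = generated_raw.rfind("}")
    if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
        raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
    data = json_repair.loads(generated_raw[first_brace : last_brace + 1])
    if not isinstance(data, dict):
        raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
    return data


assert extract_json_object('noise {"a": 1} trailing') == {"a": 1}
```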
@@ -434,20 +434,3 @@ INSTRUCTION_GENERATE_TEMPLATE_PROMPT = """The output of this prompt is not as ex
 You should edit the prompt according to the IDEAL OUTPUT."""

 INSTRUCTION_GENERATE_TEMPLATE_CODE = """Please fix the errors in the {{#error_message#}}."""
-
-DEFAULT_GENERATOR_SUMMARY_PROMPT = (
-    """Summarize the following content. Extract only the key information and main points. """
-    """Remove redundant details.
-
-Requirements:
-1. Write a concise summary in plain text
-2. Use the same language as the input content
-3. Focus on important facts, concepts, and details
-4. If images are included, describe their key information
-5. Do not use words like "好的", "ok", "I understand", "This text discusses", "The content mentions"
-6. Write directly without extra words
-
-Output only the summary text. Start summarizing now:
-
-"""
-)
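A quick illustration of the {{#...#}} placeholder convention these templates rely on: tokens are substituted verbatim into the instruction before the prompt is sent to the model. A hedged sketch only; the LAST_RUN literal is assumed to be {{#last_run#}} by analogy with ERROR_MESSAGE in the LLMGenerator hunks above, and the sample values are invented.

import json

LAST_RUN = "{{#last_run#}}"
ERROR_MESSAGE = "{{#error_message#}}"

def inject(instruction: str, last_run: dict, error_message: str | None) -> str:
    # Plain string replacement, mirroring the replace() calls in the diff.
    if LAST_RUN in instruction:
        instruction = instruction.replace(LAST_RUN, json.dumps(last_run))
    if ERROR_MESSAGE in instruction:
        instruction = instruction.replace(ERROR_MESSAGE, error_message or "null")
    return instruction

print(inject("Please fix the errors in the {{#error_message#}}.", {}, "KeyError: 'output'"))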
@@ -389,14 +389,15 @@ class RetrievalService:
             .all()
         }

+        records = []
+        include_segment_ids = set()
+        segment_child_map = {}
+
         valid_dataset_documents = {}
         image_doc_ids: list[Any] = []
         child_index_node_ids = []
         index_node_ids = []
         doc_to_document_map = {}
-        summary_segment_ids = set()  # Track segments retrieved via summary
-
-        # First pass: collect all document IDs and identify summary documents
         for document in documents:
             document_id = document.metadata.get("document_id")
             if document_id not in dataset_documents:

@@ -407,24 +408,16 @@ class RetrievalService:
                 continue
             valid_dataset_documents[document_id] = dataset_document

-            doc_id = document.metadata.get("doc_id") or ""
-            doc_to_document_map[doc_id] = document
-
-            # Check if this is a summary document
-            is_summary = document.metadata.get("is_summary", False)
-            if is_summary:
-                # For summary documents, find the original chunk via original_chunk_id
-                original_chunk_id = document.metadata.get("original_chunk_id")
-                if original_chunk_id:
-                    summary_segment_ids.add(original_chunk_id)
-                continue  # Skip adding to other lists for summary documents
-
             if dataset_document.doc_form == IndexStructureType.PARENT_CHILD_INDEX:
+                doc_id = document.metadata.get("doc_id") or ""
+                doc_to_document_map[doc_id] = document
                 if document.metadata.get("doc_type") == DocType.IMAGE:
                     image_doc_ids.append(doc_id)
                 else:
                     child_index_node_ids.append(doc_id)
             else:
+                doc_id = document.metadata.get("doc_id") or ""
+                doc_to_document_map[doc_id] = document
                 if document.metadata.get("doc_type") == DocType.IMAGE:
                     image_doc_ids.append(doc_id)
                 else:

@@ -440,7 +433,6 @@ class RetrievalService:
         attachment_map: dict[str, list[dict[str, Any]]] = {}
         child_chunk_map: dict[str, list[ChildChunk]] = {}
         doc_segment_map: dict[str, list[str]] = {}
-        segment_summary_map: dict[str, str] = {}  # Map segment_id to summary content

         with session_factory.create_session() as session:
             attachments = cls.get_segment_attachment_infos(image_doc_ids, session)

@@ -455,7 +447,6 @@ class RetrievalService:
                     doc_segment_map[attachment["segment_id"]].append(attachment["attachment_id"])
                 else:
                     doc_segment_map[attachment["segment_id"]] = [attachment["attachment_id"]]
-
             child_chunk_stmt = select(ChildChunk).where(ChildChunk.index_node_id.in_(child_index_node_ids))
             child_index_nodes = session.execute(child_chunk_stmt).scalars().all()

@@ -479,7 +470,6 @@ class RetrievalService:
             index_node_segments = session.execute(document_segment_stmt).scalars().all()  # type: ignore
             for index_node_segment in index_node_segments:
                 doc_segment_map[index_node_segment.id] = [index_node_segment.index_node_id]
-
             if segment_ids:
                 document_segment_stmt = select(DocumentSegment).where(
                     DocumentSegment.enabled == True,

@@ -491,42 +481,6 @@ class RetrievalService:
             if index_node_segments:
                 segments.extend(index_node_segments)

-            # Handle summary documents: query segments by original_chunk_id
-            if summary_segment_ids:
-                summary_segment_ids_list = list(summary_segment_ids)
-                summary_segment_stmt = select(DocumentSegment).where(
-                    DocumentSegment.enabled == True,
-                    DocumentSegment.status == "completed",
-                    DocumentSegment.id.in_(summary_segment_ids_list),
-                )
-                summary_segments = session.execute(summary_segment_stmt).scalars().all()  # type: ignore
-                segments.extend(summary_segments)
-                # Add summary segment IDs to segment_ids for summary query
-                for seg in summary_segments:
-                    if seg.id not in segment_ids:
-                        segment_ids.append(seg.id)
-
-            # Batch query summaries for segments retrieved via summary (only enabled summaries)
-            if summary_segment_ids:
-                from models.dataset import DocumentSegmentSummary
-
-                summaries = (
-                    session.query(DocumentSegmentSummary)
-                    .filter(
-                        DocumentSegmentSummary.chunk_id.in_(list(summary_segment_ids)),
-                        DocumentSegmentSummary.status == "completed",
-                        DocumentSegmentSummary.enabled == True,  # Only retrieve enabled summaries
-                    )
-                    .all()
-                )
-                for summary in summaries:
-                    if summary.summary_content:
-                        segment_summary_map[summary.chunk_id] = summary.summary_content
-
-        include_segment_ids = set()
-        segment_child_map: dict[str, dict[str, Any]] = {}
-        records: list[dict[str, Any]] = []
-
         for segment in segments:
             child_chunks: list[ChildChunk] = child_chunk_map.get(segment.id, [])
             attachment_infos: list[dict[str, Any]] = attachment_map.get(segment.id, [])

@@ -539,7 +493,7 @@ class RetrievalService:
                 child_chunk_details = []
                 max_score = 0.0
                 for child_chunk in child_chunks:
-                    document = doc_to_document_map.get(child_chunk.index_node_id)
+                    document = doc_to_document_map[child_chunk.index_node_id]
                     child_chunk_detail = {
                         "id": child_chunk.id,
                         "content": child_chunk.content,

@@ -549,7 +503,7 @@ class RetrievalService:
                     child_chunk_details.append(child_chunk_detail)
                     max_score = max(max_score, document.metadata.get("score", 0.0) if document else 0.0)
                 for attachment_info in attachment_infos:
-                    file_document = doc_to_document_map.get(attachment_info["id"])
+                    file_document = doc_to_document_map[attachment_info["id"]]
                     max_score = max(
                         max_score, file_document.metadata.get("score", 0.0) if file_document else 0.0
                     )

@@ -622,16 +576,9 @@ class RetrievalService:
                     else None
                 )

-                # Extract summary if this segment was retrieved via summary
-                summary_content = segment_summary_map.get(segment.id)

-                # Create RetrievalSegments object
                 retrieval_segment = RetrievalSegments(
-                    segment=segment,
-                    child_chunks=child_chunks_list,
-                    score=score,
-                    files=files,
-                    summary=summary_content,
+                    segment=segment, child_chunks=child_chunks_list, score=score, files=files
                 )
                 result.append(retrieval_segment)
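To make the summary-document bookkeeping above concrete: hits whose metadata carries is_summary are routed back to the original chunk via original_chunk_id, while ordinary hits keep their own doc_id. A sketch with plain dicts standing in for the vector-store Document type; the field names follow the diff, the data is invented.

documents = [
    {"doc_id": "c1", "is_summary": False, "original_chunk_id": None},
    {"doc_id": "s9", "is_summary": True, "original_chunk_id": "segment-42"},
]

summary_segment_ids: set[str] = set()
child_index_node_ids: list[str] = []
for doc in documents:
    if doc["is_summary"] and doc["original_chunk_id"]:
        # Fetch the original segment later; summary hits never join the node-id lists.
        summary_segment_ids.add(doc["original_chunk_id"])
        continue
    child_index_node_ids.append(doc["doc_id"])

print(summary_segment_ids, child_index_node_ids)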
@@ -20,4 +20,3 @@ class RetrievalSegments(BaseModel):
     child_chunks: list[RetrievalChildChunk] | None = None
     score: float | None = None
     files: list[dict[str, str | int]] | None = None
-    summary: str | None = None  # Summary content if retrieved via summary index
@@ -13,7 +13,6 @@ from urllib.parse import unquote, urlparse
 import httpx

 from configs import dify_config
-from core.entities.knowledge_entities import PreviewDetail
 from core.helper import ssrf_proxy
 from core.rag.extractor.entity.extract_setting import ExtractSetting
 from core.rag.index_processor.constant.doc_type import DocType

@@ -46,17 +45,6 @@ class BaseIndexProcessor(ABC):
     def transform(self, documents: list[Document], current_user: Account | None = None, **kwargs) -> list[Document]:
         raise NotImplementedError
-
-    @abstractmethod
-    def generate_summary_preview(
-        self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
-    ) -> list[PreviewDetail]:
-        """
-        For each segment in preview_texts, generate a summary using LLM and attach it to the segment.
-        The summary can be stored in a new attribute, e.g., summary.
-        This method should be implemented by subclasses.
-        """
-        raise NotImplementedError

     @abstractmethod
     def load(
         self,
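The hunk above removes an abstract hook from the processor base class. As a tiny self-contained sketch of that pattern: subclasses either implement the hook or inherit NotImplementedError. PreviewDetail is modeled here as a simple dataclass purely for illustration; the real class lives in core.entities.knowledge_entities.

from abc import ABC, abstractmethod
from dataclasses import dataclass

@dataclass
class PreviewDetail:
    content: str
    summary: str | None = None

class BaseIndexProcessor(ABC):
    @abstractmethod
    def generate_summary_preview(
        self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
    ) -> list[PreviewDetail]:
        raise NotImplementedError

class QAIndexProcessor(BaseIndexProcessor):
    def generate_summary_preview(self, tenant_id, preview_texts, summary_index_setting):
        return preview_texts  # QA pairs need no summaries

print(QAIndexProcessor().generate_summary_preview("t1", [PreviewDetail("hi")], {})[0])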
@@ -1,25 +1,9 @@
 """Paragraph index processor."""

 import logging
-import re
 import uuid
-from collections.abc import Mapping
 from typing import Any

-logger = logging.getLogger(__name__)
-
-from core.entities.knowledge_entities import PreviewDetail
-from core.file import File, FileTransferMethod, FileType, file_manager
-from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
-from core.model_manager import ModelInstance
-from core.model_runtime.entities.message_entities import (
-    ImagePromptMessageContent,
-    PromptMessageContentUnionTypes,
-    TextPromptMessageContent,
-    UserPromptMessage,
-)
-from core.model_runtime.entities.model_entities import ModelFeature, ModelType
-from core.provider_manager import ProviderManager
 from core.rag.cleaner.clean_processor import CleanProcessor
 from core.rag.datasource.keyword.keyword_factory import Keyword
 from core.rag.datasource.retrieval_service import RetrievalService

@@ -33,16 +17,12 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
 from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk
 from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.utils.text_processing_utils import remove_leading_symbols
-from extensions.ext_database import db
-from factories.file_factory import build_from_mapping
 from libs import helper
-from models import UploadFile
 from models.account import Account
-from models.dataset import Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding
+from models.dataset import Dataset, DatasetProcessRule
 from models.dataset import Document as DatasetDocument
 from services.account_service import AccountService
 from services.entities.knowledge_entities.knowledge_entities import Rule
-from services.summary_index_service import SummaryIndexService


 class ParagraphIndexProcessor(BaseIndexProcessor):

@@ -128,29 +108,6 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
             keyword.add_texts(documents)

     def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
-        # Note: Summary indexes are now disabled (not deleted) when segments are disabled.
-        # This method is called for actual deletion scenarios (e.g., when segment is deleted).
-        # For disable operations, disable_summaries_for_segments is called directly in the task.
-        # Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
-        delete_summaries = kwargs.get("delete_summaries", False)
-        if delete_summaries:
-            if node_ids:
-                # Find segments by index_node_id
-                segments = (
-                    db.session.query(DocumentSegment)
-                    .filter(
-                        DocumentSegment.dataset_id == dataset.id,
-                        DocumentSegment.index_node_id.in_(node_ids),
-                    )
-                    .all()
-                )
-                segment_ids = [segment.id for segment in segments]
-                if segment_ids:
-                    SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
-            else:
-                # Delete all summaries for the dataset
-                SummaryIndexService.delete_summaries_for_segments(dataset, None)
-
         if dataset.indexing_technique == "high_quality":
             vector = Vector(dataset)
             if node_ids:

@@ -270,263 +227,3 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
             }
         else:
             raise ValueError("Chunks is not a list")
-
-    def generate_summary_preview(
-        self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
-    ) -> list[PreviewDetail]:
-        """
-        For each segment, concurrently call generate_summary to generate a summary
-        and write it to the summary attribute of PreviewDetail.
-        """
-        import concurrent.futures
-
-        from flask import current_app
-
-        # Capture Flask app context for worker threads
-        flask_app = None
-        try:
-            flask_app = current_app._get_current_object()  # type: ignore
-        except RuntimeError:
-            logger.warning("No Flask application context available, summary generation may fail")
-
-        def process(preview: PreviewDetail) -> None:
-            """Generate summary for a single preview item."""
-            try:
-                if flask_app:
-                    # Ensure Flask app context in worker thread
-                    with flask_app.app_context():
-                        summary = self.generate_summary(tenant_id, preview.content, summary_index_setting)
-                        preview.summary = summary
-                else:
-                    # Fallback: try without app context (may fail)
-                    summary = self.generate_summary(tenant_id, preview.content, summary_index_setting)
-                    preview.summary = summary
-            except Exception:
-                logger.exception("Failed to generate summary for preview")
-                # Don't fail the entire preview if summary generation fails
-                preview.summary = None
-
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            list(executor.map(process, preview_texts))
-        return preview_texts
-
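The removed generate_summary_preview above follows a common Flask pattern: capture the real app object once, then re-enter app_context() inside each worker thread, because Flask contexts are thread-local. A minimal runnable sketch of that pattern, with a toy app and work function standing in for the real summary call:

import concurrent.futures

from flask import Flask, current_app

app = Flask(__name__)

with app.app_context():
    flask_app = current_app._get_current_object()  # dereference the proxy once

def work(item: int) -> str:
    with flask_app.app_context():  # each worker thread needs its own app context
        return f"{current_app.name}:{item}"

with concurrent.futures.ThreadPoolExecutor() as executor:
    print(list(executor.map(work, range(3))))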
-    @staticmethod
-    def generate_summary(
-        tenant_id: str,
-        text: str,
-        summary_index_setting: dict | None = None,
-        segment_id: str | None = None,
-    ) -> str:
-        """
-        Generate summary for the given text using ModelInstance.invoke_llm and the default or custom summary prompt,
-        and supports vision models by including images from the segment attachments or text content.
-
-        Args:
-            tenant_id: Tenant ID
-            text: Text content to summarize
-            summary_index_setting: Summary index configuration
-            segment_id: Optional segment ID to fetch attachments from SegmentAttachmentBinding table
-        """
-        if not summary_index_setting or not summary_index_setting.get("enable"):
-            raise ValueError("summary_index_setting is required and must be enabled to generate summary.")
-
-        model_name = summary_index_setting.get("model_name")
-        model_provider_name = summary_index_setting.get("model_provider_name")
-        summary_prompt = summary_index_setting.get("summary_prompt")
-
-        # Import default summary prompt
-        if not summary_prompt:
-            summary_prompt = DEFAULT_GENERATOR_SUMMARY_PROMPT
-
-        provider_manager = ProviderManager()
-        provider_model_bundle = provider_manager.get_provider_model_bundle(
-            tenant_id, model_provider_name, ModelType.LLM
-        )
-        model_instance = ModelInstance(provider_model_bundle, model_name)
-
-        # Get model schema to check if vision is supported
-        model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
-        supports_vision = model_schema and model_schema.features and ModelFeature.VISION in model_schema.features
-
-        # Extract images if model supports vision
-        image_files = []
-        if supports_vision:
-            # First, try to get images from SegmentAttachmentBinding (preferred method)
-            if segment_id:
-                image_files = ParagraphIndexProcessor._extract_images_from_segment_attachments(tenant_id, segment_id)
-
-            # If no images from attachments, fall back to extracting from text
-            if not image_files:
-                image_files = ParagraphIndexProcessor._extract_images_from_text(tenant_id, text)
-
-        # Build prompt messages
-        prompt_messages = []
-
-        if image_files:
-            # If we have images, create a UserPromptMessage with both text and images
-            prompt_message_contents: list[PromptMessageContentUnionTypes] = []
-
-            # Add images first
-            for file in image_files:
-                try:
-                    file_content = file_manager.to_prompt_message_content(
-                        file, image_detail_config=ImagePromptMessageContent.DETAIL.LOW
-                    )
-                    prompt_message_contents.append(file_content)
-                except Exception as e:
-                    logger.warning("Failed to convert image file to prompt message content: %s", str(e))
-                    continue
-
-            # Add text content
-            if prompt_message_contents:  # Only add text if we successfully added images
-                prompt_message_contents.append(TextPromptMessageContent(data=f"{summary_prompt}\n{text}"))
-                prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
-            else:
-                # If image conversion failed, fall back to text-only
-                prompt = f"{summary_prompt}\n{text}"
-                prompt_messages.append(UserPromptMessage(content=prompt))
-        else:
-            # No images, use simple text prompt
-            prompt = f"{summary_prompt}\n{text}"
-            prompt_messages.append(UserPromptMessage(content=prompt))
-
-        result = model_instance.invoke_llm(prompt_messages=prompt_messages, model_parameters={}, stream=False)
-
-        return getattr(result.message, "content", "")
-
-    @staticmethod
-    def _extract_images_from_text(tenant_id: str, text: str) -> list[File]:
-        """
-        Extract images from markdown text and convert them to File objects.
-
-        Args:
-            tenant_id: Tenant ID
-            text: Text content that may contain markdown image links
-
-        Returns:
-            List of File objects representing images found in the text
-        """
-        # Extract markdown images using regex pattern
-        pattern = r"!\[.*?\]\((.*?)\)"
-        images = re.findall(pattern, text)
-
-        if not images:
-            return []
-
-        upload_file_id_list = []
-
-        for image in images:
-            # For data before v0.10.0
-            pattern = r"/files/([a-f0-9\-]+)/image-preview(?:\?.*?)?"
-            match = re.search(pattern, image)
-            if match:
-                upload_file_id = match.group(1)
-                upload_file_id_list.append(upload_file_id)
-                continue
-
-            # For data after v0.10.0
-            pattern = r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?"
-            match = re.search(pattern, image)
-            if match:
-                upload_file_id = match.group(1)
-                upload_file_id_list.append(upload_file_id)
-                continue
-
-            # For tools directory - direct file formats (e.g., .png, .jpg, etc.)
-            pattern = r"/files/tools/([a-f0-9\-]+)\.([a-zA-Z0-9]+)(?:\?[^\s\)\"\']*)?"
-            match = re.search(pattern, image)
-            if match:
-                # Tool files are handled differently, skip for now
-                continue
-
-        if not upload_file_id_list:
-            return []
-
-        # Get unique IDs for database query
-        unique_upload_file_ids = list(set(upload_file_id_list))
-        upload_files = (
-            db.session.query(UploadFile)
-            .where(UploadFile.id.in_(unique_upload_file_ids), UploadFile.tenant_id == tenant_id)
-            .all()
-        )
-
-        # Create File objects from UploadFile records
-        file_objects = []
-        for upload_file in upload_files:
-            # Only process image files
-            if not upload_file.mime_type or "image" not in upload_file.mime_type:
-                continue
-
-            mapping = {
-                "upload_file_id": upload_file.id,
-                "transfer_method": FileTransferMethod.LOCAL_FILE.value,
-                "type": FileType.IMAGE.value,
-            }
-
-            try:
-                file_obj = build_from_mapping(
-                    mapping=mapping,
-                    tenant_id=tenant_id,
-                )
-                file_objects.append(file_obj)
-            except Exception as e:
-                logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e))
-                continue
-
-        return file_objects
-
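The two-stage extraction in _extract_images_from_text above can be demonstrated directly: first pull markdown image URLs, then pick the upload-file UUID out of /files/<id>/file-preview links. A standalone snippet with an invented sample URL:

import re

text = "See ![chart](https://example.com/files/0a1b2c3d-4e5f-6a7b-8c9d-0e1f2a3b4c5d/file-preview?t=1) for details."
# Stage 1: markdown image syntax -> URL
image_urls = re.findall(r"!\[.*?\]\((.*?)\)", text)
ids = []
for url in image_urls:
    # Stage 2: upload-file UUID from the post-v0.10.0 file-preview route
    match = re.search(r"/files/([a-f0-9\-]+)/file-preview(?:\?.*?)?", url)
    if match:
        ids.append(match.group(1))
print(ids)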
-    @staticmethod
-    def _extract_images_from_segment_attachments(tenant_id: str, segment_id: str) -> list[File]:
-        """
-        Extract images from SegmentAttachmentBinding table (preferred method).
-        This matches how DatasetRetrieval gets segment attachments.
-
-        Args:
-            tenant_id: Tenant ID
-            segment_id: Segment ID to fetch attachments for
-
-        Returns:
-            List of File objects representing images found in segment attachments
-        """
-        from sqlalchemy import select
-
-        # Query attachments from SegmentAttachmentBinding table
-        attachments_with_bindings = db.session.execute(
-            select(SegmentAttachmentBinding, UploadFile)
-            .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id)
-            .where(
-                SegmentAttachmentBinding.segment_id == segment_id,
-                SegmentAttachmentBinding.tenant_id == tenant_id,
-            )
-        ).all()
-
-        if not attachments_with_bindings:
-            return []
-
-        file_objects = []
-        for _, upload_file in attachments_with_bindings:
-            # Only process image files
-            if not upload_file.mime_type or "image" not in upload_file.mime_type:
-                continue
-
-            try:
-                # Create File object directly (similar to DatasetRetrieval)
-                file_obj = File(
-                    id=upload_file.id,
-                    filename=upload_file.name,
-                    extension="." + upload_file.extension,
-                    mime_type=upload_file.mime_type,
-                    tenant_id=tenant_id,
-                    type=FileType.IMAGE,
-                    transfer_method=FileTransferMethod.LOCAL_FILE,
-                    remote_url=upload_file.source_url,
-                    related_id=upload_file.id,
-                    size=upload_file.size,
-                    storage_key=upload_file.key,
-                )
-                file_objects.append(file_obj)
-            except Exception as e:
-                logger.warning("Failed to create File object from UploadFile %s: %s", upload_file.id, str(e))
-                continue
-
-        return file_objects
@@ -1,13 +1,11 @@
 """Paragraph index processor."""

 import json
 import logging
 import uuid
 from collections.abc import Mapping
 from typing import Any

 from configs import dify_config
-from core.entities.knowledge_entities import PreviewDetail
-from core.model_manager import ModelInstance
 from core.rag.cleaner.clean_processor import CleanProcessor
 from core.rag.datasource.retrieval_service import RetrievalService

@@ -27,9 +25,6 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
 from models.dataset import Document as DatasetDocument
 from services.account_service import AccountService
 from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
-from services.summary_index_service import SummaryIndexService
-
-logger = logging.getLogger(__name__)


 class ParentChildIndexProcessor(BaseIndexProcessor):

@@ -140,29 +135,6 @@ class ParentChildIndexProcessor(BaseIndexProcessor):

     def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
         # node_ids is segment's node_ids
-        # Note: Summary indexes are now disabled (not deleted) when segments are disabled.
-        # This method is called for actual deletion scenarios (e.g., when segment is deleted).
-        # For disable operations, disable_summaries_for_segments is called directly in the task.
-        # Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
-        delete_summaries = kwargs.get("delete_summaries", False)
-        if delete_summaries:
-            if node_ids:
-                # Find segments by index_node_id
-                segments = (
-                    db.session.query(DocumentSegment)
-                    .filter(
-                        DocumentSegment.dataset_id == dataset.id,
-                        DocumentSegment.index_node_id.in_(node_ids),
-                    )
-                    .all()
-                )
-                segment_ids = [segment.id for segment in segments]
-                if segment_ids:
-                    SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
-            else:
-                # Delete all summaries for the dataset
-                SummaryIndexService.delete_summaries_for_segments(dataset, None)
-
         if dataset.indexing_technique == "high_quality":
             delete_child_chunks = kwargs.get("delete_child_chunks") or False
             precomputed_child_node_ids = kwargs.get("precomputed_child_node_ids")

@@ -354,57 +326,3 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
             "preview": preview,
             "total_segments": len(parent_childs.parent_child_chunks),
         }
-
-    def generate_summary_preview(
-        self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
-    ) -> list[PreviewDetail]:
-        """
-        For each parent chunk in preview_texts, concurrently call generate_summary to generate a summary
-        and write it to the summary attribute of PreviewDetail.
-
-        Note: For parent-child structure, we only generate summaries for parent chunks.
-        """
-        import concurrent.futures
-
-        from flask import current_app
-
-        # Capture Flask app context for worker threads
-        flask_app = None
-        try:
-            flask_app = current_app._get_current_object()  # type: ignore
-        except RuntimeError:
-            logger.warning("No Flask application context available, summary generation may fail")
-
-        def process(preview: PreviewDetail) -> None:
-            """Generate summary for a single preview item (parent chunk)."""
-            try:
-                if flask_app:
-                    # Ensure Flask app context in worker thread
-                    with flask_app.app_context():
-                        # Use ParagraphIndexProcessor's generate_summary method
-                        from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
-                        summary = ParagraphIndexProcessor.generate_summary(
-                            tenant_id=tenant_id,
-                            text=preview.content,
-                            summary_index_setting=summary_index_setting,
-                        )
-                        if summary:
-                            preview.summary = summary
-                else:
-                    # Fallback: try without app context (may fail)
-                    from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
-                    summary = ParagraphIndexProcessor.generate_summary(
-                        tenant_id=tenant_id,
-                        text=preview.content,
-                        summary_index_setting=summary_index_setting,
-                    )
-                    if summary:
-                        preview.summary = summary
-            except Exception:
-                logger.exception("Failed to generate summary for preview")
-                # Don't fail the entire preview if summary generation fails
-                preview.summary = None
-
-        with concurrent.futures.ThreadPoolExecutor() as executor:
-            list(executor.map(process, preview_texts))
-        return preview_texts
@@ -11,7 +11,6 @@ import pandas as pd
 from flask import Flask, current_app
 from werkzeug.datastructures import FileStorage

-from core.entities.knowledge_entities import PreviewDetail
 from core.llm_generator.llm_generator import LLMGenerator
 from core.rag.cleaner.clean_processor import CleanProcessor
 from core.rag.datasource.retrieval_service import RetrievalService

@@ -26,10 +25,9 @@ from core.rag.retrieval.retrieval_methods import RetrievalMethod
 from core.tools.utils.text_processing_utils import remove_leading_symbols
 from libs import helper
 from models.account import Account
-from models.dataset import Dataset, DocumentSegment
+from models.dataset import Dataset
 from models.dataset import Document as DatasetDocument
 from services.entities.knowledge_entities.knowledge_entities import Rule
-from services.summary_index_service import SummaryIndexService

 logger = logging.getLogger(__name__)

@@ -146,30 +144,6 @@ class QAIndexProcessor(BaseIndexProcessor):
             vector.create_multimodal(multimodal_documents)

     def clean(self, dataset: Dataset, node_ids: list[str] | None, with_keywords: bool = True, **kwargs):
-        # Note: Summary indexes are now disabled (not deleted) when segments are disabled.
-        # This method is called for actual deletion scenarios (e.g., when segment is deleted).
-        # For disable operations, disable_summaries_for_segments is called directly in the task.
-        # Note: qa_model doesn't generate summaries, but we clean them for completeness
-        # Only delete summaries if explicitly requested (e.g., when segment is actually deleted)
-        delete_summaries = kwargs.get("delete_summaries", False)
-        if delete_summaries:
-            if node_ids:
-                # Find segments by index_node_id
-                segments = (
-                    db.session.query(DocumentSegment)
-                    .filter(
-                        DocumentSegment.dataset_id == dataset.id,
-                        DocumentSegment.index_node_id.in_(node_ids),
-                    )
-                    .all()
-                )
-                segment_ids = [segment.id for segment in segments]
-                if segment_ids:
-                    SummaryIndexService.delete_summaries_for_segments(dataset, segment_ids)
-            else:
-                # Delete all summaries for the dataset
-                SummaryIndexService.delete_summaries_for_segments(dataset, None)
-
         vector = Vector(dataset)
         if node_ids:
             vector.delete_by_ids(node_ids)

@@ -238,17 +212,6 @@ class QAIndexProcessor(BaseIndexProcessor):
             "total_segments": len(qa_chunks.qa_chunks),
         }

-    def generate_summary_preview(
-        self, tenant_id: str, preview_texts: list[PreviewDetail], summary_index_setting: dict
-    ) -> list[PreviewDetail]:
-        """
-        QA model doesn't generate summaries, so this method returns preview_texts unchanged.
-
-        Note: QA model uses question-answer pairs, which don't require summary generation.
-        """
-        # QA model doesn't generate summaries, return as-is
-        return preview_texts
-
     def _format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
         format_documents = []
         if document_node.page_content is None or not document_node.page_content.strip():
@@ -5,7 +5,6 @@ import logging
 from collections.abc import Generator, Mapping, Sequence
 from typing import Any, cast

-from flask import has_request_context
 from sqlalchemy import select

 from core.db.session_factory import session_factory

@@ -29,6 +28,21 @@ from models.workflow import Workflow
 logger = logging.getLogger(__name__)


+def _try_resolve_user_from_request() -> Account | EndUser | None:
+    """
+    Try to resolve user from Flask request context.
+
+    Returns None if not in a request context or if user is not available.
+    """
+    # Note: `current_user` is a LocalProxy. Never compare it with None directly.
+    # Use _get_current_object() to dereference the proxy
+    user = getattr(current_user, "_get_current_object", lambda: current_user)()
+    # Check if we got a valid user object
+    if user is not None and hasattr(user, "id"):
+        return user
+    return None
+
+
 class WorkflowTool(Tool):
     """
     Workflow tool.

@@ -209,21 +223,13 @@ class WorkflowTool(Tool):
         Returns:
             Account | EndUser | None: The resolved user object, or None if resolution fails.
         """
-        if has_request_context():
-            return self._resolve_user_from_request()
-        else:
-            return self._resolve_user_from_database(user_id=user_id)
+        # Try to resolve user from request context first
+        user = _try_resolve_user_from_request()
+        if user is not None:
+            return user

-    def _resolve_user_from_request(self) -> Account | EndUser | None:
-        """
-        Resolve user from Flask request context.
-        """
-        try:
-            # Note: `current_user` is a LocalProxy. Never compare it with None directly.
-            return getattr(current_user, "_get_current_object", lambda: current_user)()
-        except Exception as e:
-            logger.warning("Failed to resolve user from request context: %s", e)
-            return None
+        # Fall back to database resolution
+        return self._resolve_user_from_database(user_id=user_id)

     def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None:
         """
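Why the new helper dereferences the proxy: werkzeug's LocalProxy forwards attribute access, so comparing the proxy itself against None misbehaves when it is unbound. _get_current_object() yields the real object, which can be compared safely. A small runnable demonstration with an invented backing store:

from werkzeug.local import LocalProxy

_store = {"user": None}
current_user = LocalProxy(lambda: _store["user"])

# Dereference first, then compare -- the pattern used in the diff above.
user = getattr(current_user, "_get_current_object", lambda: current_user)()
print(user is None)  # True: the dereferenced object compares correctly

_store["user"] = object()
print(current_user._get_current_object() is _store["user"])  # True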
api/core/workflow/context/__init__.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+"""
+Execution Context - Context management for workflow execution.
+
+This package provides Flask-independent context management for workflow
+execution in multi-threaded environments.
+"""
+
+from core.workflow.context.execution_context import (
+    AppContext,
+    ExecutionContext,
+    IExecutionContext,
+    NullAppContext,
+    capture_current_context,
+)
+
+__all__ = [
+    "AppContext",
+    "ExecutionContext",
+    "IExecutionContext",
+    "NullAppContext",
+    "capture_current_context",
+]
api/core/workflow/context/execution_context.py (new file, 216 lines)
@@ -0,0 +1,216 @@
+"""
+Execution Context - Abstracted context management for workflow execution.
+"""
+
+import contextvars
+from abc import ABC, abstractmethod
+from collections.abc import Generator
+from contextlib import AbstractContextManager, contextmanager
+from typing import Any, Protocol, final, runtime_checkable
+
+
+class AppContext(ABC):
+    """
+    Abstract application context interface.
+
+    This abstraction allows workflow execution to work with or without Flask
+    by providing a common interface for application context management.
+    """
+
+    @abstractmethod
+    def get_config(self, key: str, default: Any = None) -> Any:
+        """Get configuration value by key."""
+        pass
+
+    @abstractmethod
+    def get_extension(self, name: str) -> Any:
+        """Get Flask extension by name (e.g., 'db', 'cache')."""
+        pass
+
+    @abstractmethod
+    def enter(self) -> AbstractContextManager[None]:
+        """Enter the application context."""
+        pass
+
+
+@runtime_checkable
+class IExecutionContext(Protocol):
+    """
+    Protocol for execution context.
+
+    This protocol defines the interface that all execution contexts must implement,
+    allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably.
+    """
+
+    def __enter__(self) -> "IExecutionContext":
+        """Enter the execution context."""
+        ...
+
+    def __exit__(self, *args: Any) -> None:
+        """Exit the execution context."""
+        ...
+
+    @property
+    def user(self) -> Any:
+        """Get user object."""
+        ...
+
+
+@final
+class ExecutionContext:
+    """
+    Execution context for workflow execution in worker threads.
+
+    This class encapsulates all context needed for workflow execution:
+    - Application context (Flask app or standalone)
+    - Context variables for Python contextvars
+    - User information (optional)
+
+    It is designed to be serializable and passable to worker threads.
+    """
+
+    def __init__(
+        self,
+        app_context: AppContext | None = None,
+        context_vars: contextvars.Context | None = None,
+        user: Any = None,
+    ) -> None:
+        """
+        Initialize execution context.
+
+        Args:
+            app_context: Application context (Flask or standalone)
+            context_vars: Python contextvars to preserve
+            user: User object (optional)
+        """
+        self._app_context = app_context
+        self._context_vars = context_vars
+        self._user = user
+
+    @property
+    def app_context(self) -> AppContext | None:
+        """Get application context."""
+        return self._app_context
+
+    @property
+    def context_vars(self) -> contextvars.Context | None:
+        """Get context variables."""
+        return self._context_vars
+
+    @property
+    def user(self) -> Any:
+        """Get user object."""
+        return self._user
+
+    @contextmanager
+    def enter(self) -> Generator[None, None, None]:
+        """
+        Enter this execution context.
+
+        This is a convenience method that creates a context manager.
+        """
+        # Restore context variables if provided
+        if self._context_vars:
+            for var, val in self._context_vars.items():
+                var.set(val)
+
+        # Enter app context if available
+        if self._app_context is not None:
+            with self._app_context.enter():
+                yield
+        else:
+            yield
+
+    def __enter__(self) -> "ExecutionContext":
+        """Enter the execution context."""
+        self._cm = self.enter()
+        self._cm.__enter__()
+        return self
+
+    def __exit__(self, *args: Any) -> None:
+        """Exit the execution context."""
+        if hasattr(self, "_cm"):
+            self._cm.__exit__(*args)
+
+
+class NullAppContext(AppContext):
+    """
+    Null implementation of AppContext for non-Flask environments.
+
+    This is used when running without Flask (e.g., in tests or standalone mode).
+    """
+
+    def __init__(self, config: dict[str, Any] | None = None) -> None:
+        """
+        Initialize null app context.
+
+        Args:
+            config: Optional configuration dictionary
+        """
+        self._config = config or {}
+        self._extensions: dict[str, Any] = {}
+
+    def get_config(self, key: str, default: Any = None) -> Any:
+        """Get configuration value by key."""
+        return self._config.get(key, default)
+
+    def get_extension(self, name: str) -> Any:
+        """Get extension by name."""
+        return self._extensions.get(name)
+
+    def set_extension(self, name: str, extension: Any) -> None:
+        """Set extension by name."""
+        self._extensions[name] = extension
+
+    @contextmanager
+    def enter(self) -> Generator[None, None, None]:
+        """Enter null context (no-op)."""
+        yield
+
+
+class ExecutionContextBuilder:
+    """
+    Builder for creating ExecutionContext instances.
+
+    This provides a fluent API for building execution contexts.
+    """
+
+    def __init__(self) -> None:
+        self._app_context: AppContext | None = None
+        self._context_vars: contextvars.Context | None = None
+        self._user: Any = None
+
+    def with_app_context(self, app_context: AppContext) -> "ExecutionContextBuilder":
+        """Set application context."""
+        self._app_context = app_context
+        return self
+
+    def with_context_vars(self, context_vars: contextvars.Context) -> "ExecutionContextBuilder":
+        """Set context variables."""
+        self._context_vars = context_vars
+        return self
+
+    def with_user(self, user: Any) -> "ExecutionContextBuilder":
+        """Set user."""
+        self._user = user
+        return self
+
+    def build(self) -> ExecutionContext:
+        """Build the execution context."""
+        return ExecutionContext(
+            app_context=self._app_context,
+            context_vars=self._context_vars,
+            user=self._user,
+        )
+
+
+def capture_current_context() -> IExecutionContext:
+    """
+    Capture current execution context from the calling environment.
+
+    Returns:
+        IExecutionContext with captured context
+    """
+    from context import capture_current_context
+
+    return capture_current_context()
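A hedged usage sketch for the new ExecutionContext API above: capture contextvars on the submitting thread, then restore them inside a worker by entering the context. The class is re-declared here in stripped-down form so the example runs standalone; in the repository it lives in core.workflow.context.

import contextvars
import threading

request_id: contextvars.ContextVar[str] = contextvars.ContextVar("request_id")

class ExecutionContext:
    def __init__(self, context_vars: contextvars.Context | None = None) -> None:
        self._context_vars = context_vars

    def __enter__(self) -> "ExecutionContext":
        if self._context_vars:
            for var, val in self._context_vars.items():
                var.set(val)  # restore captured values in this thread
        return self

    def __exit__(self, *args) -> None:
        pass

request_id.set("req-123")
ctx = ExecutionContext(context_vars=contextvars.copy_context())

def worker() -> None:
    with ctx:
        print("worker sees", request_id.get())  # prints req-123

t = threading.Thread(target=worker)
t.start()
t.join()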
@@ -7,15 +7,13 @@ Domain-Driven Design principles for improved maintainability and testability.

 from __future__ import annotations

-import contextvars
 import logging
 import queue
 import threading
 from collections.abc import Generator
 from typing import TYPE_CHECKING, cast, final

-from flask import Flask, current_app
-
+from core.workflow.context import capture_current_context
 from core.workflow.enums import NodeExecutionType
 from core.workflow.graph import Graph
 from core.workflow.graph_events import (

@@ -159,17 +157,8 @@ class GraphEngine:
         self._layers: list[GraphEngineLayer] = []

         # === Worker Pool Setup ===
-        # Capture Flask app context for worker threads
-        flask_app: Flask | None = None
-        try:
-            app = current_app._get_current_object()  # type: ignore
-            if isinstance(app, Flask):
-                flask_app = app
-        except RuntimeError:
-            pass
-
-        # Capture context variables for worker threads
-        context_vars = contextvars.copy_context()
+        # Capture execution context for worker threads
+        execution_context = capture_current_context()

         # Create worker pool for parallel node execution
         self._worker_pool = WorkerPool(

@@ -177,8 +166,7 @@
             event_queue=self._event_queue,
             graph=self._graph,
             layers=self._layers,
-            flask_app=flask_app,
-            context_vars=context_vars,
+            execution_context=execution_context,
             min_workers=self._min_workers,
             max_workers=self._max_workers,
             scale_up_threshold=self._scale_up_threshold,
@@ -5,26 +5,27 @@ Workers pull node IDs from the ready_queue, execute nodes, and push events
 to the event_queue for the dispatcher to process.
 """

-import contextvars
 import queue
 import threading
 import time
 from collections.abc import Sequence
 from datetime import datetime
-from typing import final
+from typing import TYPE_CHECKING, final
 from uuid import uuid4

-from flask import Flask
 from typing_extensions import override

+from core.workflow.context import IExecutionContext
 from core.workflow.graph import Graph
 from core.workflow.graph_engine.layers.base import GraphEngineLayer
 from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
 from core.workflow.nodes.base.node import Node
-from libs.flask_utils import preserve_flask_contexts

 from .ready_queue import ReadyQueue

+if TYPE_CHECKING:
+    pass
+

 @final
 class Worker(threading.Thread):

@@ -44,8 +45,7 @@ class Worker(threading.Thread):
         layers: Sequence[GraphEngineLayer],
         stop_event: threading.Event,
         worker_id: int = 0,
-        flask_app: Flask | None = None,
-        context_vars: contextvars.Context | None = None,
+        execution_context: IExecutionContext | None = None,
     ) -> None:
         """
         Initialize worker thread.

@@ -56,19 +56,17 @@
             graph: Graph containing nodes to execute
             layers: Graph engine layers for node execution hooks
             worker_id: Unique identifier for this worker
-            flask_app: Optional Flask application for context preservation
-            context_vars: Optional context variables to preserve in worker thread
+            execution_context: Optional execution context for context preservation
         """
         super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
         self._ready_queue = ready_queue
         self._event_queue = event_queue
         self._graph = graph
         self._worker_id = worker_id
-        self._flask_app = flask_app
-        self._context_vars = context_vars
-        self._last_task_time = time.time()
+        self._execution_context = execution_context
         self._stop_event = stop_event
         self._layers = layers if layers is not None else []
+        self._last_task_time = time.time()

     def stop(self) -> None:
         """Worker is controlled via shared stop_event from GraphEngine.

@@ -135,11 +133,9 @@

         error: Exception | None = None

-        if self._flask_app and self._context_vars:
-            with preserve_flask_contexts(
-                flask_app=self._flask_app,
-                context_vars=self._context_vars,
-            ):
+        # Execute the node with preserved context if execution context is provided
+        if self._execution_context is not None:
+            with self._execution_context:
                 self._invoke_node_run_start_hooks(node)
                 try:
                     node_events = node.run()
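For orientation, a toy version of the worker loop implied by this class: a daemon thread pulls node ids from a ready queue, runs them, and pushes events, with a shared Event for shutdown. Everything here is a stand-in for the real Worker internals:

import queue
import threading

ready_queue: "queue.Queue[str]" = queue.Queue()
event_queue: "queue.Queue[str]" = queue.Queue()
stop_event = threading.Event()

def run_worker() -> None:
    while not stop_event.is_set():
        try:
            node_id = ready_queue.get(timeout=0.1)  # poll so the stop flag is rechecked
        except queue.Empty:
            continue
        event_queue.put(f"ran:{node_id}")  # stand-in for real node events

t = threading.Thread(target=run_worker, name="GraphWorker-0", daemon=True)
t.start()
ready_queue.put("node-1")
print(event_queue.get())
stop_event.set()
t.join()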
@@ -8,9 +8,10 @@ DynamicScaler, and WorkerFactory into a single class.
 import logging
 import queue
 import threading
-from typing import TYPE_CHECKING, final
+from typing import final

 from configs import dify_config
+from core.workflow.context import IExecutionContext
 from core.workflow.graph import Graph
 from core.workflow.graph_events import GraphNodeEventBase

@@ -20,11 +21,6 @@ from ..worker import Worker

 logger = logging.getLogger(__name__)

-if TYPE_CHECKING:
-    from contextvars import Context
-
-    from flask import Flask
-

 @final
 class WorkerPool:

@@ -42,8 +38,7 @@ class WorkerPool:
         graph: Graph,
         layers: list[GraphEngineLayer],
         stop_event: threading.Event,
-        flask_app: "Flask | None" = None,
-        context_vars: "Context | None" = None,
+        execution_context: IExecutionContext | None = None,
         min_workers: int | None = None,
         max_workers: int | None = None,
         scale_up_threshold: int | None = None,

@@ -57,8 +52,7 @@ class WorkerPool:
             event_queue: Queue for worker events
             graph: The workflow graph
             layers: Graph engine layers for node execution hooks
-            flask_app: Optional Flask app for context preservation
-            context_vars: Optional context variables
+            execution_context: Optional execution context for context preservation
             min_workers: Minimum number of workers
             max_workers: Maximum number of workers
             scale_up_threshold: Queue depth to trigger scale up

@@ -67,8 +61,7 @@ class WorkerPool:
         self._ready_queue = ready_queue
         self._event_queue = event_queue
         self._graph = graph
-        self._flask_app = flask_app
-        self._context_vars = context_vars
+        self._execution_context = execution_context
         self._layers = layers

         # Scaling parameters with defaults

@@ -152,8 +145,7 @@ class WorkerPool:
             graph=self._graph,
             layers=self._layers,
             worker_id=worker_id,
-            flask_app=self._flask_app,
-            context_vars=self._context_vars,
+            execution_context=self._execution_context,
             stop_event=self._stop_event,
         )
@@ -62,21 +62,6 @@ class DocumentExtractorNode(Node[DocumentExtractorNodeData]):
         inputs = {"variable_selector": variable_selector}
         process_data = {"documents": value if isinstance(value, list) else [value]}

-        # Ensure storage_key is loaded for File objects
-        files_to_check = value if isinstance(value, list) else [value]
-        files_needing_storage_key = [
-            f for f in files_to_check if isinstance(f, File) and not f.storage_key and f.related_id
-        ]
-        if files_needing_storage_key:
-            from sqlalchemy.orm import Session
-
-            from extensions.ext_database import db
-            from factories.file_factory import StorageKeyLoader
-
-            with Session(bind=db.engine) as session:
-                storage_key_loader = StorageKeyLoader(session, tenant_id=self.tenant_id)
-                storage_key_loader.load_storage_keys(files_needing_storage_key)
-
         try:
             if isinstance(value, list):
                 extracted_text_list = list(map(_extract_text_from_file, value))

@@ -430,16 +415,6 @@ def _download_file_content(file: File) -> bytes:
             response.raise_for_status()
             return response.content
         else:
-            # Check if storage_key is set
-            if not file.storage_key:
-                raise FileDownloadError(f"File storage_key is missing for file: {file.filename}")
-
-            # Check if file exists before downloading
-            from extensions.ext_storage import storage
-
-            if not storage.exists(file.storage_key):
-                raise FileDownloadError(f"File not found in storage: {file.storage_key}")
-
             return file_manager.download(file)
     except Exception as e:
         raise FileDownloadError(f"Error downloading file: {str(e)}") from e
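The old side of the second hunk above adds fail-fast guards before downloading: verify the storage key is present and that the object exists. A condensed sketch of that guard pattern; FakeStorage stands in for the real storage extension, and the error type mirrors FileDownloadError:

class FileDownloadError(Exception):
    pass

class FakeStorage:
    def __init__(self, keys: set[str]) -> None:
        self._keys = keys

    def exists(self, key: str) -> bool:
        return key in self._keys

storage = FakeStorage({"upload/abc.pdf"})

def download(storage_key: str | None) -> str:
    # Fail fast with a specific error instead of a deep stack trace later.
    if not storage_key:
        raise FileDownloadError("File storage_key is missing")
    if not storage.exists(storage_key):
        raise FileDownloadError(f"File not found in storage: {storage_key}")
    return f"bytes-of:{storage_key}"  # stand-in for file_manager.download(file)

print(download("upload/abc.pdf"))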
@@ -1,11 +1,9 @@
-import contextvars
 import logging
 from collections.abc import Generator, Mapping, Sequence
 from concurrent.futures import Future, ThreadPoolExecutor, as_completed
 from datetime import UTC, datetime
 from typing import TYPE_CHECKING, Any, NewType, cast

-from flask import Flask, current_app
 from typing_extensions import TypeIs

 from core.model_runtime.entities.llm_entities import LLMUsage

@@ -39,7 +37,6 @@ from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
 from core.workflow.runtime import VariablePool
 from libs.datetime_utils import naive_utc_now
-from libs.flask_utils import preserve_flask_contexts

 from .exc import (
     InvalidIteratorValueError,

@@ -51,6 +48,7 @@ from .exc import (
 )

 if TYPE_CHECKING:
+    from core.workflow.context import IExecutionContext
     from core.workflow.graph_engine import GraphEngine

 logger = logging.getLogger(__name__)

@@ -252,8 +250,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
                     self._execute_single_iteration_parallel,
                     index=index,
                     item=item,
-                    flask_app=current_app._get_current_object(),  # type: ignore
-                    context_vars=contextvars.copy_context(),
+                    execution_context=self._capture_execution_context(),
                 )
                 future_to_index[future] = index

@@ -306,11 +303,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
         self,
         index: int,
         item: object,
-        flask_app: Flask,
-        context_vars: contextvars.Context,
+        execution_context: "IExecutionContext",
     ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
         """Execute a single iteration in parallel mode and return results."""
-        with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
+        with execution_context:
             iter_start_at = datetime.now(UTC).replace(tzinfo=None)
             events: list[GraphNodeEventBase] = []
             outputs_temp: list[object] = []

@@ -339,6 +335,12 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
                     graph_engine.graph_runtime_state.llm_usage,
                 )

+    def _capture_execution_context(self) -> "IExecutionContext":
+        """Capture current execution context for parallel iterations."""
+        from core.workflow.context import capture_current_context
+
+        return capture_current_context()
+
     def _handle_iteration_success(
         self,
         started_at: datetime,
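The parallel-iteration submission above hands each future a context captured on the submitting thread. A minimal sketch of the same idea using bare contextvars, so it runs standalone; the tenant variable and values are invented:

import contextvars
from concurrent.futures import ThreadPoolExecutor, as_completed

tenant: contextvars.ContextVar[str] = contextvars.ContextVar("tenant")
tenant.set("tenant-a")

def run_iteration(index: int, item: str, ctx: contextvars.Context) -> str:
    # Restore the submitter's contextvars inside the pool thread.
    for var, val in ctx.items():
        var.set(val)
    return f"{tenant.get()}:{index}:{item}"

with ThreadPoolExecutor(max_workers=2) as executor:
    futures = [
        executor.submit(run_iteration, i, item, contextvars.copy_context())
        for i, item in enumerate(["x", "y"])
    ]
    for fut in as_completed(futures):
        print(fut.result())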
@@ -158,5 +158,3 @@ class KnowledgeIndexNodeData(BaseNodeData):
     type: str = "knowledge-index"
     chunk_structure: str
     index_chunk_variable_selector: list[str]
-    indexing_technique: str | None = None
-    summary_index_setting: dict | None = None
@ -1,11 +1,9 @@
|
||||
import concurrent.futures
|
||||
import datetime
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import current_app
|
||||
from sqlalchemy import func, select
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@ -18,9 +16,7 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.template import Template
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import Dataset, Document, DocumentSegment, DocumentSegmentSummary
|
||||
from services.summary_index_service import SummaryIndexService
|
||||
from tasks.generate_summary_index_task import generate_summary_index_task
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
|
||||
from .entities import KnowledgeIndexNodeData
|
||||
from .exc import (
|
||||
@ -71,20 +67,7 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
# index knowledge
|
||||
try:
|
||||
if is_preview:
|
||||
# Preview mode: generate summaries for chunks directly without saving to database
|
||||
# Format preview and generate summaries on-the-fly
|
||||
# Get indexing_technique and summary_index_setting from node_data (workflow graph config)
|
||||
# or fallback to dataset if not available in node_data
|
||||
indexing_technique = node_data.indexing_technique or dataset.indexing_technique
|
||||
summary_index_setting = node_data.summary_index_setting or dataset.summary_index_setting
|
||||
|
||||
outputs = self._get_preview_output_with_summaries(
|
||||
node_data.chunk_structure,
|
||||
chunks,
|
||||
dataset=dataset,
|
||||
indexing_technique=indexing_technique,
|
||||
summary_index_setting=summary_index_setting,
|
||||
)
|
||||
outputs = self._get_preview_output(node_data.chunk_structure, chunks)
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
@ -180,9 +163,6 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# Generate summary index if enabled
|
||||
self._handle_summary_index_generation(dataset, document, variable_pool)
|
||||
|
||||
return {
|
||||
"dataset_id": ds_id_value,
|
||||
"dataset_name": dataset_name_value,
|
||||
@ -193,289 +173,9 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
|
||||
"display_status": "completed",
|
||||
}
|
||||
|
||||
    def _handle_summary_index_generation(
        self,
        dataset: Dataset,
        document: Document,
        variable_pool: VariablePool,
    ) -> None:
        """
        Handle summary index generation based on mode (debug/preview or production).

        Args:
            dataset: Dataset containing the document
            document: Document to generate summaries for
            variable_pool: Variable pool to check invoke_from
        """
        # Only generate summary index for high_quality indexing technique
        if dataset.indexing_technique != "high_quality":
            return

        # Check if summary index is enabled
        summary_index_setting = dataset.summary_index_setting
        if not summary_index_setting or not summary_index_setting.get("enable"):
            return

        # Skip qa_model documents
        if document.doc_form == "qa_model":
            return

        # Determine if in preview/debug mode
        invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
        is_preview = invoke_from and invoke_from.value == InvokeFrom.DEBUGGER

        # Determine if only parent chunks should be processed
        only_parent_chunks = dataset.chunk_structure == "parent_child_index"

        if is_preview:
            try:
                # Query segments that need summary generation
                query = db.session.query(DocumentSegment).filter_by(
                    dataset_id=dataset.id,
                    document_id=document.id,
                    status="completed",
                    enabled=True,
                )
                segments = query.all()

                if not segments:
                    logger.info("No segments found for document %s", document.id)
                    return

                # Filter segments based on mode
                segments_to_process = []
                for segment in segments:
                    # Skip if summary already exists
                    existing_summary = (
                        db.session.query(DocumentSegmentSummary)
                        .filter_by(chunk_id=segment.id, dataset_id=dataset.id, status="completed")
                        .first()
                    )
                    if existing_summary:
                        continue

                    # For parent-child mode, all segments are parent chunks, so process all
                    segments_to_process.append(segment)

                if not segments_to_process:
                    logger.info("No segments need summary generation for document %s", document.id)
                    return

                # Use ThreadPoolExecutor for concurrent generation
                flask_app = current_app._get_current_object()  # type: ignore
                max_workers = min(10, len(segments_to_process))  # Limit to 10 workers

                def process_segment(segment: DocumentSegment) -> None:
                    """Process a single segment in a thread with Flask app context."""
                    with flask_app.app_context():
                        try:
                            SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting)
                        except Exception:
                            logger.exception(
                                "Failed to generate summary for segment %s",
                                segment.id,
                            )
                            # Continue processing other segments

                with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
                    futures = [executor.submit(process_segment, segment) for segment in segments_to_process]
                    # Wait for all tasks to complete
                    concurrent.futures.wait(futures)

                logger.info(
                    "Successfully generated summary index for %s segments in document %s",
                    len(segments_to_process),
                    document.id,
                )
            except Exception:
                logger.exception("Failed to generate summary index for document %s", document.id)
                # Don't fail the entire indexing process if summary generation fails
        else:
            # Production mode: asynchronous generation
            logger.info(
                "Queuing summary index generation task for document %s (production mode)",
                document.id,
            )
            try:
                generate_summary_index_task.delay(dataset.id, document.id, None)
                logger.info("Summary index generation task queued for document %s", document.id)
            except Exception:
                logger.exception(
                    "Failed to queue summary index generation task for document %s",
                    document.id,
                )
                # Don't fail the entire indexing process if task queuing fails
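For reference, the production branch above only hands work to Celery; a minimal sketch of what such a task could look like (the task body and queue name are assumptions inferred from the generate_summary_index_task.delay(dataset.id, document.id, None) call, not the actual implementation in tasks/generate_summary_index_task.py):

from celery import shared_task

# Hypothetical sketch only; the real task lives in tasks/generate_summary_index_task.py.
@shared_task(queue="dataset")
def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: list[str] | None = None):
    # Load the dataset and document, then generate and vectorize summaries
    # for each eligible segment, mirroring the synchronous preview branch above.
    ...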

    def _get_preview_output_with_summaries(
        self,
        chunk_structure: str,
        chunks: Any,
        dataset: Dataset,
        indexing_technique: str | None = None,
        summary_index_setting: dict | None = None,
    ) -> Mapping[str, Any]:
        """
        Generate preview output with summaries for chunks in preview mode.
        This method generates summaries on-the-fly without saving to database.

        Args:
            chunk_structure: Chunk structure type
            chunks: Chunks to generate preview for
            dataset: Dataset object (for tenant_id)
            indexing_technique: Indexing technique from node config or dataset
            summary_index_setting: Summary index setting from node config or dataset
        """
    def _get_preview_output(self, chunk_structure: str, chunks: Any) -> Mapping[str, Any]:
        index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
        preview_output = index_processor.format_preview(chunks)

        # Check if summary index is enabled
        if indexing_technique != "high_quality":
            return preview_output

        if not summary_index_setting or not summary_index_setting.get("enable"):
            return preview_output

        # Generate summaries for chunks
        if "preview" in preview_output and isinstance(preview_output["preview"], list):
            chunk_count = len(preview_output["preview"])
            logger.info(
                "Generating summaries for %s chunks in preview mode (dataset: %s)",
                chunk_count,
                dataset.id,
            )
            # Use ParagraphIndexProcessor's generate_summary method
            from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor

            # Get Flask app for application context in worker threads
            flask_app = None
            try:
                flask_app = current_app._get_current_object()  # type: ignore
            except RuntimeError:
                logger.warning("No Flask application context available, summary generation may fail")

            def generate_summary_for_chunk(preview_item: dict) -> None:
                """Generate summary for a single chunk."""
                if "content" in preview_item:
                    try:
                        # Set Flask application context in worker thread
                        if flask_app:
                            with flask_app.app_context():
                                summary = ParagraphIndexProcessor.generate_summary(
                                    tenant_id=dataset.tenant_id,
                                    text=preview_item["content"],
                                    summary_index_setting=summary_index_setting,
                                )
                                if summary:
                                    preview_item["summary"] = summary
                        else:
                            # Fallback: try without app context (may fail)
                            summary = ParagraphIndexProcessor.generate_summary(
                                tenant_id=dataset.tenant_id,
                                text=preview_item["content"],
                                summary_index_setting=summary_index_setting,
                            )
                            if summary:
                                preview_item["summary"] = summary
                    except Exception:
                        logger.exception("Failed to generate summary for chunk")
                        # Don't fail the entire preview if summary generation fails

            # Generate summaries concurrently using ThreadPoolExecutor
            # Set a reasonable timeout to prevent hanging (60 seconds per chunk, max 5 minutes total)
            timeout_seconds = min(300, 60 * len(preview_output["preview"]))
            with concurrent.futures.ThreadPoolExecutor(max_workers=min(10, len(preview_output["preview"]))) as executor:
                futures = [
                    executor.submit(generate_summary_for_chunk, preview_item)
                    for preview_item in preview_output["preview"]
                ]
                # Wait for all tasks to complete with timeout
                done, not_done = concurrent.futures.wait(futures, timeout=timeout_seconds)

                # Cancel tasks that didn't complete in time
                if not_done:
                    logger.warning(
                        "Summary generation timeout: %s chunks did not complete within %ss. "
                        "Cancelling remaining tasks...",
                        len(not_done),
                        timeout_seconds,
                    )
                    for future in not_done:
                        future.cancel()
                    # Wait a bit for cancellation to take effect
                    concurrent.futures.wait(not_done, timeout=5)

            completed_count = sum(1 for item in preview_output["preview"] if item.get("summary") is not None)
            logger.info(
                "Completed summary generation for preview chunks: %s/%s succeeded",
                completed_count,
                len(preview_output["preview"]),
            )

        return preview_output
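For orientation, the preview payload mutated above is a plain dict; a minimal sketch of the shape the code assumes (only the "preview" and "content" keys are relied on, everything else comes from format_preview):

# Shape assumed by the enrichment loop above (illustrative values).
preview_output = {
    "preview": [
        {"content": "chunk text ...", "summary": "one-sentence summary ..."},  # "summary" is added on success
        {"content": "another chunk ..."},  # left untouched when generation fails or times out
    ]
}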

    def _get_preview_output(
        self,
        chunk_structure: str,
        chunks: Any,
        dataset: Dataset | None = None,
        variable_pool: VariablePool | None = None,
    ) -> Mapping[str, Any]:
        index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
        preview_output = index_processor.format_preview(chunks)

        # If dataset is provided, try to enrich preview with summaries
        if dataset and variable_pool:
            document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
            if document_id:
                document = db.session.query(Document).filter_by(id=document_id.value).first()
                if document:
                    # Query summaries for this document
                    summaries = (
                        db.session.query(DocumentSegmentSummary)
                        .filter_by(
                            dataset_id=dataset.id,
                            document_id=document.id,
                            status="completed",
                            enabled=True,
                        )
                        .all()
                    )

                    if summaries:
                        # Create a map of segment content to summary for matching
                        # Use content matching as chunks in preview might not be indexed yet
                        summary_by_content = {}
                        for summary in summaries:
                            segment = (
                                db.session.query(DocumentSegment)
                                .filter_by(id=summary.chunk_id, dataset_id=dataset.id)
                                .first()
                            )
                            if segment:
                                # Normalize content for matching (strip whitespace)
                                normalized_content = segment.content.strip()
                                summary_by_content[normalized_content] = summary.summary_content

                        # Enrich preview with summaries by content matching
                        if "preview" in preview_output and isinstance(preview_output["preview"], list):
                            matched_count = 0
                            for preview_item in preview_output["preview"]:
                                if "content" in preview_item:
                                    # Normalize content for matching
                                    normalized_chunk_content = preview_item["content"].strip()
                                    if normalized_chunk_content in summary_by_content:
                                        preview_item["summary"] = summary_by_content[normalized_chunk_content]
                                        matched_count += 1

                            if matched_count > 0:
                                logger.info(
                                    "Enriched preview with %s existing summaries (dataset: %s, document: %s)",
                                    matched_count,
                                    dataset.id,
                                    document.id,
                                )

        return preview_output
        return index_processor.format_preview(chunks)

    @classmethod
    def version(cls) -> str:
@@ -102,8 +102,6 @@ def init_app(app: DifyApp) -> Celery:
    imports = [
        "tasks.async_workflow_tasks",  # trigger workers
        "tasks.trigger_processing_tasks",  # async trigger processing
        "tasks.generate_summary_index_task",  # summary index generation
        "tasks.regenerate_summary_index_task",  # summary index regeneration
    ]
    day = dify_config.CELERY_BEAT_SCHEDULER_TIME

@@ -39,14 +39,6 @@ dataset_retrieval_model_fields = {
    "score_threshold_enabled": fields.Boolean,
    "score_threshold": fields.Float,
}

dataset_summary_index_fields = {
    "enable": fields.Boolean,
    "model_name": fields.String,
    "model_provider_name": fields.String,
    "summary_prompt": fields.String,
}

external_retrieval_model_fields = {
    "top_k": fields.Integer,
    "score_threshold": fields.Float,
@@ -91,7 +83,6 @@ dataset_detail_fields = {
    "embedding_model_provider": fields.String,
    "embedding_available": fields.Boolean,
    "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
    "summary_index_setting": fields.Nested(dataset_summary_index_fields),
    "tags": fields.List(fields.Nested(tag_fields)),
    "doc_form": fields.String,
    "external_knowledge_info": fields.Nested(external_knowledge_info_fields),

@@ -33,9 +33,6 @@ document_fields = {
    "hit_count": fields.Integer,
    "doc_form": fields.String,
    "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
    # Summary index generation status: "GENERATING", "COMPLETED", "ERROR", or null if not enabled
    "summary_index_status": fields.String,
    "need_summary": fields.Boolean,  # Whether this document needs summary index generation
}

document_with_segments_fields = {
@@ -63,9 +60,6 @@ document_with_segments_fields = {
    "completed_segments": fields.Integer,
    "total_segments": fields.Integer,
    "doc_metadata": fields.List(fields.Nested(document_metadata_fields), attribute="doc_metadata_details"),
    # Summary index generation status: "GENERATING", "COMPLETED", "ERROR", or null if not enabled
    "summary_index_status": fields.String,
    "need_summary": fields.Boolean,  # Whether this document needs summary index generation
}

dataset_and_document_fields = {

@@ -58,5 +58,4 @@ hit_testing_record_fields = {
    "score": fields.Float,
    "tsne_position": fields.Raw,
    "files": fields.List(fields.Nested(files_fields)),
    "summary": fields.String,  # Summary content if retrieved via summary index
}

@@ -49,5 +49,4 @@ segment_fields = {
    "stopped_at": TimestampField,
    "child_chunks": fields.List(fields.Nested(child_chunk_fields)),
    "attachments": fields.List(fields.Nested(attachment_fields)),
    "summary": fields.String,  # Summary content for the segment
}

api/libs/workspace_permission.py (new file, 74 lines)
@@ -0,0 +1,74 @@
"""
|
||||
Workspace permission helper functions.
|
||||
|
||||
These helpers check both billing/plan level and workspace-specific policy level permissions.
|
||||
Checks are performed at two levels:
|
||||
1. Billing/plan level - via FeatureService (e.g., SANDBOX plan restrictions)
|
||||
2. Workspace policy level - via EnterpriseService (admin-configured per workspace)
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def check_workspace_member_invite_permission(workspace_id: str) -> None:
|
||||
"""
|
||||
Check if workspace allows member invitations at both billing and policy levels.
|
||||
|
||||
Checks performed:
|
||||
1. Billing/plan level - For future expansion (currently no plan-level restriction)
|
||||
2. Enterprise policy level - Admin-configured workspace permission
|
||||
|
||||
Args:
|
||||
workspace_id: The workspace ID to check permissions for
|
||||
|
||||
Raises:
|
||||
Forbidden: If either billing plan or workspace policy prohibits member invitations
|
||||
"""
|
||||
# Check enterprise workspace policy level (only if enterprise enabled)
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
try:
|
||||
permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
|
||||
if not permission.allow_member_invite:
|
||||
raise Forbidden("Workspace policy prohibits member invitations")
|
||||
except Forbidden:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to check workspace invite permission for %s", workspace_id)
|
||||
|
||||
|
||||
def check_workspace_owner_transfer_permission(workspace_id: str) -> None:
|
||||
"""
|
||||
Check if workspace allows owner transfer at both billing and policy levels.
|
||||
|
||||
Checks performed:
|
||||
1. Billing/plan level - SANDBOX plan blocks owner transfer
|
||||
2. Enterprise policy level - Admin-configured workspace permission
|
||||
|
||||
Args:
|
||||
workspace_id: The workspace ID to check permissions for
|
||||
|
||||
Raises:
|
||||
Forbidden: If either billing plan or workspace policy prohibits ownership transfer
|
||||
"""
|
||||
features = FeatureService.get_features(workspace_id)
|
||||
if not features.is_allow_transfer_workspace:
|
||||
raise Forbidden("Your current plan does not allow workspace ownership transfer")
|
||||
|
||||
# Check enterprise workspace policy level (only if enterprise enabled)
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
try:
|
||||
permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
|
||||
if not permission.allow_owner_transfer:
|
||||
raise Forbidden("Workspace policy prohibits ownership transfer")
|
||||
except Forbidden:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to check workspace transfer permission for %s", workspace_id)
|
||||
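A minimal usage sketch for these helpers inside a controller (the endpoint name and argument plumbing are illustrative, not part of this diff):

# Illustrative only; endpoint and arguments are assumptions.
from libs.workspace_permission import check_workspace_member_invite_permission

def invite_member(workspace_id: str, email: str):
    # Raises werkzeug's Forbidden when workspace policy disallows invitations.
    # Note the fail-open design above: unexpected enterprise-service errors
    # are logged and the invite proceeds.
    check_workspace_member_invite_permission(workspace_id)
    ...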
@@ -1,69 +0,0 @@
"""add SummaryIndex feature

Revision ID: 562dcce7d77c
Revises: 03ea244985ce
Create Date: 2026-01-12 13:58:40.584802

"""
from alembic import op
import models as models
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '562dcce7d77c'
down_revision = '03ea244985ce'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('document_segment_summary',
        sa.Column('id', models.types.StringUUID(), nullable=False),
        sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
        sa.Column('document_id', models.types.StringUUID(), nullable=False),
        sa.Column('chunk_id', models.types.StringUUID(), nullable=False),
        sa.Column('summary_content', models.types.LongText(), nullable=True),
        sa.Column('summary_index_node_id', sa.String(length=255), nullable=True),
        sa.Column('summary_index_node_hash', sa.String(length=255), nullable=True),
        sa.Column('status', sa.String(length=32), server_default=sa.text("'generating'"), nullable=False),
        sa.Column('error', models.types.LongText(), nullable=True),
        sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
        sa.Column('disabled_at', sa.DateTime(), nullable=True),
        sa.Column('disabled_by', models.types.StringUUID(), nullable=True),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='document_segment_summary_pkey')
    )
    with op.batch_alter_table('document_segment_summary', schema=None) as batch_op:
        batch_op.create_index('document_segment_summary_chunk_id_idx', ['chunk_id'], unique=False)
        batch_op.create_index('document_segment_summary_dataset_id_idx', ['dataset_id'], unique=False)
        batch_op.create_index('document_segment_summary_document_id_idx', ['document_id'], unique=False)
        batch_op.create_index('document_segment_summary_status_idx', ['status'], unique=False)

    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.add_column(sa.Column('summary_index_setting', models.types.AdjustedJSON(), nullable=True))

    with op.batch_alter_table('documents', schema=None) as batch_op:
        batch_op.add_column(sa.Column('need_summary', sa.Boolean(), server_default=sa.text('false'), nullable=True))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('documents', schema=None) as batch_op:
        batch_op.drop_column('need_summary')

    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.drop_column('summary_index_setting')

    with op.batch_alter_table('document_segment_summary', schema=None) as batch_op:
        batch_op.drop_index('document_segment_summary_status_idx')
        batch_op.drop_index('document_segment_summary_document_id_idx')
        batch_op.drop_index('document_segment_summary_dataset_id_idx')
        batch_op.drop_index('document_segment_summary_chunk_id_idx')

    op.drop_table('document_segment_summary')
    # ### end Alembic commands ###
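If this migration ever has to be applied or rolled back by hand, a minimal sketch using Alembic's Python API (the config path is an assumption; the usual route is the flask db / alembic CLI):

from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")              # hypothetical config location
command.upgrade(cfg, "562dcce7d77c")     # apply the SummaryIndex revision
command.downgrade(cfg, "03ea244985ce")   # revert to the prior revision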
@@ -0,0 +1,35 @@
"""change workflow node execution workflow_run index

Revision ID: 288345cd01d1
Revises: 3334862ee907
Create Date: 2026-01-16 17:15:00.000000

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "288345cd01d1"
down_revision = "3334862ee907"
branch_labels = None
depends_on = None


def upgrade():
    with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op:
        batch_op.drop_index("workflow_node_execution_workflow_run_idx")
        batch_op.create_index(
            "workflow_node_execution_workflow_run_id_idx",
            ["workflow_run_id"],
            unique=False,
        )


def downgrade():
    with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op:
        batch_op.drop_index("workflow_node_execution_workflow_run_id_idx")
        batch_op.create_index(
            "workflow_node_execution_workflow_run_idx",
            ["tenant_id", "app_id", "workflow_id", "triggered_from", "workflow_run_id"],
            unique=False,
        )
@@ -72,7 +72,6 @@ class Dataset(Base):
    keyword_number = mapped_column(sa.Integer, nullable=True, server_default=sa.text("10"))
    collection_binding_id = mapped_column(StringUUID, nullable=True)
    retrieval_model = mapped_column(AdjustedJSON, nullable=True)
    summary_index_setting = mapped_column(AdjustedJSON, nullable=True)
    built_in_field_enabled = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
    icon_info = mapped_column(AdjustedJSON, nullable=True)
    runtime_mode = mapped_column(sa.String(255), nullable=True, server_default=sa.text("'general'"))
@@ -420,7 +419,6 @@ class Document(Base):
    doc_metadata = mapped_column(AdjustedJSON, nullable=True)
    doc_form = mapped_column(String(255), nullable=False, server_default=sa.text("'text_model'"))
    doc_language = mapped_column(String(255), nullable=True)
    need_summary: Mapped[bool | None] = mapped_column(sa.Boolean, nullable=True, server_default=sa.text("false"))

    DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]

@@ -1577,35 +1575,3 @@ class SegmentAttachmentBinding(Base):
    segment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    attachment_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    created_at: Mapped[datetime] = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())


class DocumentSegmentSummary(Base):
    __tablename__ = "document_segment_summary"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="document_segment_summary_pkey"),
        sa.Index("document_segment_summary_dataset_id_idx", "dataset_id"),
        sa.Index("document_segment_summary_document_id_idx", "document_id"),
        sa.Index("document_segment_summary_chunk_id_idx", "chunk_id"),
        sa.Index("document_segment_summary_status_idx", "status"),
    )

    id: Mapped[str] = mapped_column(StringUUID, nullable=False, default=lambda: str(uuid4()))
    dataset_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    document_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    # corresponds to DocumentSegment.id or parent chunk id
    chunk_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    summary_content: Mapped[str] = mapped_column(LongText, nullable=True)
    summary_index_node_id: Mapped[str] = mapped_column(String(255), nullable=True)
    summary_index_node_hash: Mapped[str] = mapped_column(String(255), nullable=True)
    status: Mapped[str] = mapped_column(String(32), nullable=False, server_default=sa.text("'generating'"))
    error: Mapped[str] = mapped_column(LongText, nullable=True)
    enabled: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))
    disabled_at: Mapped[datetime | None] = mapped_column(DateTime, nullable=True)
    disabled_by = mapped_column(StringUUID, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )

    def __repr__(self):
        return f"<DocumentSegmentSummary id={self.id} chunk_id={self.chunk_id} status={self.status}>"

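A quick sketch of how this model is typically read back (an illustrative query; the same filter appears in the service code later in this diff):

# Fetch completed, enabled summaries for one document (illustrative).
summaries = (
    db.session.query(DocumentSegmentSummary)
    .filter_by(dataset_id=dataset.id, document_id=document.id, status="completed", enabled=True)
    .all()
)
for s in summaries:
    print(s.chunk_id, s.status, (s.summary_content or "")[:80])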
@@ -781,11 +781,7 @@ class WorkflowNodeExecutionModel(Base):  # This model is expected to have `offlo
        return (
            PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
            Index(
                "workflow_node_execution_workflow_run_idx",
                "tenant_id",
                "app_id",
                "workflow_id",
                "triggered_from",
                "workflow_node_execution_workflow_run_id_idx",
                "workflow_run_id",
            ),
            Index(

@@ -13,6 +13,8 @@ from collections.abc import Sequence
from datetime import datetime
from typing import Protocol

from sqlalchemy.orm import Session

from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from models.workflow import WorkflowNodeExecutionModel

@@ -130,6 +132,18 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
        """
        ...

    def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
        """
        Count node executions and offloads for the given workflow run ids.
        """
        ...

    def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
        """
        Delete node executions and offloads for the given workflow run ids.
        """
        ...

    def delete_executions_by_app(
        self,
        tenant_id: str,

@@ -7,17 +7,15 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.

from collections.abc import Sequence
from datetime import datetime
from typing import TypedDict, cast
from typing import cast

from sqlalchemy import asc, delete, desc, func, select, tuple_
from sqlalchemy import asc, delete, desc, func, select
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session, sessionmaker

from models.enums import WorkflowRunTriggeredFrom
from models.workflow import (
    WorkflowNodeExecutionModel,
    WorkflowNodeExecutionOffload,
    WorkflowNodeExecutionTriggeredFrom,
)
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository

@@ -49,26 +47,6 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
        """
        self._session_maker = session_maker

    @staticmethod
    def _map_run_triggered_from_to_node_triggered_from(triggered_from: str) -> str:
        """
        Map workflow run triggered_from values to workflow node execution triggered_from values.
        """
        if triggered_from in {
            WorkflowRunTriggeredFrom.APP_RUN.value,
            WorkflowRunTriggeredFrom.DEBUGGING.value,
            WorkflowRunTriggeredFrom.SCHEDULE.value,
            WorkflowRunTriggeredFrom.PLUGIN.value,
            WorkflowRunTriggeredFrom.WEBHOOK.value,
        }:
            return WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
        if triggered_from in {
            WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN.value,
            WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING.value,
        }:
            return WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN.value
        return ""

    def get_node_last_execution(
        self,
        tenant_id: str,
@@ -316,51 +294,16 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
        session.commit()
        return result.rowcount

    class RunContext(TypedDict):
        run_id: str
        tenant_id: str
        app_id: str
        workflow_id: str
        triggered_from: str

    @staticmethod
    def delete_by_runs(session: Session, runs: Sequence[RunContext]) -> tuple[int, int]:
    def delete_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
        """
        Delete node executions (and offloads) for the given workflow runs using indexed columns.

        Uses the composite index on (tenant_id, app_id, workflow_id, triggered_from, workflow_run_id)
        by filtering on those columns with tuple IN.
        Delete node executions (and offloads) for the given workflow runs using workflow_run_id.
        """
        if not runs:
        if not run_ids:
            return 0, 0

        tuple_values = [
            (
                run["tenant_id"],
                run["app_id"],
                run["workflow_id"],
                DifyAPISQLAlchemyWorkflowNodeExecutionRepository._map_run_triggered_from_to_node_triggered_from(
                    run["triggered_from"]
                ),
                run["run_id"],
            )
            for run in runs
        ]

        node_execution_ids = session.scalars(
            select(WorkflowNodeExecutionModel.id).where(
                tuple_(
                    WorkflowNodeExecutionModel.tenant_id,
                    WorkflowNodeExecutionModel.app_id,
                    WorkflowNodeExecutionModel.workflow_id,
                    WorkflowNodeExecutionModel.triggered_from,
                    WorkflowNodeExecutionModel.workflow_run_id,
                ).in_(tuple_values)
            )
        ).all()

        if not node_execution_ids:
            return 0, 0
        run_ids = list(run_ids)
        run_id_filter = WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)
        node_execution_ids = select(WorkflowNodeExecutionModel.id).where(run_id_filter)

        offloads_deleted = (
            cast(
@@ -377,55 +320,32 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
        node_executions_deleted = (
            cast(
                CursorResult,
                session.execute(
                    delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(node_execution_ids))
                ),
                session.execute(delete(WorkflowNodeExecutionModel).where(run_id_filter)),
            ).rowcount
            or 0
        )

        return node_executions_deleted, offloads_deleted
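A small usage sketch of the new run_id-based API (the construction mirrors the factory call used by the cleanup code later in this diff; the run ids and imports are placeholders):

# Illustrative only; run ids are placeholders, imports as in the cleanup module.
with Session(db.engine) as session:
    repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
        session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
    )
    executions, offloads = repo.count_by_runs(session, ["run-1", "run-2"])
    deleted_execs, deleted_offloads = repo.delete_by_runs(session, ["run-1", "run-2"])
    session.commit()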

    @staticmethod
    def count_by_runs(session: Session, runs: Sequence[RunContext]) -> tuple[int, int]:
    def count_by_runs(self, session: Session, run_ids: Sequence[str]) -> tuple[int, int]:
        """
        Count node executions (and offloads) for the given workflow runs using indexed columns.
        Count node executions (and offloads) for the given workflow runs using workflow_run_id.
        """
        if not runs:
        if not run_ids:
            return 0, 0

        tuple_values = [
            (
                run["tenant_id"],
                run["app_id"],
                run["workflow_id"],
                DifyAPISQLAlchemyWorkflowNodeExecutionRepository._map_run_triggered_from_to_node_triggered_from(
                    run["triggered_from"]
                ),
                run["run_id"],
            )
            for run in runs
        ]
        tuple_filter = tuple_(
            WorkflowNodeExecutionModel.tenant_id,
            WorkflowNodeExecutionModel.app_id,
            WorkflowNodeExecutionModel.workflow_id,
            WorkflowNodeExecutionModel.triggered_from,
            WorkflowNodeExecutionModel.workflow_run_id,
        ).in_(tuple_values)
        run_ids = list(run_ids)
        run_id_filter = WorkflowNodeExecutionModel.workflow_run_id.in_(run_ids)

        node_executions_count = (
            session.scalar(select(func.count()).select_from(WorkflowNodeExecutionModel).where(tuple_filter)) or 0
            session.scalar(select(func.count()).select_from(WorkflowNodeExecutionModel).where(run_id_filter)) or 0
        )
        node_execution_ids = select(WorkflowNodeExecutionModel.id).where(run_id_filter)
        offloads_count = (
            session.scalar(
                select(func.count())
                .select_from(WorkflowNodeExecutionOffload)
                .join(
                    WorkflowNodeExecutionModel,
                    WorkflowNodeExecutionOffload.node_execution_id == WorkflowNodeExecutionModel.id,
                )
                .where(tuple_filter)
                .where(WorkflowNodeExecutionOffload.node_execution_id.in_(node_execution_ids))
            )
            or 0
        )

@@ -1381,6 +1381,11 @@ class RegisterService:
        normalized_email = email.lower()

        """Invite new member"""
        # Check workspace permission for member invitations
        from libs.workspace_permission import check_workspace_member_invite_permission

        check_workspace_member_invite_permission(tenant.id)

        with Session(db.engine) as session:
            account = AccountService.get_account_by_email_with_case_fallback(email, session=session)


@@ -87,7 +87,6 @@ from tasks.disable_segments_from_index_task import disable_segments_from_index_t
from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
from tasks.regenerate_summary_index_task import regenerate_summary_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task
from tasks.retry_document_indexing_task import retry_document_indexing_task
from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
@@ -475,11 +474,6 @@ class DatasetService:
        if external_retrieval_model:
            dataset.retrieval_model = external_retrieval_model

        # Update summary index setting if provided
        summary_index_setting = data.get("summary_index_setting", None)
        if summary_index_setting is not None:
            dataset.summary_index_setting = summary_index_setting

        # Update basic dataset properties
        dataset.name = data.get("name", dataset.name)
        dataset.description = data.get("description", dataset.description)
@@ -562,18 +556,12 @@ class DatasetService:
        # Handle indexing technique changes and embedding model updates
        action = DatasetService._handle_indexing_technique_change(dataset, data, filtered_data)

        # Check if summary_index_setting model changed (before updating database)
        summary_model_changed = DatasetService._check_summary_index_setting_model_changed(dataset, data)

        # Add metadata fields
        filtered_data["updated_by"] = user.id
        filtered_data["updated_at"] = naive_utc_now()
        # update retrieval model
        if data.get("retrieval_model"):
            filtered_data["retrieval_model"] = data["retrieval_model"]
        # update summary index setting
        if data.get("summary_index_setting"):
            filtered_data["summary_index_setting"] = data.get("summary_index_setting")
        # update icon info
        if data.get("icon_info"):
            filtered_data["icon_info"] = data.get("icon_info")
@@ -582,30 +570,12 @@ class DatasetService:
        db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
        db.session.commit()

        # Reload dataset to get updated values
        db.session.refresh(dataset)

        # update pipeline knowledge base node data
        DatasetService._update_pipeline_knowledge_base_node_data(dataset, user.id)

        # Trigger vector index task if indexing technique changed
        if action:
            deal_dataset_vector_index_task.delay(dataset.id, action)
            # If embedding_model changed, also regenerate summary vectors
            if action == "update":
                regenerate_summary_index_task.delay(
                    dataset.id,
                    regenerate_reason="embedding_model_changed",
                    regenerate_vectors_only=True,
                )

        # Trigger summary index regeneration if summary model changed
        if summary_model_changed:
            regenerate_summary_index_task.delay(
                dataset.id,
                regenerate_reason="summary_model_changed",
                regenerate_vectors_only=False,
            )

        return dataset

@@ -644,7 +614,6 @@ class DatasetService:
                    knowledge_index_node_data["chunk_structure"] = dataset.chunk_structure
                    knowledge_index_node_data["indexing_technique"] = dataset.indexing_technique  # pyright: ignore[reportAttributeAccessIssue]
                    knowledge_index_node_data["keyword_number"] = dataset.keyword_number
                    knowledge_index_node_data["summary_index_setting"] = dataset.summary_index_setting
                    node["data"] = knowledge_index_node_data
                    updated = True
                except Exception:
@@ -883,53 +852,6 @@ class DatasetService:
            )
            filtered_data["collection_binding_id"] = dataset_collection_binding.id

    @staticmethod
    def _check_summary_index_setting_model_changed(dataset: Dataset, data: dict[str, Any]) -> bool:
        """
        Check if summary_index_setting model (model_name or model_provider_name) has changed.

        Args:
            dataset: Current dataset object
            data: Update data dictionary

        Returns:
            bool: True if summary model changed, False otherwise
        """
        # Check if summary_index_setting is being updated
        if "summary_index_setting" not in data or data.get("summary_index_setting") is None:
            return False

        new_summary_setting = data.get("summary_index_setting")
        old_summary_setting = dataset.summary_index_setting

        # If old setting doesn't exist or is disabled, no need to regenerate
        if not old_summary_setting or not old_summary_setting.get("enable"):
            return False

        # If new setting is disabled, no need to regenerate
        if not new_summary_setting or not new_summary_setting.get("enable"):
            return False

        # Compare model_name and model_provider_name
        old_model_name = old_summary_setting.get("model_name")
        old_model_provider = old_summary_setting.get("model_provider_name")
        new_model_name = new_summary_setting.get("model_name")
        new_model_provider = new_summary_setting.get("model_provider_name")

        # Check if model changed
        if old_model_name != new_model_name or old_model_provider != new_model_provider:
            logger.info(
                "Summary index setting model changed for dataset %s: old=%s/%s, new=%s/%s",
                dataset.id,
                old_model_provider,
                old_model_name,
                new_model_provider,
                new_model_name,
            )
            return True

        return False

    @staticmethod
    def update_rag_pipeline_dataset_settings(
        session: Session, dataset: Dataset, knowledge_configuration: KnowledgeConfiguration, has_published: bool = False
@@ -965,9 +887,6 @@ class DatasetService:
            else:
                raise ValueError("Invalid index method")
            dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
            # Update summary_index_setting if provided
            if knowledge_configuration.summary_index_setting is not None:
                dataset.summary_index_setting = knowledge_configuration.summary_index_setting
            session.add(dataset)
        else:
            if dataset.chunk_structure and dataset.chunk_structure != knowledge_configuration.chunk_structure:
@@ -1073,9 +992,6 @@ class DatasetService:
            if dataset.keyword_number != knowledge_configuration.keyword_number:
                dataset.keyword_number = knowledge_configuration.keyword_number
            dataset.retrieval_model = knowledge_configuration.retrieval_model.model_dump()
            # Update summary_index_setting if provided
            if knowledge_configuration.summary_index_setting is not None:
                dataset.summary_index_setting = knowledge_configuration.summary_index_setting
            session.add(dataset)
            session.commit()
            if action:
@@ -1908,8 +1824,6 @@ class DocumentService:
                    DuplicateDocumentIndexingTaskProxy(
                        dataset.tenant_id, dataset.id, duplicate_document_ids
                    ).delay()
                    # Note: Summary index generation is triggered in document_indexing_task after indexing completes
                    # to ensure segments are available. See tasks/document_indexing_task.py
            except LockNotOwnedError:
                pass

@@ -2214,11 +2128,6 @@ class DocumentService:
        name: str,
        batch: str,
    ):
        # Set need_summary based on dataset's summary_index_setting
        need_summary = False
        if dataset.summary_index_setting and dataset.summary_index_setting.get("enable") is True:
            need_summary = True

        document = Document(
            tenant_id=dataset.tenant_id,
            dataset_id=dataset.id,
@@ -2232,7 +2141,6 @@ class DocumentService:
            created_by=account.id,
            doc_form=document_form,
            doc_language=document_language,
            need_summary=need_summary,
        )
        doc_metadata = {}
        if dataset.built_in_field_enabled:
@@ -2457,7 +2365,6 @@ class DocumentService:
            embedding_model_provider=knowledge_config.embedding_model_provider,
            collection_binding_id=dataset_collection_binding_id,
            retrieval_model=retrieval_model.model_dump() if retrieval_model else None,
            summary_index_setting=knowledge_config.summary_index_setting,
            is_multimodal=knowledge_config.is_multimodal,
        )

@@ -2639,14 +2546,6 @@ class DocumentService:
        if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int):
            raise ValueError("Process rule segmentation max_tokens is invalid")

        # Validate summary index setting
        summary_index_setting = args["process_rule"].get("summary_index_setting")
        if summary_index_setting and summary_index_setting.get("enable"):
            if "model_name" not in summary_index_setting or not summary_index_setting["model_name"]:
                raise ValueError("Summary index model name is required")
            if "model_provider_name" not in summary_index_setting or not summary_index_setting["model_provider_name"]:
                raise ValueError("Summary index model provider name is required")
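For reference, a payload that passes these checks (field names match dataset_summary_index_fields earlier in this diff; the model and prompt values are illustrative):

# Illustrative payload that satisfies the validation above.
summary_index_setting = {
    "enable": True,
    "model_name": "gpt-4o-mini",       # any configured model name
    "model_provider_name": "openai",   # its provider
    "summary_prompt": "Summarize this chunk in one sentence.",
}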

    @staticmethod
    def batch_update_document_status(
        dataset: Dataset, document_ids: list[str], action: Literal["enable", "disable", "archive", "un_archive"], user
@@ -3115,39 +3014,6 @@ class SegmentService:
                if args.enabled or keyword_changed:
                    # update segment vector index
                    VectorService.update_segment_vector(args.keywords, segment, dataset)
                # update summary index if summary is provided and has changed
                if args.summary is not None:
                    # Check if summary index is enabled
                    has_summary_index = (
                        dataset.indexing_technique == "high_quality"
                        and dataset.summary_index_setting
                        and dataset.summary_index_setting.get("enable") is True
                    )

                    if has_summary_index:
                        # Query existing summary from database
                        from models.dataset import DocumentSegmentSummary

                        existing_summary = (
                            db.session.query(DocumentSegmentSummary)
                            .where(
                                DocumentSegmentSummary.chunk_id == segment.id,
                                DocumentSegmentSummary.dataset_id == dataset.id,
                            )
                            .first()
                        )

                        # Check if summary has changed
                        existing_summary_content = existing_summary.summary_content if existing_summary else None
                        if existing_summary_content != args.summary:
                            # Summary has changed, update it
                            from services.summary_index_service import SummaryIndexService

                            try:
                                SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary)
                            except Exception:
                                logger.exception("Failed to update summary for segment %s", segment.id)
                                # Don't fail the entire update if summary update fails
            else:
                segment_hash = helper.generate_text_hash(content)
                tokens = 0
@@ -3222,15 +3088,6 @@ class SegmentService:
                elif document.doc_form in (IndexStructureType.PARAGRAPH_INDEX, IndexStructureType.QA_INDEX):
                    # update segment vector index
                    VectorService.update_segment_vector(args.keywords, segment, dataset)
                    # update summary index if summary is provided
                    if args.summary is not None:
                        from services.summary_index_service import SummaryIndexService

                        try:
                            SummaryIndexService.update_summary_for_segment(segment, dataset, args.summary)
                        except Exception:
                            logger.exception("Failed to update summary for segment %s", segment.id)
                            # Don't fail the entire update if summary update fails
                    # update multimodal vector index
                    VectorService.update_multimodel_vector(segment, args.attachment_ids or [], dataset)
        except Exception as e:

@@ -13,6 +13,23 @@ class WebAppSettings(BaseModel):
    )


class WorkspacePermission(BaseModel):
    workspace_id: str = Field(
        description="The ID of the workspace.",
        alias="workspaceId",
    )
    allow_member_invite: bool = Field(
        description="Whether to allow members to invite new members to the workspace.",
        default=False,
        alias="allowMemberInvite",
    )
    allow_owner_transfer: bool = Field(
        description="Whether to allow owners to transfer ownership of the workspace.",
        default=False,
        alias="allowOwnerTransfer",
    )


class EnterpriseService:
    @classmethod
    def get_info(cls):
@@ -44,6 +61,16 @@ class EnterpriseService:
            except ValueError as e:
                raise ValueError(f"Invalid date format: {data}") from e

    class WorkspacePermissionService:
        @classmethod
        def get_permission(cls, workspace_id: str):
            if not workspace_id:
                raise ValueError("workspace_id must be provided.")
            data = EnterpriseRequest.send_request("GET", f"/workspaces/{workspace_id}/permission")
            if not data or "permission" not in data:
                raise ValueError("No data found.")
            return WorkspacePermission.model_validate(data["permission"])
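Because the fields are declared with camelCase aliases, validating the enterprise API response looks roughly like this (payload values are illustrative):

# Aliases map the camelCase JSON keys onto snake_case model fields.
permission = WorkspacePermission.model_validate(
    {"workspaceId": "ws-123", "allowMemberInvite": True, "allowOwnerTransfer": False}
)
assert permission.allow_member_invite is True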

class WebAppAuth:
    @classmethod
    def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str):

@@ -119,7 +119,6 @@ class KnowledgeConfig(BaseModel):
    data_source: DataSource | None = None
    process_rule: ProcessRule | None = None
    retrieval_model: RetrievalModel | None = None
    summary_index_setting: dict | None = None
    doc_form: str = "text_model"
    doc_language: str = "English"
    embedding_model: str | None = None
@@ -142,7 +141,6 @@ class SegmentUpdateArgs(BaseModel):
    regenerate_child_chunks: bool = False
    enabled: bool | None = None
    attachment_ids: list[str] | None = None
    summary: str | None = None  # Summary content for summary index


class ChildChunkUpdateArgs(BaseModel):

@@ -116,8 +116,6 @@ class KnowledgeConfiguration(BaseModel):
    embedding_model: str = ""
    keyword_number: int | None = 10
    retrieval_model: RetrievalSetting
    # add summary index setting
    summary_index_setting: dict | None = None

    @field_validator("embedding_model_provider", mode="before")
    @classmethod

@@ -10,9 +10,7 @@ from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from models.workflow import WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
    DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
)
from repositories.factory import DifyAPIRepositoryFactory
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.billing_service import BillingService, SubscriptionPlan

@@ -92,9 +90,12 @@ class WorkflowRunCleanup:
            paid_or_skipped = len(run_rows) - len(free_runs)

            if not free_runs:
                skipped_message = (
                    f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)"
                )
                click.echo(
                    click.style(
                        f"[batch #{batch_index}] skipped (no sandbox runs in batch, {paid_or_skipped} paid/unknown)",
                        skipped_message,
                        fg="yellow",
                    )
                )
@@ -255,21 +256,6 @@ class WorkflowRunCleanup:
            trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
            return trigger_repo.count_by_run_ids(run_ids)

    @staticmethod
    def _build_run_contexts(
        runs: Sequence[WorkflowRun],
    ) -> list[DifyAPISQLAlchemyWorkflowNodeExecutionRepository.RunContext]:
        return [
            {
                "run_id": run.id,
                "tenant_id": run.tenant_id,
                "app_id": run.app_id,
                "workflow_id": run.workflow_id,
                "triggered_from": run.triggered_from,
            }
            for run in runs
        ]

    @staticmethod
    def _empty_related_counts() -> dict[str, int]:
        return {
@@ -293,9 +279,15 @@ class WorkflowRunCleanup:
        )

    def _count_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
        run_contexts = self._build_run_contexts(runs)
        return DifyAPISQLAlchemyWorkflowNodeExecutionRepository.count_by_runs(session, run_contexts)
        run_ids = [run.id for run in runs]
        repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
            session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
        )
        return repo.count_by_runs(session, run_ids)

    def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
        run_contexts = self._build_run_contexts(runs)
        return DifyAPISQLAlchemyWorkflowNodeExecutionRepository.delete_by_runs(session, run_contexts)
        run_ids = [run.id for run in runs]
        repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
            session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
        )
        return repo.delete_by_runs(session, run_ids)

@@ -1,625 +0,0 @@
"""Summary index service for generating and managing document segment summaries."""

import logging
import time
import uuid

from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import Document
from extensions.ext_database import db
from libs import helper
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
from models.dataset import Document as DatasetDocument

logger = logging.getLogger(__name__)


class SummaryIndexService:
    """Service for generating and managing summary indexes."""

    @staticmethod
    def generate_summary_for_segment(
        segment: DocumentSegment,
        dataset: Dataset,
        summary_index_setting: dict,
    ) -> str:
        """
        Generate summary for a single segment.

        Args:
            segment: DocumentSegment to generate summary for
            dataset: Dataset containing the segment
            summary_index_setting: Summary index configuration

        Returns:
            Generated summary text

        Raises:
            ValueError: If summary_index_setting is invalid or generation fails
        """
        # Reuse the existing generate_summary method from ParagraphIndexProcessor
        # Use lazy import to avoid circular import
        from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor

        summary_content = ParagraphIndexProcessor.generate_summary(
            tenant_id=dataset.tenant_id,
            text=segment.content,
            summary_index_setting=summary_index_setting,
            segment_id=segment.id,
        )

        if not summary_content:
            raise ValueError("Generated summary is empty")

        return summary_content

    @staticmethod
    def create_summary_record(
        segment: DocumentSegment,
        dataset: Dataset,
        summary_content: str,
        status: str = "generating",
    ) -> DocumentSegmentSummary:
        """
        Create or update a DocumentSegmentSummary record.
        If a summary record already exists for this segment, it will be updated instead of creating a new one.

        Args:
            segment: DocumentSegment to create summary for
            dataset: Dataset containing the segment
            summary_content: Generated summary content
            status: Summary status (default: "generating")

        Returns:
            Created or updated DocumentSegmentSummary instance
        """
        # Check if summary record already exists
        existing_summary = (
            db.session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
        )

        if existing_summary:
            # Update existing record
            existing_summary.summary_content = summary_content
            existing_summary.status = status
            existing_summary.error = None  # Clear any previous errors
            # Re-enable if it was disabled
            if not existing_summary.enabled:
                existing_summary.enabled = True
                existing_summary.disabled_at = None
                existing_summary.disabled_by = None
            db.session.add(existing_summary)
            db.session.flush()
            return existing_summary
        else:
            # Create new record (enabled by default)
            summary_record = DocumentSegmentSummary(
                dataset_id=dataset.id,
                document_id=segment.document_id,
                chunk_id=segment.id,
                summary_content=summary_content,
                status=status,
                enabled=True,  # Explicitly set enabled to True
            )
            db.session.add(summary_record)
            db.session.flush()
            return summary_record

    @staticmethod
    def vectorize_summary(
        summary_record: DocumentSegmentSummary,
        segment: DocumentSegment,
        dataset: Dataset,
    ) -> None:
        """
        Vectorize summary and store it in the vector database.

        Args:
            summary_record: DocumentSegmentSummary record
            segment: Original DocumentSegment
            dataset: Dataset containing the segment
        """
        if dataset.indexing_technique != "high_quality":
            logger.warning(
                "Summary vectorization skipped for dataset %s: indexing_technique is not high_quality",
                dataset.id,
            )
            return

        # Reuse the existing index_node_id if available (like segment does), otherwise generate a new one
        old_summary_node_id = summary_record.summary_index_node_id
        if old_summary_node_id:
            # Reuse existing index_node_id (like segment behavior)
            summary_index_node_id = old_summary_node_id
        else:
            # Generate a new index node ID only for new summaries
            summary_index_node_id = str(uuid.uuid4())

        # Always regenerate the hash (in case summary content changed)
        summary_hash = helper.generate_text_hash(summary_record.summary_content)

        # Delete the old vector only if we're reusing the same index_node_id (to overwrite)
        # If index_node_id changed, the old vector should have been deleted elsewhere
        if old_summary_node_id and old_summary_node_id == summary_index_node_id:
            try:
                vector = Vector(dataset)
                vector.delete_by_ids([old_summary_node_id])
            except Exception as e:
                logger.warning(
                    "Failed to delete old summary vector for segment %s: %s. Continuing with new vectorization.",
                    segment.id,
                    str(e),
                )

        # Create document with summary content and metadata
        summary_document = Document(
            page_content=summary_record.summary_content,
            metadata={
                "doc_id": summary_index_node_id,
                "doc_hash": summary_hash,
                "dataset_id": dataset.id,
                "document_id": segment.document_id,
                "original_chunk_id": segment.id,  # Key: link to original chunk
                "doc_type": DocType.TEXT,
                "is_summary": True,  # Identifier for summary documents
            },
        )

        # Vectorize and store with retry mechanism for connection errors
        max_retries = 3
        retry_delay = 2.0

        for attempt in range(max_retries):
            try:
                vector = Vector(dataset)
                vector.add_texts([summary_document], duplicate_check=True)

                # Success - update summary record with index node info
                summary_record.summary_index_node_id = summary_index_node_id
                summary_record.summary_index_node_hash = summary_hash
                summary_record.status = "completed"
                db.session.add(summary_record)
                db.session.flush()
                return  # Success, exit function

            except Exception as e:  # Exception already covers ConnectionError; transient errors are detected below
                error_str = str(e).lower()
                # Check if it's a connection-related error that might be transient
                is_connection_error = any(
                    keyword in error_str
                    for keyword in [
                        "connection",
                        "disconnected",
                        "timeout",
                        "network",
                        "could not connect",
                        "server disconnected",
                        "weaviate",
                    ]
                )

                if is_connection_error and attempt < max_retries - 1:
                    # Retry for connection errors
                    wait_time = retry_delay * (2**attempt)  # Exponential backoff
                    logger.warning(
                        "Vectorization attempt %s/%s failed for segment %s: %s. Retrying in %.1f seconds...",
                        attempt + 1,
                        max_retries,
                        segment.id,
                        str(e),
                        wait_time,
                    )
                    time.sleep(wait_time)
                    continue
                else:
                    # Final attempt failed or non-connection error - log and update status
                    logger.error(
                        "Failed to vectorize summary for segment %s after %s attempts: %s",
                        segment.id,
                        attempt + 1,
                        str(e),
                        exc_info=True,
                    )
                    summary_record.status = "error"
                    summary_record.error = f"Vectorization failed: {str(e)}"
                    db.session.add(summary_record)
                    db.session.flush()
                    raise
|
||||
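    # Editor's note, a minimal sketch (not part of the original code): with
    # max_retries == 3 and retry_delay == 2.0 above, the wait before each retry
    # is retry_delay * 2**attempt, i.e. 2.0s after the first failure and 4.0s
    # after the second; a third failure takes the error branch and raises.
    # The schedule can be checked standalone:
    #
    #     waits = [2.0 * (2**attempt) for attempt in range(2)]
    #     assert waits == [2.0, 4.0]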
    @staticmethod
    def generate_and_vectorize_summary(
        segment: DocumentSegment,
        dataset: Dataset,
        summary_index_setting: dict,
    ) -> DocumentSegmentSummary:
        """
        Generate a summary for a segment and vectorize it.

        Args:
            segment: DocumentSegment to generate a summary for
            dataset: Dataset containing the segment
            summary_index_setting: Summary index configuration

        Returns:
            Created DocumentSegmentSummary instance

        Raises:
            ValueError: If summary generation fails
        """
        try:
            # Generate the summary
            summary_content = SummaryIndexService.generate_summary_for_segment(segment, dataset, summary_index_setting)

            # Create or update the summary record (handles overwrite internally)
            summary_record = SummaryIndexService.create_summary_record(
                segment, dataset, summary_content, status="generating"
            )

            # Vectorize the summary (deletes any old vector before creating the new one)
            SummaryIndexService.vectorize_summary(summary_record, segment, dataset)

            db.session.commit()
            logger.info("Successfully generated and vectorized summary for segment %s", segment.id)
            return summary_record

        except Exception as e:
            logger.exception("Failed to generate summary for segment %s", segment.id)
            # Record the error status on the summary record if one exists
            summary_record = (
                db.session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
            )
            if summary_record:
                summary_record.status = "error"
                summary_record.error = str(e)
                db.session.add(summary_record)
                db.session.commit()
            raise

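    # Hedged usage sketch (editor's addition, not from the diff): generating a
    # single segment's summary inline, assuming `segment` and `dataset` are
    # already-loaded ORM objects and the dataset's summary index is enabled:
    #
    #     record = SummaryIndexService.generate_and_vectorize_summary(
    #         segment, dataset, dataset.summary_index_setting
    #     )
    #     assert record.status == "completed"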
    @staticmethod
    def generate_summaries_for_document(
        dataset: Dataset,
        document: DatasetDocument,
        summary_index_setting: dict,
        segment_ids: list[str] | None = None,
        only_parent_chunks: bool = False,
    ) -> list[DocumentSegmentSummary]:
        """
        Generate summaries for all segments in a document, including vectorization.

        Args:
            dataset: Dataset containing the document
            document: DatasetDocument to generate summaries for
            summary_index_setting: Summary index configuration
            segment_ids: Optional list of specific segment IDs to process
            only_parent_chunks: If True, only process parent chunks (for parent-child mode)

        Returns:
            List of created DocumentSegmentSummary instances
        """
        # Only generate a summary index for the high_quality indexing technique
        if dataset.indexing_technique != "high_quality":
            logger.info(
                "Skipping summary generation for dataset %s: indexing_technique is %s, not 'high_quality'",
                dataset.id,
                dataset.indexing_technique,
            )
            return []

        if not summary_index_setting or not summary_index_setting.get("enable"):
            logger.info("Summary index is disabled for dataset %s", dataset.id)
            return []

        # Skip qa_model documents
        if document.doc_form == "qa_model":
            logger.info("Skipping summary generation for qa_model document %s", document.id)
            return []

        logger.info(
            "Starting summary generation for document %s in dataset %s, segments: %s, only_parent_chunks: %s",
            document.id,
            dataset.id,
            len(segment_ids) if segment_ids else "all",
            only_parent_chunks,
        )

        # Query segments (completed and enabled only)
        query = db.session.query(DocumentSegment).filter_by(
            dataset_id=dataset.id,
            document_id=document.id,
            status="completed",
            enabled=True,  # Only generate summaries for enabled segments
        )

        if segment_ids:
            query = query.filter(DocumentSegment.id.in_(segment_ids))

        segments = query.all()

        if not segments:
            logger.info("No segments found for document %s", document.id)
            return []

        summary_records = []

        for segment in segments:
            # In parent-child mode every DocumentSegment is a parent chunk; child
            # chunks live in the ChildChunk table and never appear in this list,
            # so there is nothing extra to filter when only_parent_chunks is set.
            # The flag is kept for clarity and future-proofing.
            try:
                summary_record = SummaryIndexService.generate_and_vectorize_summary(
                    segment, dataset, summary_index_setting
                )
                summary_records.append(summary_record)
            except Exception:
                logger.exception("Failed to generate summary for segment %s", segment.id)
                # Continue with the remaining segments
                continue

        logger.info(
            "Completed summary generation for document %s: %s summaries generated and vectorized",
            document.id,
            len(summary_records),
        )
        return summary_records

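    # Hedged usage sketch (editor's addition): refreshing summaries for two
    # specific chunks of one document; the segment IDs below are illustrative:
    #
    #     records = SummaryIndexService.generate_summaries_for_document(
    #         dataset=dataset,
    #         document=document,
    #         summary_index_setting=dataset.summary_index_setting,
    #         segment_ids=["segment-id-1", "segment-id-2"],
    #     )
    #     logger.info("%s summaries refreshed", len(records))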
    @staticmethod
    def disable_summaries_for_segments(
        dataset: Dataset,
        segment_ids: list[str] | None = None,
        disabled_by: str | None = None,
    ) -> None:
        """
        Disable summary records and remove their vectors from the vector database.
        Unlike delete, this preserves the summary records and only marks them as disabled.

        Args:
            dataset: Dataset containing the segments
            segment_ids: List of segment IDs to disable summaries for. If None, disable all.
            disabled_by: ID of the user who disabled the summaries
        """
        from libs.datetime_utils import naive_utc_now

        query = db.session.query(DocumentSegmentSummary).filter_by(
            dataset_id=dataset.id,
            enabled=True,  # Only disable currently enabled summaries
        )

        if segment_ids:
            query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))

        summaries = query.all()

        if not summaries:
            return

        logger.info(
            "Disabling %s summary records for dataset %s, segments: %s",
            len(summaries),
            dataset.id,
            len(segment_ids) if segment_ids else "all",
        )

        # Remove vectors from the vector database (but keep the records)
        if dataset.indexing_technique == "high_quality":
            summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
            if summary_node_ids:
                try:
                    vector = Vector(dataset)
                    vector.delete_by_ids(summary_node_ids)
                except Exception as e:
                    logger.warning("Failed to remove summary vectors: %s", str(e))

        # Disable the summary records (do not delete them)
        now = naive_utc_now()
        for summary in summaries:
            summary.enabled = False
            summary.disabled_at = now
            summary.disabled_by = disabled_by
            db.session.add(summary)

        db.session.commit()
        logger.info("Disabled %s summary records for dataset %s", len(summaries), dataset.id)

    @staticmethod
    def enable_summaries_for_segments(
        dataset: Dataset,
        segment_ids: list[str] | None = None,
    ) -> None:
        """
        Enable summary records and re-add their vectors to the vector database.

        Args:
            dataset: Dataset containing the segments
            segment_ids: List of segment IDs to enable summaries for. If None, enable all.
        """
        # Only enable the summary index for the high_quality indexing technique
        if dataset.indexing_technique != "high_quality":
            return

        # Check whether the summary index is enabled
        summary_index_setting = dataset.summary_index_setting
        if not summary_index_setting or not summary_index_setting.get("enable"):
            return

        query = db.session.query(DocumentSegmentSummary).filter_by(
            dataset_id=dataset.id,
            enabled=False,  # Only enable currently disabled summaries
        )

        if segment_ids:
            query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))

        summaries = query.all()

        if not summaries:
            return

        logger.info(
            "Enabling %s summary records for dataset %s, segments: %s",
            len(summaries),
            dataset.id,
            len(segment_ids) if segment_ids else "all",
        )

        # Re-vectorize and re-add to the vector database
        enabled_count = 0
        for summary in summaries:
            # Fetch the original segment
            segment = (
                db.session.query(DocumentSegment)
                .filter_by(
                    id=summary.chunk_id,
                    dataset_id=dataset.id,
                )
                .first()
            )

            if not segment or not segment.enabled or segment.status != "completed":
                continue

            if not summary.summary_content:
                continue

            try:
                # Re-vectorize the summary
                SummaryIndexService.vectorize_summary(summary, segment, dataset)

                # Enable the summary record
                summary.enabled = True
                summary.disabled_at = None
                summary.disabled_by = None
                db.session.add(summary)
                enabled_count += 1
            except Exception:
                logger.exception("Failed to re-vectorize summary %s", summary.id)
                # Keep the record disabled if vectorization fails
                continue

        db.session.commit()
        logger.info("Enabled %s summary records for dataset %s", enabled_count, dataset.id)

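    # Hedged usage sketch (editor's addition): the disable/enable pair is meant
    # to round-trip; `user_id` here is illustrative:
    #
    #     SummaryIndexService.disable_summaries_for_segments(
    #         dataset=dataset, segment_ids=[segment.id], disabled_by=user_id
    #     )  # vectors removed, records kept with enabled=False
    #     SummaryIndexService.enable_summaries_for_segments(
    #         dataset=dataset, segment_ids=[segment.id]
    #     )  # records re-vectorized and re-enabled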
    @staticmethod
    def delete_summaries_for_segments(
        dataset: Dataset,
        segment_ids: list[str] | None = None,
    ) -> None:
        """
        Delete summary records and vectors for segments (for actual deletion scenarios only).
        For disable/enable operations, use disable_summaries_for_segments/enable_summaries_for_segments.

        Args:
            dataset: Dataset containing the segments
            segment_ids: List of segment IDs to delete summaries for. If None, delete all.
        """
        query = db.session.query(DocumentSegmentSummary).filter_by(dataset_id=dataset.id)

        if segment_ids:
            query = query.filter(DocumentSegmentSummary.chunk_id.in_(segment_ids))

        summaries = query.all()

        if not summaries:
            return

        # Delete from the vector database
        if dataset.indexing_technique == "high_quality":
            summary_node_ids = [s.summary_index_node_id for s in summaries if s.summary_index_node_id]
            if summary_node_ids:
                vector = Vector(dataset)
                vector.delete_by_ids(summary_node_ids)

        # Delete the summary records
        for summary in summaries:
            db.session.delete(summary)

        db.session.commit()
        logger.info("Deleted %s summary records for dataset %s", len(summaries), dataset.id)

    @staticmethod
    def update_summary_for_segment(
        segment: DocumentSegment,
        dataset: Dataset,
        summary_content: str,
    ) -> DocumentSegmentSummary | None:
        """
        Update the summary for a segment and re-vectorize it.

        Args:
            segment: DocumentSegment to update the summary for
            dataset: Dataset containing the segment
            summary_content: New summary content

        Returns:
            Updated DocumentSegmentSummary instance, or None if the summary index is not enabled
        """
        # Only update the summary index for the high_quality indexing technique
        if dataset.indexing_technique != "high_quality":
            return None

        # Check whether the summary index is enabled
        summary_index_setting = dataset.summary_index_setting
        if not summary_index_setting or not summary_index_setting.get("enable"):
            return None

        # Skip qa_model documents
        if segment.document and segment.document.doc_form == "qa_model":
            return None

        try:
            # Find the existing summary record
            summary_record = (
                db.session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
            )

            if summary_record:
                # Update the existing summary
                old_summary_node_id = summary_record.summary_index_node_id

                # Update the summary content
                summary_record.summary_content = summary_content
                summary_record.status = "generating"
                db.session.add(summary_record)
                db.session.flush()

                # Delete the old vector if one exists
                if old_summary_node_id:
                    vector = Vector(dataset)
                    vector.delete_by_ids([old_summary_node_id])

                # Re-vectorize the summary
                SummaryIndexService.vectorize_summary(summary_record, segment, dataset)

                db.session.commit()
                logger.info("Successfully updated and re-vectorized summary for segment %s", segment.id)
                return summary_record
            else:
                # Create a new summary record if none exists
                summary_record = SummaryIndexService.create_summary_record(
                    segment, dataset, summary_content, status="generating"
                )
                SummaryIndexService.vectorize_summary(summary_record, segment, dataset)
                db.session.commit()
                logger.info("Successfully created and vectorized summary for segment %s", segment.id)
                return summary_record

        except Exception as e:
            logger.exception("Failed to update summary for segment %s", segment.id)
            # Record the error status on the summary record if one exists
            summary_record = (
                db.session.query(DocumentSegmentSummary).filter_by(chunk_id=segment.id, dataset_id=dataset.id).first()
            )
            if summary_record:
                summary_record.status = "error"
                summary_record.error = str(e)
                db.session.add(summary_record)
                db.session.commit()
            raise

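    # Hedged usage sketch (editor's addition): a handler that lets a user
    # hand-edit a chunk's summary would pass the edited text straight through;
    # `edited_text` is illustrative:
    #
    #     SummaryIndexService.update_summary_for_segment(
    #         segment=segment, dataset=dataset, summary_content=edited_text
    #     )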
@@ -117,19 +117,6 @@ def add_document_to_index_task(dataset_document_id: str):
        )
        db.session.commit()

        # Enable summary indexes for all segments in this document
        from services.summary_index_service import SummaryIndexService

        segment_ids_list = [segment.id for segment in segments]
        if segment_ids_list:
            try:
                SummaryIndexService.enable_summaries_for_segments(
                    dataset=dataset,
                    segment_ids=segment_ids_list,
                )
            except Exception as e:
                logger.warning("Failed to enable summaries for document %s: %s", dataset_document.id, str(e))

        end_at = time.perf_counter()
        logger.info(
            click.style(f"Document added to index: {dataset_document.id} latency: {end_at - start_at}", fg="green")
@@ -42,7 +42,6 @@ def delete_segment_from_index_task(
    doc_form = dataset_document.doc_form

    # Proceed with index cleanup using the index_node_ids directly
    # For actual deletion, summaries should be deleted (not just disabled)
    index_processor = IndexProcessorFactory(doc_form).init_index_processor()
    index_processor.clean(
        dataset,
@@ -50,7 +49,6 @@ def delete_segment_from_index_task(
        with_keywords=True,
        delete_child_chunks=True,
        precomputed_child_node_ids=child_node_ids,
        delete_summaries=True,  # Actually delete summaries when the segment is deleted
    )
    if dataset.is_multimodal:
        # delete segment attachment binding
@@ -53,18 +53,6 @@ def disable_segment_from_index_task(segment_id: str):
        logger.info(click.style(f"Segment {segment.id} document status is invalid, pass.", fg="cyan"))
        return

    # Disable the summary index for this segment
    from services.summary_index_service import SummaryIndexService

    try:
        SummaryIndexService.disable_summaries_for_segments(
            dataset=dataset,
            segment_ids=[segment.id],
            disabled_by=segment.disabled_by,
        )
    except Exception as e:
        logger.warning("Failed to disable summary for segment %s: %s", segment.id, str(e))

    index_type = dataset_document.doc_form
    index_processor = IndexProcessorFactory(index_type).init_index_processor()
    index_processor.clean(dataset, [segment.index_node_id])
@@ -58,26 +58,12 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
        return

    try:
        # Disable summary indexes for these segments
        from services.summary_index_service import SummaryIndexService

        segment_ids_list = [segment.id for segment in segments]
        try:
            # Get disabled_by from the first segment (they should all share the same disabled_by)
            disabled_by = segments[0].disabled_by if segments else None
            SummaryIndexService.disable_summaries_for_segments(
                dataset=dataset,
                segment_ids=segment_ids_list,
                disabled_by=disabled_by,
            )
        except Exception as e:
            logger.warning("Failed to disable summaries for segments: %s", str(e))

        index_node_ids = [segment.index_node_id for segment in segments]
        if dataset.is_multimodal:
            segment_ids = [segment.id for segment in segments]
            segment_attachment_bindings = (
                db.session.query(SegmentAttachmentBinding)
                .where(SegmentAttachmentBinding.segment_id.in_(segment_ids_list))
                .where(SegmentAttachmentBinding.segment_id.in_(segment_ids))
                .all()
            )
            if segment_attachment_bindings:
@@ -14,7 +14,6 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from services.feature_service import FeatureService
from tasks.generate_summary_index_task import generate_summary_index_task

logger = logging.getLogger(__name__)

@@ -101,69 +100,6 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
        indexing_runner.run(documents)
        end_at = time.perf_counter()
        logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))

        # Trigger summary index generation for completed documents, if enabled.
        # Only for the high_quality indexing technique and when summary_index_setting is enabled.
        # Re-query the dataset to get the latest summary_index_setting (in case it was updated).
        dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
        if not dataset:
            logger.warning("Dataset %s not found after indexing", dataset_id)
            return

        if dataset.indexing_technique == "high_quality":
            summary_index_setting = dataset.summary_index_setting
            if summary_index_setting and summary_index_setting.get("enable"):
                # Check each document's indexing status and trigger summary generation if completed
                for document_id in document_ids:
                    # Re-query the document to get the latest status (IndexingRunner may have updated it)
                    document = (
                        db.session.query(Document)
                        .where(Document.id == document_id, Document.dataset_id == dataset_id)
                        .first()
                    )
                    if document:
                        logger.info(
                            "Checking document %s for summary generation: status=%s, doc_form=%s",
                            document_id,
                            document.indexing_status,
                            document.doc_form,
                        )
                        if document.indexing_status == "completed" and document.doc_form != "qa_model":
                            try:
                                generate_summary_index_task.delay(dataset.id, document_id, None)
                                logger.info(
                                    "Queued summary index generation task for document %s in dataset %s "
                                    "after indexing completed",
                                    document_id,
                                    dataset.id,
                                )
                            except Exception:
                                logger.exception(
                                    "Failed to queue summary index generation task for document %s",
                                    document_id,
                                )
                                # Don't fail the entire indexing process if summary task queuing fails
                        else:
                            logger.info(
                                "Skipping summary generation for document %s: status=%s, doc_form=%s",
                                document_id,
                                document.indexing_status,
                                document.doc_form,
                            )
                    else:
                        logger.warning("Document %s not found after indexing", document_id)
            else:
                logger.info(
                    "Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
                    dataset.id,
                    summary_index_setting.get("enable") if summary_index_setting else None,
                )
        else:
            logger.info(
                "Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
                dataset.id,
                dataset.indexing_technique,
            )
    except DocumentIsPausedError as ex:
        logger.info(click.style(str(ex), fg="yellow"))
    except Exception:

@@ -103,17 +103,6 @@ def enable_segment_to_index_task(segment_id: str):
        # save vector index
        index_processor.load(dataset, [document], multimodal_documents=multimodel_documents)

        # Enable the summary index for this segment
        from services.summary_index_service import SummaryIndexService

        try:
            SummaryIndexService.enable_summaries_for_segments(
                dataset=dataset,
                segment_ids=[segment.id],
            )
        except Exception as e:
            logger.warning("Failed to enable summary for segment %s: %s", segment.id, str(e))

        end_at = time.perf_counter()
        logger.info(click.style(f"Segment enabled to index: {segment.id} latency: {end_at - start_at}", fg="green"))
    except Exception as e:

@@ -108,18 +108,6 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
        # save vector index
        index_processor.load(dataset, documents, multimodal_documents=multimodal_documents)

        # Enable summary indexes for these segments
        from services.summary_index_service import SummaryIndexService

        segment_ids_list = [segment.id for segment in segments]
        try:
            SummaryIndexService.enable_summaries_for_segments(
                dataset=dataset,
                segment_ids=segment_ids_list,
            )
        except Exception as e:
            logger.warning("Failed to enable summaries for segments: %s", str(e))

        end_at = time.perf_counter()
        logger.info(click.style(f"Segments enabled to index latency: {end_at - start_at}", fg="green"))
    except Exception as e:

@@ -1,112 +0,0 @@
"""Async task for generating summary indexes."""

import logging
import time

import click
from celery import shared_task

from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.summary_index_service import SummaryIndexService

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def generate_summary_index_task(dataset_id: str, document_id: str, segment_ids: list[str] | None = None):
    """
    Asynchronously generate the summary index for document segments.

    Args:
        dataset_id: Dataset ID
        document_id: Document ID
        segment_ids: Optional list of specific segment IDs to process. If None, process all segments.

    Usage:
        generate_summary_index_task.delay(dataset_id, document_id)
        generate_summary_index_task.delay(dataset_id, document_id, segment_ids)
    """
    logger.info(
        click.style(
            f"Start generating summary index for document {document_id} in dataset {dataset_id}",
            fg="green",
        )
    )
    start_at = time.perf_counter()

    try:
        dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
        if not dataset:
            logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red"))
            db.session.close()
            return

        document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
        if not document:
            logger.error(click.style(f"Document not found: {document_id}", fg="red"))
            db.session.close()
            return

        # Only generate a summary index for the high_quality indexing technique
        if dataset.indexing_technique != "high_quality":
            logger.info(
                click.style(
                    f"Skipping summary generation for dataset {dataset_id}: "
                    f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'",
                    fg="cyan",
                )
            )
            db.session.close()
            return

        # Check whether the summary index is enabled
        summary_index_setting = dataset.summary_index_setting
        if not summary_index_setting or not summary_index_setting.get("enable"):
            logger.info(
                click.style(
                    f"Summary index is disabled for dataset {dataset_id}",
                    fg="cyan",
                )
            )
            db.session.close()
            return

        # Determine whether only parent chunks should be processed
        only_parent_chunks = dataset.chunk_structure == "parent_child_index"

        # Generate summaries
        summary_records = SummaryIndexService.generate_summaries_for_document(
            dataset=dataset,
            document=document,
            summary_index_setting=summary_index_setting,
            segment_ids=segment_ids,
            only_parent_chunks=only_parent_chunks,
        )

        end_at = time.perf_counter()
        logger.info(
            click.style(
                f"Summary index generation completed for document {document_id}: "
                f"{len(summary_records)} summaries generated, latency: {end_at - start_at}",
                fg="green",
            )
        )

    except Exception as e:
        logger.exception("Failed to generate summary index for document %s", document_id)
        # Update the document segments with an error status if needed
        if segment_ids:
            db.session.query(DocumentSegment).filter(
                DocumentSegment.id.in_(segment_ids),
                DocumentSegment.dataset_id == dataset_id,
            ).update(
                {
                    DocumentSegment.error: f"Summary generation failed: {str(e)}",
                },
                synchronize_session=False,
            )
            db.session.commit()
    finally:
        db.session.close()
@@ -1,221 +0,0 @@
"""Task for regenerating summary indexes when dataset settings change."""

import logging
import time

import click
from celery import shared_task
from sqlalchemy import select

from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary
from models.dataset import Document as DatasetDocument
from services.summary_index_service import SummaryIndexService

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def regenerate_summary_index_task(
    dataset_id: str,
    regenerate_reason: str = "summary_model_changed",
    regenerate_vectors_only: bool = False,
):
    """
    Regenerate summary indexes for all documents in a dataset.

    This task is triggered when:
    1. The summary_index_setting model changes (regenerate_reason="summary_model_changed")
       - Regenerates summary content and vectors for all existing summaries
    2. The embedding_model changes (regenerate_reason="embedding_model_changed")
       - Only regenerates vectors for existing summaries (keeps the summary content)

    Args:
        dataset_id: Dataset ID
        regenerate_reason: Reason for regeneration ("summary_model_changed" or "embedding_model_changed")
        regenerate_vectors_only: If True, only regenerate vectors without regenerating summary content
    """
    logger.info(
        click.style(
            f"Start regenerate summary index for dataset {dataset_id}, reason: {regenerate_reason}",
            fg="green",
        )
    )
    start_at = time.perf_counter()

    try:
        dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
        if not dataset:
            logger.error(click.style(f"Dataset not found: {dataset_id}", fg="red"))
            db.session.close()
            return

        # Only regenerate the summary index for the high_quality indexing technique
        if dataset.indexing_technique != "high_quality":
            logger.info(
                click.style(
                    f"Skipping summary regeneration for dataset {dataset_id}: "
                    f"indexing_technique is {dataset.indexing_technique}, not 'high_quality'",
                    fg="cyan",
                )
            )
            db.session.close()
            return

        # Check whether the summary index is enabled
        summary_index_setting = dataset.summary_index_setting
        if not summary_index_setting or not summary_index_setting.get("enable"):
            logger.info(
                click.style(
                    f"Summary index is disabled for dataset {dataset_id}",
                    fg="cyan",
                )
            )
            db.session.close()
            return

        # Get all documents with a completed indexing status
        dataset_documents = db.session.scalars(
            select(DatasetDocument).where(
                DatasetDocument.dataset_id == dataset_id,
                DatasetDocument.indexing_status == "completed",
                DatasetDocument.enabled == True,
                DatasetDocument.archived == False,
            )
        ).all()

        if not dataset_documents:
            logger.info(
                click.style(
                    f"No documents found for summary regeneration in dataset {dataset_id}",
                    fg="cyan",
                )
            )
            db.session.close()
            return

        logger.info(
            "Found %s documents for summary regeneration in dataset %s",
            len(dataset_documents),
            dataset_id,
        )

        total_segments_processed = 0
        total_segments_failed = 0

        for dataset_document in dataset_documents:
            # Skip qa_model documents
            if dataset_document.doc_form == "qa_model":
                continue

            try:
                # Get all segments with existing summaries
                segments = (
                    db.session.query(DocumentSegment)
                    .join(
                        DocumentSegmentSummary,
                        DocumentSegment.id == DocumentSegmentSummary.chunk_id,
                    )
                    .where(
                        DocumentSegment.document_id == dataset_document.id,
                        DocumentSegment.dataset_id == dataset_id,
                        DocumentSegment.status == "completed",
                        DocumentSegment.enabled == True,
                        DocumentSegmentSummary.dataset_id == dataset_id,
                    )
                    .order_by(DocumentSegment.position.asc())
                    .all()
                )

                if not segments:
                    continue

                logger.info(
                    "Regenerating summaries for %s segments in document %s",
                    len(segments),
                    dataset_document.id,
                )

                for segment in segments:
                    # Initialize so the error handler below can safely reference it
                    summary_record = None
                    try:
                        # Get the existing summary record
                        summary_record = (
                            db.session.query(DocumentSegmentSummary)
                            .filter_by(
                                chunk_id=segment.id,
                                dataset_id=dataset_id,
                            )
                            .first()
                        )

                        if not summary_record:
                            logger.warning("Summary record not found for segment %s, skipping", segment.id)
                            continue

                        if regenerate_vectors_only:
                            # Only regenerate vectors (for an embedding_model change)
                            # Delete the old vector first
                            if summary_record.summary_index_node_id:
                                try:
                                    from core.rag.datasource.vdb.vector_factory import Vector

                                    vector = Vector(dataset)
                                    vector.delete_by_ids([summary_record.summary_index_node_id])
                                except Exception as e:
                                    logger.warning(
                                        "Failed to delete old summary vector for segment %s: %s",
                                        segment.id,
                                        str(e),
                                    )

                            # Re-vectorize with the new embedding model
                            SummaryIndexService.vectorize_summary(summary_record, segment, dataset)
                            db.session.commit()
                        else:
                            # Regenerate both summary content and vectors (for a summary model change)
                            SummaryIndexService.generate_and_vectorize_summary(segment, dataset, summary_index_setting)
                            db.session.commit()

                        total_segments_processed += 1

                    except Exception as e:
                        logger.error(
                            "Failed to regenerate summary for segment %s: %s",
                            segment.id,
                            str(e),
                            exc_info=True,
                        )
                        total_segments_failed += 1
                        # Record the error status on the summary record
                        if summary_record:
                            summary_record.status = "error"
                            summary_record.error = f"Regeneration failed: {str(e)}"
                            db.session.add(summary_record)
                            db.session.commit()
                        continue

            except Exception as e:
                logger.error(
                    "Failed to process document %s for summary regeneration: %s",
                    dataset_document.id,
                    str(e),
                    exc_info=True,
                )
                continue

        end_at = time.perf_counter()
        logger.info(
            click.style(
                f"Summary index regeneration completed for dataset {dataset_id}: "
                f"{total_segments_processed} segments processed successfully, "
                f"{total_segments_failed} segments failed, "
                f"total documents: {len(dataset_documents)}, "
                f"latency: {end_at - start_at:.2f}s",
                fg="green",
            )
        )

    except Exception:
        logger.exception("Regenerate summary index failed for dataset %s", dataset_id)
    finally:
        db.session.close()
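# Hedged sketch (editor's addition, not from the diff): a settings-update
# handler would pick the cheaper vector-only path when only the embedding
# model changed, and full regeneration when the summary model changed. The
# function name and flags below are illustrative:
#
#     def on_dataset_settings_changed(dataset_id: str,
#                                     summary_model_changed: bool,
#                                     embedding_model_changed: bool) -> None:
#         if summary_model_changed:
#             regenerate_summary_index_task.delay(
#                 dataset_id, "summary_model_changed", False
#             )
#         elif embedding_model_changed:
#             regenerate_summary_index_task.delay(
#                 dataset_id, "embedding_model_changed", True
#             )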
@@ -47,21 +47,6 @@ def remove_document_from_index_task(document_id: str):
    index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

    segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()

    # Disable summary indexes for all segments in this document
    from services.summary_index_service import SummaryIndexService

    segment_ids_list = [segment.id for segment in segments]
    if segment_ids_list:
        try:
            SummaryIndexService.disable_summaries_for_segments(
                dataset=dataset,
                segment_ids=segment_ids_list,
                disabled_by=document.disabled_by,
            )
        except Exception as e:
            logger.warning("Failed to disable summaries for document %s: %s", document.id, str(e))

    index_node_ids = [segment.index_node_id for segment in segments]
    if index_node_ids:
        try:

api/tests/unit_tests/core/workflow/context/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
"""Tests for workflow context management."""
@@ -0,0 +1,258 @@
"""Tests for execution context module."""

import contextvars
from typing import Any
from unittest.mock import MagicMock

import pytest

from core.workflow.context.execution_context import (
    AppContext,
    ExecutionContext,
    ExecutionContextBuilder,
    IExecutionContext,
    NullAppContext,
)


class TestAppContext:
    """Test AppContext abstract base class."""

    def test_app_context_is_abstract(self):
        """Test that AppContext cannot be instantiated directly."""
        with pytest.raises(TypeError):
            AppContext()  # type: ignore


class TestNullAppContext:
    """Test NullAppContext implementation."""

    def test_null_app_context_get_config(self):
        """Test get_config returns value from config dict."""
        config = {"key1": "value1", "key2": "value2"}
        ctx = NullAppContext(config=config)

        assert ctx.get_config("key1") == "value1"
        assert ctx.get_config("key2") == "value2"

    def test_null_app_context_get_config_default(self):
        """Test get_config returns default when key not found."""
        ctx = NullAppContext()

        assert ctx.get_config("nonexistent", "default") == "default"
        assert ctx.get_config("nonexistent") is None

    def test_null_app_context_get_extension(self):
        """Test get_extension returns stored extension."""
        ctx = NullAppContext()
        extension = MagicMock()
        ctx.set_extension("db", extension)

        assert ctx.get_extension("db") == extension

    def test_null_app_context_get_extension_not_found(self):
        """Test get_extension returns None when extension not found."""
        ctx = NullAppContext()

        assert ctx.get_extension("nonexistent") is None

    def test_null_app_context_enter_yield(self):
        """Test enter method yields without any side effects."""
        ctx = NullAppContext()

        with ctx.enter():
            # Should not raise any exception
            pass


class TestExecutionContext:
    """Test ExecutionContext class."""

    def test_initialization_with_all_params(self):
        """Test ExecutionContext initialization with all parameters."""
        app_ctx = NullAppContext()
        context_vars = contextvars.copy_context()
        user = MagicMock()

        ctx = ExecutionContext(
            app_context=app_ctx,
            context_vars=context_vars,
            user=user,
        )

        assert ctx.app_context == app_ctx
        assert ctx.context_vars == context_vars
        assert ctx.user == user

    def test_initialization_with_minimal_params(self):
        """Test ExecutionContext initialization with minimal parameters."""
        ctx = ExecutionContext()

        assert ctx.app_context is None
        assert ctx.context_vars is None
        assert ctx.user is None

    def test_enter_with_context_vars(self):
        """Test enter restores context variables."""
        test_var = contextvars.ContextVar("test_var")
        test_var.set("original_value")

        # Copy the context with the variable set
        context_vars = contextvars.copy_context()

        # Change the variable
        test_var.set("new_value")

        # Create an execution context and enter it
        ctx = ExecutionContext(context_vars=context_vars)

        with ctx.enter():
            # The variable should be restored to its original value
            assert test_var.get() == "original_value"

        # After exiting, the variable stays at the value from within the context
        # (this is expected Python contextvars behavior)
        assert test_var.get() == "original_value"

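    # Editor's note, a hedged standalone sketch (not part of the test file):
    # the restore behavior above rests on stdlib contextvars semantics - a
    # snapshot taken by copy_context() is unaffected by later set() calls,
    # which is a plausible basis for how enter() brings "original_value" back:
    #
    #     import contextvars
    #     v = contextvars.ContextVar("v")
    #     v.set("old")
    #     snap = contextvars.copy_context()
    #     v.set("new")
    #     assert snap.run(v.get) == "old"  # the snapshot still sees "old"
    #     assert v.get() == "new"          # the current context is unaffected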
    def test_enter_with_app_context(self):
        """Test enter enters app context if available."""
        app_ctx = NullAppContext()
        ctx = ExecutionContext(app_context=app_ctx)

        # Should not raise any exception
        with ctx.enter():
            pass

    def test_enter_without_app_context(self):
        """Test enter works without app context."""
        ctx = ExecutionContext(app_context=None)

        # Should not raise any exception
        with ctx.enter():
            pass

    def test_context_manager_protocol(self):
        """Test ExecutionContext supports context manager protocol."""
        ctx = ExecutionContext()

        with ctx:
            # Should not raise any exception
            pass

    def test_user_property(self):
        """Test user property returns set user."""
        user = MagicMock()
        ctx = ExecutionContext(user=user)

        assert ctx.user == user


class TestIExecutionContextProtocol:
    """Test IExecutionContext protocol."""

    def test_execution_context_implements_protocol(self):
        """Test that ExecutionContext implements IExecutionContext protocol."""
        ctx = ExecutionContext()

        # Should have __enter__ and __exit__ methods
        assert hasattr(ctx, "__enter__")
        assert hasattr(ctx, "__exit__")
        assert hasattr(ctx, "user")

    def test_protocol_compatibility(self):
        """Test that ExecutionContext can be used where IExecutionContext is expected."""

        def accept_context(context: IExecutionContext) -> Any:
            """Function that accepts the IExecutionContext protocol."""
            # Just verify it has the required protocol attributes
            assert hasattr(context, "__enter__")
            assert hasattr(context, "__exit__")
            assert hasattr(context, "user")
            return context.user

        ctx = ExecutionContext(user="test_user")
        result = accept_context(ctx)

        assert result == "test_user"

    def test_protocol_with_flask_execution_context(self):
        """Test that IExecutionContext protocol is compatible with different implementations."""
        # Verify the protocol works with ExecutionContext
        ctx = ExecutionContext(user="test_user")

        # Should have the required protocol attributes
        assert hasattr(ctx, "__enter__")
        assert hasattr(ctx, "__exit__")
        assert hasattr(ctx, "user")
        assert ctx.user == "test_user"

        # Should work as a context manager
        with ctx:
            assert ctx.user == "test_user"


class TestExecutionContextBuilder:
    """Test ExecutionContextBuilder class."""

    def test_builder_with_all_params(self):
        """Test builder with all parameters set."""
        app_ctx = NullAppContext()
        context_vars = contextvars.copy_context()
        user = MagicMock()

        ctx = (
            ExecutionContextBuilder().with_app_context(app_ctx).with_context_vars(context_vars).with_user(user).build()
        )

        assert ctx.app_context == app_ctx
        assert ctx.context_vars == context_vars
        assert ctx.user == user

    def test_builder_with_partial_params(self):
        """Test builder with only some parameters set."""
        app_ctx = NullAppContext()

        ctx = ExecutionContextBuilder().with_app_context(app_ctx).build()

        assert ctx.app_context == app_ctx
        assert ctx.context_vars is None
        assert ctx.user is None

    def test_builder_fluent_interface(self):
        """Test builder provides fluent interface."""
        builder = ExecutionContextBuilder()

        # Each method should return the builder
        assert isinstance(builder.with_app_context(NullAppContext()), ExecutionContextBuilder)
        assert isinstance(builder.with_context_vars(contextvars.copy_context()), ExecutionContextBuilder)
        assert isinstance(builder.with_user(None), ExecutionContextBuilder)


class TestCaptureCurrentContext:
    """Test capture_current_context function."""

    def test_capture_current_context_returns_context(self):
        """Test that capture_current_context returns a valid context."""
        from core.workflow.context.execution_context import capture_current_context

        result = capture_current_context()

        # Should return an object that implements IExecutionContext
        assert hasattr(result, "__enter__")
        assert hasattr(result, "__exit__")
        assert hasattr(result, "user")

    def test_capture_current_context_captures_contextvars(self):
        """Test that capture_current_context captures context variables."""
        # Set a context variable before capturing
        import contextvars

        test_var = contextvars.ContextVar("capture_test_var")
        test_var.set("test_value_123")

        from core.workflow.context.execution_context import capture_current_context

        result = capture_current_context()

        # Context variables should be captured
        assert result.context_vars is not None
@@ -0,0 +1,316 @@
"""Tests for Flask app context module."""

import contextvars
from unittest.mock import MagicMock, patch

import pytest


class TestFlaskAppContext:
    """Test FlaskAppContext implementation."""

    @pytest.fixture
    def mock_flask_app(self):
        """Create a mock Flask app."""
        app = MagicMock()
        app.config = {"TEST_KEY": "test_value"}
        app.extensions = {"db": MagicMock(), "cache": MagicMock()}
        app.app_context = MagicMock()
        app.app_context.return_value.__enter__ = MagicMock(return_value=None)
        app.app_context.return_value.__exit__ = MagicMock(return_value=None)
        return app

    def test_flask_app_context_initialization(self, mock_flask_app):
        """Test FlaskAppContext initialization."""
        # Import here to avoid a Flask dependency in the test environment
        from context.flask_app_context import FlaskAppContext

        ctx = FlaskAppContext(mock_flask_app)

        assert ctx.flask_app == mock_flask_app

    def test_flask_app_context_get_config(self, mock_flask_app):
        """Test get_config returns Flask app config value."""
        from context.flask_app_context import FlaskAppContext

        ctx = FlaskAppContext(mock_flask_app)

        assert ctx.get_config("TEST_KEY") == "test_value"

    def test_flask_app_context_get_config_default(self, mock_flask_app):
        """Test get_config returns default when key not found."""
        from context.flask_app_context import FlaskAppContext

        ctx = FlaskAppContext(mock_flask_app)

        assert ctx.get_config("NONEXISTENT", "default") == "default"

    def test_flask_app_context_get_extension(self, mock_flask_app):
        """Test get_extension returns Flask extension."""
        from context.flask_app_context import FlaskAppContext

        ctx = FlaskAppContext(mock_flask_app)
        db_ext = mock_flask_app.extensions["db"]

        assert ctx.get_extension("db") == db_ext

    def test_flask_app_context_get_extension_not_found(self, mock_flask_app):
        """Test get_extension returns None when extension not found."""
        from context.flask_app_context import FlaskAppContext

        ctx = FlaskAppContext(mock_flask_app)

        assert ctx.get_extension("nonexistent") is None

    def test_flask_app_context_enter(self, mock_flask_app):
        """Test enter method enters Flask app context."""
        from context.flask_app_context import FlaskAppContext

        ctx = FlaskAppContext(mock_flask_app)

        with ctx.enter():
            # Should not raise any exception
            pass

        # Verify app_context was called
        mock_flask_app.app_context.assert_called_once()


class TestFlaskExecutionContext:
    """Test FlaskExecutionContext class."""

    @pytest.fixture
    def mock_flask_app(self):
        """Create a mock Flask app."""
        app = MagicMock()
        app.config = {}
        app.app_context = MagicMock()
        app.app_context.return_value.__enter__ = MagicMock(return_value=None)
        app.app_context.return_value.__exit__ = MagicMock(return_value=None)
        return app

    def test_initialization(self, mock_flask_app):
        """Test FlaskExecutionContext initialization."""
        from context.flask_app_context import FlaskExecutionContext

        context_vars = contextvars.copy_context()
        user = MagicMock()

        ctx = FlaskExecutionContext(
            flask_app=mock_flask_app,
            context_vars=context_vars,
            user=user,
        )

        assert ctx.context_vars == context_vars
        assert ctx.user == user

    def test_app_context_property(self, mock_flask_app):
        """Test app_context property returns FlaskAppContext."""
        from context.flask_app_context import FlaskAppContext, FlaskExecutionContext

        ctx = FlaskExecutionContext(
            flask_app=mock_flask_app,
            context_vars=contextvars.copy_context(),
        )

        assert isinstance(ctx.app_context, FlaskAppContext)
        assert ctx.app_context.flask_app == mock_flask_app

    def test_context_manager_protocol(self, mock_flask_app):
        """Test FlaskExecutionContext supports context manager protocol."""
        from context.flask_app_context import FlaskExecutionContext

        ctx = FlaskExecutionContext(
            flask_app=mock_flask_app,
            context_vars=contextvars.copy_context(),
        )

        # Should have __enter__ and __exit__ methods
        assert hasattr(ctx, "__enter__")
        assert hasattr(ctx, "__exit__")

        # Should work as a context manager
        with ctx:
            pass


class TestCaptureFlaskContext:
    """Test capture_flask_context function."""

    @patch("context.flask_app_context.current_app")
    @patch("context.flask_app_context.g")
    def test_capture_flask_context_captures_app(self, mock_g, mock_current_app):
        """Test capture_flask_context captures Flask app."""
        mock_app = MagicMock()
        mock_app._get_current_object = MagicMock(return_value=mock_app)
        mock_current_app._get_current_object = MagicMock(return_value=mock_app)

        from context.flask_app_context import capture_flask_context

        ctx = capture_flask_context()

        assert ctx._flask_app == mock_app

    @patch("context.flask_app_context.current_app")
    @patch("context.flask_app_context.g")
    def test_capture_flask_context_captures_user_from_g(self, mock_g, mock_current_app):
        """Test capture_flask_context captures user from Flask g object."""
        mock_app = MagicMock()
        mock_app._get_current_object = MagicMock(return_value=mock_app)
        mock_current_app._get_current_object = MagicMock(return_value=mock_app)

        mock_user = MagicMock()
        mock_user.id = "user_123"
        mock_g._login_user = mock_user

        from context.flask_app_context import capture_flask_context

        ctx = capture_flask_context()

        assert ctx.user == mock_user

    @patch("context.flask_app_context.current_app")
    def test_capture_flask_context_with_explicit_user(self, mock_current_app):
        """Test capture_flask_context uses explicit user parameter."""
        mock_app = MagicMock()
        mock_app._get_current_object = MagicMock(return_value=mock_app)
        mock_current_app._get_current_object = MagicMock(return_value=mock_app)

        explicit_user = MagicMock()
        explicit_user.id = "user_456"

        from context.flask_app_context import capture_flask_context

        ctx = capture_flask_context(user=explicit_user)

        assert ctx.user == explicit_user

    @patch("context.flask_app_context.current_app")
    def test_capture_flask_context_captures_contextvars(self, mock_current_app):
        """Test capture_flask_context captures context variables."""
        mock_app = MagicMock()
        mock_app._get_current_object = MagicMock(return_value=mock_app)
        mock_current_app._get_current_object = MagicMock(return_value=mock_app)

        # Set a context variable
        test_var = contextvars.ContextVar("test_var")
        test_var.set("test_value")

        from context.flask_app_context import capture_flask_context

        ctx = capture_flask_context()

        # Context variables should be captured
        assert ctx.context_vars is not None
        # Verify the variable is in the captured context
        captured_value = ctx.context_vars[test_var]
        assert captured_value == "test_value"


class TestFlaskExecutionContextIntegration:
    """Integration tests for FlaskExecutionContext."""

    @pytest.fixture
    def mock_flask_app(self):
        """Create a mock Flask app with proper app context."""
        app = MagicMock()
        app.config = {"TEST": "value"}
        app.extensions = {"db": MagicMock()}

        # Mock app context
        mock_app_context = MagicMock()
        mock_app_context.__enter__ = MagicMock(return_value=None)
        mock_app_context.__exit__ = MagicMock(return_value=None)
        app.app_context.return_value = mock_app_context

        return app

    def test_enter_restores_context_vars(self, mock_flask_app):
        """Test that enter restores captured context variables."""
        # Create a context variable and set a value
        test_var = contextvars.ContextVar("integration_test_var")
        test_var.set("original_value")

        # Capture the context
        context_vars = contextvars.copy_context()

        # Change the value
        test_var.set("new_value")

        # Create FlaskExecutionContext and enter it
        from context.flask_app_context import FlaskExecutionContext

        ctx = FlaskExecutionContext(
            flask_app=mock_flask_app,
            context_vars=context_vars,
        )

        with ctx:
            # The value should be restored to the original
            assert test_var.get() == "original_value"

        # After exiting, the variable stays at the value from within the context
        # (this is expected Python contextvars behavior)
        assert test_var.get() == "original_value"

    def test_enter_enters_flask_app_context(self, mock_flask_app):
        """Test that enter enters Flask app context."""
        from context.flask_app_context import FlaskExecutionContext

        ctx = FlaskExecutionContext(
            flask_app=mock_flask_app,
            context_vars=contextvars.copy_context(),
        )

        with ctx:
            # Verify the app context was entered
            assert mock_flask_app.app_context.called

    @patch("context.flask_app_context.g")
    def test_enter_restores_user_in_g(self, mock_g, mock_flask_app):
        """Test that enter restores user in Flask g object."""
        mock_user = MagicMock()
        mock_user.id = "test_user"

        # Note: FlaskExecutionContext saves the user from g before entering the
        # context, then restores it after entering the app context. The user
        # passed to the constructor is NOT restored to g, so this test checks
        # the actual behavior.

        # Create FlaskExecutionContext with the user in the constructor
        from context.flask_app_context import FlaskExecutionContext

        ctx = FlaskExecutionContext(
            flask_app=mock_flask_app,
            context_vars=contextvars.copy_context(),
            user=mock_user,
        )

        # Set the user in g before entering (simulating an existing user in g)
        mock_g._login_user = mock_user

        with ctx:
            # After entering, the user that was in g before entry should be restored
            assert mock_g._login_user == mock_user

        # The constructor user is stored but not automatically restored to g
        # (it is available via the ctx.user property)
        assert ctx.user == mock_user

    def test_enter_method_as_context_manager(self, mock_flask_app):
        """Test enter method returns a proper context manager."""
        from context.flask_app_context import FlaskExecutionContext

        ctx = FlaskExecutionContext(
            flask_app=mock_flask_app,
            context_vars=contextvars.copy_context(),
        )

        # enter() should return a generator/context manager
        with ctx.enter():
            # Should work without issues
            pass

        # Verify the app context was called
        assert mock_flask_app.app_context.called
api/tests/unit_tests/libs/test_workspace_permission.py (new file, 142 lines)
@@ -0,0 +1,142 @@
from unittest.mock import Mock, patch

import pytest
from werkzeug.exceptions import Forbidden

from libs.workspace_permission import (
    check_workspace_member_invite_permission,
    check_workspace_owner_transfer_permission,
)


class TestWorkspacePermissionHelper:
    """Test workspace permission helper functions."""

    @patch("libs.workspace_permission.dify_config")
    @patch("libs.workspace_permission.EnterpriseService")
    def test_community_edition_allows_invite(self, mock_enterprise_service, mock_config):
        """Community edition should always allow invitations without calling any service."""
        mock_config.ENTERPRISE_ENABLED = False

        # Should not raise
        check_workspace_member_invite_permission("test-workspace-id")

        # EnterpriseService should NOT be called in community edition
        mock_enterprise_service.WorkspacePermissionService.get_permission.assert_not_called()

    @patch("libs.workspace_permission.dify_config")
    @patch("libs.workspace_permission.FeatureService")
    def test_community_edition_allows_transfer(self, mock_feature_service, mock_config):
        """Community edition should check billing plan but not call enterprise service."""
        mock_config.ENTERPRISE_ENABLED = False
        mock_features = Mock()
        mock_features.is_allow_transfer_workspace = True
        mock_feature_service.get_features.return_value = mock_features

        # Should not raise
        check_workspace_owner_transfer_permission("test-workspace-id")

        mock_feature_service.get_features.assert_called_once_with("test-workspace-id")

    @patch("libs.workspace_permission.EnterpriseService")
    @patch("libs.workspace_permission.dify_config")
    def test_enterprise_blocks_invite_when_disabled(self, mock_config, mock_enterprise_service):
        """Enterprise edition should block invitations when workspace policy is False."""
        mock_config.ENTERPRISE_ENABLED = True

        mock_permission = Mock()
        mock_permission.allow_member_invite = False
        mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission

        with pytest.raises(Forbidden, match="Workspace policy prohibits member invitations"):
            check_workspace_member_invite_permission("test-workspace-id")

        mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")

    @patch("libs.workspace_permission.EnterpriseService")
    @patch("libs.workspace_permission.dify_config")
    def test_enterprise_allows_invite_when_enabled(self, mock_config, mock_enterprise_service):
        """Enterprise edition should allow invitations when workspace policy is True."""
        mock_config.ENTERPRISE_ENABLED = True

        mock_permission = Mock()
        mock_permission.allow_member_invite = True
        mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission

        # Should not raise
        check_workspace_member_invite_permission("test-workspace-id")

        mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")

    @patch("libs.workspace_permission.EnterpriseService")
    @patch("libs.workspace_permission.dify_config")
    @patch("libs.workspace_permission.FeatureService")
    def test_billing_plan_blocks_transfer(self, mock_feature_service, mock_config, mock_enterprise_service):
        """SANDBOX billing plan should block owner transfer before checking enterprise policy."""
        mock_config.ENTERPRISE_ENABLED = True
        mock_features = Mock()
        mock_features.is_allow_transfer_workspace = False  # SANDBOX plan
        mock_feature_service.get_features.return_value = mock_features

        with pytest.raises(Forbidden, match="Your current plan does not allow workspace ownership transfer"):
            check_workspace_owner_transfer_permission("test-workspace-id")

        # The enterprise service should NOT be called since the billing plan already blocks
        mock_enterprise_service.WorkspacePermissionService.get_permission.assert_not_called()

    @patch("libs.workspace_permission.EnterpriseService")
    @patch("libs.workspace_permission.dify_config")
    @patch("libs.workspace_permission.FeatureService")
    def test_enterprise_blocks_transfer_when_disabled(self, mock_feature_service, mock_config, mock_enterprise_service):
        """Enterprise edition should block transfer when workspace policy is False."""
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_features = Mock()
|
||||
mock_features.is_allow_transfer_workspace = True # Billing plan allows
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
|
||||
mock_permission = Mock()
|
||||
mock_permission.allow_owner_transfer = False # Workspace policy blocks
|
||||
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
|
||||
|
||||
with pytest.raises(Forbidden, match="Workspace policy prohibits ownership transfer"):
|
||||
check_workspace_owner_transfer_permission("test-workspace-id")
|
||||
|
||||
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
|
||||
|
||||
@patch("libs.workspace_permission.EnterpriseService")
|
||||
@patch("libs.workspace_permission.dify_config")
|
||||
@patch("libs.workspace_permission.FeatureService")
|
||||
def test_enterprise_allows_transfer_when_both_enabled(
|
||||
self, mock_feature_service, mock_config, mock_enterprise_service
|
||||
):
|
||||
"""Enterprise edition should allow transfer when both billing and workspace policy allow."""
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
mock_features = Mock()
|
||||
mock_features.is_allow_transfer_workspace = True # Billing plan allows
|
||||
mock_feature_service.get_features.return_value = mock_features
|
||||
|
||||
mock_permission = Mock()
|
||||
mock_permission.allow_owner_transfer = True # Workspace policy allows
|
||||
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
|
||||
|
||||
# Should not raise
|
||||
check_workspace_owner_transfer_permission("test-workspace-id")
|
||||
|
||||
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
|
||||
|
||||
@patch("libs.workspace_permission.logger")
|
||||
@patch("libs.workspace_permission.EnterpriseService")
|
||||
@patch("libs.workspace_permission.dify_config")
|
||||
def test_enterprise_service_error_fails_open(self, mock_config, mock_enterprise_service, mock_logger):
|
||||
"""On enterprise service error, should fail-open (allow) and log error."""
|
||||
mock_config.ENTERPRISE_ENABLED = True
|
||||
|
||||
# Simulate enterprise service error
|
||||
mock_enterprise_service.WorkspacePermissionService.get_permission.side_effect = Exception("Service unavailable")
|
||||
|
||||
# Should not raise (fail-open)
|
||||
check_workspace_member_invite_permission("test-workspace-id")
|
||||
|
||||
# Should log the error
|
||||
mock_logger.exception.assert_called_once()
|
||||
assert "Failed to check workspace invite permission" in str(mock_logger.exception.call_args)
|
||||
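libs/workspace_permission.py itself is not part of this diff. Read off the patch targets and assertions above, it plausibly looks like the following minimal sketch — an inference, not the shipped implementation; the import paths and anything the tests do not exercise are assumptions.

import logging

from werkzeug.exceptions import Forbidden

from configs import dify_config  # assumed import path
from services.enterprise.enterprise_service import EnterpriseService  # assumed import path
from services.feature_service import FeatureService  # assumed import path

logger = logging.getLogger(__name__)


def check_workspace_member_invite_permission(tenant_id: str) -> None:
    """Raise Forbidden when the workspace policy forbids member invitations."""
    if not dify_config.ENTERPRISE_ENABLED:
        return  # community edition: always allowed, no service call
    try:
        permission = EnterpriseService.WorkspacePermissionService.get_permission(tenant_id)
    except Exception:
        # Fail-open: an enterprise-service outage must not block invitations.
        logger.exception("Failed to check workspace invite permission for tenant %s", tenant_id)
        return
    if not permission.allow_member_invite:
        raise Forbidden("Workspace policy prohibits member invitations")


def check_workspace_owner_transfer_permission(tenant_id: str) -> None:
    """Raise Forbidden when the billing plan or workspace policy forbids owner transfer."""
    # The billing plan is checked first, in both community and enterprise editions.
    features = FeatureService.get_features(tenant_id)
    if not features.is_allow_transfer_workspace:
        raise Forbidden("Your current plan does not allow workspace ownership transfer")
    if not dify_config.ENTERPRISE_ENABLED:
        return
    permission = EnterpriseService.WorkspacePermissionService.get_permission(tenant_id)
    if not permission.allow_owner_transfer:
        raise Forbidden("Workspace policy prohibits ownership transfer")

Note the invite check fetches the policy inside try/except so that a service outage fails open, while the Forbidden raise sits outside the except block — otherwise the fail-open handler would swallow it and the pytest.raises assertions above could not pass.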
web/.gitignore (vendored, +2 lines)
@@ -64,3 +64,5 @@ public/fallback-*.js
 .vscode/settings.json
+.vscode/mcp.json

+.eslintcache
@@ -35,6 +35,7 @@ export const baseProviderContextValue: ProviderContextState = {
  refreshLicenseLimit: noop,
  isAllowTransferWorkspace: false,
  isAllowPublishAsCustomKnowledgePipelineTemplate: false,
+  humanInputEmailDeliveryEnabled: false,
}

export const createMockProviderContextValue = (overrides: Partial<ProviderContextState> = {}): ProviderContextState => {
@@ -8,7 +8,6 @@ describe('SVG Attribute Error Reproduction', () => {
  // Capture console errors
  const originalError = console.error
  let errorMessages: string[] = []

  beforeEach(() => {
    errorMessages = []
    console.error = vi.fn((message) => {
web/app/(humanInputLayout)/form/[token]/form.tsx (new file, 259 lines)
@@ -0,0 +1,259 @@
'use client'
import type { FormInputItem, UserAction } from '@/app/components/workflow/nodes/human-input/types'
import {
  RiCheckboxCircleFill,
  RiInformation2Fill,
} from '@remixicon/react'
import { useParams } from 'next/navigation'
import * as React from 'react'
import { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import AppIcon from '@/app/components/base/app-icon'
import Button from '@/app/components/base/button'
import ContentItem from '@/app/components/base/chat/chat/answer/human-input-content/content-item'
import ExpirationTime from '@/app/components/base/chat/chat/answer/human-input-content/expiration-time'
import { getButtonStyle } from '@/app/components/base/chat/chat/answer/human-input-content/utils'
import Loading from '@/app/components/base/loading'
import DifyLogo from '@/app/components/base/logo/dify-logo'
import useDocumentTitle from '@/hooks/use-document-title'
import { getHumanInputForm, submitHumanInputForm } from '@/service/share'
import { asyncRunSafe } from '@/utils'
import { cn } from '@/utils/classnames'

export type FormData = {
  site: any
  form_content: string
  inputs: FormInputItem[]
  resolved_placeholder_values: Record<string, string>
  user_actions: UserAction[]
  expiration_time: number
}

const FormContent = () => {
  const { t } = useTranslation()

  const { token } = useParams<{ token: string }>()
  useDocumentTitle('')

  const [isLoading, setIsLoading] = useState(false)
  const [isSubmitting, setIsSubmitting] = useState(false)
  const [formData, setFormData] = useState<FormData>()
  const [contentList, setContentList] = useState<string[]>([])
  const [inputs, setInputs] = useState({})
  const [success, setSuccess] = useState(false)
  const [expired, setExpired] = useState(false)
  const [submitted, setSubmitted] = useState(false)

  const site = formData?.site.site

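  // Split the form content on output-variable placeholders while keeping each
  // placeholder as its own segment (the regex capture group preserves the
  // separators); with an illustrative input:
  // 'Notes {{#$output.content#}} end' -> ['Notes ', '{{#$output.content#}}', ' end']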
  const splitByOutputVar = (content: string): string[] => {
    const outputVarRegex = /(\{\{#\$output\.[^#]+#\}\})/g
    const parts = content.split(outputVarRegex)
    return parts.filter(part => part.length > 0)
  }

  const initializeInputs = (formInputs: FormInputItem[]) => {
    const initialInputs: Record<string, any> = {}
    formInputs.forEach((item) => {
      if (item.type === 'text-input' || item.type === 'paragraph')
        initialInputs[item.output_variable_name] = ''
      else
        initialInputs[item.output_variable_name] = undefined
    })
    setInputs(initialInputs)
  }

  const initializeContentList = (formContent: string) => {
    const parts = splitByOutputVar(formContent)
    setContentList(parts)
  }

  // TODO: use immer for these nested input updates
  const handleInputsChange = (name: string, value: any) => {
    setInputs(prev => ({
      ...prev,
      [name]: value,
    }))
  }

  const getForm = async (token: string) => {
    try {
      const data = await getHumanInputForm(token)
      setFormData(data)
      initializeInputs(data.inputs)
      initializeContentList(data.form_content)
      setIsLoading(false)
    }
    catch (error) {
      console.error(error)
    }
  }

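  // Submit the chosen action with the collected inputs. A 400 response carries
  // an error_code that tells an expired form apart from one that was already
  // submitted, and each case switches to the matching result screen.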
  const submit = async (actionID: string) => {
    setIsSubmitting(true)
    try {
      await submitHumanInputForm(token, { inputs, action: actionID })
      setSuccess(true)
    }
    catch (e: any) {
      if (e.status === 400) {
        const [, errRespData] = await asyncRunSafe<{ error_code: string }>(e.json())
        const { error_code } = errRespData || {}
        if (error_code === 'human_input_form_expired')
          setExpired(true)
        if (error_code === 'human_input_form_submitted')
          setSubmitted(true)
      }
    }
    finally {
      setIsSubmitting(false)
    }
  }

  useEffect(() => {
    getForm(token)
  }, [token])

  if (isLoading || !formData) {
    return (
      <Loading type="app" />
    )
  }

  if (success) {
    return (
      <div className={cn('flex h-full w-full flex-col items-center justify-center')}>
        <div className="min-w-[480px] max-w-[640px]">
          <div className="border-components-divider-subtle flex h-[320px] flex-col gap-4 rounded-[20px] border bg-chat-bubble-bg p-10 pb-9 shadow-lg backdrop-blur-sm">
            <div className="h-[56px] w-[56px] shrink-0 rounded-2xl border border-components-panel-border-subtle bg-background-default-dodge p-3">
              <RiCheckboxCircleFill className="h-8 w-8 text-text-success" />
            </div>
            <div className="grow">
              <div className="title-4xl-semi-bold text-text-primary">{t('humanInput.thanks', { ns: 'share' })}</div>
              <div className="title-4xl-semi-bold text-text-primary">{t('humanInput.recorded', { ns: 'share' })}</div>
            </div>
            <div className="system-2xs-regular-uppercase shrink-0 text-text-tertiary">{t('humanInput.submissionID', { id: token, ns: 'share' })}</div>
          </div>
          <div className="flex flex-row-reverse px-2 py-3">
            <div className={cn(
              'flex shrink-0 items-center gap-1.5 px-1',
            )}
            >
              <div className="system-2xs-medium-uppercase text-text-tertiary">{t('chat.poweredBy', { ns: 'share' })}</div>
              <DifyLogo size="small" />
            </div>
          </div>
        </div>
      </div>
    )
  }

  if (expired) {
    return (
      <div className={cn('flex h-full w-full flex-col items-center justify-center')}>
        <div className="min-w-[480px] max-w-[640px]">
          <div className="border-components-divider-subtle flex h-[320px] flex-col gap-4 rounded-[20px] border bg-chat-bubble-bg p-10 pb-9 shadow-lg backdrop-blur-sm">
            <div className="h-[56px] w-[56px] shrink-0 rounded-2xl border border-components-panel-border-subtle bg-background-default-dodge p-3">
              <RiInformation2Fill className="h-8 w-8 text-text-accent" />
            </div>
            <div className="grow">
              <div className="title-4xl-semi-bold text-text-primary">{t('humanInput.sorry', { ns: 'share' })}</div>
              <div className="title-4xl-semi-bold text-text-primary">{t('humanInput.expired', { ns: 'share' })}</div>
            </div>
            <div className="system-2xs-regular-uppercase shrink-0 text-text-tertiary">{t('humanInput.submissionID', { id: token, ns: 'share' })}</div>
          </div>
          <div className="flex flex-row-reverse px-2 py-3">
            <div className={cn(
              'flex shrink-0 items-center gap-1.5 px-1',
            )}
            >
              <div className="system-2xs-medium-uppercase text-text-tertiary">{t('chat.poweredBy', { ns: 'share' })}</div>
              <DifyLogo size="small" />
            </div>
          </div>
        </div>
      </div>
    )
  }

  if (submitted) {
    return (
      <div className={cn('flex h-full w-full flex-col items-center justify-center')}>
        <div className="min-w-[480px] max-w-[640px]">
          <div className="border-components-divider-subtle flex h-[320px] flex-col gap-4 rounded-[20px] border bg-chat-bubble-bg p-10 pb-9 shadow-lg backdrop-blur-sm">
            <div className="h-[56px] w-[56px] shrink-0 rounded-2xl border border-components-panel-border-subtle bg-background-default-dodge p-3">
              <RiInformation2Fill className="h-8 w-8 text-text-accent" />
            </div>
            <div className="grow">
              <div className="title-4xl-semi-bold text-text-primary">{t('humanInput.sorry', { ns: 'share' })}</div>
              <div className="title-4xl-semi-bold text-text-primary">{t('humanInput.completed', { ns: 'share' })}</div>
            </div>
            <div className="system-2xs-regular-uppercase shrink-0 text-text-tertiary">{t('humanInput.submissionID', { id: token, ns: 'share' })}</div>
          </div>
          <div className="flex flex-row-reverse px-2 py-3">
            <div className={cn(
              'flex shrink-0 items-center gap-1.5 px-1',
            )}
            >
              <div className="system-2xs-medium-uppercase text-text-tertiary">{t('chat.poweredBy', { ns: 'share' })}</div>
              <DifyLogo size="small" />
            </div>
          </div>
        </div>
      </div>
    )
  }

  return (
    <div className={cn('mx-auto flex h-full w-full max-w-[720px] flex-col items-center')}>
      <div className="mt-4 flex w-full shrink-0 items-center gap-3 py-3">
        <AppIcon
          size="large"
          iconType={site.icon_type as any}
          icon={site.icon}
          background={site.icon_background}
          imageUrl={site.icon_url}
        />
        <div className="system-xl-semibold grow text-text-primary">{site.title}</div>
      </div>
      <div className="h-0 w-full grow overflow-y-auto">
        <div className="border-components-divider-subtle rounded-[20px] border bg-chat-bubble-bg p-4 shadow-lg backdrop-blur-sm">
          {contentList.map((content, index) => (
            <ContentItem
              key={index}
              content={content}
              formInputFields={formData.inputs}
              inputs={inputs}
              onInputChange={handleInputsChange}
              resolvedPlaceholderValues={formData.resolved_placeholder_values}
            />
          ))}
          <div className="flex flex-wrap gap-1 py-1">
            {formData.user_actions.map((action: any) => (
              <Button
                key={action.id}
                disabled={isSubmitting}
                variant={getButtonStyle(action.button_style) as any}
                onClick={() => submit(action.id)}
              >
                {action.title}
              </Button>
            ))}
          </div>
          <ExpirationTime expirationTime={formData.expiration_time * 1000} />
        </div>
        <div className="flex flex-row-reverse px-2 py-3">
          <div className={cn(
            'flex shrink-0 items-center gap-1.5 px-1',
          )}
          >
            <div className="system-2xs-medium-uppercase text-text-tertiary">{t('chat.poweredBy', { ns: 'share' })}</div>
            <DifyLogo size="small" />
          </div>
        </div>
      </div>
    </div>
  )
}

export default React.memo(FormContent)
web/app/(humanInputLayout)/form/[token]/mock.ts (new file, 107 lines)
@@ -0,0 +1,107 @@
import { UserActionButtonType } from '@/app/components/workflow/nodes/human-input/types'
import { InputVarType } from '@/app/components/workflow/types'

export const MOCK_DATA = {
  site: {
    app_id: 'e9823576-d836-4f2b-b46f-bd4df1d82230',
    end_user_id: 'b7aa295d-1560-4d87-a828-77b3f39b30d0',
    enable_site: true,
    site: {
      title: 'wf',
      chat_color_theme: null,
      chat_color_theme_inverted: false,
      icon_type: 'emoji',
      icon: '\uD83E\uDD16',
      icon_background: '#FFEAD5',
      icon_url: null,
      description: null,
      copyright: null,
      privacy_policy: null,
      custom_disclaimer: '',
      default_language: 'en-US',
      prompt_public: false,
      show_workflow_steps: true,
      use_icon_as_answer_icon: false,
    },
    model_config: null,
    plan: 'basic',
    can_replace_logo: false,
    custom_config: null,
  },
  // Uses the same data structure as the form editor above; the only difference
  // is that for Text items the variable substitution has already been applied
  // (i.e. every position that references another variable with the
  // {{#node_name.var_name#}} syntax has been replaced by that variable's value).
  //
  // See FormContent.
  form_content: `
# Experiencing the Four Seasons



## My Seasonal Guide
Name: {{#nodename.name#}}
Location: {{#nodename.location#}}
Favorite Season: {{#nodename.season#}}

The four seasons throughout the year:
- Spring: Cherry blossoms, returning birds
- Summer: Long sunny days, thunderstorms, beach adventures
- Autumn: Red and gold foliage, harvest festivals, apple picking
- Winter: Snowfall, frozen lakes, holiday celebrations
## Notes

{{#$output.content#}}
`,
  // For the meaning of each field, see the FormInput definition above.
  inputs: [
    {
      output_variable_name: 'content',
      type: InputVarType.paragraph,
      label: 'Name',
      options: [],
      max_length: 4096,
      placeholder: 'Enter your name',
    },
    {
      output_variable_name: 'location',
      type: InputVarType.textInput,
      label: 'Location',
      // placeholder: 'Enter your location',
      options: [],
      max_length: 4096,
    },
    {
      output_variable_name: 'season',
      type: InputVarType.textInput,
      label: 'Favorite Season',
      // placeholder: 'Enter your favorite season',
      options: [],
      max_length: 4096,
    },
  ],
  // Actions the user can take on this form; see the UserAction definition above.
  user_actions: [
    {
      id: 'approve-action',
      title: 'Post to X',
      button_style: UserActionButtonType.Primary,
    },
    {
      id: 'regenerate-action',
      title: 'regenerate',
      button_style: UserActionButtonType.Default,
    },
    {
      id: 'thinking-action',
      title: 'thinking',
      button_style: UserActionButtonType.Accent,
    },
    {
      id: 'cancel-action',
      title: 'cancel',
      button_style: UserActionButtonType.Ghost,
    },
  ],
  timeout: 3,
  timeout_unit: 'day',
}
web/app/(humanInputLayout)/form/[token]/page.tsx (new file, 13 lines)
@@ -0,0 +1,13 @@
'use client'
import * as React from 'react'
import FormContent from './form'

const FormPage = () => {
  return (
    <div className="h-full min-w-[300px] bg-chatbot-bg pb-[env(safe-area-inset-bottom)]">
      <FormContent />
    </div>
  )
}

export default React.memo(FormPage)
@@ -47,7 +47,7 @@ const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => {
    await webAppLogout(shareCode!)
    const url = getSigninUrl()
    router.replace(url)
-  }, [getSigninUrl, router, webAppLogout, shareCode])
+  }, [getSigninUrl, router, shareCode])

  if (appInfoError) {
    return (
@@ -31,7 +31,7 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
    await webAppLogout(shareCode!)
    const url = getSigninUrl()
    router.replace(url)
-  }, [getSigninUrl, router, webAppLogout, shareCode])
+  }, [getSigninUrl, router, shareCode])

  const [isLoading, setIsLoading] = useState(true)
  useEffect(() => {
@@ -114,6 +114,7 @@ export type AppPublisherProps = {
  missingStartNode?: boolean
  hasTriggerNode?: boolean // Whether workflow currently contains any trigger nodes (used to hide missing-start CTA when triggers exist).
  startNodeLimitExceeded?: boolean
+  hasHumanInputNode?: boolean
}

const PUBLISH_SHORTCUT = ['ctrl', '⇧', 'P']
@@ -137,13 +138,14 @@ const AppPublisher = ({
  missingStartNode = false,
  hasTriggerNode = false,
  startNodeLimitExceeded = false,
+  hasHumanInputNode = false,
}: AppPublisherProps) => {
  const { t } = useTranslation()

  const [published, setPublished] = useState(false)
  const [open, setOpen] = useState(false)
  const [showAppAccessControl, setShowAppAccessControl] = useState(false)
-  const [isAppAccessSet, setIsAppAccessSet] = useState(true)

  const [embeddingModalOpen, setEmbeddingModalOpen] = useState(false)

  const appDetail = useAppStore(state => state.appDetail)
@@ -160,6 +162,13 @@
  const { data: appAccessSubjects, isLoading: isGettingAppWhiteListSubjects } = useAppWhiteListSubjects(appDetail?.id, open && systemFeatures.webapp_auth.enabled && appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS)
  const openAsyncWindow = useAsyncWindowOpen()

+  const isAppAccessSet = useMemo(() => {
+    if (appDetail && appAccessSubjects) {
+      return !(appDetail.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS && appAccessSubjects.groups?.length === 0 && appAccessSubjects.members?.length === 0)
+    }
+    return true
+  }, [appAccessSubjects, appDetail])
+
  const noAccessPermission = useMemo(() => systemFeatures.webapp_auth.enabled && appDetail && appDetail.access_mode !== AccessMode.EXTERNAL_MEMBERS && !userCanAccessApp?.result, [systemFeatures, appDetail, userCanAccessApp])
  const disabledFunctionButton = useMemo(() => (!publishedAt || missingStartNode || noAccessPermission), [publishedAt, missingStartNode, noAccessPermission])
@@ -170,25 +179,13 @@
      return t('noUserInputNode', { ns: 'app' })
    if (noAccessPermission)
      return t('noAccessPermission', { ns: 'app' })
-  }, [missingStartNode, noAccessPermission, publishedAt])
+  }, [missingStartNode, noAccessPermission, publishedAt, t])

  useEffect(() => {
    if (systemFeatures.webapp_auth.enabled && open && appDetail)
      refetch()
  }, [open, appDetail, refetch, systemFeatures])

-  useEffect(() => {
-    if (appDetail && appAccessSubjects) {
-      if (appDetail.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS && appAccessSubjects.groups?.length === 0 && appAccessSubjects.members?.length === 0)
-        setIsAppAccessSet(false)
-      else
-        setIsAppAccessSet(true)
-    }
-    else {
-      setIsAppAccessSet(true)
-    }
-  }, [appAccessSubjects, appDetail])
-
  const handlePublish = useCallback(async (params?: ModelAndParameter | PublishWorkflowParams) => {
    try {
      await onPublish?.(params)
@@ -466,7 +463,7 @@
            {t('common.accessAPIReference', { ns: 'workflow' })}
          </SuggestedAction>
        </Tooltip>
-        {appDetail?.mode === AppModeEnum.WORKFLOW && (
+        {appDetail?.mode === AppModeEnum.WORKFLOW && !hasHumanInputNode && (
          <WorkflowToolConfigureButton
            disabled={workflowToolDisabled}
            published={!!toolPublished}
@@ -271,9 +271,9 @@ const ConfigVar: FC<IConfigVarProps> = ({ promptVariables, readonly, onPromptVar
        </div>
      )}
      {hasVar && (
-        <div className={cn('mt-1 grid px-3 pb-3')}>
+        <div className="mt-1 px-3 pb-3">
          <ReactSortable
-            className={cn('grid-col-1 grid space-y-1', readonly && 'grid-cols-2 gap-1 space-y-0')}
+            className="space-y-1"
            list={promptVariablesWithIds}
            setList={(list) => { onPromptVariablesChange?.(list.map(item => item.variable)) }}
            handle=".handle"
@@ -39,7 +39,7 @@ const VarItem: FC<ItemProps> = ({
  const [isDeleting, setIsDeleting] = useState(false)

  return (
-    <div className={cn('group relative mb-1 flex h-[34px] w-full items-center rounded-lg border-[0.5px] border-components-panel-border-subtle bg-components-panel-on-panel-item-bg pl-2.5 pr-3 shadow-xs last-of-type:mb-0 hover:bg-components-panel-on-panel-item-bg-hover hover:shadow-sm', isDeleting && 'border-state-destructive-border hover:bg-state-destructive-hover', readonly && 'cursor-not-allowed', className)}>
+    <div className={cn('group relative mb-1 flex h-[34px] w-full items-center rounded-lg border-[0.5px] border-components-panel-border-subtle bg-components-panel-on-panel-item-bg pl-2.5 pr-3 shadow-xs last-of-type:mb-0 hover:bg-components-panel-on-panel-item-bg-hover hover:shadow-sm', isDeleting && 'border-state-destructive-border hover:bg-state-destructive-hover', readonly && 'cursor-not-allowed opacity-30', className)}>
      <VarIcon className={cn('mr-1 h-4 w-4 shrink-0 text-text-accent', canDrag && 'group-hover:opacity-0')} />
      {canDrag && (
        <RiDraggable className="absolute left-3 top-3 hidden h-3 w-3 cursor-pointer text-text-tertiary group-hover:block" />
@@ -1,6 +1,5 @@
'use client'
import type { FC } from 'react'
-import { noop } from 'es-toolkit/function'
import { produce } from 'immer'
import * as React from 'react'
import { useCallback } from 'react'
@@ -11,17 +10,14 @@ import { useFeatures, useFeaturesStore } from '@/app/components/base/features/ho
import { Vision } from '@/app/components/base/icons/src/vender/features'
import Switch from '@/app/components/base/switch'
import Tooltip from '@/app/components/base/tooltip'
-import OptionCard from '@/app/components/workflow/nodes/_base/components/option-card'
import { SupportUploadFileTypes } from '@/app/components/workflow/types'
+// import OptionCard from '@/app/components/workflow/nodes/_base/components/option-card'
import ConfigContext from '@/context/debug-configuration'
import { Resolution } from '@/types/app'
import { cn } from '@/utils/classnames'
import ParamConfig from './param-config'
const ConfigVision: FC = () => {
  const { t } = useTranslation()
-  const { isShowVisionConfig, isAllowVideoUpload, readonly } = useContext(ConfigContext)
+  const { isShowVisionConfig, isAllowVideoUpload } = useContext(ConfigContext)
  const file = useFeatures(s => s.features.file)
  const featuresStore = useFeaturesStore()
@@ -58,7 +54,7 @@ const ConfigVision: FC = () => {
    setFeatures(newFeatures)
  }, [featuresStore, isAllowVideoUpload])

-  if (!isShowVisionConfig || (readonly && !isImageEnabled))
+  if (!isShowVisionConfig)
    return null

  return (
@@ -79,55 +75,37 @@ const ConfigVision: FC = () => {
          />
        </div>
        <div className="flex shrink-0 items-center">
-          {readonly
-            ? (
-              <>
-                <div className="mr-2 flex items-center gap-0.5">
-                  <div className="system-xs-medium-uppercase text-text-tertiary">{t('vision.visionSettings.resolution', { ns: 'appDebug' })}</div>
-                  <Tooltip
-                    popupContent={(
-                      <div className="w-[180px]">
-                        {t('vision.visionSettings.resolutionTooltip', { ns: 'appDebug' }).split('\n').map(item => (
-                          <div key={item}>{item}</div>
-                        ))}
-                      </div>
-                    )}
-                  />
-                </div>
-                <div className="flex items-center gap-1">
-                  <OptionCard
-                    title={t('vision.visionSettings.high', { ns: 'appDebug' })}
-                    selected={file?.image?.detail === Resolution.high}
-                    onSelect={noop}
-                    className={cn(
-                      'cursor-not-allowed rounded-lg px-3 hover:shadow-none',
-                      file?.image?.detail !== Resolution.high && 'hover:border-components-option-card-option-border',
-                    )}
-                  />
-                  <OptionCard
-                    title={t('vision.visionSettings.low', { ns: 'appDebug' })}
-                    selected={file?.image?.detail === Resolution.low}
-                    onSelect={noop}
-                    className={cn(
-                      'cursor-not-allowed rounded-lg px-3 hover:shadow-none',
-                      file?.image?.detail !== Resolution.low && 'hover:border-components-option-card-option-border',
-                    )}
-                  />
-                </div>
-              </>
-            )
-            : (
-              <>
-                <ParamConfig />
-                <div className="ml-1 mr-3 h-3.5 w-[1px] bg-divider-regular"></div>
-                <Switch
-                  defaultValue={isImageEnabled}
-                  onChange={handleChange}
-                  size="md"
-                />
-              </>
-            )}
+          {/* <div className='mr-2 flex items-center gap-0.5'>
+            <div className='text-text-tertiary system-xs-medium-uppercase'>{t('appDebug.vision.visionSettings.resolution')}</div>
+            <Tooltip
+              popupContent={
+                <div className='w-[180px]' >
+                  {t('appDebug.vision.visionSettings.resolutionTooltip').split('\n').map(item => (
+                    <div key={item}>{item}</div>
+                  ))}
+                </div>
+              }
+            />
+          </div> */}
+          {/* <div className='flex items-center gap-1'>
+            <OptionCard
+              title={t('appDebug.vision.visionSettings.high')}
+              selected={file?.image?.detail === Resolution.high}
+              onSelect={() => handleChange(Resolution.high)}
+            />
+            <OptionCard
+              title={t('appDebug.vision.visionSettings.low')}
+              selected={file?.image?.detail === Resolution.low}
+              onSelect={() => handleChange(Resolution.low)}
+            />
+          </div> */}
+          <ParamConfig />
+          <div className="ml-1 mr-3 h-3.5 w-[1px] bg-divider-regular"></div>
+          <Switch
+            defaultValue={isImageEnabled}
+            onChange={handleChange}
+            size="md"
+          />
        </div>
      </div>
    )
@@ -40,7 +40,7 @@ type AgentToolWithMoreInfo = AgentTool & { icon: any, collection?: Collection }
const AgentTools: FC = () => {
  const { t } = useTranslation()
  const [isShowChooseTool, setIsShowChooseTool] = useState(false)
-  const { readonly, modelConfig, setModelConfig } = useContext(ConfigContext)
+  const { modelConfig, setModelConfig } = useContext(ConfigContext)
  const { data: buildInTools } = useAllBuiltInTools()
  const { data: customTools } = useAllCustomTools()
  const { data: workflowTools } = useAllWorkflowTools()
@@ -168,10 +168,10 @@ const AgentTools: FC = () => {
          {tools.filter(item => !!item.enabled).length}
          /
          {tools.length}

          {t('agent.tools.enabled', { ns: 'appDebug' })}
        </div>
-        {tools.length < MAX_TOOLS_NUM && !readonly && (
+        {tools.length < MAX_TOOLS_NUM && (
          <>
            <div className="ml-3 mr-1 h-3.5 w-px bg-divider-regular"></div>
            <ToolPicker
@@ -189,7 +189,7 @@ const AgentTools: FC = () => {
        </div>
      )}
    >
-      <div className={cn('grid grid-cols-1 items-center gap-1 2xl:grid-cols-2', readonly && 'cursor-not-allowed grid-cols-2')}>
+      <div className="grid grid-cols-1 flex-wrap items-center justify-between gap-1 2xl:grid-cols-2">
        {tools.map((item: AgentTool & { icon: any, collection?: Collection }, index) => (
          <div
            key={index}
@@ -214,7 +214,7 @@ const AgentTools: FC = () => {
          >
            <span className="system-xs-medium pr-1.5 text-text-secondary">{getProviderShowName(item)}</span>
            <span className="text-text-tertiary">{item.tool_label}</span>
-            {!item.isDeleted && !readonly && (
+            {!item.isDeleted && (
              <Tooltip
                popupContent={(
                  <div className="w-[180px]">
@@ -259,7 +259,7 @@ const AgentTools: FC = () => {
                </div>
              </div>
            )}
-            {!item.isDeleted && !readonly && (
+            {!item.isDeleted && (
              <div className="mr-2 hidden items-center gap-1 group-hover:flex">
                {!item.notAuthor && (
                  <Tooltip
@@ -298,7 +298,7 @@ const AgentTools: FC = () => {
                {!item.notAuthor && (
                  <Switch
                    defaultValue={item.isDeleted ? false : item.enabled}
-                    disabled={item.isDeleted || readonly}
+                    disabled={item.isDeleted}
                    size="md"
                    onChange={(enabled) => {
                      const newModelConfig = produce(modelConfig, (draft) => {
@@ -312,7 +312,6 @@ const AgentTools: FC = () => {
                {item.notAuthor && (
                  <Button
                    variant="secondary"
-                    disabled={readonly}
                    size="small"
                    onClick={() => {
                      setCurrentTool(item)
@@ -17,7 +17,7 @@ const ConfigAudio: FC = () => {
  const { t } = useTranslation()
  const file = useFeatures(s => s.features.file)
  const featuresStore = useFeaturesStore()
-  const { isShowAudioConfig, readonly } = useContext(ConfigContext)
+  const { isShowAudioConfig } = useContext(ConfigContext)

  const isAudioEnabled = file?.allowed_file_types?.includes(SupportUploadFileTypes.audio) ?? false
@@ -45,7 +45,7 @@ const ConfigAudio: FC = () => {
    setFeatures(newFeatures)
  }, [featuresStore])

-  if (!isShowAudioConfig || (readonly && !isAudioEnabled))
+  if (!isShowAudioConfig)
    return null

  return (
@@ -65,16 +65,14 @@ const ConfigAudio: FC = () => {
        )}
      />
    </div>
-    {!readonly && (
-      <div className="flex shrink-0 items-center">
-        <div className="ml-1 mr-3 h-3.5 w-[1px] bg-divider-subtle"></div>
-        <Switch
-          defaultValue={isAudioEnabled}
-          onChange={handleChange}
-          size="md"
-        />
-      </div>
-    )}
+    <div className="flex shrink-0 items-center">
+      <div className="ml-1 mr-3 h-3.5 w-[1px] bg-divider-subtle"></div>
+      <Switch
+        defaultValue={isAudioEnabled}
+        onChange={handleChange}
+        size="md"
+      />
+    </div>
  </div>
)
}
@@ -17,7 +17,7 @@ const ConfigDocument: FC = () => {
  const { t } = useTranslation()
  const file = useFeatures(s => s.features.file)
  const featuresStore = useFeaturesStore()
-  const { isShowDocumentConfig, readonly } = useContext(ConfigContext)
+  const { isShowDocumentConfig } = useContext(ConfigContext)

  const isDocumentEnabled = file?.allowed_file_types?.includes(SupportUploadFileTypes.document) ?? false
@@ -45,7 +45,7 @@ const ConfigDocument: FC = () => {
    setFeatures(newFeatures)
  }, [featuresStore])

-  if (!isShowDocumentConfig || (readonly && !isDocumentEnabled))
+  if (!isShowDocumentConfig)
    return null

  return (
@@ -65,16 +65,14 @@ const ConfigDocument: FC = () => {
        )}
      />
    </div>
-    {!readonly && (
-      <div className="flex shrink-0 items-center">
-        <div className="ml-1 mr-3 h-3.5 w-[1px] bg-divider-subtle"></div>
-        <Switch
-          defaultValue={isDocumentEnabled}
-          onChange={handleChange}
-          size="md"
-        />
-      </div>
-    )}
+    <div className="flex shrink-0 items-center">
+      <div className="ml-1 mr-3 h-3.5 w-[1px] bg-divider-subtle"></div>
+      <Switch
+        defaultValue={isDocumentEnabled}
+        onChange={handleChange}
+        size="md"
+      />
+    </div>
  </div>
)
}
@@ -18,7 +18,6 @@ import ConfigDocument from './config-document'

const Config: FC = () => {
  const {
-    readonly,
    mode,
    isAdvancedMode,
    modelModeType,
@@ -28,7 +27,6 @@ const Config: FC = () => {
    modelConfig,
    setModelConfig,
    setPrevPromptConfig,
-    dataSets,
  } = useContext(ConfigContext)
  const isChatApp = [AppModeEnum.ADVANCED_CHAT, AppModeEnum.AGENT_CHAT, AppModeEnum.CHAT].includes(mode)
  const formattingChangedDispatcher = useFormattingChangedDispatcher()
@@ -67,27 +65,19 @@ const Config: FC = () => {
          promptTemplate={promptTemplate}
          promptVariables={promptVariables}
          onChange={handlePromptChange}
-          readonly={readonly}
        />

        {/* Variables */}
-        {!(readonly && promptVariables.length === 0) && (
-          <ConfigVar
-            promptVariables={promptVariables}
-            onPromptVariablesChange={handlePromptVariablesNameChange}
-            readonly={readonly}
-          />
-        )}
+        <ConfigVar
+          promptVariables={promptVariables}
+          onPromptVariablesChange={handlePromptVariablesNameChange}
+        />

        {/* Dataset */}
-        {!(readonly && dataSets.length === 0) && (
-          <DatasetConfig
-            readonly={readonly}
-            hideMetadataFilter={readonly}
-          />
-        )}
+        <DatasetConfig />

        {/* Tools */}
-        {isAgent && !(readonly && modelConfig.agentConfig.tools.length === 0) && (
+        {isAgent && (
          <AgentTools />
        )}
@@ -98,7 +88,7 @@ const Config: FC = () => {
        <ConfigAudio />

        {/* Chat History */}
-        {!readonly && isAdvancedMode && isChatApp && modelModeType === ModelModeType.completion && (
+        {isAdvancedMode && isChatApp && modelModeType === ModelModeType.completion && (
          <HistoryPanel
            showWarning={!hasSetBlockStatus.history}
            onShowEditModal={showHistoryModal}
@@ -30,7 +30,6 @@ const Item: FC<ItemProps> = ({
  config,
  onSave,
  onRemove,
-  readonly = false,
  editable = true,
}) => {
  const media = useBreakpoints()
@@ -57,7 +56,6 @@ const Item: FC<ItemProps> = ({
    <div className={cn(
      'group relative mb-1 flex h-10 w-full cursor-pointer items-center justify-between rounded-lg border-[0.5px] border-components-panel-border-subtle bg-components-panel-on-panel-item-bg px-2 last-of-type:mb-0 hover:bg-components-panel-on-panel-item-bg-hover',
      isDeleting && 'border-state-destructive-border hover:bg-state-destructive-hover',
-      readonly && 'cursor-not-allowed',
    )}
    >
      <div className="flex w-0 grow items-center space-x-1.5">
@@ -72,7 +70,7 @@ const Item: FC<ItemProps> = ({
      </div>
      <div className="ml-2 hidden shrink-0 items-center space-x-1 group-hover:flex">
        {
-          editable && !readonly && (
+          editable && (
            <ActionButton
              onClick={(e) => {
                e.stopPropagation()
@@ -83,18 +81,14 @@ const Item: FC<ItemProps> = ({
            </ActionButton>
          )
        }
-        {
-          !readonly && (
-            <ActionButton
-              onClick={() => onRemove(config.id)}
-              state={isDeleting ? ActionButtonState.Destructive : ActionButtonState.Default}
-              onMouseEnter={() => setIsDeleting(true)}
-              onMouseLeave={() => setIsDeleting(false)}
-            >
-              <RiDeleteBinLine className={cn('h-4 w-4 shrink-0 text-text-tertiary', isDeleting && 'text-text-destructive')} />
-            </ActionButton>
-          )
-        }
+        <ActionButton
+          onClick={() => onRemove(config.id)}
+          state={isDeleting ? ActionButtonState.Destructive : ActionButtonState.Default}
+          onMouseEnter={() => setIsDeleting(true)}
+          onMouseLeave={() => setIsDeleting(false)}
+        >
+          <RiDeleteBinLine className={cn('h-4 w-4 shrink-0 text-text-tertiary', isDeleting && 'text-text-destructive')} />
+        </ActionButton>
      </div>
      {
        config.indexing_technique && (
@@ -113,13 +107,11 @@ const Item: FC<ItemProps> = ({
        )
      }
      <Drawer isOpen={showSettingsModal} onClose={() => setShowSettingsModal(false)} footer={null} mask={isMobile} panelClassName="mt-16 mx-2 sm:mr-2 mb-3 !p-0 !max-w-[640px] rounded-xl">
-        {showSettingsModal && (
-          <SettingsModal
-            currentDataset={config}
-            onCancel={() => setShowSettingsModal(false)}
-            onSave={handleSave}
-          />
-        )}
+        <SettingsModal
+          currentDataset={config}
+          onCancel={() => setShowSettingsModal(false)}
+          onSave={handleSave}
+        />
      </Drawer>
    </div>
  )
@@ -30,7 +30,6 @@ import {
import { useSelector as useAppContextSelector } from '@/context/app-context'
import ConfigContext from '@/context/debug-configuration'
import { AppModeEnum } from '@/types/app'
-import { cn } from '@/utils/classnames'
import { hasEditPermissionForDataset } from '@/utils/permission'
import FeaturePanel from '../base/feature-panel'
import OperationBtn from '../base/operation-btn'
@@ -39,11 +38,7 @@ import CardItem from './card-item'
import ContextVar from './context-var'
import ParamsConfig from './params-config'

-type Props = {
-  readonly?: boolean
-  hideMetadataFilter?: boolean
-}
-const DatasetConfig: FC<Props> = ({ readonly, hideMetadataFilter }) => {
+const DatasetConfig: FC = () => {
  const { t } = useTranslation()
  const userProfile = useAppContextSelector(s => s.userProfile)
@@ -264,19 +259,17 @@ const DatasetConfig: FC<Props> = ({ readonly, hideMetadataFilter }) => {
      className="mt-2"
      title={t('feature.dataSet.title', { ns: 'appDebug' })}
      headerRight={(
-        !readonly && (
-          <div className="flex items-center gap-1">
-            {!isAgent && <ParamsConfig disabled={!hasData} selectedDatasets={dataSet} />}
-            <OperationBtn type="add" onClick={showSelectDataSet} />
-          </div>
-        )
+        <div className="flex items-center gap-1">
+          {!isAgent && <ParamsConfig disabled={!hasData} selectedDatasets={dataSet} />}
+          <OperationBtn type="add" onClick={showSelectDataSet} />
+        </div>
      )}
      hasHeaderBottomBorder={!hasData}
      noBodySpacing
    >
      {hasData
        ? (
-          <div className={cn('mt-1 grid grid-cols-1 px-3 pb-3', readonly && 'grid-cols-2 gap-1')}>
+          <div className="mt-1 flex flex-wrap justify-between px-3 pb-3">
            {formattedDataset.map(item => (
              <CardItem
                key={item.id}
@@ -284,7 +277,6 @@ const DatasetConfig: FC<Props> = ({ readonly, hideMetadataFilter }) => {
                onRemove={onRemove}
                onSave={handleSave}
                editable={item.editable}
-                readonly={readonly}
              />
            ))}
          </div>
@@ -295,29 +287,27 @@ const DatasetConfig: FC<Props> = ({ readonly, hideMetadataFilter }) => {
        </div>
      )}

-      {!hideMetadataFilter && (
-        <div className="border-t border-t-divider-subtle py-2">
-          <MetadataFilter
-            metadataList={metadataList}
-            selectedDatasetsLoaded
-            metadataFilterMode={datasetConfigs.metadata_filtering_mode}
-            metadataFilteringConditions={datasetConfigs.metadata_filtering_conditions}
-            handleAddCondition={handleAddCondition}
-            handleMetadataFilterModeChange={handleMetadataFilterModeChange}
-            handleRemoveCondition={handleRemoveCondition}
-            handleToggleConditionLogicalOperator={handleToggleConditionLogicalOperator}
-            handleUpdateCondition={handleUpdateCondition}
-            metadataModelConfig={datasetConfigs.metadata_model_config}
-            handleMetadataModelChange={handleMetadataModelChange}
-            handleMetadataCompletionParamsChange={handleMetadataCompletionParamsChange}
-            isCommonVariable
-            availableCommonStringVars={promptVariablesToSelect.filter(item => item.type === MetadataFilteringVariableType.string || item.type === MetadataFilteringVariableType.select)}
-            availableCommonNumberVars={promptVariablesToSelect.filter(item => item.type === MetadataFilteringVariableType.number)}
-          />
-        </div>
-      )}
+      <div className="border-t border-t-divider-subtle py-2">
+        <MetadataFilter
+          metadataList={metadataList}
+          selectedDatasetsLoaded
+          metadataFilterMode={datasetConfigs.metadata_filtering_mode}
+          metadataFilteringConditions={datasetConfigs.metadata_filtering_conditions}
+          handleAddCondition={handleAddCondition}
+          handleMetadataFilterModeChange={handleMetadataFilterModeChange}
+          handleRemoveCondition={handleRemoveCondition}
+          handleToggleConditionLogicalOperator={handleToggleConditionLogicalOperator}
+          handleUpdateCondition={handleUpdateCondition}
+          metadataModelConfig={datasetConfigs.metadata_model_config}
+          handleMetadataModelChange={handleMetadataModelChange}
+          handleMetadataCompletionParamsChange={handleMetadataCompletionParamsChange}
+          isCommonVariable
+          availableCommonStringVars={promptVariablesToSelect.filter(item => item.type === MetadataFilteringVariableType.string || item.type === MetadataFilteringVariableType.select)}
+          availableCommonNumberVars={promptVariablesToSelect.filter(item => item.type === MetadataFilteringVariableType.number)}
+        />
+      </div>

-      {!readonly && mode === AppModeEnum.COMPLETION && dataSet.length > 0 && (
+      {mode === AppModeEnum.COMPLETION && dataSet.length > 0 && (
        <ContextVar
          value={selectedContextVar?.key}
          options={promptVariablesToSelect}
@@ -7,7 +7,6 @@ import Input from '@/app/components/base/input'
import Select from '@/app/components/base/select'
import Textarea from '@/app/components/base/textarea'
import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input'
-import { DEFAULT_VALUE_MAX_LEN } from '@/config'
import ConfigContext from '@/context/debug-configuration'
import { cn } from '@/utils/classnames'
@@ -19,7 +18,7 @@ const ChatUserInput = ({
  inputs,
}: Props) => {
  const { t } = useTranslation()
-  const { modelConfig, setInputs, readonly } = useContext(ConfigContext)
+  const { modelConfig, setInputs } = useContext(ConfigContext)

  const promptVariables = modelConfig.configs.prompt_variables.filter(({ key, name }) => {
    return key && key?.trim() && name && name?.trim()
@@ -88,8 +87,7 @@ const ChatUserInput = ({
            onChange={(e) => { handleInputValueChange(key, e.target.value) }}
            placeholder={name}
            autoFocus={index === 0}
-            maxLength={max_length || DEFAULT_VALUE_MAX_LEN}
-            readOnly={readonly}
+            maxLength={max_length}
          />
        )}
        {type === 'paragraph' && (
@@ -98,7 +96,6 @@ const ChatUserInput = ({
            placeholder={name}
            value={inputs[key] ? `${inputs[key]}` : ''}
            onChange={(e) => { handleInputValueChange(key, e.target.value) }}
-            readOnly={readonly}
          />
        )}
        {type === 'select' && (
@@ -108,7 +105,6 @@ const ChatUserInput = ({
            onSelect={(i) => { handleInputValueChange(key, i.value as string) }}
            items={(options || []).map(i => ({ name: i, value: i }))}
            allowSearch={false}
-            disabled={readonly}
          />
        )}
        {type === 'number' && (
@@ -118,8 +114,7 @@ const ChatUserInput = ({
            onChange={(e) => { handleInputValueChange(key, e.target.value) }}
            placeholder={name}
            autoFocus={index === 0}
-            maxLength={max_length || DEFAULT_VALUE_MAX_LEN}
-            readOnly={readonly}
+            maxLength={max_length}
          />
        )}
        {type === 'checkbox' && (
@@ -128,7 +123,6 @@ const ChatUserInput = ({
            value={!!inputs[key]}
            required={required}
            onChange={(value) => { handleInputValueChange(key, value) }}
-            readonly={readonly}
          />
        )}
      </div>
@@ -15,7 +15,6 @@ import { DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/
import { useDebugConfigurationContext } from '@/context/debug-configuration'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { useProviderContext } from '@/context/provider-context'
import { AppSourceType } from '@/service/share'
import { promptVariablesToUserInputsForm } from '@/utils/model-config'
import { APP_CHAT_WITH_MULTIPLE_MODEL } from '../types'
@@ -131,11 +130,11 @@ const TextGenerationItem: FC<TextGenerationItemProps> = ({

  return (
    <TextGeneration
+      appSourceType={AppSourceType.webApp}
      className="flex h-full flex-col overflow-y-auto border-none"
      content={completion}
      isLoading={!completion && isResponding}
      isResponding={isResponding}
      isInstalledApp={false}
      siteInfo={null}
      messageId={messageId}
      isError={false}
@@ -39,7 +39,6 @@ const DebugWithSingleModel = (
) => {
  const { userProfile } = useAppContext()
  const {
-    readonly,
    modelConfig,
    appId,
    inputs,
@@ -151,7 +150,6 @@

  return (
    <Chat
-      readonly={readonly}
      config={config}
      chatList={chatList}
      isResponding={isResponding}
@@ -38,7 +38,6 @@ import ConfigContext from '@/context/debug-configuration'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { useProviderContext } from '@/context/provider-context'
import { sendCompletionMessage } from '@/service/debug'
import { AppSourceType } from '@/service/share'
import { AppModeEnum, ModelModeType, TransferMethod } from '@/types/app'
import { formatBooleanInputs, promptVariablesToUserInputsForm } from '@/utils/model-config'
import GroupName from '../base/group-name'
@@ -73,7 +72,6 @@ const Debug: FC<IDebug> = ({
}) => {
  const { t } = useTranslation()
  const {
-    readonly,
    appId,
    mode,
    modelModeType,
@@ -418,33 +416,25 @@
          }
          {mode !== AppModeEnum.COMPLETION && (
            <>
-              {
-                !readonly && (
-                  <TooltipPlus
-                    popupContent={t('operation.refresh', { ns: 'common' })}
-                  >
-                    <ActionButton onClick={clearConversation}>
-                      <RefreshCcw01 className="h-4 w-4" />
-                    </ActionButton>
-                  </TooltipPlus>
-                )
-              }
-              {
-                varList.length > 0 && (
-                  <div className="relative ml-1 mr-2">
-                    <TooltipPlus
-                      popupContent={t('panel.userInputField', { ns: 'workflow' })}
-                    >
-                      <ActionButton state={expanded ? ActionButtonState.Active : undefined} onClick={() => !readonly && setExpanded(!expanded)}>
-                        <RiEqualizer2Line className="h-4 w-4" />
-                      </ActionButton>
-                    </TooltipPlus>
-                    {expanded && <div className="absolute bottom-[-14px] right-[5px] z-10 h-3 w-3 rotate-45 border-l-[0.5px] border-t-[0.5px] border-components-panel-border-subtle bg-components-panel-on-panel-item-bg" />}
-                  </div>
-                )
-              }
+              <TooltipPlus
+                popupContent={t('operation.refresh', { ns: 'common' })}
+              >
+                <ActionButton onClick={clearConversation}>
+                  <RefreshCcw01 className="h-4 w-4" />
+                </ActionButton>
+              </TooltipPlus>
+              {varList.length > 0 && (
+                <div className="relative ml-1 mr-2">
+                  <TooltipPlus
+                    popupContent={t('panel.userInputField', { ns: 'workflow' })}
+                  >
+                    <ActionButton state={expanded ? ActionButtonState.Active : undefined} onClick={() => setExpanded(!expanded)}>
+                      <RiEqualizer2Line className="h-4 w-4" />
+                    </ActionButton>
+                  </TooltipPlus>
+                  {expanded && <div className="absolute bottom-[-14px] right-[5px] z-10 h-3 w-3 rotate-45 border-l-[0.5px] border-t-[0.5px] border-components-panel-border-subtle bg-components-panel-on-panel-item-bg" />}
+                </div>
+              )}
            </>
          )}
        </div>
@@ -454,21 +444,19 @@
          <ChatUserInput inputs={inputs} />
        </div>
      )}
-      {
-        mode === AppModeEnum.COMPLETION && (
-          <PromptValuePanel
-            appType={mode as AppModeEnum}
-            onSend={handleSendTextCompletion}
-            inputs={inputs}
-            visionConfig={{
-              ...features.file! as VisionSettings,
-              transfer_methods: features.file!.allowed_file_upload_methods || [],
-              image_file_size_limit: features.file?.fileUploadConfig?.image_file_size_limit,
-            }}
-            onVisionFilesChange={setCompletionFiles}
-          />
-        )
-      }
+      {mode === AppModeEnum.COMPLETION && (
+        <PromptValuePanel
+          appType={mode as AppModeEnum}
+          onSend={handleSendTextCompletion}
+          inputs={inputs}
+          visionConfig={{
+            ...features.file! as VisionSettings,
+            transfer_methods: features.file!.allowed_file_upload_methods || [],
+            image_file_size_limit: features.file?.fileUploadConfig?.image_file_size_limit,
+          }}
+          onVisionFilesChange={setCompletionFiles}
+        />
+      )}
    </div>
    {
      debugWithMultipleModel && (
@@ -522,12 +510,12 @@
          <div className="mx-4 mt-3"><GroupName name={t('result', { ns: 'appDebug' })} /></div>
          <div className="mx-3 mb-8">
            <TextGeneration
+              appSourceType={AppSourceType.webApp}
              className="mt-2"
              content={completionRes}
              isLoading={!completionRes && isResponding}
              isShowTextToSpeech={textToSpeechConfig.enabled && !!text2speechDefaultModel}
              isResponding={isResponding}
-              isInstalledApp={false}
              messageId={messageId}
              isError={false}
              onRetry={noop}
@@ -562,15 +550,13 @@
        </div>
      )
    }
-    {
-      isShowFormattingChangeConfirm && (
-        <FormattingChanged
-          onConfirm={handleConfirm}
-          onCancel={handleCancel}
-        />
-      )
-    }
-    {!isAPIKeySet && !readonly && (<HasNotSetAPIKEY isTrailFinished={!IS_CE_EDITION} onSetting={onSetting} />)}
+    {isShowFormattingChangeConfirm && (
+      <FormattingChanged
+        onConfirm={handleConfirm}
+        onCancel={handleCancel}
+      />
+    )}
+    {!isAPIKeySet && (<HasNotSetAPIKEY isTrailFinished={!IS_CE_EDITION} onSetting={onSetting} />)}
  </>
)
}
@@ -20,7 +20,6 @@ import Select from '@/app/components/base/select'
import Textarea from '@/app/components/base/textarea'
import Tooltip from '@/app/components/base/tooltip'
import BoolInput from '@/app/components/workflow/nodes/_base/components/before-run-form/bool-input'
-import { DEFAULT_VALUE_MAX_LEN } from '@/config'
import ConfigContext from '@/context/debug-configuration'
import { AppModeEnum, ModelModeType } from '@/types/app'
import { cn } from '@/utils/classnames'
@@ -41,7 +40,7 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
  onVisionFilesChange,
}) => {
  const { t } = useTranslation()
-  const { readonly, modelModeType, modelConfig, setInputs, mode, isAdvancedMode, completionPromptConfig, chatPromptConfig } = useContext(ConfigContext)
+  const { modelModeType, modelConfig, setInputs, mode, isAdvancedMode, completionPromptConfig, chatPromptConfig } = useContext(ConfigContext)
  const [userInputFieldCollapse, setUserInputFieldCollapse] = useState(false)
  const promptVariables = modelConfig.configs.prompt_variables.filter(({ key, name }) => {
    return key && key?.trim() && name && name?.trim()
@ -79,12 +78,12 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
|
||||
|
||||
if (isAdvancedMode) {
|
||||
if (modelModeType === ModelModeType.chat)
|
||||
return chatPromptConfig?.prompt.every(({ text }) => !text)
|
||||
return chatPromptConfig.prompt.every(({ text }) => !text)
|
||||
return !completionPromptConfig.prompt?.text
|
||||
}
|
||||
|
||||
else { return !modelConfig.configs.prompt_template }
|
||||
}, [chatPromptConfig?.prompt, completionPromptConfig.prompt?.text, isAdvancedMode, mode, modelConfig.configs.prompt_template, modelModeType])
|
||||
}, [chatPromptConfig.prompt, completionPromptConfig.prompt?.text, isAdvancedMode, mode, modelConfig.configs.prompt_template, modelModeType])
|
||||
|
||||
const handleInputValueChange = (key: string, value: string | boolean) => {
|
||||
if (!(key in promptVariableObj))
|
||||
@ -142,8 +141,7 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
|
||||
onChange={(e) => { handleInputValueChange(key, e.target.value) }}
|
||||
placeholder={name}
|
||||
autoFocus={index === 0}
|
||||
maxLength={max_length || DEFAULT_VALUE_MAX_LEN}
|
||||
readOnly={readonly}
|
||||
maxLength={max_length}
|
||||
/>
|
||||
)}
|
||||
{type === 'paragraph' && (
|
||||
@ -152,7 +150,6 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
|
||||
placeholder={name}
|
||||
value={inputs[key] ? `${inputs[key]}` : ''}
|
||||
onChange={(e) => { handleInputValueChange(key, e.target.value) }}
|
||||
readOnly={readonly}
|
||||
/>
|
||||
)}
|
||||
{type === 'select' && (
|
||||
@ -163,7 +160,6 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
|
||||
items={(options || []).map(i => ({ name: i, value: i }))}
|
||||
allowSearch={false}
|
||||
bgClassName="bg-gray-50"
|
||||
disabled={readonly}
|
||||
/>
|
||||
)}
|
||||
{type === 'number' && (
|
||||
@ -173,8 +169,7 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
|
||||
onChange={(e) => { handleInputValueChange(key, e.target.value) }}
|
||||
placeholder={name}
|
||||
autoFocus={index === 0}
|
||||
maxLength={max_length || DEFAULT_VALUE_MAX_LEN}
|
||||
readOnly={readonly}
|
||||
maxLength={max_length}
|
||||
/>
|
||||
)}
|
||||
{type === 'checkbox' && (
|
||||
@ -183,7 +178,6 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
|
||||
value={!!inputs[key]}
|
||||
required={required}
|
||||
onChange={(value) => { handleInputValueChange(key, value) }}
|
||||
readonly={readonly}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
@ -202,7 +196,6 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
|
||||
url: fileItem.url,
|
||||
upload_file_id: fileItem.fileId,
|
||||
})))}
|
||||
disabled={readonly}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
@ -211,12 +204,12 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
|
||||
)}
|
||||
{!userInputFieldCollapse && (
|
||||
<div className="flex justify-between border-t border-divider-subtle p-4 pt-3">
|
||||
<Button className="w-[72px]" disabled={readonly} onClick={onClear}>{t('operation.clear', { ns: 'common' })}</Button>
|
||||
<Button className="w-[72px]" onClick={onClear}>{t('operation.clear', { ns: 'common' })}</Button>
|
||||
{canNotRun && (
|
||||
<Tooltip popupContent={t('otherError.promptNoBeEmpty', { ns: 'appDebug' })}>
|
||||
<Button
|
||||
variant="primary"
|
||||
disabled={canNotRun || readonly}
|
||||
disabled={canNotRun}
|
||||
onClick={() => onSend?.()}
|
||||
className="w-[96px]"
|
||||
>
|
||||
@ -228,7 +221,7 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
|
||||
{!canNotRun && (
|
||||
<Button
|
||||
variant="primary"
|
||||
disabled={canNotRun || readonly}
|
||||
disabled={canNotRun}
|
||||
onClick={() => onSend?.()}
|
||||
className="w-[96px]"
|
||||
>
|
||||
@ -244,8 +237,6 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
|
||||
showFileUpload={false}
|
||||
isChatMode={appType !== AppModeEnum.COMPLETION}
|
||||
onFeatureBarClick={setShowAppConfigureFeaturesModal}
|
||||
disabled={readonly}
|
||||
hideEditEntrance={readonly}
|
||||
/>
|
||||
</div>
|
||||
</>
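
Note on the canNotRun memo in the hunk above (@@ -79,12 +78,12 @@): it disables the run button when the active prompt is empty, checking the chat or completion prompt in advanced mode and the plain prompt template otherwise. A condensed sketch of the same decision as a standalone hook, with the values the component reads from ConfigContext passed in as parameters so the sketch is self-contained:

    import { useMemo } from 'react'

    type PromptMessage = { text: string }

    function useCanNotRun(
      isAdvancedMode: boolean,
      isChatMode: boolean,
      chatPrompt: PromptMessage[],
      completionPromptText: string | undefined,
      promptTemplate: string,
    ) {
      return useMemo(() => {
        if (isAdvancedMode) {
          if (isChatMode)
            return chatPrompt.every(({ text }) => !text) // every message empty
          return !completionPromptText
        }
        return !promptTemplate
      }, [isAdvancedMode, isChatMode, chatPrompt, completionPromptText, promptTemplate])
    }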

@@ -10,7 +10,6 @@ vi.mock('@heroicons/react/20/solid', () => ({
}))

const mockApp: App = {
can_trial: true,
app: {
id: 'test-app-id',
mode: AppModeEnum.CHAT,
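
Note on the test hunk above: it stubs the Heroicons package before building a typed App fixture. A minimal Vitest sketch of that mocking pattern; the stubbed PlusIcon body is an assumption, since the factory's contents are not shown in this hunk:

    import { vi } from 'vitest'

    // Replace the icon package with a lightweight stub so the component under
    // test renders without pulling in real SVG components.
    vi.mock('@heroicons/react/20/solid', () => ({
      PlusIcon: () => null,
    }))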

@@ -1,14 +1,9 @@
'use client'
import type { App } from '@/models/explore'
import { PlusIcon } from '@heroicons/react/20/solid'
import { RiInformation2Line } from '@remixicon/react'
import { useCallback } from 'react'
import { useTranslation } from 'react-i18next'
import { useContextSelector } from 'use-context-selector'
import AppIcon from '@/app/components/base/app-icon'
import Button from '@/app/components/base/button'
import AppListContext from '@/context/app-list-context'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { cn } from '@/utils/classnames'
import { AppTypeIcon, AppTypeLabel } from '../../type-selector'

@@ -25,14 +20,6 @@ const AppCard = ({
}: AppCardProps) => {
const { t } = useTranslation()
const { app: appBasicInfo } = app
const { systemFeatures } = useGlobalPublicStore()
const isTrialApp = app.can_trial && systemFeatures.enable_trial_app
const setShowTryAppPanel = useContextSelector(AppListContext, ctx => ctx.setShowTryAppPanel)
const showTryAPPPanel = useCallback((appId: string) => {
return () => {
setShowTryAppPanel?.(true, { appId, app })
}
}, [setShowTryAppPanel, app.category])
return (
<div className={cn('group relative flex h-[132px] cursor-pointer flex-col overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-on-panel-item-bg p-4 shadow-xs hover:shadow-lg')}>
<div className="flex shrink-0 grow-0 items-center gap-3 pb-2">
@@ -64,17 +51,11 @@
</div>
{canCreate && (
<div className={cn('absolute bottom-0 left-0 right-0 hidden bg-gradient-to-t from-components-panel-gradient-2 from-[60.27%] to-transparent p-4 pt-8 group-hover:flex')}>
<div className={cn('grid h-8 w-full grid-cols-1 items-center space-x-2', isTrialApp && 'grid-cols-2')}>
<Button variant="primary" onClick={() => onCreate()}>
<div className={cn('flex h-8 w-full items-center space-x-2')}>
<Button variant="primary" className="grow" onClick={() => onCreate()}>
<PlusIcon className="mr-1 h-4 w-4" />
<span className="text-xs">{t('newApp.useTemplate', { ns: 'app' })}</span>
</Button>
{isTrialApp && (
<Button onClick={showTryAPPPanel(app.app_id)}>
<RiInformation2Line className="mr-1 size-4" />
<span>{t('appCard.try', { ns: 'explore' })}</span>
</Button>
)}
</div>
</div>
)}
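
Note on showTryAPPPanel above: it is a curried handler, so the JSX can write onClick={showTryAPPPanel(app.app_id)} and receive a ready-made callback instead of an inline arrow. (The useCallback dependency list names app.category while the closure captures the whole app object, so the memoized handler can hold a stale app.) A minimal standalone sketch of the currying pattern, where the hypothetical open() stands in for setShowTryAppPanel from the context:

    import { useCallback } from 'react'

    function useOpenPanel(open: (show: boolean, params: { appId: string }) => void) {
      // Each call with an id returns the actual click handler for that card.
      return useCallback((appId: string) => {
        return () => open(true, { appId })
      }, [open])
    }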

@@ -39,7 +39,6 @@ import { useAppContext } from '@/context/app-context'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import useTimestamp from '@/hooks/use-timestamp'
import { fetchChatMessages, updateLogMessageAnnotations, updateLogMessageFeedbacks } from '@/service/log'
import { AppSourceType } from '@/service/share'
import { useChatConversationDetail, useCompletionConversationDetail } from '@/service/use-log'
import { AppModeEnum } from '@/types/app'
import { cn } from '@/utils/classnames'
@@ -89,7 +88,16 @@ const HandThumbIconWithCount: FC<{ count: number, iconType: 'up' | 'down' }> = (
)
}

const statusTdRender = (statusCount: StatusCount) => {
const statusTdRender = (status: 'normal' | 'finished' | 'paused', statusCount: StatusCount) => {
if (status === 'paused') {
return (
<div className="system-xs-semibold-uppercase inline-flex items-center gap-1">
<Indicator color="yellow" />
<span className="text-util-colors-warning-warning-600">Pending</span>
</div>
)
}

if (!statusCount)
return null

@@ -296,7 +304,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
if (abortControllerRef.current === controller)
abortControllerRef.current = null
}
}, [detail.id, hasMore, timezone, t, appDetail, detail?.model_config?.configs?.introduction])
}, [detail.id, hasMore, timezone, t, appDetail])

// Derive chatItemTree, threadChatItems, and oldestAnswerIdRef from allChatItems
useEffect(() => {
@@ -411,7 +419,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
notify({ type: 'error', message: t('actionMsg.modifiedUnsuccessfully', { ns: 'common' }) })
return false
}
}, [allChatItems, appDetail?.id, t])
}, [allChatItems, appDetail?.id, notify, t])

const fetchInitiated = useRef(false)

@@ -504,7 +512,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
finally {
setIsLoading(false)
}
}, [detail.id, hasMore, isLoading, timezone, t, appDetail, detail?.model_config?.configs?.introduction])
}, [detail.id, hasMore, isLoading, timezone, t, appDetail])

useEffect(() => {
const scrollableDiv = document.getElementById('scrollableDiv')
@@ -639,12 +647,12 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
</div>
</div>
<TextGeneration
appSourceType={AppSourceType.webApp}
className="mt-2"
content={detail.message.answer}
messageId={detail.message.id}
isError={false}
onRetry={noop}
isInstalledApp={false}
supportFeedback
feedback={detail.message.feedbacks.find((item: any) => item.from_source === 'admin')}
onFeedback={feedback => onFeedback(detail.message.id, feedback)}
@@ -1030,7 +1038,7 @@ const ConversationList: FC<IConversationList> = ({ logs, appDetail, onRefresh })
<td className="p-3 pr-2">{renderTdValue(endUser || defaultValue, !endUser)}</td>
{isChatflow && (
<td className="w-[160px] p-3 pr-2" style={{ maxWidth: isChatMode ? 300 : 200 }}>
{statusTdRender(log.status_count)}
{statusTdRender(log.status, log.status_count)}
</td>
)}
<td className="p-3 pr-2" style={{ maxWidth: isChatMode ? 100 : 200 }}>
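
Note on the statusTdRender change above: the run status is now passed in alongside the per-node counts, and a 'paused' run (one waiting on human input) short-circuits to a yellow Pending badge before statusCount is consulted. A reduced sketch of the branching, with the JSX collapsed to strings and a hypothetical StatusCount shape:

    type StatusCount = { success: number, failed: number } | null
    type RunStatus = 'normal' | 'finished' | 'paused'

    // Paused runs win unconditionally; otherwise fall through to the usual
    // success/failure rendering driven by the per-node counts.
    function statusLabel(status: RunStatus, statusCount: StatusCount): string | null {
      if (status === 'paused')
        return 'Pending'
      if (!statusCount)
        return null
      return statusCount.failed > 0 ? 'Failure' : 'Success'
    }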

@@ -53,6 +53,7 @@ const defaultProviderContext = {
refreshLicenseLimit: noop,
isAllowTransferWorkspace: false,
isAllowPublishAsCustomKnowledgePipelineTemplate: false,
humanInputEmailDeliveryEnabled: false,
}

const defaultModalContext: ModalContextState = {

@@ -8,7 +8,7 @@ import {
RiClipboardLine,
RiFileList3Line,
RiPlayList2Line,
RiReplay15Line,
RiResetLeftLine,
RiSparklingFill,
RiSparklingLine,
RiThumbDownLine,
@@ -18,10 +18,12 @@ import { useBoolean } from 'ahooks'
import copy from 'copy-to-clipboard'
import { useParams } from 'next/navigation'
import * as React from 'react'
import { useEffect, useState } from 'react'
import { useCallback, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useStore as useAppStore } from '@/app/components/app/store'
import ActionButton, { ActionButtonState } from '@/app/components/base/action-button'
import HumanInputFilledFormList from '@/app/components/base/chat/chat/answer/human-input-filled-form-list'
import HumanInputFormList from '@/app/components/base/chat/chat/answer/human-input-form-list'
import WorkflowProcessItem from '@/app/components/base/chat/chat/answer/workflow-process'
import { useChatContext } from '@/app/components/base/chat/chat/context'
import Loading from '@/app/components/base/loading'
@@ -29,7 +31,8 @@ import { Markdown } from '@/app/components/base/markdown'
import NewAudioButton from '@/app/components/base/new-audio-button'
import Toast from '@/app/components/base/toast'
import { fetchTextGenerationMessage } from '@/service/debug'
import { AppSourceType, fetchMoreLikeThis, updateFeedback } from '@/service/share'
import { fetchMoreLikeThis, submitHumanInputForm, updateFeedback } from '@/service/share'
import { submitHumanInputForm as submitHumanInputFormService } from '@/service/workflow'
import { cn } from '@/utils/classnames'
import ResultTab from './result-tab'

@@ -53,7 +56,7 @@ export type IGenerationItemProps = {
onFeedback?: (feedback: FeedbackType) => void
onSave?: (messageId: string) => void
isMobile?: boolean
appSourceType: AppSourceType
isInstalledApp: boolean
installedAppId?: string
taskId?: string
controlClearMoreLikeThis?: number
@@ -87,7 +90,7 @@ const GenerationItem: FC<IGenerationItemProps> = ({
onSave,
depth = 1,
isMobile,
appSourceType,
isInstalledApp,
installedAppId,
taskId,
controlClearMoreLikeThis,
@@ -100,7 +103,6 @@ const GenerationItem: FC<IGenerationItemProps> = ({
const { t } = useTranslation()
const params = useParams()
const isTop = depth === 1
const isTryApp = appSourceType === AppSourceType.tryApp
const [completionRes, setCompletionRes] = useState('')
const [childMessageId, setChildMessageId] = useState<string | null>(null)
const [childFeedback, setChildFeedback] = useState<FeedbackType>({
@@ -114,14 +116,14 @@ const GenerationItem: FC<IGenerationItemProps> = ({
const setShowPromptLogModal = useAppStore(s => s.setShowPromptLogModal)

const handleFeedback = async (childFeedback: FeedbackType) => {
await updateFeedback({ url: `/messages/${childMessageId}/feedbacks`, body: { rating: childFeedback.rating } }, appSourceType, installedAppId)
await updateFeedback({ url: `/messages/${childMessageId}/feedbacks`, body: { rating: childFeedback.rating } }, isInstalledApp, installedAppId)
setChildFeedback(childFeedback)
}

const [isQuerying, { setTrue: startQuerying, setFalse: stopQuerying }] = useBoolean(false)

const childProps = {
isInWebApp: true,
isInWebApp,
content: completionRes,
messageId: childMessageId,
depth: depth + 1,
@@ -132,7 +134,7 @@ const GenerationItem: FC<IGenerationItemProps> = ({
onSave,
isShowTextToSpeech,
isMobile,
appSourceType,
isInstalledApp,
installedAppId,
controlClearMoreLikeThis,
isWorkflow,
@@ -146,7 +148,7 @@ const GenerationItem: FC<IGenerationItemProps> = ({
return
}
startQuerying()
const res: any = await fetchMoreLikeThis(messageId as string, appSourceType, installedAppId)
const res: any = await fetchMoreLikeThis(messageId as string, isInstalledApp, installedAppId)
setCompletionRes(res.answer)
setChildFeedback({
rating: null,
@@ -202,16 +204,22 @@ const GenerationItem: FC<IGenerationItemProps> = ({
}

const [currentTab, setCurrentTab] = useState<string>('DETAIL')
const showResultTabs = !!workflowProcessData?.resultText || !!workflowProcessData?.files?.length
const showResultTabs = !!workflowProcessData?.resultText || !!workflowProcessData?.files?.length || (workflowProcessData?.humanInputFormDataList && workflowProcessData?.humanInputFormDataList.length > 0) || (workflowProcessData?.humanInputFilledFormDataList && workflowProcessData?.humanInputFilledFormDataList.length > 0)
const switchTab = async (tab: string) => {
setCurrentTab(tab)
}
useEffect(() => {
if (workflowProcessData?.resultText || !!workflowProcessData?.files?.length)
if (workflowProcessData?.resultText || !!workflowProcessData?.files?.length || (workflowProcessData?.humanInputFormDataList && workflowProcessData?.humanInputFormDataList.length > 0) || (workflowProcessData?.humanInputFilledFormDataList && workflowProcessData?.humanInputFilledFormDataList.length > 0))
switchTab('RESULT')
else
switchTab('DETAIL')
}, [workflowProcessData?.files?.length, workflowProcessData?.resultText])
}, [workflowProcessData?.files?.length, workflowProcessData?.resultText, workflowProcessData?.humanInputFormDataList, workflowProcessData?.humanInputFilledFormDataList])
const handleSubmitHumanInputForm = useCallback(async (formToken: string, formData: any) => {
if (isInstalledApp)
await submitHumanInputFormService(formToken, formData)
else
await submitHumanInputForm(formToken, formData)
}, [isInstalledApp])

return (
<>
@@ -275,7 +283,24 @@ const GenerationItem: FC<IGenerationItemProps> = ({
)}
</div>
{!isError && (
<ResultTab data={workflowProcessData} content={content} currentTab={currentTab} />
<>
{currentTab === 'RESULT' && workflowProcessData.humanInputFormDataList && workflowProcessData.humanInputFormDataList.length > 0 && (
<div className="px-4 pt-4">
<HumanInputFormList
humanInputFormDataList={workflowProcessData.humanInputFormDataList}
onHumanInputFormSubmit={handleSubmitHumanInputForm}
/>
</div>
)}
{currentTab === 'RESULT' && workflowProcessData.humanInputFilledFormDataList && workflowProcessData.humanInputFilledFormDataList.length > 0 && (
<div className="px-4 pt-4">
<HumanInputFilledFormList
humanInputFilledFormDataList={workflowProcessData.humanInputFilledFormDataList}
/>
</div>
)}
<ResultTab data={workflowProcessData} content={content} currentTab={currentTab} />
</>
)}
</>
)}
@@ -311,7 +336,7 @@ const GenerationItem: FC<IGenerationItemProps> = ({
)}
{/* action buttons */}
<div className="absolute bottom-1 right-2 flex items-center">
{!isInWebApp && (appSourceType !== AppSourceType.installedApp) && !isResponding && (
{!isInWebApp && !isInstalledApp && !isResponding && (
<div className="ml-1 flex items-center gap-0.5 rounded-[10px] border-[0.5px] border-components-actionbar-border bg-components-actionbar-bg p-0.5 shadow-md backdrop-blur-sm">
<ActionButton disabled={isError || !messageId} onClick={handleOpenLogModal}>
<RiFileList3Line className="h-4 w-4" />
@@ -320,12 +345,12 @@ const GenerationItem: FC<IGenerationItemProps> = ({
</div>
)}
<div className="ml-1 flex items-center gap-0.5 rounded-[10px] border-[0.5px] border-components-actionbar-border bg-components-actionbar-bg p-0.5 shadow-md backdrop-blur-sm">
{moreLikeThis && !isTryApp && (
{moreLikeThis && (
<ActionButton state={depth === MAX_DEPTH ? ActionButtonState.Disabled : ActionButtonState.Default} disabled={depth === MAX_DEPTH} onClick={handleMoreLikeThis}>
<RiSparklingLine className="h-4 w-4" />
</ActionButton>
)}
{isShowTextToSpeech && !isTryApp && (
{isShowTextToSpeech && (
<NewAudioButton
id={messageId!}
voice={config?.text_to_speech?.voice}
@@ -348,16 +373,16 @@ const GenerationItem: FC<IGenerationItemProps> = ({
)}
{isInWebApp && isError && (
<ActionButton onClick={onRetry}>
<RiReplay15Line className="h-4 w-4" />
<RiResetLeftLine className="h-4 w-4" />
</ActionButton>
)}
{isInWebApp && !isWorkflow && !isTryApp && (
{isInWebApp && !isWorkflow && (
<ActionButton disabled={isError || !messageId} onClick={() => { onSave?.(messageId as string) }}>
<RiBookmark3Line className="h-4 w-4" />
</ActionButton>
)}
</div>
{(supportFeedback || isInWebApp) && !isWorkflow && !isTryApp && !isError && messageId && (
{(supportFeedback || isInWebApp) && !isWorkflow && !isError && messageId && (
<div className="ml-1 flex items-center gap-0.5 rounded-[10px] border-[0.5px] border-components-actionbar-border bg-components-actionbar-bg p-0.5 shadow-md backdrop-blur-sm">
{!feedback?.rating && (
<>
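
Note on handleSubmitHumanInputForm in the hunks above: the human-in-the-loop form is submitted through one of two services depending on where the app runs, the workflow service for installed apps and the share service otherwise. A minimal sketch of that dispatch; the two async stubs are placeholders for the real submitHumanInputForm exports from '@/service/workflow' and '@/service/share':

    import { useCallback } from 'react'

    // Placeholder implementations; the real calls live in the two services.
    async function submitViaWorkflow(formToken: string, formData: Record<string, unknown>) {}
    async function submitViaShare(formToken: string, formData: Record<string, unknown>) {}

    function useSubmitHumanInputForm(isInstalledApp: boolean) {
      return useCallback(async (formToken: string, formData: Record<string, unknown>) => {
        if (isInstalledApp)
          await submitViaWorkflow(formToken, formData)
        else
          await submitViaShare(formToken, formData)
      }, [isInstalledApp])
    }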

@@ -81,6 +81,14 @@ const WorkflowAppLogList: FC<ILogs> = ({ logs, appDetail, onRefresh }) => {
</div>
)
}
if (status === 'paused') {
return (
<div className="system-xs-semibold-uppercase inline-flex items-center gap-1">
<Indicator color="yellow" />
<span className="text-util-colors-warning-warning-600">Pending</span>
</div>
)
}
if (status === 'running') {
return (
<div className="system-xs-semibold-uppercase inline-flex items-center gap-1">

@@ -1,17 +1,7 @@
'use client'
import type { CreateAppModalProps } from '../explore/create-app-modal'
import type { CurrentTryAppParams } from '@/context/explore-context'
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useEducationInit } from '@/app/education-apply/hooks'
import AppListContext from '@/context/app-list-context'
import useDocumentTitle from '@/hooks/use-document-title'
import { useImportDSL } from '@/hooks/use-import-dsl'
import { DSLImportMode } from '@/models/app'
import { fetchAppDetail } from '@/service/explore'
import DSLConfirmModal from '../app/create-from-dsl-modal/dsl-confirm-modal'
import CreateAppModal from '../explore/create-app-modal'
import TryApp from '../explore/try-app'
import List from './list'

const Apps = () => {
@@ -20,124 +10,10 @@ const Apps = () => {
useDocumentTitle(t('menus.apps', { ns: 'common' }))
useEducationInit()

const [currentTryAppParams, setCurrentTryAppParams] = useState<CurrentTryAppParams | undefined>(undefined)
const currApp = currentTryAppParams?.app
const [isShowTryAppPanel, setIsShowTryAppPanel] = useState(false)
const hideTryAppPanel = useCallback(() => {
setIsShowTryAppPanel(false)
}, [])
const setShowTryAppPanel = (showTryAppPanel: boolean, params?: CurrentTryAppParams) => {
if (showTryAppPanel)
setCurrentTryAppParams(params)
else
setCurrentTryAppParams(undefined)
setIsShowTryAppPanel(showTryAppPanel)
}
const [isShowCreateModal, setIsShowCreateModal] = useState(false)

const handleShowFromTryApp = useCallback(() => {
setIsShowCreateModal(true)
}, [])

const [controlRefreshList, setControlRefreshList] = useState(0)
const [controlHideCreateFromTemplatePanel, setControlHideCreateFromTemplatePanel] = useState(0)
const onSuccess = useCallback(() => {
setControlRefreshList(prev => prev + 1)
setControlHideCreateFromTemplatePanel(prev => prev + 1)
}, [])

const [showDSLConfirmModal, setShowDSLConfirmModal] = useState(false)

const {
handleImportDSL,
handleImportDSLConfirm,
versions,
isFetching,
} = useImportDSL()

const onConfirmDSL = useCallback(async () => {
await handleImportDSLConfirm({
onSuccess,
})
}, [handleImportDSLConfirm, onSuccess])

const onCreate: CreateAppModalProps['onConfirm'] = async ({
name,
icon_type,
icon,
icon_background,
description,
}) => {
hideTryAppPanel()

const { export_data } = await fetchAppDetail(
currApp?.app.id as string,
)
const payload = {
mode: DSLImportMode.YAML_CONTENT,
yaml_content: export_data,
name,
icon_type,
icon,
icon_background,
description,
}
await handleImportDSL(payload, {
onSuccess: () => {
setIsShowCreateModal(false)
},
onPending: () => {
setShowDSLConfirmModal(true)
},
})
}

return (
<AppListContext.Provider value={{
currentApp: currentTryAppParams,
isShowTryAppPanel,
setShowTryAppPanel,
controlHideCreateFromTemplatePanel,
}}
>
<div className="relative flex h-0 shrink-0 grow flex-col overflow-y-auto bg-background-body">
<List controlRefreshList={controlRefreshList} />
{isShowTryAppPanel && (
<TryApp
appId={currentTryAppParams?.appId || ''}
category={currentTryAppParams?.app?.category}
onClose={hideTryAppPanel}
onCreate={handleShowFromTryApp}
/>
)}

{
showDSLConfirmModal && (
<DSLConfirmModal
versions={versions}
onCancel={() => setShowDSLConfirmModal(false)}
onConfirm={onConfirmDSL}
confirmDisabled={isFetching}
/>
)
}

{isShowCreateModal && (
<CreateAppModal
appIconType={currApp?.app.icon_type || 'emoji'}
appIcon={currApp?.app.icon || ''}
appIconBackground={currApp?.app.icon_background || ''}
appIconUrl={currApp?.app.icon_url}
appName={currApp?.app.name || ''}
appDescription={currApp?.app.description || ''}
show
onConfirm={onCreate}
confirmDisabled={isFetching}
onHide={() => setIsShowCreateModal(false)}
/>
)}
</div>
</AppListContext.Provider>
<div className="relative flex h-0 shrink-0 grow flex-col overflow-y-auto bg-background-body">
<List />
</div>
)
}
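
Note on the removed Apps state above: the try-app panel's visibility and its parameters were kept in lock-step, so opening stores the params and closing clears them, and a closed panel can never render against leftover data. A minimal sketch of that pairing as a standalone hook, with a simplified params type standing in for CurrentTryAppParams:

    import { useCallback, useState } from 'react'

    // Simplified stand-in for CurrentTryAppParams.
    type TryAppParams = { appId: string }

    function useTryAppPanel() {
      const [params, setParams] = useState<TryAppParams | undefined>(undefined)
      const [isOpen, setIsOpen] = useState(false)

      // Visibility and params always change together.
      const setShow = useCallback((show: boolean, next?: TryAppParams) => {
        setParams(show ? next : undefined)
        setIsOpen(show)
      }, [])

      return { isOpen, params, setShow }
    }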

@@ -1,6 +1,5 @@
'use client'

import type { FC } from 'react'
import {
RiApps2Line,
RiDragDropLine,
@@ -54,12 +53,7 @@ const CreateFromDSLModal = dynamic(() => import('@/app/components/app/create-fro
ssr: false,
})

type Props = {
controlRefreshList?: number
}
const List: FC<Props> = ({
controlRefreshList = 0,
}) => {
const List = () => {
const { t } = useTranslation()
const { systemFeatures } = useGlobalPublicStore()
const router = useRouter()
@@ -116,13 +110,6 @@ const List: FC<Props> = ({
refetch,
} = useInfiniteAppList(appListQueryParams, { enabled: !isCurrentWorkspaceDatasetOperator })

useEffect(() => {
if (controlRefreshList > 0) {
refetch()
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [controlRefreshList])

const anchorRef = useRef<HTMLDivElement>(null)
const options = [
{ value: 'all', text: t('types.all', { ns: 'app' }), icon: <RiApps2Line className="mr-1 h-[14px] w-[14px]" /> },
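
Note on the removed controlRefreshList prop above: it is a counter used as an imperative signal, where the parent increments it and the child's effect refetches whenever the value moves past zero. A minimal sketch of the pattern, assuming a refetch function such as the one returned by useInfiniteAppList:

    import { useEffect } from 'react'

    // counter starts at 0; each parent increment requests one refetch.
    function useRefreshSignal(counter: number, refetch: () => void) {
      useEffect(() => {
        if (counter > 0)
          refetch()
        // Keyed on the counter alone, mirroring the original effect.
        // eslint-disable-next-line react-hooks/exhaustive-deps
      }, [counter])
    }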

@@ -6,12 +6,10 @@ import {
useSearchParams,
} from 'next/navigation'
import * as React from 'react'
import { useEffect, useMemo, useState } from 'react'
import { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { useContextSelector } from 'use-context-selector'
import { CreateFromDSLModalTab } from '@/app/components/app/create-from-dsl-modal'
import { FileArrow01, FilePlus01, FilePlus02 } from '@/app/components/base/icons/src/vender/line/files'
import AppListContext from '@/context/app-list-context'
import { useProviderContext } from '@/context/provider-context'
import { cn } from '@/utils/classnames'

@@ -57,12 +55,6 @@ const CreateAppCard = ({
return undefined
}, [dslUrl])

const controlHideCreateFromTemplatePanel = useContextSelector(AppListContext, ctx => ctx.controlHideCreateFromTemplatePanel)
useEffect(() => {
if (controlHideCreateFromTemplatePanel > 0)
setShowNewAppTemplateDialog(false)
}, [controlHideCreateFromTemplatePanel])

return (
<div
ref={ref}

@@ -26,6 +26,10 @@
@apply p-0.5 w-6 h-6 rounded-lg
}

.action-btn-s {
@apply w-5 h-5 rounded-[6px]
}

.action-btn-xs {
@apply p-0 w-4 h-4 rounded
}