mirror of https://github.com/langgenius/dify.git
synced 2026-01-22 21:15:24 +08:00

Compare commits (1 commit): feat/pull-... feat/crede...
Commit: fcb288c031

.gitignore (vendored)
@@ -209,7 +209,6 @@ api/.vscode
.history

.idea/
web/migration/

# pnpm
/.pnpm-store
@@ -71,8 +71,6 @@ def create_app() -> DifyApp:


def initialize_extensions(app: DifyApp):
    # Initialize Flask context capture for workflow execution
    from context.flask_app_context import init_flask_context
    from extensions import (
        ext_app_metrics,
        ext_blueprints,
@@ -102,8 +100,6 @@ def initialize_extensions(app: DifyApp):
        ext_warnings,
    )

    init_flask_context()

    extensions = [
        ext_timezone,
        ext_logging,
@@ -1,74 +0,0 @@
"""
Core Context - Framework-agnostic context management.

This module provides context management that is independent of any specific
web framework. Framework-specific implementations register their context
capture functions at application initialization time.

This ensures the workflow layer remains completely decoupled from Flask
or any other web framework.
"""

import contextvars
from collections.abc import Callable

from core.workflow.context.execution_context import (
    ExecutionContext,
    IExecutionContext,
    NullAppContext,
)

# Global capturer function - set by framework-specific modules
_capturer: Callable[[], IExecutionContext] | None = None


def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
    """
    Register a context capture function.

    This should be called by framework-specific modules (e.g., Flask)
    during application initialization.

    Args:
        capturer: Function that captures current context and returns IExecutionContext
    """
    global _capturer
    _capturer = capturer


def capture_current_context() -> IExecutionContext:
    """
    Capture current execution context.

    This function uses the registered context capturer. If no capturer
    is registered, it returns a minimal context with only contextvars
    (suitable for non-framework environments like tests or standalone scripts).

    Returns:
        IExecutionContext with captured context
    """
    if _capturer is None:
        # No framework registered - return minimal context
        return ExecutionContext(
            app_context=NullAppContext(),
            context_vars=contextvars.copy_context(),
        )

    return _capturer()


def reset_context_provider() -> None:
    """
    Reset the context capturer.

    This is primarily useful for testing to ensure a clean state.
    """
    global _capturer
    _capturer = None


__all__ = [
    "capture_current_context",
    "register_context_capturer",
    "reset_context_provider",
]
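For illustration, a minimal usage sketch of the capturer registry above (not part of the diff; the capturer body is hypothetical):

    from context import capture_current_context, register_context_capturer

    # A framework module registers its capturer once at startup.
    register_context_capturer(lambda: some_framework_capture())  # hypothetical capturer
    # Later, a worker snapshots whatever context is current...
    ctx = capture_current_context()
    # ...and replays it around workflow code.
    with ctx.app_context.enter():
        pass  # workflow code runs with the captured app context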
@@ -1,198 +0,0 @@
"""
Flask App Context - Flask implementation of AppContext interface.
"""

import contextvars
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, final

from flask import Flask, current_app, g

from context import register_context_capturer
from core.workflow.context.execution_context import (
    AppContext,
    IExecutionContext,
)


@final
class FlaskAppContext(AppContext):
    """
    Flask implementation of AppContext.

    This adapts Flask's app context to the AppContext interface.
    """

    def __init__(self, flask_app: Flask) -> None:
        """
        Initialize Flask app context.

        Args:
            flask_app: The Flask application instance
        """
        self._flask_app = flask_app

    def get_config(self, key: str, default: Any = None) -> Any:
        """Get configuration value from Flask app config."""
        return self._flask_app.config.get(key, default)

    def get_extension(self, name: str) -> Any:
        """Get Flask extension by name."""
        return self._flask_app.extensions.get(name)

    @contextmanager
    def enter(self) -> Generator[None, None, None]:
        """Enter Flask app context."""
        with self._flask_app.app_context():
            yield

    @property
    def flask_app(self) -> Flask:
        """Get the underlying Flask app instance."""
        return self._flask_app


def capture_flask_context(user: Any = None) -> IExecutionContext:
    """
    Capture current Flask execution context.

    This function captures the Flask app context and contextvars from the
    current environment. It should be called from within a Flask request or
    app context.

    Args:
        user: Optional user object to include in context

    Returns:
        IExecutionContext with captured Flask context

    Raises:
        RuntimeError: If called outside Flask context
    """
    # Get Flask app instance
    flask_app = current_app._get_current_object()  # type: ignore

    # Save current user if available
    saved_user = user
    if saved_user is None:
        # Check for user in g (flask-login)
        if hasattr(g, "_login_user"):
            saved_user = g._login_user

    # Capture contextvars
    context_vars = contextvars.copy_context()

    return FlaskExecutionContext(
        flask_app=flask_app,
        context_vars=context_vars,
        user=saved_user,
    )


@final
class FlaskExecutionContext:
    """
    Flask-specific execution context.

    This is a specialized version of ExecutionContext that includes Flask app
    context. It provides the same interface as ExecutionContext but with
    Flask-specific implementation.
    """

    def __init__(
        self,
        flask_app: Flask,
        context_vars: contextvars.Context,
        user: Any = None,
    ) -> None:
        """
        Initialize Flask execution context.

        Args:
            flask_app: Flask application instance
            context_vars: Python contextvars
            user: Optional user object
        """
        self._app_context = FlaskAppContext(flask_app)
        self._context_vars = context_vars
        self._user = user
        self._flask_app = flask_app

    @property
    def app_context(self) -> FlaskAppContext:
        """Get Flask app context."""
        return self._app_context

    @property
    def context_vars(self) -> contextvars.Context:
        """Get context variables."""
        return self._context_vars

    @property
    def user(self) -> Any:
        """Get user object."""
        return self._user

    def __enter__(self) -> "FlaskExecutionContext":
        """Enter the Flask execution context."""
        # Restore context variables
        for var, val in self._context_vars.items():
            var.set(val)

        # Save current user from g if available
        saved_user = None
        if hasattr(g, "_login_user"):
            saved_user = g._login_user

        # Enter Flask app context
        self._cm = self._app_context.enter()
        self._cm.__enter__()

        # Restore user in new app context
        if saved_user is not None:
            g._login_user = saved_user

        return self

    def __exit__(self, *args: Any) -> None:
        """Exit the Flask execution context."""
        if hasattr(self, "_cm"):
            self._cm.__exit__(*args)

    @contextmanager
    def enter(self) -> Generator[None, None, None]:
        """Enter Flask execution context as context manager."""
        # Restore context variables
        for var, val in self._context_vars.items():
            var.set(val)

        # Save current user from g if available
        saved_user = None
        if hasattr(g, "_login_user"):
            saved_user = g._login_user

        # Enter Flask app context
        with self._flask_app.app_context():
            # Restore user in new app context
            if saved_user is not None:
                g._login_user = saved_user
            yield


def init_flask_context() -> None:
    """
    Initialize Flask context capture by registering the capturer.

    This function should be called during Flask application initialization
    to register the Flask-specific context capturer with the core context module.

    Example:
        app = Flask(__name__)
        init_flask_context()  # Register Flask context capturer

    Note:
        This function does not need the app instance as it uses Flask's
        `current_app` to get the app when capturing context.
    """
    register_context_capturer(capture_flask_context)
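For illustration, a minimal sketch of the capture/replay flow defined above (not part of the diff):

    import threading

    # Inside a Flask request: snapshot the app, contextvars, and login user.
    ctx = capture_flask_context()

    def worker():
        # In a background thread: replay the snapshot around workflow code.
        with ctx.enter():
            pass  # current_app and g._login_user are available here

    threading.Thread(target=worker).start()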
@@ -55,35 +55,6 @@ class InstructionTemplatePayload(BaseModel):
    type: str = Field(..., description="Instruction template type")


class ContextGeneratePayload(BaseModel):
    """Payload for generating extractor code node."""

    workflow_id: str = Field(..., description="Workflow ID")
    node_id: str = Field(..., description="Current tool/llm node ID")
    parameter_name: str = Field(..., description="Parameter name to generate code for")
    language: str = Field(default="python3", description="Code language (python3/javascript)")
    prompt_messages: list[dict[str, Any]] = Field(
        ..., description="Multi-turn conversation history, last message is the current instruction"
    )
    model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")


class SuggestedQuestionsPayload(BaseModel):
    """Payload for generating suggested questions."""

    workflow_id: str = Field(..., description="Workflow ID")
    node_id: str = Field(..., description="Current tool/llm node ID")
    parameter_name: str = Field(..., description="Parameter name")
    language: str = Field(
        default="English", description="Language for generated questions (e.g. English, Chinese, Japanese)"
    )
    model_config_data: dict[str, Any] | None = Field(
        default=None,
        alias="model_config",
        description="Model configuration (optional, uses system default if not provided)",
    )


def reg(cls: type[BaseModel]):
    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))

@@ -93,8 +64,6 @@ reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
reg(ContextGeneratePayload)
reg(SuggestedQuestionsPayload)


@console_ns.route("/rule-generate")
@@ -309,74 +278,3 @@ class InstructionGenerationTemplateApi(Resource):
                return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
            case _:
                raise ValueError(f"Invalid type: {args.type}")


@console_ns.route("/context-generate")
class ContextGenerateApi(Resource):
    @console_ns.doc("generate_with_context")
    @console_ns.doc(description="Generate with multi-turn conversation context")
    @console_ns.expect(console_ns.models[ContextGeneratePayload.__name__])
    @console_ns.response(200, "Content generated successfully")
    @console_ns.response(400, "Invalid request parameters or workflow not found")
    @console_ns.response(402, "Provider quota exceeded")
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        from core.llm_generator.utils import deserialize_prompt_messages

        args = ContextGeneratePayload.model_validate(console_ns.payload)
        _, current_tenant_id = current_account_with_tenant()

        prompt_messages = deserialize_prompt_messages(args.prompt_messages)

        try:
            return LLMGenerator.generate_with_context(
                tenant_id=current_tenant_id,
                workflow_id=args.workflow_id,
                node_id=args.node_id,
                parameter_name=args.parameter_name,
                language=args.language,
                prompt_messages=prompt_messages,
                model_config=args.model_config_data,
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)


@console_ns.route("/context-generate/suggested-questions")
class SuggestedQuestionsApi(Resource):
    @console_ns.doc("generate_suggested_questions")
    @console_ns.doc(description="Generate suggested questions for context generation")
    @console_ns.expect(console_ns.models[SuggestedQuestionsPayload.__name__])
    @console_ns.response(200, "Questions generated successfully")
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        args = SuggestedQuestionsPayload.model_validate(console_ns.payload)
        _, current_tenant_id = current_account_with_tenant()

        try:
            return LLMGenerator.generate_suggested_questions(
                tenant_id=current_tenant_id,
                workflow_id=args.workflow_id,
                node_id=args.node_id,
                parameter_name=args.parameter_name,
                language=args.language,
                model_config=args.model_config_data,
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
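For illustration, a minimal /context-generate request body matching ContextGeneratePayload above (not part of the diff; all values are hypothetical):

    payload = {
        "workflow_id": "app-id-here",  # hypothetical app/workflow ID
        "node_id": "tool_node_1",      # hypothetical tool/llm node ID
        "parameter_name": "order_id",
        "language": "python3",
        "prompt_messages": [{"role": "user", "content": "Extract the order ID"}],
        "model_config": {"provider": "openai", "name": "gpt-4o", "completion_params": {}},
    }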
@@ -46,8 +46,6 @@ from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.workflow.entities import MentionGraphRequest, MentionParameterSchema
from services.workflow.mention_graph_service import MentionGraphService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService

logger = logging.getLogger(__name__)
@@ -190,15 +188,6 @@ class DraftWorkflowTriggerRunAllPayload(BaseModel):
    node_ids: list[str]


class MentionGraphPayload(BaseModel):
    """Request payload for generating mention graph."""

    parent_node_id: str = Field(description="ID of the parent node that uses the extracted value")
    parameter_key: str = Field(description="Key of the parameter being extracted")
    context_source: list[str] = Field(description="Variable selector for the context source")
    parameter_schema: dict[str, Any] = Field(description="Schema of the parameter to extract")


def reg(cls: type[BaseModel]):
    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))

@@ -216,7 +205,6 @@ reg(WorkflowListQuery)
reg(WorkflowUpdatePayload)
reg(DraftWorkflowTriggerRunPayload)
reg(DraftWorkflowTriggerRunAllPayload)
reg(MentionGraphPayload)


# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
@@ -1178,54 +1166,3 @@ class DraftWorkflowTriggerRunAllApi(Resource):
                        "status": "error",
                    }
                ), 400


@console_ns.route("/apps/<uuid:app_id>/workflows/draft/mention-graph")
class MentionGraphApi(Resource):
    """
    API for generating Mention LLM node graph structures.

    This endpoint creates a complete graph structure containing an LLM node
    configured to extract values from list[PromptMessage] variables.
    """

    @console_ns.doc("generate_mention_graph")
    @console_ns.doc(description="Generate a Mention LLM node graph structure")
    @console_ns.doc(params={"app_id": "Application ID"})
    @console_ns.expect(console_ns.models[MentionGraphPayload.__name__])
    @console_ns.response(200, "Mention graph generated successfully")
    @console_ns.response(400, "Invalid request parameters")
    @console_ns.response(403, "Permission denied")
    @setup_required
    @login_required
    @account_initialization_required
    @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
    @edit_permission_required
    def post(self, app_model: App):
        """
        Generate a Mention LLM node graph structure.

        Returns a complete graph structure containing a single LLM node
        configured for extracting values from list[PromptMessage] context.
        """

        payload = MentionGraphPayload.model_validate(console_ns.payload or {})

        parameter_schema = MentionParameterSchema(
            name=payload.parameter_schema.get("name", payload.parameter_key),
            type=payload.parameter_schema.get("type", "string"),
            description=payload.parameter_schema.get("description", ""),
        )

        request = MentionGraphRequest(
            parent_node_id=payload.parent_node_id,
            parameter_key=payload.parameter_key,
            context_source=payload.context_source,
            parameter_schema=parameter_schema,
        )

        with Session(db.engine) as session:
            service = MentionGraphService(session)
            response = service.generate_mention_graph(tenant_id=app_model.tenant_id, request=request)

        return response.model_dump()
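For illustration, a minimal mention-graph request body matching MentionGraphPayload above (not part of the diff; all values are hypothetical):

    payload = {
        "parent_node_id": "tool_node_1",
        "parameter_key": "order_id",
        "context_source": ["sys", "dialogue_history"],  # hypothetical variable selector
        "parameter_schema": {"name": "order_id", "type": "string", "description": "Order ID to extract"},
    }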
@@ -17,7 +17,7 @@ from controllers.console.wraps import account_initialization_required, edit_perm
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.file import helpers as file_helpers
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
@@ -58,8 +58,6 @@ def _convert_values_to_json_serializable_object(value: Segment):
        return value.value.model_dump()
    elif isinstance(value, ArrayFileSegment):
        return [i.model_dump() for i in value.value]
    elif isinstance(value, ArrayPromptMessageSegment):
        return value.to_object()
    elif isinstance(value, SegmentGroup):
        return [_convert_values_to_json_serializable_object(i) for i in value.value]
    else:
@@ -69,13 +69,6 @@ class ActivateCheckApi(Resource):
        if invitation:
            data = invitation.get("data", {})
            tenant = invitation.get("tenant", None)

            # Check workspace permission
            if tenant:
                from libs.workspace_permission import check_workspace_member_invite_permission

                check_workspace_member_invite_permission(tenant.id)

            workspace_name = tenant.name if tenant else None
            workspace_id = tenant.id if tenant else None
            invitee_email = data.get("email") if data else None
@@ -107,12 +107,6 @@ class MemberInviteEmailApi(Resource):
        inviter = current_user
        if not inviter.current_tenant:
            raise ValueError("No current tenant")

        # Check workspace permission for member invitations
        from libs.workspace_permission import check_workspace_member_invite_permission

        check_workspace_member_invite_permission(inviter.current_tenant.id)

        invitation_results = []
        console_web_url = dify_config.CONSOLE_WEB_URL
@@ -20,7 +20,6 @@ from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
    account_initialization_required,
    cloud_edition_billing_resource_check,
    only_edition_enterprise,
    setup_required,
)
from enums.cloud_plan import CloudPlan
@@ -29,7 +28,6 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workspace_service import WorkspaceService
@@ -290,31 +288,3 @@ class WorkspaceInfoApi(Resource):
        db.session.commit()

        return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}


@console_ns.route("/workspaces/current/permission")
class WorkspacePermissionApi(Resource):
    """Get workspace permissions for the current workspace."""

    @setup_required
    @login_required
    @account_initialization_required
    @only_edition_enterprise
    def get(self):
        """
        Get workspace permission settings.
        Returns permission flags that control workspace features like member invitations and owner transfer.
        """
        _, current_tenant_id = current_account_with_tenant()

        if not current_tenant_id:
            raise ValueError("No current tenant")

        # Get workspace permissions from enterprise service
        permission = EnterpriseService.WorkspacePermissionService.get_permission(current_tenant_id)

        return {
            "workspace_id": permission.workspace_id,
            "allow_member_invite": permission.allow_member_invite,
            "allow_owner_transfer": permission.allow_owner_transfer,
        }, 200
@@ -286,12 +286,13 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]):
    @wraps(view)
    def decorated(*args: P.args, **kwargs: P.kwargs):
        from libs.workspace_permission import check_workspace_owner_transfer_permission

        _, current_tenant_id = current_account_with_tenant()
        # Check both billing/plan level and workspace policy level permissions
        check_workspace_owner_transfer_permission(current_tenant_id)
        return view(*args, **kwargs)
        features = FeatureService.get_features(current_tenant_id)
        if features.is_allow_transfer_workspace:
            return view(*args, **kwargs)

        # otherwise, return 403
        abort(403)

    return decorated
@@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
                response_chunk.update(sub_stream_response.model_dump(mode="json"))
            yield response_chunk

    @classmethod
@@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
            }

            if isinstance(sub_stream_response, MessageEndStreamResponse):
                sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
                sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
                metadata = sub_stream_response_dict.get("metadata", {})
                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                response_chunk.update(sub_stream_response_dict)
@@ -120,6 +120,6 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
            elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
            else:
                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
                response_chunk.update(sub_stream_response.model_dump(mode="json"))

            yield response_chunk
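For illustration, the behavioral difference in the model_dump change repeated throughout these converters (a minimal Pydantic sketch, not part of the diff):

    from pydantic import BaseModel

    class Chunk(BaseModel):
        event: str
        task_id: str | None = None

    Chunk(event="message").model_dump(mode="json", exclude_none=True)  # {'event': 'message'}
    Chunk(event="message").model_dump(mode="json")  # {'event': 'message', 'task_id': None}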
@@ -81,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
                response_chunk.update(sub_stream_response.model_dump(mode="json"))
            yield response_chunk

    @classmethod
@@ -109,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
            }

            if isinstance(sub_stream_response, MessageEndStreamResponse):
                sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
                sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
                metadata = sub_stream_response_dict.get("metadata", {})
                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                response_chunk.update(sub_stream_response_dict)
@@ -117,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
                response_chunk.update(sub_stream_response.model_dump(mode="json"))

            yield response_chunk
@@ -81,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
                response_chunk.update(sub_stream_response.model_dump(mode="json"))
            yield response_chunk

    @classmethod
@@ -109,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
            }

            if isinstance(sub_stream_response, MessageEndStreamResponse):
                sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
                sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
                metadata = sub_stream_response_dict.get("metadata", {})
                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                response_chunk.update(sub_stream_response_dict)
@@ -117,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
                response_chunk.update(sub_stream_response.model_dump(mode="json"))

            yield response_chunk
@@ -70,8 +70,6 @@ class _NodeSnapshot:
    """Empty string means the node is not executing inside an iteration."""
    loop_id: str = ""
    """Empty string means the node is not executing inside a loop."""
    mention_parent_id: str = ""
    """Empty string means the node is not an extractor node."""


class WorkflowResponseConverter:
@@ -133,7 +131,6 @@ class WorkflowResponseConverter:
            start_at=event.start_at,
            iteration_id=event.in_iteration_id or "",
            loop_id=event.in_loop_id or "",
            mention_parent_id=event.in_mention_parent_id or "",
        )
        node_execution_id = NodeExecutionId(event.node_execution_id)
        self._node_snapshots[node_execution_id] = snapshot
@@ -290,7 +287,6 @@ class WorkflowResponseConverter:
                created_at=int(snapshot.start_at.timestamp()),
                iteration_id=event.in_iteration_id,
                loop_id=event.in_loop_id,
                mention_parent_id=event.in_mention_parent_id,
                agent_strategy=event.agent_strategy,
            ),
        )
@@ -377,7 +373,6 @@ class WorkflowResponseConverter:
                files=self.fetch_files_from_node_outputs(event.outputs or {}),
                iteration_id=event.in_iteration_id,
                loop_id=event.in_loop_id,
                mention_parent_id=event.in_mention_parent_id,
            ),
        )

@@ -427,7 +422,6 @@ class WorkflowResponseConverter:
                files=self.fetch_files_from_node_outputs(event.outputs or {}),
                iteration_id=event.in_iteration_id,
                loop_id=event.in_loop_id,
                mention_parent_id=event.in_mention_parent_id,
                retry_index=event.retry_index,
            ),
        )
@@ -79,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
                response_chunk.update(sub_stream_response.model_dump(mode="json"))
            yield response_chunk

    @classmethod
@@ -106,7 +106,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
            }

            if isinstance(sub_stream_response, MessageEndStreamResponse):
                sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
                sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
                metadata = sub_stream_response_dict.get("metadata", {})
                if not isinstance(metadata, dict):
                    metadata = {}
@@ -116,6 +116,6 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
                response_chunk.update(sub_stream_response.model_dump(mode="json"))

            yield response_chunk
@@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(cast(dict, data))
            else:
                response_chunk.update(sub_stream_response.model_dump(exclude_none=True))
                response_chunk.update(sub_stream_response.model_dump())
            yield response_chunk

    @classmethod
@@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
            elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
                response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
            else:
                response_chunk.update(sub_stream_response.model_dump(exclude_none=True))
                response_chunk.update(sub_stream_response.model_dump())
            yield response_chunk
@@ -8,7 +8,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker

import contexts
from configs import dify_config
@@ -23,7 +23,6 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
@@ -477,7 +476,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
        :return:
        """
        with preserve_flask_contexts(flask_app, context_vars=context):
            with session_factory.create_session() as session:
            with Session(db.engine, expire_on_commit=False) as session:
                workflow = session.scalar(
                    select(Workflow).where(
                        Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
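For illustration, why expire_on_commit=False matters in the session swap above (a minimal SQLAlchemy sketch, not part of the diff):

    from sqlalchemy import select
    from sqlalchemy.orm import Session

    with Session(db.engine, expire_on_commit=False) as session:
        workflow = session.scalar(select(Workflow).limit(1))
        session.commit()
    # With expire_on_commit=False, loaded attributes remain readable after the
    # session closes; with the default, the commit expires them and any later
    # attribute access triggers a refresh that fails outside the session.
    print(workflow.id if workflow else None)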
@@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
                response_chunk.update(sub_stream_response.model_dump(mode="json"))
            yield response_chunk

    @classmethod
@@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
            elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
            else:
                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
                response_chunk.update(sub_stream_response.model_dump(mode="json"))
            yield response_chunk
@@ -385,7 +385,6 @@ class WorkflowBasedAppRunner:
                    start_at=event.start_at,
                    in_iteration_id=event.in_iteration_id,
                    in_loop_id=event.in_loop_id,
                    in_mention_parent_id=event.in_mention_parent_id,
                    inputs=inputs,
                    process_data=process_data,
                    outputs=outputs,
@@ -406,7 +405,6 @@ class WorkflowBasedAppRunner:
                    start_at=event.start_at,
                    in_iteration_id=event.in_iteration_id,
                    in_loop_id=event.in_loop_id,
                    in_mention_parent_id=event.in_mention_parent_id,
                    agent_strategy=event.agent_strategy,
                    provider_type=event.provider_type,
                    provider_id=event.provider_id,
@@ -430,7 +428,6 @@ class WorkflowBasedAppRunner:
                    execution_metadata=execution_metadata,
                    in_iteration_id=event.in_iteration_id,
                    in_loop_id=event.in_loop_id,
                    in_mention_parent_id=event.in_mention_parent_id,
                )
            )
        elif isinstance(event, NodeRunFailedEvent):
@@ -447,7 +444,6 @@ class WorkflowBasedAppRunner:
                    execution_metadata=event.node_run_result.metadata,
                    in_iteration_id=event.in_iteration_id,
                    in_loop_id=event.in_loop_id,
                    in_mention_parent_id=event.in_mention_parent_id,
                )
            )
        elif isinstance(event, NodeRunExceptionEvent):
@@ -464,7 +460,6 @@ class WorkflowBasedAppRunner:
                    execution_metadata=event.node_run_result.metadata,
                    in_iteration_id=event.in_iteration_id,
                    in_loop_id=event.in_loop_id,
                    in_mention_parent_id=event.in_mention_parent_id,
                )
            )
        elif isinstance(event, NodeRunStreamChunkEvent):
@@ -474,7 +469,6 @@ class WorkflowBasedAppRunner:
                    from_variable_selector=list(event.selector),
                    in_iteration_id=event.in_iteration_id,
                    in_loop_id=event.in_loop_id,
                    in_mention_parent_id=event.in_mention_parent_id,
                )
            )
        elif isinstance(event, NodeRunRetrieverResourceEvent):
@@ -483,7 +477,6 @@ class WorkflowBasedAppRunner:
                    retriever_resources=event.retriever_resources,
                    in_iteration_id=event.in_iteration_id,
                    in_loop_id=event.in_loop_id,
                    in_mention_parent_id=event.in_mention_parent_id,
                )
            )
        elif isinstance(event, NodeRunAgentLogEvent):
@@ -190,8 +190,6 @@ class QueueTextChunkEvent(AppQueueEvent):
    """iteration id if node is in iteration"""
    in_loop_id: str | None = None
    """loop id if node is in loop"""
    in_mention_parent_id: str | None = None
    """parent node id if this is an extractor node event"""


class QueueAgentMessageEvent(AppQueueEvent):
@@ -231,8 +229,6 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
    """iteration id if node is in iteration"""
    in_loop_id: str | None = None
    """loop id if node is in loop"""
    in_mention_parent_id: str | None = None
    """parent node id if this is an extractor node event"""


class QueueAnnotationReplyEvent(AppQueueEvent):
@@ -310,8 +306,6 @@ class QueueNodeStartedEvent(AppQueueEvent):
    node_run_index: int = 1  # FIXME(-LAN-): may not used
    in_iteration_id: str | None = None
    in_loop_id: str | None = None
    in_mention_parent_id: str | None = None
    """parent node id if this is an extractor node event"""
    start_at: datetime
    agent_strategy: AgentNodeStrategyInit | None = None
@@ -334,8 +328,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
    """iteration id if node is in iteration"""
    in_loop_id: str | None = None
    """loop id if node is in loop"""
    in_mention_parent_id: str | None = None
    """parent node id if this is an extractor node event"""
    start_at: datetime

    inputs: Mapping[str, object] = Field(default_factory=dict)
@@ -391,8 +383,6 @@ class QueueNodeExceptionEvent(AppQueueEvent):
    """iteration id if node is in iteration"""
    in_loop_id: str | None = None
    """loop id if node is in loop"""
    in_mention_parent_id: str | None = None
    """parent node id if this is an extractor node event"""
    start_at: datetime

    inputs: Mapping[str, object] = Field(default_factory=dict)
@@ -417,8 +407,6 @@ class QueueNodeFailedEvent(AppQueueEvent):
    """iteration id if node is in iteration"""
    in_loop_id: str | None = None
    """loop id if node is in loop"""
    in_mention_parent_id: str | None = None
    """parent node id if this is an extractor node event"""
    start_at: datetime

    inputs: Mapping[str, object] = Field(default_factory=dict)
@@ -262,7 +262,6 @@ class NodeStartStreamResponse(StreamResponse):
        extras: dict[str, object] = Field(default_factory=dict)
        iteration_id: str | None = None
        loop_id: str | None = None
        mention_parent_id: str | None = None
        agent_strategy: AgentNodeStrategyInit | None = None

    event: StreamEvent = StreamEvent.NODE_STARTED
@@ -286,7 +285,6 @@ class NodeStartStreamResponse(StreamResponse):
                "extras": {},
                "iteration_id": self.data.iteration_id,
                "loop_id": self.data.loop_id,
                "mention_parent_id": self.data.mention_parent_id,
            },
        }
@@ -322,7 +320,6 @@ class NodeFinishStreamResponse(StreamResponse):
        files: Sequence[Mapping[str, Any]] | None = []
        iteration_id: str | None = None
        loop_id: str | None = None
        mention_parent_id: str | None = None

    event: StreamEvent = StreamEvent.NODE_FINISHED
    workflow_run_id: str
@@ -352,7 +349,6 @@ class NodeFinishStreamResponse(StreamResponse):
                "files": [],
                "iteration_id": self.data.iteration_id,
                "loop_id": self.data.loop_id,
                "mention_parent_id": self.data.mention_parent_id,
            },
        }
@@ -388,7 +384,6 @@ class NodeRetryStreamResponse(StreamResponse):
        files: Sequence[Mapping[str, Any]] | None = []
        iteration_id: str | None = None
        loop_id: str | None = None
        mention_parent_id: str | None = None
        retry_index: int = 0

    event: StreamEvent = StreamEvent.NODE_RETRY
@@ -419,7 +414,6 @@ class NodeRetryStreamResponse(StreamResponse):
                "files": [],
                "iteration_id": self.data.iteration_id,
                "loop_id": self.data.loop_id,
                "mention_parent_id": self.data.mention_parent_id,
                "retry_index": self.data.retry_index,
            },
        }
@@ -1,5 +1,4 @@
import base64
import logging
from collections.abc import Mapping

from configs import dify_config
@@ -11,10 +10,7 @@ from core.model_runtime.entities import (
    TextPromptMessageContent,
    VideoPromptMessageContent,
)
from core.model_runtime.entities.message_entities import (
    MultiModalPromptMessageContent,
    PromptMessageContentUnionTypes,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from core.tools.signature import sign_tool_file
from extensions.ext_storage import storage

@@ -22,8 +18,6 @@ from . import helpers
from .enums import FileAttribute
from .models import File, FileTransferMethod, FileType

logger = logging.getLogger(__name__)


def get_attr(*, file: File, attr: FileAttribute):
    match attr:
@@ -95,8 +89,6 @@ def to_prompt_message_content(
        "format": f.extension.removeprefix("."),
        "mime_type": f.mime_type,
        "filename": f.filename or "",
        # Encoded file reference for context restoration: "transfer_method:related_id" or "remote:url"
        "file_ref": _encode_file_ref(f),
    }
    if f.type == FileType.IMAGE:
        params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
@@ -104,17 +96,6 @@ def to_prompt_message_content(
    return prompt_class_map[f.type].model_validate(params)


def _encode_file_ref(f: File) -> str | None:
    """Encode file reference as 'transfer_method:id_or_url' string."""
    if f.transfer_method == FileTransferMethod.REMOTE_URL:
        return f"remote:{f.remote_url}" if f.remote_url else None
    elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
        return f"local:{f.related_id}" if f.related_id else None
    elif f.transfer_method == FileTransferMethod.TOOL_FILE:
        return f"tool:{f.related_id}" if f.related_id else None
    return None


def download(f: File, /):
    if f.transfer_method in (
        FileTransferMethod.TOOL_FILE,
@@ -183,128 +164,3 @@ def _to_url(f: File, /):
        return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
    else:
        raise ValueError(f"Unsupported transfer method: {f.transfer_method}")


def restore_multimodal_content(
    content: MultiModalPromptMessageContent,
) -> MultiModalPromptMessageContent:
    """
    Restore base64_data or url for multimodal content from file_ref.

    file_ref format: "transfer_method:id_or_url" (e.g., "local:abc123", "remote:https://...")

    Args:
        content: MultiModalPromptMessageContent with file_ref field

    Returns:
        MultiModalPromptMessageContent with restored base64_data or url
    """
    # Skip if no file reference or content already has data
    if not content.file_ref:
        return content
    if content.base64_data or content.url:
        return content

    try:
        file = _build_file_from_ref(
            file_ref=content.file_ref,
            file_format=content.format,
            mime_type=content.mime_type,
            filename=content.filename,
        )
        if not file:
            return content

        # Restore content based on config
        if dify_config.MULTIMODAL_SEND_FORMAT == "base64":
            restored_base64 = _get_encoded_string(file)
            return content.model_copy(update={"base64_data": restored_base64})
        else:
            restored_url = _to_url(file)
            return content.model_copy(update={"url": restored_url})

    except Exception as e:
        logger.warning("Failed to restore multimodal content: %s", e)
        return content


def _build_file_from_ref(
    file_ref: str,
    file_format: str | None,
    mime_type: str | None,
    filename: str | None,
) -> File | None:
    """
    Build a File object from encoded file_ref string.

    Args:
        file_ref: Encoded reference "transfer_method:id_or_url"
        file_format: The file format/extension (without dot)
        mime_type: The mime type
        filename: The filename

    Returns:
        File object with storage_key loaded, or None if not found
    """
    from sqlalchemy import select
    from sqlalchemy.orm import Session

    from extensions.ext_database import db
    from models.model import UploadFile
    from models.tools import ToolFile

    # Parse file_ref: "method:value"
    if ":" not in file_ref:
        logger.warning("Invalid file_ref format: %s", file_ref)
        return None

    method, value = file_ref.split(":", 1)
    extension = f".{file_format}" if file_format else None

    if method == "remote":
        return File(
            tenant_id="",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.REMOTE_URL,
            remote_url=value,
            extension=extension,
            mime_type=mime_type,
            filename=filename,
            storage_key="",
        )

    # Query database for storage_key
    with Session(db.engine) as session:
        if method == "local":
            stmt = select(UploadFile).where(UploadFile.id == value)
            upload_file = session.scalar(stmt)
            if upload_file:
                return File(
                    tenant_id=upload_file.tenant_id,
                    type=FileType(upload_file.extension)
                    if hasattr(FileType, upload_file.extension.upper())
                    else FileType.IMAGE,
                    transfer_method=FileTransferMethod.LOCAL_FILE,
                    related_id=value,
                    extension=extension or ("." + upload_file.extension if upload_file.extension else None),
                    mime_type=mime_type or upload_file.mime_type,
                    filename=filename or upload_file.name,
                    storage_key=upload_file.key,
                )
        elif method == "tool":
            stmt = select(ToolFile).where(ToolFile.id == value)
            tool_file = session.scalar(stmt)
            if tool_file:
                return File(
                    tenant_id=tool_file.tenant_id,
                    type=FileType.IMAGE,
                    transfer_method=FileTransferMethod.TOOL_FILE,
                    related_id=value,
                    extension=extension,
                    mime_type=mime_type or tool_file.mimetype,
                    filename=filename or tool_file.name,
                    storage_key=tool_file.file_key,
                )

    logger.warning("File not found for file_ref: %s", file_ref)
    return None
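For illustration, the file_ref encode/restore round trip above (a minimal sketch, not part of the diff; the File values are hypothetical):

    f = File(tenant_id="t1", type=FileType.IMAGE,
             transfer_method=FileTransferMethod.TOOL_FILE,
             related_id="abc123", extension=".png", storage_key="k")
    ref = _encode_file_ref(f)  # "tool:abc123"
    # Later, with only the serialized prompt message content at hand,
    # _build_file_from_ref("tool:abc123", "png", "image/png", "x.png")
    # looks up the ToolFile row and rebuilds an equivalent File.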
@@ -1,16 +1,11 @@
import json
import logging
import re
from collections.abc import Mapping, Sequence
from typing import Any, Protocol, cast
from collections.abc import Sequence
from typing import Protocol, cast

import json_repair

from core.llm_generator.output_models import (
    CodeNodeStructuredOutput,
    InstructionModifyOutput,
    SuggestedQuestionsOutput,
)
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import (
@@ -398,432 +393,6 @@ class LLMGenerator:
            logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
            return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}

    @classmethod
    def generate_with_context(
        cls,
        tenant_id: str,
        workflow_id: str,
        node_id: str,
        parameter_name: str,
        language: str,
        prompt_messages: list[PromptMessage],
        model_config: dict,
    ) -> dict:
        """
        Generate extractor code node based on conversation context.

        Args:
            tenant_id: Tenant/workspace ID
            workflow_id: Workflow ID
            node_id: Current tool/llm node ID
            parameter_name: Parameter name to generate code for
            language: Code language (python3/javascript)
            prompt_messages: Multi-turn conversation history (last message is instruction)
            model_config: Model configuration (provider, name, completion_params)

        Returns:
            dict with CodeNodeData format:
            - variables: Input variable selectors
            - code_language: Code language
            - code: Generated code
            - outputs: Output definitions
            - message: Explanation
            - error: Error message if any
        """
        from sqlalchemy import select
        from sqlalchemy.orm import Session

        from services.workflow_service import WorkflowService

        # Get workflow
        with Session(db.engine) as session:
            stmt = select(App).where(App.id == workflow_id)
            app = session.scalar(stmt)
            if not app:
                return cls._error_response(f"App {workflow_id} not found")

            workflow = WorkflowService().get_draft_workflow(app_model=app)
            if not workflow:
                return cls._error_response(f"Workflow for app {workflow_id} not found")

        # Get upstream nodes via edge backtracking
        upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)

        # Get current node info
        current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
        if not current_node:
            return cls._error_response(f"Node {node_id} not found")

        # Get parameter info
        parameter_info = cls._get_parameter_info(
            tenant_id=tenant_id,
            node_data=current_node.get("data", {}),
            parameter_name=parameter_name,
        )

        # Build system prompt
        system_prompt = cls._build_extractor_system_prompt(
            upstream_nodes=upstream_nodes,
            current_node=current_node,
            parameter_info=parameter_info,
            language=language,
        )

        # Construct complete prompt_messages with system prompt
        complete_messages: list[PromptMessage] = [
            SystemPromptMessage(content=system_prompt),
            *prompt_messages,
        ]

        from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model

        # Get model instance and schema
        provider = model_config.get("provider", "")
        model_name = model_config.get("name", "")
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
            provider=provider,
            model=model_name,
        )

        model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
        if not model_schema:
            return cls._error_response(f"Model schema not found for {model_name}")

        model_parameters = model_config.get("completion_params", {})
        try:
            response = invoke_llm_with_pydantic_model(
                provider=provider,
                model_schema=model_schema,
                model_instance=model_instance,
                prompt_messages=complete_messages,
                output_model=CodeNodeStructuredOutput,
                model_parameters=model_parameters,
                stream=False,
                tenant_id=tenant_id,
            )

            return cls._parse_code_node_output(
                response.structured_output, language, parameter_info.get("type", "string")
            )

        except InvokeError as e:
            return cls._error_response(str(e))
        except Exception as e:
            logger.exception("Failed to generate with context, model: %s", model_config.get("name"))
            return cls._error_response(f"An unexpected error occurred: {str(e)}")

    @classmethod
    def _error_response(cls, error: str) -> dict:
        """Return error response in CodeNodeData format."""
        return {
            "variables": [],
            "code_language": "python3",
            "code": "",
            "outputs": {},
            "message": "",
            "error": error,
        }

    @classmethod
    def generate_suggested_questions(
        cls,
        tenant_id: str,
        workflow_id: str,
        node_id: str,
        parameter_name: str,
        language: str,
        model_config: dict | None = None,
    ) -> dict:
        """
        Generate suggested questions for context generation.

        Returns dict with questions array and error field.
        """
        from sqlalchemy import select
        from sqlalchemy.orm import Session

        from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model
        from services.workflow_service import WorkflowService

        # Get workflow context (reuse existing logic)
        with Session(db.engine) as session:
            stmt = select(App).where(App.id == workflow_id)
            app = session.scalar(stmt)
            if not app:
                return {"questions": [], "error": f"App {workflow_id} not found"}

            workflow = WorkflowService().get_draft_workflow(app_model=app)
            if not workflow:
                return {"questions": [], "error": f"Workflow for app {workflow_id} not found"}

        upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)
        current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
        if not current_node:
            return {"questions": [], "error": f"Node {node_id} not found"}

        parameter_info = cls._get_parameter_info(
            tenant_id=tenant_id,
            node_data=current_node.get("data", {}),
            parameter_name=parameter_name,
        )

        # Build prompt
        system_prompt = cls._build_suggested_questions_prompt(
            upstream_nodes=upstream_nodes,
            current_node=current_node,
            parameter_info=parameter_info,
            language=language,
        )

        prompt_messages: list[PromptMessage] = [
            SystemPromptMessage(content=system_prompt),
        ]

        # Get model instance - use default if model_config not provided
        model_manager = ModelManager()
        if model_config:
            provider = model_config.get("provider", "")
            model_name = model_config.get("name", "")
            model_instance = model_manager.get_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM,
                provider=provider,
                model=model_name,
            )
        else:
            model_instance = model_manager.get_default_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM,
            )
            model_name = model_instance.model

        model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
        if not model_schema:
            return {"questions": [], "error": f"Model schema not found for {model_name}"}

        completion_params = model_config.get("completion_params", {}) if model_config else {}
        model_parameters = {**completion_params, "max_tokens": 256}
        try:
            response = invoke_llm_with_pydantic_model(
                provider=model_instance.provider,
                model_schema=model_schema,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                output_model=SuggestedQuestionsOutput,
                model_parameters=model_parameters,
                stream=False,
                tenant_id=tenant_id,
            )

            questions = response.structured_output.get("questions", []) if response.structured_output else []
            return {"questions": questions, "error": ""}

        except InvokeError as e:
            return {"questions": [], "error": str(e)}
        except Exception as e:
            logger.exception("Failed to generate suggested questions, model: %s", model_name)
            return {"questions": [], "error": f"An unexpected error occurred: {str(e)}"}

    @classmethod
    def _build_suggested_questions_prompt(
        cls,
        upstream_nodes: list[dict],
        current_node: dict,
        parameter_info: dict,
        language: str = "English",
    ) -> str:
        """Build minimal prompt for suggested questions generation."""
        # Simplify upstream nodes to reduce tokens
        sources = [f"{n['title']}({','.join(n.get('outputs', {}).keys())})" for n in upstream_nodes[:5]]
        param_type = parameter_info.get("type", "string")
        param_desc = parameter_info.get("description", "")[:100]

        return f"""Suggest 3 code generation questions for extracting data.
Sources: {", ".join(sources)}
Target: {parameter_info.get("name")}({param_type}) - {param_desc}
Output 3 short, practical questions in {language}."""
@classmethod
|
||||
def _get_upstream_nodes(cls, graph_dict: Mapping[str, Any], node_id: str) -> list[dict]:
|
||||
"""
|
||||
Get all upstream nodes via edge backtracking.
|
||||
|
||||
Traverses the graph backwards from node_id to collect all reachable nodes.
|
||||
"""
|
||||
from collections import defaultdict
|
||||
|
||||
nodes = {n["id"]: n for n in graph_dict.get("nodes", [])}
|
||||
edges = graph_dict.get("edges", [])
|
||||
|
||||
# Build reverse adjacency list
|
||||
reverse_adj: dict[str, list[str]] = defaultdict(list)
|
||||
for edge in edges:
|
||||
reverse_adj[edge["target"]].append(edge["source"])
|
||||
|
||||
# BFS to find all upstream nodes
|
||||
visited: set[str] = set()
|
||||
queue = [node_id]
|
||||
upstream: list[dict] = []
|
||||
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
for source in reverse_adj.get(current, []):
|
||||
if source not in visited:
|
||||
visited.add(source)
|
||||
queue.append(source)
|
||||
if source in nodes:
|
||||
upstream.append(cls._extract_node_info(nodes[source]))
|
||||
|
||||
return upstream
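
    # Illustrative only (not part of the diff): how the edge backtracking above
    # behaves on a toy graph. With edges start -> llm -> tool, querying the
    # upstream of "tool" yields both ancestors, nearest first:
    #
    #   graph = {
    #       "nodes": [
    #           {"id": "start", "data": {"type": "start", "title": "Start"}},
    #           {"id": "llm", "data": {"type": "llm", "title": "LLM"}},
    #           {"id": "tool", "data": {"type": "tool", "title": "Tool"}},
    #       ],
    #       "edges": [
    #           {"source": "start", "target": "llm"},
    #           {"source": "llm", "target": "tool"},
    #       ],
    #   }
    #   cls._get_upstream_nodes(graph, "tool")
    #   # -> [<info for "llm">, <info for "start">]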

    @classmethod
    def _get_node_by_id(cls, graph_dict: Mapping[str, Any], node_id: str) -> dict | None:
        """Get node by ID from graph."""
        for node in graph_dict.get("nodes", []):
            if node["id"] == node_id:
                return node
        return None

    @classmethod
    def _extract_node_info(cls, node: dict) -> dict:
        """Extract minimal node info with outputs based on node type."""
        node_type = node["data"]["type"]
        node_data = node.get("data", {})

        # Build outputs based on node type (only type, no description to reduce tokens)
        outputs: dict[str, str] = {}
        match node_type:
            case "start":
                for var in node_data.get("variables", []):
                    name = var.get("variable", var.get("name", ""))
                    outputs[name] = var.get("type", "string")
            case "llm":
                outputs["text"] = "string"
            case "code":
                for name, output in node_data.get("outputs", {}).items():
                    outputs[name] = output.get("type", "string")
            case "http-request":
                outputs = {"body": "string", "status_code": "number", "headers": "object"}
            case "knowledge-retrieval":
                outputs["result"] = "array[object]"
            case "tool":
                outputs = {"text": "string", "json": "object"}
            case _:
                outputs["output"] = "string"

        info: dict = {
            "id": node["id"],
            "title": node_data.get("title", node["id"]),
            "outputs": outputs,
        }
        # Only include description if not empty
        desc = node_data.get("desc", "")
        if desc:
            info["desc"] = desc

        return info

    @classmethod
    def _get_parameter_info(cls, tenant_id: str, node_data: dict, parameter_name: str) -> dict:
        """Get parameter info from tool schema using ToolManager."""
        default_info = {"name": parameter_name, "type": "string", "description": ""}

        if node_data.get("type") != "tool":
            return default_info

        try:
            from core.app.entities.app_invoke_entities import InvokeFrom
            from core.tools.entities.tool_entities import ToolProviderType
            from core.tools.tool_manager import ToolManager

            provider_type_str = node_data.get("provider_type", "")
            provider_type = ToolProviderType(provider_type_str) if provider_type_str else ToolProviderType.BUILT_IN

            tool_runtime = ToolManager.get_tool_runtime(
                provider_type=provider_type,
                provider_id=node_data.get("provider_id", ""),
                tool_name=node_data.get("tool_name", ""),
                tenant_id=tenant_id,
                invoke_from=InvokeFrom.DEBUGGER,
            )

            parameters = tool_runtime.get_merged_runtime_parameters()
            for param in parameters:
                if param.name == parameter_name:
                    return {
                        "name": param.name,
                        "type": param.type.value if hasattr(param.type, "value") else str(param.type),
                        "description": param.llm_description
                        or (param.human_description.en_US if param.human_description else ""),
                        "required": param.required,
                    }
        except Exception as e:
            logger.debug("Failed to get parameter info from ToolManager: %s", e)

        return default_info

    @classmethod
    def _build_extractor_system_prompt(
        cls,
        upstream_nodes: list[dict],
        current_node: dict,
        parameter_info: dict,
        language: str,
    ) -> str:
        """Build system prompt for extractor code generation."""
        upstream_json = json.dumps(upstream_nodes, indent=2, ensure_ascii=False)
        param_type = parameter_info.get("type", "string")
        return f"""You are a code generator for workflow automation.

Generate {language} code to extract/transform upstream node outputs for the target parameter.

## Upstream Nodes
{upstream_json}

## Target
Node: {current_node["data"].get("title", current_node["id"])}
Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("description", "")}

## Requirements
- Write a main function that returns type: {param_type}
- Use value_selector format: ["node_id", "output_name"]
"""

    @classmethod
    def _parse_code_node_output(cls, content: Mapping[str, Any] | None, language: str, parameter_type: str) -> dict:
        """
        Parse structured output to CodeNodeData format.

        Args:
            content: Structured output dict from invoke_llm_with_structured_output
            language: Code language
            parameter_type: Expected parameter type

        Returns dict with variables, code_language, code, outputs, message, error.
        """
        if content is None:
            return cls._error_response("Empty or invalid response from LLM")

        # Validate and normalize variables
        variables = [
            {"variable": v.get("variable", ""), "value_selector": v.get("value_selector", [])}
            for v in content.get("variables", [])
            if isinstance(v, dict)
        ]

        outputs = content.get("outputs", {"result": {"type": parameter_type}})

        return {
            "variables": variables,
            "code_language": language,
            "code": content.get("code", ""),
            "outputs": outputs,
            "message": content.get("explanation", ""),
            "error": "",
        }

    @staticmethod
    def instruction_modify_legacy(
        tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
@@ -960,10 +529,6 @@ Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("de
            provider=model_config.get("provider", ""),
            model=model_config.get("name", ""),
        )
        model_name = model_config.get("name", "")
        model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
        if not model_schema:
            return {"error": f"Model schema not found for {model_name}"}
        match node_type:
            case "llm" | "agent":
                system_prompt = LLM_MODIFY_PROMPT_SYSTEM
@@ -987,18 +552,20 @@ Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("de
        model_parameters = {"temperature": 0.4}

        try:
            from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model

            response = invoke_llm_with_pydantic_model(
                provider=model_instance.provider,
                model_schema=model_schema,
                model_instance=model_instance,
                prompt_messages=list(prompt_messages),
                output_model=InstructionModifyOutput,
                model_parameters=model_parameters,
                stream=False,
            response: LLMResult = model_instance.invoke_llm(
                prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
            )
            return response.structured_output or {}

            generated_raw = response.message.get_text_content()
            first_brace = generated_raw.find("{")
            last_brace = generated_raw.rfind("}")
            if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
                raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
            json_str = generated_raw[first_brace : last_brace + 1]
            data = json_repair.loads(json_str)
            if not isinstance(data, dict):
                raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
            return data
        except InvokeError as e:
            error = str(e)
            return {"error": f"Failed to generate code. Error: {error}"}

@@ -1,34 +0,0 @@
from __future__ import annotations

from pydantic import BaseModel, ConfigDict, Field

from core.variables.types import SegmentType
from core.workflow.nodes.base.entities import VariableSelector


class SuggestedQuestionsOutput(BaseModel):
    model_config = ConfigDict(extra="forbid")

    questions: list[str] = Field(min_length=3, max_length=3)


class CodeNodeOutput(BaseModel):
    model_config = ConfigDict(extra="forbid")

    type: SegmentType


class CodeNodeStructuredOutput(BaseModel):
    model_config = ConfigDict(extra="forbid")

    variables: list[VariableSelector]
    code: str
    outputs: dict[str, CodeNodeOutput]
    explanation: str


class InstructionModifyOutput(BaseModel):
    model_config = ConfigDict(extra="forbid")

    modified: str
    message: str
@@ -1,188 +0,0 @@
"""
File reference detection and conversion for structured output.

This module provides utilities to:
1. Detect file reference fields in JSON Schema (format: "dify-file-ref")
2. Convert file ID strings to File objects after LLM returns
"""

import uuid
from collections.abc import Mapping
from typing import Any

from core.file import File
from core.variables.segments import ArrayFileSegment, FileSegment
from factories.file_factory import build_from_mapping

FILE_REF_FORMAT = "dify-file-ref"


def is_file_ref_property(schema: dict) -> bool:
    """Check if a schema property is a file reference."""
    return schema.get("type") == "string" and schema.get("format") == FILE_REF_FORMAT


def detect_file_ref_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
    """
    Recursively detect file reference fields in schema.

    Args:
        schema: JSON Schema to analyze
        path: Current path in the schema (used for recursion)

    Returns:
        List of JSON paths containing file refs, e.g., ["image_id", "files[*]"]
    """
    file_ref_paths: list[str] = []
    schema_type = schema.get("type")

    if schema_type == "object":
        for prop_name, prop_schema in schema.get("properties", {}).items():
            current_path = f"{path}.{prop_name}" if path else prop_name

            if is_file_ref_property(prop_schema):
                file_ref_paths.append(current_path)
            elif isinstance(prop_schema, dict):
                file_ref_paths.extend(detect_file_ref_fields(prop_schema, current_path))

    elif schema_type == "array":
        items_schema = schema.get("items", {})
        array_path = f"{path}[*]" if path else "[*]"

        if is_file_ref_property(items_schema):
            file_ref_paths.append(array_path)
        elif isinstance(items_schema, dict):
            file_ref_paths.extend(detect_file_ref_fields(items_schema, array_path))

    return file_ref_paths


def convert_file_refs_in_output(
    output: Mapping[str, Any],
    json_schema: Mapping[str, Any],
    tenant_id: str,
) -> dict[str, Any]:
    """
    Convert file ID strings to File objects based on schema.

    Args:
        output: The structured_output from LLM result
        json_schema: The original JSON schema (to detect file ref fields)
        tenant_id: Tenant ID for file lookup

    Returns:
        Output with file references converted to File objects
    """
    file_ref_paths = detect_file_ref_fields(json_schema)
    if not file_ref_paths:
        return dict(output)

    result = _deep_copy_dict(output)

    for path in file_ref_paths:
        _convert_path_in_place(result, path.split("."), tenant_id)

    return result


def _deep_copy_dict(obj: Mapping[str, Any]) -> dict[str, Any]:
    """Deep copy a mapping to a mutable dict."""
    result: dict[str, Any] = {}
    for key, value in obj.items():
        if isinstance(value, Mapping):
            result[key] = _deep_copy_dict(value)
        elif isinstance(value, list):
            result[key] = [_deep_copy_dict(item) if isinstance(item, Mapping) else item for item in value]
        else:
            result[key] = value
    return result


def _convert_path_in_place(obj: dict, path_parts: list[str], tenant_id: str) -> None:
    """Convert file refs at the given path in place, wrapping in Segment types."""
    if not path_parts:
        return

    current = path_parts[0]
    remaining = path_parts[1:]

    # Handle array notation like "files[*]"
    if current.endswith("[*]"):
        key = current[:-3] if current != "[*]" else None
        target = obj.get(key) if key else obj

        if isinstance(target, list):
            if remaining:
                # Nested array with remaining path - recurse into each item
                for item in target:
                    if isinstance(item, dict):
                        _convert_path_in_place(item, remaining, tenant_id)
            else:
                # Array of file IDs - convert all and wrap in ArrayFileSegment
                files: list[File] = []
                for item in target:
                    file = _convert_file_id(item, tenant_id)
                    if file is not None:
                        files.append(file)
                # Replace the array with ArrayFileSegment
                if key:
                    obj[key] = ArrayFileSegment(value=files)
        return

    if not remaining:
        # Leaf node - convert the value and wrap in FileSegment
        if current in obj:
            file = _convert_file_id(obj[current], tenant_id)
            if file is not None:
                obj[current] = FileSegment(value=file)
            else:
                obj[current] = None
    else:
        # Recurse into nested object
        if current in obj and isinstance(obj[current], dict):
            _convert_path_in_place(obj[current], remaining, tenant_id)


def _convert_file_id(file_id: Any, tenant_id: str) -> File | None:
    """
    Convert a file ID string to a File object.

    Tries multiple file sources in order:
    1. ToolFile (files generated by tools/workflows)
    2. UploadFile (files uploaded by users)
    """
    if not isinstance(file_id, str):
        return None

    # Validate UUID format
    try:
        uuid.UUID(file_id)
    except ValueError:
        return None

    # Try ToolFile first (files generated by tools/workflows)
    try:
        return build_from_mapping(
            mapping={
                "transfer_method": "tool_file",
                "tool_file_id": file_id,
            },
            tenant_id=tenant_id,
        )
    except ValueError:
        pass

    # Try UploadFile (files uploaded by users)
    try:
        return build_from_mapping(
            mapping={
                "transfer_method": "local_file",
                "upload_file_id": file_id,
            },
            tenant_id=tenant_id,
        )
    except ValueError:
        pass

    # File not found in any source
    return None
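
# Illustrative only (not part of the diff): a minimal sketch of the
# "dify-file-ref" flow under the schema conventions defined above.
#
#   schema = {
#       "type": "object",
#       "properties": {
#           "summary": {"type": "string"},
#           "image_id": {"type": "string", "format": "dify-file-ref"},
#           "attachments": {
#               "type": "array",
#               "items": {"type": "string", "format": "dify-file-ref"},
#           },
#       },
#   }
#   detect_file_ref_fields(schema)
#   # -> ["image_id", "attachments[*]"]
#
#   # After the LLM returns {"summary": "...", "image_id": "<uuid>", ...},
#   # convert_file_refs_in_output(output, schema, tenant_id) replaces
#   # "image_id" with a FileSegment and "attachments" with an ArrayFileSegment.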
@@ -2,13 +2,12 @@ import json
from collections.abc import Generator, Mapping, Sequence
from copy import deepcopy
from enum import StrEnum
from typing import Any, Literal, TypeVar, cast, overload
from typing import Any, Literal, cast, overload

import json_repair
from pydantic import BaseModel, TypeAdapter, ValidationError
from pydantic import TypeAdapter, ValidationError

from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
from core.model_manager import ModelInstance
from core.model_runtime.callbacks.base_callback import Callback
@@ -44,9 +43,6 @@ class SpecialModelType(StrEnum):
    OLLAMA = "ollama"


T = TypeVar("T", bound=BaseModel)


@overload
def invoke_llm_with_structured_output(
    *,
@@ -61,7 +57,6 @@ def invoke_llm_with_structured_output(
    stream: Literal[True],
    user: str | None = None,
    callbacks: list[Callback] | None = None,
    tenant_id: str | None = None,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@overload
def invoke_llm_with_structured_output(
@@ -77,7 +72,6 @@ def invoke_llm_with_structured_output(
    stream: Literal[False],
    user: str | None = None,
    callbacks: list[Callback] | None = None,
    tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput: ...
@overload
def invoke_llm_with_structured_output(
@@ -93,7 +87,6 @@ def invoke_llm_with_structured_output(
    stream: bool = True,
    user: str | None = None,
    callbacks: list[Callback] | None = None,
    tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
    *,
@@ -108,30 +101,23 @@ def invoke_llm_with_structured_output(
    stream: bool = True,
    user: str | None = None,
    callbacks: list[Callback] | None = None,
    tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
    """
    Invoke large language model with structured output.
    Invoke large language model with structured output
    1. This method invokes model_instance.invoke_llm with json_schema
    2. Try to parse the result as structured output

    This method invokes model_instance.invoke_llm with json_schema and parses
    the result as structured output.

    :param provider: model provider name
    :param model_schema: model schema entity
    :param model_instance: model instance to invoke
    :param prompt_messages: prompt messages
    :param json_schema: json schema for structured output
    :param json_schema: json schema
    :param model_parameters: model parameters
    :param tools: tools for tool calling
    :param stop: stop words
    :param stream: is stream response
    :param user: unique user id
    :param callbacks: callbacks
    :param tenant_id: tenant ID for file reference conversion. When provided and
        json_schema contains file reference fields (format: "dify-file-ref"),
        file IDs in the output will be automatically converted to File objects.
    :return: full response or stream response chunk generator result
    """

    # handle native json schema
    model_parameters_with_json_schema: dict[str, Any] = {
        **(model_parameters or {}),
@@ -167,18 +153,8 @@ def invoke_llm_with_structured_output(
                f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}"
            )

        structured_output = _parse_structured_output(llm_result.message.content)

        # Convert file references if tenant_id is provided
        if tenant_id is not None:
            structured_output = convert_file_refs_in_output(
                output=structured_output,
                json_schema=json_schema,
                tenant_id=tenant_id,
            )

        return LLMResultWithStructuredOutput(
            structured_output=structured_output,
            structured_output=_parse_structured_output(llm_result.message.content),
            model=llm_result.model,
            message=llm_result.message,
            usage=llm_result.usage,
@@ -210,18 +186,8 @@ def invoke_llm_with_structured_output(
                    delta=event.delta,
                )

            structured_output = _parse_structured_output(result_text)

            # Convert file references if tenant_id is provided
            if tenant_id is not None:
                structured_output = convert_file_refs_in_output(
                    output=structured_output,
                    json_schema=json_schema,
                    tenant_id=tenant_id,
                )

            yield LLMResultChunkWithStructuredOutput(
                structured_output=structured_output,
                structured_output=_parse_structured_output(result_text),
                model=model_schema.model,
                prompt_messages=prompt_messages,
                system_fingerprint=system_fingerprint,
@@ -236,87 +202,6 @@ def invoke_llm_with_structured_output(
    return generator()


@overload
def invoke_llm_with_pydantic_model(
    *,
    provider: str,
    model_schema: AIModelEntity,
    model_instance: ModelInstance,
    prompt_messages: Sequence[PromptMessage],
    output_model: type[T],
    model_parameters: Mapping | None = None,
    tools: Sequence[PromptMessageTool] | None = None,
    stop: list[str] | None = None,
    stream: Literal[False] = False,
    user: str | None = None,
    callbacks: list[Callback] | None = None,
    tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput: ...


def invoke_llm_with_pydantic_model(
    *,
    provider: str,
    model_schema: AIModelEntity,
    model_instance: ModelInstance,
    prompt_messages: Sequence[PromptMessage],
    output_model: type[T],
    model_parameters: Mapping | None = None,
    tools: Sequence[PromptMessageTool] | None = None,
    stop: list[str] | None = None,
    stream: bool = False,
    user: str | None = None,
    callbacks: list[Callback] | None = None,
    tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput:
    """
    Invoke large language model with a Pydantic output model.

    This helper generates a JSON schema from the Pydantic model, invokes the
    structured-output LLM path, and validates the result in non-streaming mode.
    """
    if stream:
        raise ValueError("invoke_llm_with_pydantic_model only supports stream=False")

    json_schema = _schema_from_pydantic(output_model)
    result = invoke_llm_with_structured_output(
        provider=provider,
        model_schema=model_schema,
        model_instance=model_instance,
        prompt_messages=prompt_messages,
        json_schema=json_schema,
        model_parameters=model_parameters,
        tools=tools,
        stop=stop,
        stream=False,
        user=user,
        callbacks=callbacks,
        tenant_id=tenant_id,
    )

    structured_output = result.structured_output
    if structured_output is None:
        raise OutputParserError("Structured output is empty")

    validated_output = _validate_structured_output(output_model, structured_output)
    return result.model_copy(update={"structured_output": validated_output})


def _schema_from_pydantic(output_model: type[BaseModel]) -> dict[str, Any]:
    return output_model.model_json_schema()


def _validate_structured_output(
    output_model: type[T],
    structured_output: Mapping[str, Any],
) -> dict[str, Any]:
    try:
        validated_output = output_model.model_validate(structured_output)
    except ValidationError as exc:
        raise OutputParserError(f"Structured output validation failed: {exc}") from exc
    return validated_output.model_dump(mode="python")
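
# Illustrative only (not part of the diff): a minimal caller sketch for
# invoke_llm_with_pydantic_model; `model_instance` and `model_schema` are
# assumed to come from ModelManager, as elsewhere in this changeset.
#
#   class Answer(BaseModel):
#       model_config = ConfigDict(extra="forbid")
#       text: str
#
#   result = invoke_llm_with_pydantic_model(
#       provider=model_instance.provider,
#       model_schema=model_schema,
#       model_instance=model_instance,
#       prompt_messages=[SystemPromptMessage(content="Reply with JSON.")],
#       output_model=Answer,
#       stream=False,
#   )
#   result.structured_output  # validated dict, e.g. {"text": "..."}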


def _handle_native_json_schema(
    provider: str,
    model_schema: AIModelEntity,

@@ -1,45 +0,0 @@
"""Utility functions for LLM generator."""

from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageRole,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)


def deserialize_prompt_messages(messages: list[dict]) -> list[PromptMessage]:
    """
    Deserialize list of dicts to list[PromptMessage].

    Expected format:
    [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."},
    ]
    """
    result: list[PromptMessage] = []
    for msg in messages:
        role = PromptMessageRole.value_of(msg["role"])
        content = msg.get("content", "")

        match role:
            case PromptMessageRole.USER:
                result.append(UserPromptMessage(content=content))
            case PromptMessageRole.ASSISTANT:
                result.append(AssistantPromptMessage(content=content))
            case PromptMessageRole.SYSTEM:
                result.append(SystemPromptMessage(content=content))
            case PromptMessageRole.TOOL:
                result.append(ToolPromptMessage(content=content, tool_call_id=msg.get("tool_call_id", "")))

    return result


def serialize_prompt_messages(messages: list[PromptMessage]) -> list[dict]:
    """
    Serialize list[PromptMessage] to list of dicts.
    """
    return [{"role": msg.role.value, "content": msg.content} for msg in messages]
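
# Illustrative only (not part of the diff): the two helpers round-trip, so a
# serialized history can be stored as JSON and rebuilt later.
#
#   msgs = [UserPromptMessage(content="Hi"), AssistantPromptMessage(content="Hello")]
#   serialize_prompt_messages(msgs)
#   # -> [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello"}]
#   deserialize_prompt_messages(serialize_prompt_messages(msgs))  # back to PromptMessage objects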
@@ -1,267 +0,0 @@
# Memory Module

This module provides memory management for LLM conversations, enabling context retention across dialogue turns.

## Overview

The memory module contains two types of memory implementations:

1. **TokenBufferMemory** - Conversation-level memory (existing)
2. **NodeTokenBufferMemory** - Node-level memory (**Chatflow only**)

> **Note**: `NodeTokenBufferMemory` is only available in **Chatflow** (advanced-chat mode).
> This is because it requires both `conversation_id` and `node_id`, which are only present in Chatflow.
> Standard Workflow mode does not have `conversation_id` and therefore cannot use node-level memory.

```
┌──────────────────────────────────────────────────────────────────┐
│                        Memory Architecture                       │
├──────────────────────────────────────────────────────────────────┤
│                                                                  │
│  ┌────────────────────────────────────────────────────────────┐  │
│  │                     TokenBufferMemory                      │  │
│  │  Scope: Conversation                                       │  │
│  │  Storage: Database (Message table)                         │  │
│  │  Key: conversation_id                                      │  │
│  └────────────────────────────────────────────────────────────┘  │
│                                                                  │
│  ┌────────────────────────────────────────────────────────────┐  │
│  │                   NodeTokenBufferMemory                    │  │
│  │  Scope: Node within Conversation                           │  │
│  │  Storage: WorkflowNodeExecutionModel.outputs["context"]    │  │
│  │  Key: (conversation_id, node_id, workflow_run_id)          │  │
│  └────────────────────────────────────────────────────────────┘  │
│                                                                  │
└──────────────────────────────────────────────────────────────────┘
```

---

## TokenBufferMemory (Existing)

### Purpose

`TokenBufferMemory` retrieves conversation history from the `Message` table and converts it to `PromptMessage` objects for LLM context.

### Key Features

- **Conversation-scoped**: All messages within a conversation are candidates
- **Thread-aware**: Uses `parent_message_id` to extract only the current thread (supports regeneration scenarios)
- **Token-limited**: Truncates history to fit within `max_token_limit`
- **File support**: Handles `MessageFile` attachments (images, documents, etc.)

### Data Flow

```
Message Table                   TokenBufferMemory                  LLM
     │                                 │                            │
     │  SELECT * FROM messages         │                            │
     │  WHERE conversation_id = ?      │                            │
     │  ORDER BY created_at DESC       │                            │
     ├────────────────────────────────▶│                            │
     │                                 │                            │
     │                  extract_thread_messages()                   │
     │                                 │                            │
     │              build_prompt_message_with_files()               │
     │                                 │                            │
     │                 truncate by max_token_limit                  │
     │                                 │                            │
     │                                 │  Sequence[PromptMessage]   │
     │                                 ├───────────────────────────▶│
     │                                 │                            │
```

### Thread Extraction

When a user regenerates a response, a new thread is created:

```
Message A (user)
├── Message A' (assistant)
│   └── Message B (user)
│       └── Message B' (assistant)
└── Message A'' (assistant, regenerated) ← New thread
    └── Message C (user)
        └── Message C' (assistant)
```

`extract_thread_messages()` traces back from the latest message using `parent_message_id` to get only the current thread: `[A, A'', C, C']`
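
A minimal sketch of that backtracking, assuming a simplified message shape where each entry is a dict with `id` and `parent_message_id` (the real implementation lives in `core/prompt/utils/extract_thread_messages.py` and operates on `Message` rows):

```python
def extract_thread(messages: list[dict]) -> list[dict]:
    """messages is sorted newest-first; walk parent links from the latest
    message and return the current thread oldest-first."""
    by_id = {m["id"]: m for m in messages}
    thread: list[dict] = []
    cursor = messages[0] if messages else None  # the latest message
    while cursor is not None:
        thread.append(cursor)
        parent_id = cursor.get("parent_message_id")
        cursor = by_id.get(parent_id) if parent_id else None
    return list(reversed(thread))
```

On the tree above, walking back from `C'` yields `[A, A'', C, C']`; the abandoned `A' → B → B'` branch is never visited.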

### Usage

```python
from core.memory.token_buffer_memory import TokenBufferMemory

memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
history = memory.get_history_prompt_messages(max_token_limit=2000, message_limit=100)
```

---

## NodeTokenBufferMemory

### Purpose

`NodeTokenBufferMemory` provides **node-scoped memory** within a conversation. Each LLM node in a workflow can maintain its own independent conversation history.

### Use Cases

1. **Multi-LLM Workflows**: Different LLM nodes need separate context
2. **Iterative Processing**: An LLM node in a loop needs to accumulate context across iterations
3. **Specialized Agents**: Each agent node maintains its own dialogue history

### Design: Zero Extra Storage

**Key insight**: LLM node already saves complete context in `outputs["context"]`.

Each LLM node execution outputs:

```python
outputs = {
    "text": clean_text,
    "context": self._build_context(prompt_messages, clean_text),  # Complete dialogue history!
    ...
}
```

This `outputs["context"]` contains:

- All previous user/assistant messages (excluding system prompt)
- The current assistant response

**No separate storage needed** - we just read from the last execution's `outputs["context"]`.

### Benefits

| Aspect      | Old Design (Object Storage)   | New Design (outputs["context"])        |
| ----------- | ----------------------------- | -------------------------------------- |
| Storage     | Separate JSON file            | Already in WorkflowNodeExecutionModel  |
| Concurrency | Race condition risk           | No issue (each execution is INSERT)    |
| Cleanup     | Need separate cleanup task    | Follows node execution lifecycle       |
| Migration   | Required                      | None                                   |
| Complexity  | High                          | Low                                    |

### Data Flow

```
WorkflowNodeExecutionModel        NodeTokenBufferMemory                LLM Node
        │                                │                                │
        │                                │◀── get_history_prompt_messages()
        │                                │                                │
        │  SELECT outputs FROM           │                                │
        │  workflow_node_executions      │                                │
        │  WHERE workflow_run_id = ?     │                                │
        │    AND node_id = ?             │                                │
        │◀───────────────────────────────┤                                │
        │                                │                                │
        │  outputs["context"]            │                                │
        ├───────────────────────────────▶│                                │
        │                                │                                │
        │                   deserialize PromptMessages                    │
        │                                │                                │
        │                  truncate by max_token_limit                    │
        │                                │                                │
        │                                │  Sequence[PromptMessage]       │
        │                                ├───────────────────────────────▶│
        │                                │                                │
```

### Thread Tracking

Thread extraction still uses `Message` table's `parent_message_id` structure:

1. Query `Message` table for conversation → get thread's `workflow_run_ids`
2. Get the last completed `workflow_run_id` in the thread
3. Query `WorkflowNodeExecutionModel` for that execution's `outputs["context"]`

### API

```python
class NodeTokenBufferMemory:
    def __init__(
        self,
        app_id: str,
        conversation_id: str,
        node_id: str,
        tenant_id: str,
        model_instance: ModelInstance,
    ):
        """Initialize node-level memory."""
        ...

    def get_history_prompt_messages(
        self,
        *,
        max_token_limit: int = 2000,
        message_limit: int | None = None,
    ) -> Sequence[PromptMessage]:
        """
        Retrieve history as PromptMessage sequence.

        Reads from last completed execution's outputs["context"].
        """
        ...

    # Legacy methods (no-op, kept for compatibility)
    def add_messages(self, *args, **kwargs) -> None: pass
    def flush(self) -> None: pass
    def clear(self) -> None: pass
```
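
A usage sketch, mirroring the `TokenBufferMemory` example above (the identifiers are the Chatflow values described earlier):

```python
from core.memory.node_token_buffer_memory import NodeTokenBufferMemory

memory = NodeTokenBufferMemory(
    app_id=app_id,
    conversation_id=conversation_id,  # Chatflow only
    node_id=node_id,                  # the LLM node whose history is wanted
    tenant_id=tenant_id,
    model_instance=model_instance,
)
history = memory.get_history_prompt_messages(max_token_limit=2000)
```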

### Configuration

Add to `MemoryConfig` in `core/workflow/nodes/llm/entities.py`:

```python
class MemoryMode(StrEnum):
    CONVERSATION = "conversation"  # Use TokenBufferMemory (default)
    NODE = "node"  # Use NodeTokenBufferMemory (Chatflow only)

class MemoryConfig(BaseModel):
    role_prefix: RolePrefix | None = None
    window: MemoryWindowConfig | None = None
    query_prompt_template: str | None = None
    mode: MemoryMode = MemoryMode.CONVERSATION
```

**Mode Behavior:**

| Mode           | Memory Class          | Scope                    | Availability  |
| -------------- | --------------------- | ------------------------ | ------------- |
| `conversation` | TokenBufferMemory     | Entire conversation      | All app modes |
| `node`         | NodeTokenBufferMemory | Per-node in conversation | Chatflow only |

> When `mode=node` is used in a non-Chatflow context (no conversation_id), it falls back to no memory.
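
A minimal dispatch sketch of that rule; `build_memory` and its wiring are hypothetical, shown only to make the fallback explicit:

```python
def build_memory(memory_config, conversation, app_id, node_id, tenant_id, model_instance):
    # Hypothetical factory: pick the memory class from MemoryConfig.mode.
    if memory_config.mode == MemoryMode.NODE:
        if conversation is None:
            return None  # non-Chatflow context: fall back to no memory
        return NodeTokenBufferMemory(
            app_id=app_id,
            conversation_id=conversation.id,
            node_id=node_id,
            tenant_id=tenant_id,
            model_instance=model_instance,
        )
    return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
```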

---

## Comparison

| Feature        | TokenBufferMemory        | NodeTokenBufferMemory              |
| -------------- | ------------------------ | ---------------------------------- |
| Scope          | Conversation             | Node within Conversation           |
| Storage        | Database (Message table) | WorkflowNodeExecutionModel.outputs |
| Thread Support | Yes                      | Yes                                |
| File Support   | Yes (via MessageFile)    | Yes (via context serialization)    |
| Token Limit    | Yes                      | Yes                                |
| Use Case       | Standard chat apps       | Complex workflows                  |

---

## Extending to Other Nodes

Currently only the **LLM Node** includes `context` in its outputs. To enable node memory for other nodes:

1. Add `outputs["context"] = self._build_context(prompt_messages, response)` in the node (see the sketch after this list)
2. `NodeTokenBufferMemory` will then automatically pick it up

Nodes that could potentially support this:

- `question_classifier`
- `parameter_extractor`
- `agent`
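
A sketch of step 1, assuming a hypothetical `_build_context` helper that serializes the dialogue with `serialize_prompt_messages` (shown elsewhere in this changeset):

```python
def _build_context(self, prompt_messages, response_text):
    # Hypothetical: history (minus the system prompt) plus the new answer,
    # serialized to plain dicts so it can live in outputs["context"].
    history = [m for m in prompt_messages if m.role != PromptMessageRole.SYSTEM]
    history.append(AssistantPromptMessage(content=response_text))
    return serialize_prompt_messages(history)
```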

---

## Future Considerations

1. **Cleanup**: Node memory lifecycle follows `WorkflowNodeExecutionModel`, which already has cleanup mechanisms
2. **Compression**: For very long conversations, consider summarization strategies
3. **Extension**: Other nodes may benefit from node-level memory
@@ -1,11 +0,0 @@
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import (
    NodeTokenBufferMemory,
)
from core.memory.token_buffer_memory import TokenBufferMemory

__all__ = [
    "BaseMemory",
    "NodeTokenBufferMemory",
    "TokenBufferMemory",
]
@@ -1,83 +0,0 @@
"""
Base memory interfaces and types.

This module defines the common protocol for memory implementations.
"""

from abc import ABC, abstractmethod
from collections.abc import Sequence

from core.model_runtime.entities import ImagePromptMessageContent, PromptMessage


class BaseMemory(ABC):
    """
    Abstract base class for memory implementations.

    Provides a common interface for both conversation-level and node-level memory.
    """

    @abstractmethod
    def get_history_prompt_messages(
        self,
        *,
        max_token_limit: int = 2000,
        message_limit: int | None = None,
    ) -> Sequence[PromptMessage]:
        """
        Get history prompt messages.

        :param max_token_limit: Maximum tokens for history
        :param message_limit: Maximum number of messages
        :return: Sequence of PromptMessage for LLM context
        """
        pass

    def get_history_prompt_text(
        self,
        human_prefix: str = "Human",
        ai_prefix: str = "Assistant",
        max_token_limit: int = 2000,
        message_limit: int | None = None,
    ) -> str:
        """
        Get history prompt as formatted text.

        :param human_prefix: Prefix for human messages
        :param ai_prefix: Prefix for assistant messages
        :param max_token_limit: Maximum tokens for history
        :param message_limit: Maximum number of messages
        :return: Formatted history text
        """
        from core.model_runtime.entities import (
            PromptMessageRole,
            TextPromptMessageContent,
        )

        prompt_messages = self.get_history_prompt_messages(
            max_token_limit=max_token_limit,
            message_limit=message_limit,
        )

        string_messages = []
        for m in prompt_messages:
            if m.role == PromptMessageRole.USER:
                role = human_prefix
            elif m.role == PromptMessageRole.ASSISTANT:
                role = ai_prefix
            else:
                continue

            if isinstance(m.content, list):
                inner_msg = ""
                for content in m.content:
                    if isinstance(content, TextPromptMessageContent):
                        inner_msg += f"{content.data}\n"
                    elif isinstance(content, ImagePromptMessageContent):
                        inner_msg += "[image]\n"
                string_messages.append(f"{role}: {inner_msg.strip()}")
            else:
                message = f"{role}: {m.content}"
                string_messages.append(message)

        return "\n".join(string_messages)
@@ -1,197 +0,0 @@
"""
Node-level Token Buffer Memory for Chatflow.

This module provides node-scoped memory within a conversation.
Each LLM node in a workflow can maintain its own independent conversation history.

Note: This is only available in Chatflow (advanced-chat mode) because it requires
both conversation_id and node_id.

Design:
- History is read directly from WorkflowNodeExecutionModel.outputs["context"]
- No separate storage needed - the context is already saved during node execution
- Thread tracking leverages Message table's parent_message_id structure
"""

import logging
from collections.abc import Sequence
from typing import cast

from sqlalchemy import select
from sqlalchemy.orm import Session

from core.file import file_manager
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
    AssistantPromptMessage,
    MultiModalPromptMessageContent,
    PromptMessage,
    PromptMessageRole,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import Message
from models.workflow import WorkflowNodeExecutionModel

logger = logging.getLogger(__name__)


class NodeTokenBufferMemory(BaseMemory):
    """
    Node-level Token Buffer Memory.

    Provides node-scoped memory within a conversation. Each LLM node can maintain
    its own independent conversation history.

    Key design: History is read directly from WorkflowNodeExecutionModel.outputs["context"],
    which is already saved during node execution. No separate storage needed.
    """

    def __init__(
        self,
        app_id: str,
        conversation_id: str,
        node_id: str,
        tenant_id: str,
        model_instance: ModelInstance,
    ):
        self.app_id = app_id
        self.conversation_id = conversation_id
        self.node_id = node_id
        self.tenant_id = tenant_id
        self.model_instance = model_instance

    def _get_thread_workflow_run_ids(self) -> list[str]:
        """
        Get workflow_run_ids for the current thread by querying the Message table.
        Returns workflow_run_ids in chronological order (oldest first).
        """
        with Session(db.engine, expire_on_commit=False) as session:
            stmt = (
                select(Message)
                .where(Message.conversation_id == self.conversation_id)
                .order_by(Message.created_at.desc())
                .limit(500)
            )
            messages = list(session.scalars(stmt).all())

        if not messages:
            return []

        # Extract thread messages using existing logic
        thread_messages = extract_thread_messages(messages)

        # The answer of a newly created message is temporarily empty; skip it
        if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
            thread_messages.pop(0)

        # Reverse to get chronological order, extract workflow_run_ids
        return [msg.workflow_run_id for msg in reversed(thread_messages) if msg.workflow_run_id]

    def _deserialize_prompt_message(self, msg_dict: dict) -> PromptMessage:
        """Deserialize a dict to PromptMessage based on role."""
        role = msg_dict.get("role")
        if role in (PromptMessageRole.USER, "user"):
            return UserPromptMessage.model_validate(msg_dict)
        elif role in (PromptMessageRole.ASSISTANT, "assistant"):
            return AssistantPromptMessage.model_validate(msg_dict)
        elif role in (PromptMessageRole.SYSTEM, "system"):
            return SystemPromptMessage.model_validate(msg_dict)
        elif role in (PromptMessageRole.TOOL, "tool"):
            return ToolPromptMessage.model_validate(msg_dict)
        else:
            return PromptMessage.model_validate(msg_dict)

    def _deserialize_context(self, context_data: list[dict]) -> list[PromptMessage]:
        """Deserialize context data from outputs to a list of PromptMessage."""
        messages = []
        for msg_dict in context_data:
            try:
                msg = self._deserialize_prompt_message(msg_dict)
                msg = self._restore_multimodal_content(msg)
                messages.append(msg)
            except Exception as e:
                logger.warning("Failed to deserialize prompt message: %s", e)
        return messages

    def _restore_multimodal_content(self, message: PromptMessage) -> PromptMessage:
        """
        Restore multimodal content (base64 or url) from file_ref.

        When context is saved, base64_data is cleared to save storage space.
        This method restores the content by parsing file_ref (format: "method:id_or_url").
        """
        content = message.content
        if content is None or isinstance(content, str):
            return message

        # Process list content, restoring multimodal data from file references
        restored_content: list[PromptMessageContentUnionTypes] = []
        for item in content:
            if isinstance(item, MultiModalPromptMessageContent):
                # restore_multimodal_content preserves the concrete subclass type
                restored_item = file_manager.restore_multimodal_content(item)
                restored_content.append(cast(PromptMessageContentUnionTypes, restored_item))
            else:
                restored_content.append(item)

        return message.model_copy(update={"content": restored_content})

    def get_history_prompt_messages(
        self,
        *,
        max_token_limit: int = 2000,
        message_limit: int | None = None,
    ) -> Sequence[PromptMessage]:
        """
        Retrieve history as a PromptMessage sequence.
        History is read directly from the last completed node execution's outputs["context"].
        """
        _ = message_limit  # unused, kept for interface compatibility

        thread_workflow_run_ids = self._get_thread_workflow_run_ids()
        if not thread_workflow_run_ids:
            return []

        # Get the last completed workflow_run_id (contains accumulated context)
        last_run_id = thread_workflow_run_ids[-1]

        with Session(db.engine, expire_on_commit=False) as session:
            stmt = select(WorkflowNodeExecutionModel).where(
                WorkflowNodeExecutionModel.workflow_run_id == last_run_id,
                WorkflowNodeExecutionModel.node_id == self.node_id,
                WorkflowNodeExecutionModel.status == "succeeded",
            )
            execution = session.scalars(stmt).first()

        if not execution:
            return []

        outputs = execution.outputs_dict
        if not outputs:
            return []

        context_data = outputs.get("context")

        if not context_data or not isinstance(context_data, list):
            return []

        prompt_messages = self._deserialize_context(context_data)
        if not prompt_messages:
            return []

        # Truncate by token limit
        try:
            current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
            while current_tokens > max_token_limit and len(prompt_messages) > 1:
                prompt_messages.pop(0)
                current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
        except Exception as e:
            logger.warning("Failed to count tokens for truncation: %s", e)

        return prompt_messages
@@ -5,12 +5,12 @@ from sqlalchemy.orm import sessionmaker

from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageRole,
    TextPromptMessageContent,
    UserPromptMessage,
)
@@ -24,7 +24,7 @@ from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory


class TokenBufferMemory(BaseMemory):
class TokenBufferMemory:
    def __init__(
        self,
        conversation: Conversation,
@@ -115,14 +115,10 @@ class TokenBufferMemory(BaseMemory):
        return AssistantPromptMessage(content=prompt_message_contents)

    def get_history_prompt_messages(
        self,
        *,
        max_token_limit: int = 2000,
        message_limit: int | None = None,
        self, max_token_limit: int = 2000, message_limit: int | None = None
    ) -> Sequence[PromptMessage]:
        """
        Get history prompt messages.

        :param max_token_limit: max token limit
        :param message_limit: message limit
        """
@@ -204,3 +200,44 @@ class TokenBufferMemory(BaseMemory):
            curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)

        return prompt_messages

    def get_history_prompt_text(
        self,
        human_prefix: str = "Human",
        ai_prefix: str = "Assistant",
        max_token_limit: int = 2000,
        message_limit: int | None = None,
    ) -> str:
        """
        Get history prompt text.
        :param human_prefix: human prefix
        :param ai_prefix: ai prefix
        :param max_token_limit: max token limit
        :param message_limit: message limit
        :return:
        """
        prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)

        string_messages = []
        for m in prompt_messages:
            if m.role == PromptMessageRole.USER:
                role = human_prefix
            elif m.role == PromptMessageRole.ASSISTANT:
                role = ai_prefix
            else:
                continue

            if isinstance(m.content, list):
                inner_msg = ""
                for content in m.content:
                    if isinstance(content, TextPromptMessageContent):
                        inner_msg += f"{content.data}\n"
                    elif isinstance(content, ImagePromptMessageContent):
                        inner_msg += "[image]\n"

                string_messages.append(f"{role}: {inner_msg.strip()}")
            else:
                message = f"{role}: {m.content}"
                string_messages.append(message)

        return "\n".join(string_messages)

@@ -91,9 +91,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
    mime_type: str = Field(default=..., description="the mime type of multi-modal file")
    filename: str = Field(default="", description="the filename of multi-modal file")

    # File reference for context restoration, format: "transfer_method:related_id" or "remote:url"
    file_ref: str | None = Field(default=None, description="Encoded file reference for restoration")

    @property
    def data(self):
        return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
@@ -279,5 +276,7 @@ class ToolPromptMessage(PromptMessage):

        :return: True if prompt message is empty, False otherwise
        """
        # ToolPromptMessage is not empty if it has content OR has a tool_call_id
        return super().is_empty() and not self.tool_call_id
        if not super().is_empty() and not self.tool_call_id:
            return False

        return True

@@ -5,7 +5,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.file import file_manager
from core.file.models import File
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities import (
    AssistantPromptMessage,
    PromptMessage,
@@ -43,7 +43,7 @@ class AdvancedPromptTransform(PromptTransform):
        files: Sequence[File],
        context: str | None,
        memory_config: MemoryConfig | None,
        memory: BaseMemory | None,
        memory: TokenBufferMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:
@@ -84,7 +84,7 @@ class AdvancedPromptTransform(PromptTransform):
        files: Sequence[File],
        context: str | None,
        memory_config: MemoryConfig | None,
        memory: BaseMemory | None,
        memory: TokenBufferMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:
@@ -145,7 +145,7 @@ class AdvancedPromptTransform(PromptTransform):
        files: Sequence[File],
        context: str | None,
        memory_config: MemoryConfig | None,
        memory: BaseMemory | None,
        memory: TokenBufferMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:
@@ -270,7 +270,7 @@ class AdvancedPromptTransform(PromptTransform):

    def _set_histories_variable(
        self,
        memory: BaseMemory,
        memory: TokenBufferMemory,
        memory_config: MemoryConfig,
        raw_prompt: str,
        role_prefix: MemoryConfig.RolePrefix,

@@ -1,4 +1,3 @@
from enum import StrEnum
from typing import Literal

from pydantic import BaseModel
@@ -6,13 +5,6 @@ from pydantic import BaseModel
from core.model_runtime.entities.message_entities import PromptMessageRole


class MemoryMode(StrEnum):
    """Memory mode for LLM nodes."""

    CONVERSATION = "conversation"  # Use TokenBufferMemory (default, existing behavior)
    NODE = "node"  # Use NodeTokenBufferMemory (Chatflow only)


class ChatModelMessage(BaseModel):
    """
    Chat Message.
@@ -56,4 +48,3 @@ class MemoryConfig(BaseModel):
    role_prefix: RolePrefix | None = None
    window: WindowConfig
    query_prompt_template: str | None = None
    mode: MemoryMode = MemoryMode.CONVERSATION

@@ -1,7 +1,7 @@
from typing import Any

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
@@ -11,7 +11,7 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
class PromptTransform:
    def _append_chat_histories(
        self,
        memory: BaseMemory,
        memory: TokenBufferMemory,
        memory_config: MemoryConfig,
        prompt_messages: list[PromptMessage],
        model_config: ModelConfigWithCredentialsEntity,
@@ -52,7 +52,7 @@ class PromptTransform:

    def _get_history_messages_from_memory(
        self,
        memory: BaseMemory,
        memory: TokenBufferMemory,
        memory_config: MemoryConfig,
        max_token_limit: int,
        human_prefix: str | None = None,
@@ -73,7 +73,7 @@ class PromptTransform:
        return memory.get_history_prompt_text(**kwargs)

    def _get_history_messages_list_from_memory(
        self, memory: BaseMemory, memory_config: MemoryConfig, max_token_limit: int
        self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
    ) -> list[PromptMessage]:
        """Get memory messages."""
        return list(

@@ -1047,8 +1047,6 @@ class ToolManager:
                    continue
                tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
                if tool_input.type == "variable":
                    if not isinstance(tool_input.value, list):
                        raise ToolParameterError(f"Invalid variable selector for {parameter.name}")
                    variable = variable_pool.get(tool_input.value)
                    if variable is None:
                        raise ToolParameterError(f"Variable {tool_input.value} does not exist")
@@ -1058,11 +1056,6 @@ class ToolManager:
                elif tool_input.type == "mixed":
                    segment_group = variable_pool.convert_template(str(tool_input.value))
                    parameter_value = segment_group.text
                elif tool_input.type == "mention":
                    # Mention type not supported in agent mode
                    raise ToolParameterError(
                        f"Mention type not supported in agent for parameter '{parameter.name}'"
                    )
                else:
                    raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
                runtime_parameters[parameter.name] = parameter_value

@ -5,6 +5,7 @@ import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import has_request_context
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.db.session_factory import session_factory
|
||||
@ -28,21 +29,6 @@ from models.workflow import Workflow
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _try_resolve_user_from_request() -> Account | EndUser | None:
|
||||
"""
|
||||
Try to resolve user from Flask request context.
|
||||
|
||||
Returns None if not in a request context or if user is not available.
|
||||
"""
|
||||
# Note: `current_user` is a LocalProxy. Never compare it with None directly.
|
||||
# Use _get_current_object() to dereference the proxy
|
||||
user = getattr(current_user, "_get_current_object", lambda: current_user)()
|
||||
# Check if we got a valid user object
|
||||
if user is not None and hasattr(user, "id"):
|
||||
return user
|
||||
return None
|
||||
|
||||
|
||||
class WorkflowTool(Tool):
|
||||
"""
|
||||
Workflow tool.
|
||||
@ -223,13 +209,21 @@ class WorkflowTool(Tool):
|
||||
Returns:
|
||||
Account | EndUser | None: The resolved user object, or None if resolution fails.
|
||||
"""
|
||||
# Try to resolve user from request context first
|
||||
user = _try_resolve_user_from_request()
|
||||
if user is not None:
|
||||
return user
|
||||
if has_request_context():
|
||||
return self._resolve_user_from_request()
|
||||
else:
|
||||
return self._resolve_user_from_database(user_id=user_id)
|
||||
|
||||
# Fall back to database resolution
|
||||
return self._resolve_user_from_database(user_id=user_id)
|
||||
def _resolve_user_from_request(self) -> Account | EndUser | None:
|
||||
"""
|
||||
Resolve user from Flask request context.
|
||||
"""
|
||||
try:
|
||||
# Note: `current_user` is a LocalProxy. Never compare it with None directly.
|
||||
return getattr(current_user, "_get_current_object", lambda: current_user)()
|
||||
except Exception as e:
|
||||
logger.warning("Failed to resolve user from request context: %s", e)
|
||||
return None
|
||||
|
||||
def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None:
|
||||
"""
|
||||
|
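Background for the current_user handling above: flask_login's current_user is a werkzeug LocalProxy, so the code dereferences it with _get_current_object() before doing identity checks. A self-contained demonstration with stand-in objects (not dify's user models):

from werkzeug.local import LocalProxy

_user = {"id": "u1"}
current_user = LocalProxy(lambda: _user)

assert current_user is not _user  # the proxy is a distinct object
assert current_user._get_current_object() is _user  # dereferences to the real one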
@@ -4,7 +4,6 @@ from .segments import (
     ArrayFileSegment,
     ArrayNumberSegment,
     ArrayObjectSegment,
-    ArrayPromptMessageSegment,
     ArraySegment,
     ArrayStringSegment,
     FileSegment,
@@ -21,7 +20,6 @@ from .variables import (
     ArrayFileVariable,
     ArrayNumberVariable,
     ArrayObjectVariable,
-    ArrayPromptMessageVariable,
     ArrayStringVariable,
     ArrayVariable,
     FileVariable,
@@ -44,8 +42,6 @@ __all__ = [
     "ArrayNumberVariable",
     "ArrayObjectSegment",
     "ArrayObjectVariable",
-    "ArrayPromptMessageSegment",
-    "ArrayPromptMessageVariable",
     "ArraySegment",
     "ArrayStringSegment",
    "ArrayStringVariable",
@@ -6,7 +6,6 @@ from typing import Annotated, Any, TypeAlias
 from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator

 from core.file import File
-from core.model_runtime.entities import PromptMessage

 from .types import SegmentType

@@ -209,15 +208,6 @@ class ArrayBooleanSegment(ArraySegment):
     value: Sequence[bool]


-class ArrayPromptMessageSegment(ArraySegment):
-    value_type: SegmentType = SegmentType.ARRAY_PROMPT_MESSAGE
-    value: Sequence[PromptMessage]
-
-    def to_object(self):
-        """Convert to JSON-serializable format for database storage and frontend."""
-        return [msg.model_dump() for msg in self.value]
-
-
 def get_segment_discriminator(v: Any) -> SegmentType | None:
     if isinstance(v, Segment):
         return v.value_type
@@ -258,7 +248,6 @@ SegmentUnion: TypeAlias = Annotated[
         | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
         | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
         | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
-        | Annotated[ArrayPromptMessageSegment, Tag(SegmentType.ARRAY_PROMPT_MESSAGE)]
     ),
     Discriminator(get_segment_discriminator),
 ]
@@ -45,7 +45,6 @@ class SegmentType(StrEnum):
     ARRAY_OBJECT = "array[object]"
     ARRAY_FILE = "array[file]"
     ARRAY_BOOLEAN = "array[boolean]"
-    ARRAY_PROMPT_MESSAGE = "array[message]"

     NONE = "none"
@@ -3,10 +3,8 @@ from typing import Any

 import orjson

-from core.model_runtime.entities import PromptMessage
-
 from .segment_group import SegmentGroup
-from .segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
+from .segments import ArrayFileSegment, FileSegment, Segment


 def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
@@ -18,7 +16,7 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[

 def segment_orjson_default(o: Any):
     """Default function for orjson serialization of Segment types"""
-    if isinstance(o, (ArrayFileSegment, ArrayPromptMessageSegment)):
+    if isinstance(o, ArrayFileSegment):
         return [v.model_dump() for v in o.value]
     elif isinstance(o, FileSegment):
         return o.value.model_dump()
@@ -26,8 +24,6 @@
         return [segment_orjson_default(seg) for seg in o.value]
     elif isinstance(o, Segment):
         return o.value
-    elif isinstance(o, PromptMessage):
-        return o.model_dump()
     raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")
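For reference, this is how a default= hook such as segment_orjson_default plugs into orjson: the hook is called for any object orjson cannot serialize natively and must return a JSON-compatible value or raise TypeError. A minimal sketch with a toy wrapper type:

import orjson


class Wrapper:
    def __init__(self, value):
        self.value = value


def fallback(o):
    if isinstance(o, Wrapper):
        return o.value
    raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")


print(orjson.dumps({"x": Wrapper(42)}, default=fallback))  # b'{"x":42}'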
@@ -12,7 +12,6 @@ from .segments import (
     ArrayFileSegment,
     ArrayNumberSegment,
     ArrayObjectSegment,
-    ArrayPromptMessageSegment,
     ArraySegment,
     ArrayStringSegment,
     BooleanSegment,
@@ -111,10 +110,6 @@ class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
     pass


-class ArrayPromptMessageVariable(ArrayPromptMessageSegment, ArrayVariable):
-    pass
-
-
 class RAGPipelineVariable(BaseModel):
     belong_to_node_id: str = Field(description="belong to which node id, shared means public")
     type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
@@ -165,7 +160,6 @@ Variable: TypeAlias = Annotated[
         | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
         | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
         | Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
-        | Annotated[ArrayPromptMessageVariable, Tag(SegmentType.ARRAY_PROMPT_MESSAGE)]
         | Annotated[SecretVariable, Tag(SegmentType.SECRET)]
     ),
     Discriminator(get_segment_discriminator),
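The Variable and SegmentUnion aliases edited above use pydantic's callable-discriminator pattern (Discriminator plus Tag annotations). A self-contained sketch of the same pattern, with two toy segment types in place of dify's:

from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class StringSeg(BaseModel):
    value_type: Literal["string"] = "string"
    value: str


class NumberSeg(BaseModel):
    value_type: Literal["number"] = "number"
    value: float


def pick_tag(v: Any) -> str | None:
    # Mirrors get_segment_discriminator: read the tag off a dict or a model.
    if isinstance(v, dict):
        return v.get("value_type")
    return getattr(v, "value_type", None)


Seg = Annotated[
    Union[Annotated[StringSeg, Tag("string")], Annotated[NumberSeg, Tag("number")]],
    Discriminator(pick_tag),
]

adapter = TypeAdapter(Seg)
assert isinstance(adapter.validate_python({"value_type": "number", "value": 3}), NumberSeg)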
@@ -1,22 +0,0 @@
-"""
-Execution Context - Context management for workflow execution.
-
-This package provides Flask-independent context management for workflow
-execution in multi-threaded environments.
-"""
-
-from core.workflow.context.execution_context import (
-    AppContext,
-    ExecutionContext,
-    IExecutionContext,
-    NullAppContext,
-    capture_current_context,
-)
-
-__all__ = [
-    "AppContext",
-    "ExecutionContext",
-    "IExecutionContext",
-    "NullAppContext",
-    "capture_current_context",
-]
@@ -1,216 +0,0 @@
-"""
-Execution Context - Abstracted context management for workflow execution.
-"""
-
-import contextvars
-from abc import ABC, abstractmethod
-from collections.abc import Generator
-from contextlib import AbstractContextManager, contextmanager
-from typing import Any, Protocol, final, runtime_checkable
-
-
-class AppContext(ABC):
-    """
-    Abstract application context interface.
-
-    This abstraction allows workflow execution to work with or without Flask
-    by providing a common interface for application context management.
-    """
-
-    @abstractmethod
-    def get_config(self, key: str, default: Any = None) -> Any:
-        """Get configuration value by key."""
-        pass
-
-    @abstractmethod
-    def get_extension(self, name: str) -> Any:
-        """Get Flask extension by name (e.g., 'db', 'cache')."""
-        pass
-
-    @abstractmethod
-    def enter(self) -> AbstractContextManager[None]:
-        """Enter the application context."""
-        pass
-
-
-@runtime_checkable
-class IExecutionContext(Protocol):
-    """
-    Protocol for execution context.
-
-    This protocol defines the interface that all execution contexts must implement,
-    allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably.
-    """
-
-    def __enter__(self) -> "IExecutionContext":
-        """Enter the execution context."""
-        ...
-
-    def __exit__(self, *args: Any) -> None:
-        """Exit the execution context."""
-        ...
-
-    @property
-    def user(self) -> Any:
-        """Get user object."""
-        ...
-
-
-@final
-class ExecutionContext:
-    """
-    Execution context for workflow execution in worker threads.
-
-    This class encapsulates all context needed for workflow execution:
-    - Application context (Flask app or standalone)
-    - Context variables for Python contextvars
-    - User information (optional)
-
-    It is designed to be serializable and passable to worker threads.
-    """
-
-    def __init__(
-        self,
-        app_context: AppContext | None = None,
-        context_vars: contextvars.Context | None = None,
-        user: Any = None,
-    ) -> None:
-        """
-        Initialize execution context.
-
-        Args:
-            app_context: Application context (Flask or standalone)
-            context_vars: Python contextvars to preserve
-            user: User object (optional)
-        """
-        self._app_context = app_context
-        self._context_vars = context_vars
-        self._user = user
-
-    @property
-    def app_context(self) -> AppContext | None:
-        """Get application context."""
-        return self._app_context
-
-    @property
-    def context_vars(self) -> contextvars.Context | None:
-        """Get context variables."""
-        return self._context_vars
-
-    @property
-    def user(self) -> Any:
-        """Get user object."""
-        return self._user
-
-    @contextmanager
-    def enter(self) -> Generator[None, None, None]:
-        """
-        Enter this execution context.
-
-        This is a convenience method that creates a context manager.
-        """
-        # Restore context variables if provided
-        if self._context_vars:
-            for var, val in self._context_vars.items():
-                var.set(val)
-
-        # Enter app context if available
-        if self._app_context is not None:
-            with self._app_context.enter():
-                yield
-        else:
-            yield
-
-    def __enter__(self) -> "ExecutionContext":
-        """Enter the execution context."""
-        self._cm = self.enter()
-        self._cm.__enter__()
-        return self
-
-    def __exit__(self, *args: Any) -> None:
-        """Exit the execution context."""
-        if hasattr(self, "_cm"):
-            self._cm.__exit__(*args)
-
-
-class NullAppContext(AppContext):
-    """
-    Null implementation of AppContext for non-Flask environments.
-
-    This is used when running without Flask (e.g., in tests or standalone mode).
-    """
-
-    def __init__(self, config: dict[str, Any] | None = None) -> None:
-        """
-        Initialize null app context.
-
-        Args:
-            config: Optional configuration dictionary
-        """
-        self._config = config or {}
-        self._extensions: dict[str, Any] = {}
-
-    def get_config(self, key: str, default: Any = None) -> Any:
-        """Get configuration value by key."""
-        return self._config.get(key, default)
-
-    def get_extension(self, name: str) -> Any:
-        """Get extension by name."""
-        return self._extensions.get(name)
-
-    def set_extension(self, name: str, extension: Any) -> None:
-        """Set extension by name."""
-        self._extensions[name] = extension
-
-    @contextmanager
-    def enter(self) -> Generator[None, None, None]:
-        """Enter null context (no-op)."""
-        yield
-
-
-class ExecutionContextBuilder:
-    """
-    Builder for creating ExecutionContext instances.
-
-    This provides a fluent API for building execution contexts.
-    """
-
-    def __init__(self) -> None:
-        self._app_context: AppContext | None = None
-        self._context_vars: contextvars.Context | None = None
-        self._user: Any = None
-
-    def with_app_context(self, app_context: AppContext) -> "ExecutionContextBuilder":
-        """Set application context."""
-        self._app_context = app_context
-        return self
-
-    def with_context_vars(self, context_vars: contextvars.Context) -> "ExecutionContextBuilder":
-        """Set context variables."""
-        self._context_vars = context_vars
-        return self
-
-    def with_user(self, user: Any) -> "ExecutionContextBuilder":
-        """Set user."""
-        self._user = user
-        return self
-
-    def build(self) -> ExecutionContext:
-        """Build the execution context."""
-        return ExecutionContext(
-            app_context=self._app_context,
-            context_vars=self._context_vars,
-            user=self._user,
-        )
-
-
-def capture_current_context() -> IExecutionContext:
-    """
-    Capture current execution context from the calling environment.
-
-    Returns:
-        IExecutionContext with captured context
-    """
-    from context import capture_current_context
-
-    return capture_current_context()
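The module deleted above existed to carry context across threads. Its core trick reduces to standard-library pieces: capture contextvars in the submitting thread and replay them inside the worker, so values set before dispatch stay visible during execution. A minimal sketch:

import contextvars
import threading

request_id = contextvars.ContextVar("request_id", default="-")


def worker(ctx: contextvars.Context) -> None:
    # ctx.run executes the callable with the captured values in place.
    ctx.run(lambda: print("worker sees:", request_id.get()))


request_id.set("req-42")
t = threading.Thread(target=worker, args=(contextvars.copy_context(),))
t.start()
t.join()  # prints: worker sees: req-42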
File diff suppressed because it is too large
@@ -63,7 +63,6 @@ class NodeType(StrEnum):
     TRIGGER_SCHEDULE = "trigger-schedule"
     TRIGGER_PLUGIN = "trigger-plugin"
     HUMAN_INPUT = "human-input"
-    GROUP = "group"

     @property
     def is_trigger_node(self) -> bool:
@@ -253,7 +252,6 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
     LOOP_VARIABLE_MAP = "loop_variable_map"  # single loop variable output
     DATASOURCE_INFO = "datasource_info"
     COMPLETED_REASON = "completed_reason"  # completed reason for loop node
-    MENTION_PARENT_ID = "mention_parent_id"  # parent node id for extractor nodes


 class WorkflowNodeExecutionStatus(StrEnum):
@@ -307,14 +307,7 @@ class Graph:
         if not node_configs:
             raise ValueError("Graph must have at least one node")

-        # Filter out UI-only node types:
-        # - custom-note: top-level type (node_config.type == "custom-note")
-        # - group: data-level type (node_config.data.type == "group")
-        node_configs = [
-            node_config
-            for node_config in node_configs
-            if node_config.get("type", "") != "custom-note" and node_config.get("data", {}).get("type", "") != "group"
-        ]
+        node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]

         # Parse node configurations
         node_configs_map = cls._parse_node_configs(node_configs)
@@ -93,8 +93,8 @@ class EventHandler:
         Args:
             event: The event to handle
         """
-        # Events in loops, iterations, or extractor groups are always collected
-        if event.in_loop_id or event.in_iteration_id or event.in_mention_parent_id:
+        # Events in loops or iterations are always collected
+        if event.in_loop_id or event.in_iteration_id:
             self._event_collector.collect(event)
             return
         return self._dispatch(event)
@@ -125,11 +125,6 @@
         Args:
             event: The node started event
         """
-        # Check if this is an extractor node (has parent_node_id)
-        if self._is_extractor_node(event.node_id):
-            self._handle_extractor_node_started(event)
-            return
-
         # Track execution in domain model
         node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
         is_initial_attempt = node_execution.retry_count == 0
@@ -169,11 +164,6 @@
         Args:
             event: The node succeeded event
         """
-        # Check if this is an extractor node (has parent_node_id)
-        if self._is_extractor_node(event.node_id):
-            self._handle_extractor_node_success(event)
-            return
-
         # Update domain model
         node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
         node_execution.mark_taken()
@@ -236,11 +226,6 @@
         Args:
             event: The node failed event
         """
-        # Check if this is an extractor node (has parent_node_id)
-        if self._is_extractor_node(event.node_id):
-            self._handle_extractor_node_failed(event)
-            return
-
         # Update domain model
         node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
         node_execution.mark_failed(event.error)
@@ -360,57 +345,3 @@
             self._graph_runtime_state.set_output("answer", value)
         else:
             self._graph_runtime_state.set_output(key, value)
-
-    def _is_extractor_node(self, node_id: str) -> bool:
-        """
-        Check if node_id represents an extractor node (has parent_node_id).
-
-        Extractor nodes extract values from list[PromptMessage] for their parent node.
-        They have a parent_node_id field pointing to their parent node.
-        """
-        node = self._graph.nodes.get(node_id)
-        if node is None:
-            return False
-        return node.node_data.is_extractor_node
-
-    def _handle_extractor_node_started(self, event: NodeRunStartedEvent) -> None:
-        """
-        Handle extractor node started event.
-
-        Extractor nodes don't need full execution tracking, just collect the event.
-        """
-        # Track in response coordinator for stream ordering
-        self._response_coordinator.track_node_execution(event.node_id, event.id)
-
-        # Collect the event
-        self._event_collector.collect(event)
-
-    def _handle_extractor_node_success(self, event: NodeRunSucceededEvent) -> None:
-        """
-        Handle extractor node success event.
-
-        Extractor nodes need special handling:
-        - Store outputs in variable pool (for reference by other nodes)
-        - Accumulate token usage
-        - Collect the event for logging
-        - Do NOT process edges or enqueue next nodes (parent node handles that)
-        """
-        self._accumulate_node_usage(event.node_run_result.llm_usage)
-
-        # Store outputs in variable pool
-        self._store_node_outputs(event.node_id, event.node_run_result.outputs)
-
-        # Collect the event
-        self._event_collector.collect(event)
-
-    def _handle_extractor_node_failed(self, event: NodeRunFailedEvent) -> None:
-        """
-        Handle extractor node failed event.
-
-        Extractor node failures are collected for logging,
-        but the parent node is responsible for handling the error.
-        """
-        self._accumulate_node_usage(event.node_run_result.llm_usage)
-
-        # Collect the event for logging
-        self._event_collector.collect(event)
@@ -7,13 +7,15 @@ Domain-Driven Design principles for improved maintainability and testability.

 from __future__ import annotations

+import contextvars
 import logging
 import queue
 import threading
 from collections.abc import Generator
 from typing import TYPE_CHECKING, cast, final

-from core.workflow.context import capture_current_context
+from flask import Flask, current_app

 from core.workflow.enums import NodeExecutionType
 from core.workflow.graph import Graph
 from core.workflow.graph_events import (
@@ -157,8 +159,17 @@
         self._layers: list[GraphEngineLayer] = []

         # === Worker Pool Setup ===
-        # Capture execution context for worker threads
-        execution_context = capture_current_context()
+        # Capture Flask app context for worker threads
+        flask_app: Flask | None = None
+        try:
+            app = current_app._get_current_object()  # type: ignore
+            if isinstance(app, Flask):
+                flask_app = app
+        except RuntimeError:
+            pass
+
+        # Capture context variables for worker threads
+        context_vars = contextvars.copy_context()

         # Create worker pool for parallel node execution
         self._worker_pool = WorkerPool(
@@ -166,7 +177,8 @@
             event_queue=self._event_queue,
             graph=self._graph,
             layers=self._layers,
-            execution_context=execution_context,
+            flask_app=flask_app,
+            context_vars=context_vars,
             min_workers=self._min_workers,
             max_workers=self._max_workers,
             scale_up_threshold=self._scale_up_threshold,
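The GraphEngine hunk above captures the concrete Flask app and a copy of the contextvars so workers can re-enter them. preserve_flask_contexts itself lives in libs.flask_utils and is not shown in this diff; a hedged sketch of the minimum it must do, re-entering the captured app's context inside the worker thread:

import threading

from flask import Flask, current_app

app = Flask(__name__)


def run_in_worker(flask_app: Flask) -> None:
    # Re-entering the app context makes current_app and app-bound extensions
    # resolve inside a thread that never saw the original request.
    with flask_app.app_context():
        print("worker sees:", current_app.name)


with app.app_context():
    captured = current_app._get_current_object()  # the concrete Flask app, not the proxy

t = threading.Thread(target=run_in_worker, args=(captured,))
t.start()
t.join()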
@@ -68,7 +68,6 @@ class _NodeRuntimeSnapshot:
     predecessor_node_id: str | None
     iteration_id: str | None
     loop_id: str | None
-    mention_parent_id: str | None
     created_at: datetime

@@ -231,7 +230,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
         metadata = {
             WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
             WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
-            WorkflowNodeExecutionMetadataKey.MENTION_PARENT_ID: event.in_mention_parent_id,
         }

         domain_execution = WorkflowNodeExecution(
@@ -258,7 +256,6 @@
             predecessor_node_id=event.predecessor_node_id,
             iteration_id=event.in_iteration_id,
             loop_id=event.in_loop_id,
-            mention_parent_id=event.in_mention_parent_id,
             created_at=event.start_at,
         )
         self._node_snapshots[event.id] = snapshot
@@ -5,27 +5,26 @@ Workers pull node IDs from the ready_queue, execute nodes, and push events
 to the event_queue for the dispatcher to process.
 """

+import contextvars
 import queue
 import threading
 import time
 from collections.abc import Sequence
 from datetime import datetime
-from typing import TYPE_CHECKING, final
+from typing import final
 from uuid import uuid4

+from flask import Flask
 from typing_extensions import override

-from core.workflow.context import IExecutionContext
 from core.workflow.graph import Graph
 from core.workflow.graph_engine.layers.base import GraphEngineLayer
 from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
 from core.workflow.nodes.base.node import Node
+from libs.flask_utils import preserve_flask_contexts

 from .ready_queue import ReadyQueue

-if TYPE_CHECKING:
-    pass
-

 @final
 class Worker(threading.Thread):
@@ -45,7 +44,8 @@ class Worker(threading.Thread):
         layers: Sequence[GraphEngineLayer],
         stop_event: threading.Event,
         worker_id: int = 0,
-        execution_context: IExecutionContext | None = None,
+        flask_app: Flask | None = None,
+        context_vars: contextvars.Context | None = None,
     ) -> None:
         """
         Initialize worker thread.
@@ -56,17 +56,19 @@
             graph: Graph containing nodes to execute
             layers: Graph engine layers for node execution hooks
             worker_id: Unique identifier for this worker
-            execution_context: Optional execution context for context preservation
+            flask_app: Optional Flask application for context preservation
+            context_vars: Optional context variables to preserve in worker thread
         """
         super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
         self._ready_queue = ready_queue
         self._event_queue = event_queue
         self._graph = graph
         self._worker_id = worker_id
-        self._execution_context = execution_context
+        self._flask_app = flask_app
+        self._context_vars = context_vars
+        self._last_task_time = time.time()
         self._stop_event = stop_event
         self._layers = layers if layers is not None else []
-        self._last_task_time = time.time()

     def stop(self) -> None:
         """Worker is controlled via shared stop_event from GraphEngine.
@@ -133,9 +135,11 @@
         error: Exception | None = None

-        # Execute the node with preserved context if execution context is provided
-        if self._execution_context is not None:
-            with self._execution_context:
+        if self._flask_app and self._context_vars:
+            with preserve_flask_contexts(
+                flask_app=self._flask_app,
+                context_vars=self._context_vars,
+            ):
                 self._invoke_node_run_start_hooks(node)
                 try:
                     node_events = node.run()
@@ -8,10 +8,9 @@ DynamicScaler, and WorkerFactory into a single class.
 import logging
 import queue
 import threading
-from typing import final
+from typing import TYPE_CHECKING, final

 from configs import dify_config
-from core.workflow.context import IExecutionContext
 from core.workflow.graph import Graph
 from core.workflow.graph_events import GraphNodeEventBase

@@ -21,6 +20,11 @@ from ..worker import Worker

 logger = logging.getLogger(__name__)

+if TYPE_CHECKING:
+    from contextvars import Context
+
+    from flask import Flask
+

 @final
 class WorkerPool:
@@ -38,7 +42,8 @@
         graph: Graph,
         layers: list[GraphEngineLayer],
         stop_event: threading.Event,
-        execution_context: IExecutionContext | None = None,
+        flask_app: "Flask | None" = None,
+        context_vars: "Context | None" = None,
         min_workers: int | None = None,
         max_workers: int | None = None,
         scale_up_threshold: int | None = None,
@@ -52,7 +57,8 @@
             event_queue: Queue for worker events
             graph: The workflow graph
             layers: Graph engine layers for node execution hooks
-            execution_context: Optional execution context for context preservation
+            flask_app: Optional Flask app for context preservation
+            context_vars: Optional context variables
             min_workers: Minimum number of workers
             max_workers: Maximum number of workers
             scale_up_threshold: Queue depth to trigger scale up
@@ -61,7 +67,8 @@
         self._ready_queue = ready_queue
         self._event_queue = event_queue
         self._graph = graph
-        self._execution_context = execution_context
+        self._flask_app = flask_app
+        self._context_vars = context_vars
         self._layers = layers

         # Scaling parameters with defaults
@@ -145,7 +152,8 @@
             graph=self._graph,
             layers=self._layers,
             worker_id=worker_id,
-            execution_context=self._execution_context,
+            flask_app=self._flask_app,
+            context_vars=self._context_vars,
             stop_event=self._stop_event,
         )
@@ -21,12 +21,6 @@ class GraphNodeEventBase(GraphEngineEvent):
     """iteration id if node is in iteration"""
     in_loop_id: str | None = None
     """loop id if node is in loop"""
-    in_mention_parent_id: str | None = None
-    """Parent node id if this is an extractor node event.
-
-    When set, indicates this event belongs to an extractor node that
-    is extracting values for the specified parent node.
-    """

     # The version of the node, or "1" if not specified.
     node_version: str = "1"
@@ -12,20 +12,11 @@ from sqlalchemy.orm import Session
 from core.agent.entities import AgentToolEntity
 from core.agent.plugin_entities import AgentStrategyParameter
 from core.file import File, FileTransferMethod
-from core.memory.base import BaseMemory
-from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
 from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
-from core.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    PromptMessage,
-    ToolPromptMessage,
-    UserPromptMessage,
-)
 from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from core.model_runtime.utils.encoders import jsonable_encoder
-from core.prompt.entities.advanced_prompt_entities import MemoryMode
 from core.provider_manager import ProviderManager
 from core.tools.entities.tool_entities import (
     ToolIdentity,
@@ -145,9 +136,6 @@ class AgentNode(Node[AgentNodeData]):
             )
             return

-        # Fetch memory for node memory saving
-        memory = self._fetch_memory_for_save()
-
         try:
             yield from self._transform_message(
                 messages=message_stream,
@@ -161,7 +149,6 @@
                 node_type=self.node_type,
                 node_id=self._node_id,
                 node_execution_id=self.id,
-                memory=memory,
             )
         except PluginDaemonClientSideError as e:
             transform_error = AgentMessageTransformError(
@@ -408,20 +395,8 @@
             icon = None
         return icon

-    def _fetch_memory(self, model_instance: ModelInstance) -> BaseMemory | None:
-        """
-        Fetch memory based on configuration mode.
-
-        Returns TokenBufferMemory for conversation mode (default),
-        or NodeTokenBufferMemory for node mode (Chatflow only).
-        """
-        node_data = self.node_data
-        memory_config = node_data.memory
-
-        if not memory_config:
-            return None
-
-        # get conversation id (required for both modes in Chatflow)
+    def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
+        # get conversation id
         conversation_id_variable = self.graph_runtime_state.variable_pool.get(
             ["sys", SystemVariableKey.CONVERSATION_ID]
         )
@@ -429,26 +404,16 @@
             return None
         conversation_id = conversation_id_variable.value

-        # Return appropriate memory type based on mode
-        if memory_config.mode == MemoryMode.NODE:
-            # Node-level memory (Chatflow only)
-            return NodeTokenBufferMemory(
-                app_id=self.app_id,
-                conversation_id=conversation_id,
-                node_id=self._node_id,
-                tenant_id=self.tenant_id,
-                model_instance=model_instance,
-            )
-        else:
-            # Conversation-level memory (default)
-            with Session(db.engine, expire_on_commit=False) as session:
-                stmt = select(Conversation).where(
-                    Conversation.app_id == self.app_id, Conversation.id == conversation_id
-                )
-                conversation = session.scalar(stmt)
-                if not conversation:
-                    return None
-                return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
+        with Session(db.engine, expire_on_commit=False) as session:
+            stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
+            conversation = session.scalar(stmt)
+
+        if not conversation:
+            return None
+
+        memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
+
+        return memory

     def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
         provider_manager = ProviderManager()
@@ -492,136 +457,6 @@
         else:
             return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]

-    def _fetch_memory_for_save(self) -> BaseMemory | None:
-        """
-        Fetch memory instance for saving node memory.
-        This is a simplified version that doesn't require model_instance.
-        """
-        from core.model_manager import ModelManager
-        from core.model_runtime.entities.model_entities import ModelType
-
-        node_data = self.node_data
-        if not node_data.memory:
-            return None
-
-        # Get conversation_id
-        conversation_id_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
-        if not isinstance(conversation_id_var, StringSegment):
-            return None
-        conversation_id = conversation_id_var.value
-
-        # Return appropriate memory type based on mode
-        if node_data.memory.mode == MemoryMode.NODE:
-            # For node memory, we need a model_instance for token counting
-            # Use a simple default model for this purpose
-            try:
-                model_instance = ModelManager().get_default_model_instance(
-                    tenant_id=self.tenant_id,
-                    model_type=ModelType.LLM,
-                )
-            except Exception:
-                return None
-
-            return NodeTokenBufferMemory(
-                app_id=self.app_id,
-                conversation_id=conversation_id,
-                node_id=self._node_id,
-                tenant_id=self.tenant_id,
-                model_instance=model_instance,
-            )
-        else:
-            # Conversation-level memory doesn't need saving here
-            return None
-
-    def _build_context(
-        self,
-        parameters_for_log: dict[str, Any],
-        user_query: str,
-        assistant_response: str,
-        agent_logs: list[AgentLogEvent],
-    ) -> list[PromptMessage]:
-        """
-        Build context from user query, tool calls, and assistant response.
-        Format: user -> assistant(with tool_calls) -> tool -> assistant
-
-        The context includes:
-        - Current user query (always present, may be empty)
-        - Assistant message with tool_calls (if tools were called)
-        - Tool results
-        - Assistant's final response
-        """
-        context_messages: list[PromptMessage] = []
-
-        # Always add user query (even if empty, to maintain conversation structure)
-        context_messages.append(UserPromptMessage(content=user_query or ""))
-
-        # Extract actual tool calls from agent logs
-        # Only include logs with label starting with "CALL " - these are real tool invocations
-        tool_calls: list[AssistantPromptMessage.ToolCall] = []
-        tool_results: list[tuple[str, str, str]] = []  # (tool_call_id, tool_name, result)
-
-        for log in agent_logs:
-            if log.status == "success" and log.label and log.label.startswith("CALL "):
-                # Extract tool name from label (format: "CALL tool_name")
-                tool_name = log.label[5:]  # Remove "CALL " prefix
-                tool_call_id = log.message_id
-
-                # Parse tool response from data
-                data = log.data or {}
-                tool_response = ""
-
-                # Try to extract the actual tool response
-                if "tool_response" in data:
-                    tool_response = data["tool_response"]
-                elif "output" in data:
-                    tool_response = data["output"]
-                elif "result" in data:
-                    tool_response = data["result"]
-
-                if isinstance(tool_response, dict):
-                    tool_response = str(tool_response)
-
-                # Get tool input for arguments
-                tool_input = data.get("tool_call_input", {}) or data.get("input", {})
-                if isinstance(tool_input, dict):
-                    import json
-
-                    tool_input_str = json.dumps(tool_input, ensure_ascii=False)
-                else:
-                    tool_input_str = str(tool_input) if tool_input else ""
-
-                if tool_response:
-                    tool_calls.append(
-                        AssistantPromptMessage.ToolCall(
-                            id=tool_call_id,
-                            type="function",
-                            function=AssistantPromptMessage.ToolCall.ToolCallFunction(
-                                name=tool_name,
-                                arguments=tool_input_str,
-                            ),
-                        )
-                    )
-                    tool_results.append((tool_call_id, tool_name, str(tool_response)))
-
-        # Add assistant message with tool_calls if there were tool calls
-        if tool_calls:
-            context_messages.append(AssistantPromptMessage(content="", tool_calls=tool_calls))
-
-        # Add tool result messages
-        for tool_call_id, tool_name, result in tool_results:
-            context_messages.append(
-                ToolPromptMessage(
-                    content=result,
-                    tool_call_id=tool_call_id,
-                    name=tool_name,
-                )
-            )
-
-        # Add final assistant response
-        context_messages.append(AssistantPromptMessage(content=assistant_response))
-
-        return context_messages
-
     def _transform_message(
         self,
         messages: Generator[ToolInvokeMessage, None, None],
@@ -632,7 +467,6 @@
         node_type: NodeType,
         node_id: str,
         node_execution_id: str,
-        memory: BaseMemory | None = None,
     ) -> Generator[NodeEventBase, None, None]:
         """
         Convert ToolInvokeMessages into tuple[plain_text, files]
@@ -877,12 +711,6 @@
                     is_final=True,
                 )

-                # Get user query from parameters for building context
-                user_query = parameters_for_log.get("query", "")
-
-                # Build context from history, user query, tool calls and assistant response
-                context = self._build_context(parameters_for_log, user_query, text, agent_logs)
-
                 yield StreamCompletedEvent(
                     node_run_result=NodeRunResult(
                         status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -891,7 +719,6 @@
                             "usage": jsonable_encoder(llm_usage),
                             "files": ArrayFileSegment(value=files),
                             "json": json_output,
-                            "context": context,
                             **variables,
                         },
                         metadata={
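The removed _build_context above assembled OpenAI-style tool-call turns in the order user, assistant with tool_calls, tool, assistant. A sketch of that message shape using plain dicts in place of dify's PromptMessage classes (the call id here is hypothetical; dify used the agent log's message_id):

import json


def build_context(user_query, tool_name, tool_args, tool_result, final_answer):
    call_id = "call_0"  # hypothetical id; dify used the agent log's message_id
    return [
        {"role": "user", "content": user_query},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": call_id,
                    "type": "function",
                    "function": {"name": tool_name, "arguments": json.dumps(tool_args)},
                }
            ],
        },
        {"role": "tool", "tool_call_id": call_id, "name": tool_name, "content": tool_result},
        {"role": "assistant", "content": final_answer},
    ]


messages = build_context("weather in Tokyo?", "get_weather", {"city": "Tokyo"}, "22C, clear", "It is 22C and clear.")
print(len(messages))  # 4 turns: user, assistant(tool_calls), tool, assistant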
@@ -1,10 +1,4 @@
-from .entities import (
-    BaseIterationNodeData,
-    BaseIterationState,
-    BaseLoopNodeData,
-    BaseLoopState,
-    BaseNodeData,
-)
+from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
 from .usage_tracking_mixin import LLMUsageTrackingMixin

 __all__ = [
@@ -175,16 +175,6 @@ class BaseNodeData(ABC, BaseModel):
     default_value: list[DefaultValue] | None = None
     retry_config: RetryConfig = RetryConfig()

-    # Parent node ID when this node is used as an extractor.
-    # If set, this node is an "attached" extractor node that extracts values
-    # from list[PromptMessage] for the parent node's parameters.
-    parent_node_id: str | None = None
-
-    @property
-    def is_extractor_node(self) -> bool:
-        """Check if this node is an extractor node (has parent_node_id)."""
-        return self.parent_node_id is not None
-
     @property
     def default_value_dict(self) -> dict[str, Any]:
         if self.default_value:
@@ -270,87 +270,10 @@ class Node(Generic[NodeDataT]):
         """Check if execution should be stopped."""
         return self.graph_runtime_state.stop_event.is_set()

-    def _find_extractor_node_configs(self) -> list[dict[str, Any]]:
-        """
-        Find all extractor node configurations that have parent_node_id == self._node_id.
-
-        Returns:
-            List of node configuration dicts for extractor nodes
-        """
-        nodes = self.graph_config.get("nodes", [])
-        extractor_configs = []
-        for node_config in nodes:
-            node_data = node_config.get("data", {})
-            if node_data.get("parent_node_id") == self._node_id:
-                extractor_configs.append(node_config)
-        return extractor_configs
-
-    def _execute_mention_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
-        """
-        Execute all extractor nodes associated with this node.
-
-        Extractor nodes are nodes with parent_node_id == self._node_id.
-        They are executed before the main node to extract values from list[PromptMessage].
-        """
-        from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
-
-        extractor_configs = self._find_extractor_node_configs()
-        logger.debug("[Extractor] Found %d extractor nodes for parent '%s'", len(extractor_configs), self._node_id)
-        if not extractor_configs:
-            return
-
-        for config in extractor_configs:
-            node_id = config.get("id")
-            node_data = config.get("data", {})
-            node_type_str = node_data.get("type")
-
-            if not node_id or not node_type_str:
-                continue
-
-            # Get node class
-            try:
-                node_type = NodeType(node_type_str)
-            except ValueError:
-                continue
-
-            node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
-            if not node_mapping:
-                continue
-
-            node_version = str(node_data.get("version", "1"))
-            node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
-            if not node_cls:
-                continue
-
-            # Instantiate and execute the extractor node
-            extractor_node = node_cls(
-                id=node_id,
-                config=config,
-                graph_init_params=self._graph_init_params,
-                graph_runtime_state=self.graph_runtime_state,
-            )
-
-            # Execute and process extractor node events
-            for event in extractor_node.run():
-                # Tag event with parent node id for stream ordering and history tracking
-                if isinstance(event, GraphNodeEventBase):
-                    event.in_mention_parent_id = self._node_id
-
-                if isinstance(event, NodeRunSucceededEvent):
-                    # Store extractor node outputs in variable pool
-                    outputs: Mapping[str, Any] = event.node_run_result.outputs
-                    for variable_name, variable_value in outputs.items():
-                        self.graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
-                if not isinstance(event, NodeRunStreamChunkEvent):
-                    yield event
-
     def run(self) -> Generator[GraphNodeEventBase, None, None]:
         execution_id = self.ensure_execution_id()
         self._start_at = naive_utc_now()

-        # Step 1: Execute associated extractor nodes before main node execution
-        yield from self._execute_mention_nodes()
-
         # Create and push start event with required fields
         start_event = NodeRunStartedEvent(
             id=execution_id,
@@ -1,9 +1,11 @@
+import contextvars
 import logging
 from collections.abc import Generator, Mapping, Sequence
 from concurrent.futures import Future, ThreadPoolExecutor, as_completed
 from datetime import UTC, datetime
 from typing import TYPE_CHECKING, Any, NewType, cast

+from flask import Flask, current_app
 from typing_extensions import TypeIs

 from core.model_runtime.entities.llm_entities import LLMUsage
@@ -37,6 +39,7 @@ from core.workflow.nodes.base.node import Node
 from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
 from core.workflow.runtime import VariablePool
 from libs.datetime_utils import naive_utc_now
+from libs.flask_utils import preserve_flask_contexts

 from .exc import (
     InvalidIteratorValueError,
@@ -48,7 +51,6 @@
 )

 if TYPE_CHECKING:
-    from core.workflow.context import IExecutionContext
     from core.workflow.graph_engine import GraphEngine

 logger = logging.getLogger(__name__)
@@ -250,7 +252,8 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
                     self._execute_single_iteration_parallel,
                     index=index,
                     item=item,
-                    execution_context=self._capture_execution_context(),
+                    flask_app=current_app._get_current_object(),  # type: ignore
+                    context_vars=contextvars.copy_context(),
                 )
                 future_to_index[future] = index
@@ -303,10 +306,11 @@
         self,
         index: int,
         item: object,
-        execution_context: "IExecutionContext",
+        flask_app: Flask,
+        context_vars: contextvars.Context,
     ) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
         """Execute a single iteration in parallel mode and return results."""
-        with execution_context:
+        with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
             iter_start_at = datetime.now(UTC).replace(tzinfo=None)
             events: list[GraphNodeEventBase] = []
             outputs_temp: list[object] = []
@@ -335,12 +339,6 @@
             graph_engine.graph_runtime_state.llm_usage,
         )

-    def _capture_execution_context(self) -> "IExecutionContext":
-        """Capture current execution context for parallel iterations."""
-        from core.workflow.context import capture_current_context
-
-        return capture_current_context()
-
     def _handle_iteration_success(
         self,
         started_at: datetime,
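The iteration hunks above hand flask_app and a copied contextvars.Context into each parallel iteration. A reduced, executor-shaped sketch of the same hand-off with a stand-in work function:

import contextvars
from concurrent.futures import ThreadPoolExecutor

item_scale = contextvars.ContextVar("item_scale", default=0)


def run_iteration(index: int, ctx: contextvars.Context) -> int:
    # Replay the caller's context, then run the per-item work inside it.
    return ctx.run(lambda: index * item_scale.get())


item_scale.set(10)
with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(run_iteration, i, contextvars.copy_context()) for i in range(3)]
    print([f.result() for f in futures])  # [0, 10, 20]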
@@ -1,7 +1,7 @@
 from collections.abc import Mapping, Sequence
-from typing import Annotated, Any, Literal, TypeAlias
+from typing import Any, Literal

-from pydantic import BaseModel, ConfigDict, Field, field_validator
+from pydantic import BaseModel, Field, field_validator

 from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
@@ -58,28 +58,9 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
     jinja2_text: str | None = None


-class PromptMessageContext(BaseModel):
-    """Context variable reference in prompt template.
-
-    YAML/JSON format: { "$context": ["node_id", "variable_name"] }
-    This will be expanded to list[PromptMessage] at runtime.
-    """
-
-    model_config = ConfigDict(populate_by_name=True)
-
-    value_selector: Sequence[str] = Field(alias="$context")
-
-
-# Union type for prompt template items (static message or context variable reference)
-PromptTemplateItem: TypeAlias = Annotated[
-    LLMNodeChatModelMessage | PromptMessageContext,
-    Field(discriminator=None),
-]
-
-
 class LLMNodeData(BaseNodeData):
     model: ModelConfig
-    prompt_template: Sequence[PromptTemplateItem] | LLMNodeCompletionModelPromptTemplate
+    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
     prompt_config: PromptConfig = Field(default_factory=PromptConfig)
     memory: MemoryConfig | None = None
     context: ContextConfig
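The PromptMessageContext model removed above parsed template items of the form {"$context": ["node_id", "variable_name"]} via a pydantic field alias. Re-validating the removed class on its own shows the mechanism:

from collections.abc import Sequence

from pydantic import BaseModel, ConfigDict, Field


class PromptMessageContext(BaseModel):
    model_config = ConfigDict(populate_by_name=True)
    value_selector: Sequence[str] = Field(alias="$context")


ref = PromptMessageContext.model_validate({"$context": ["llm_1", "context"]})
print(list(ref.value_selector))  # ['llm_1', 'context']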
@@ -8,20 +8,12 @@ from configs import dify_config
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
 from core.file.models import File
-from core.memory import NodeTokenBufferMemory, TokenBufferMemory
-from core.memory.base import BaseMemory
+from core.memory.token_buffer_memory import TokenBufferMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.llm_entities import LLMUsage
-from core.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    MultiModalPromptMessageContent,
-    PromptMessage,
-    PromptMessageContentUnionTypes,
-    PromptMessageRole,
-)
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig
 from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
 from core.workflow.enums import SystemVariableKey
 from core.workflow.nodes.llm.entities import ModelConfig
@@ -94,56 +86,25 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc


 def fetch_memory(
-    variable_pool: VariablePool,
-    app_id: str,
-    tenant_id: str,
-    node_data_memory: MemoryConfig | None,
-    model_instance: ModelInstance,
-    node_id: str = "",
-) -> BaseMemory | None:
-    """
-    Fetch memory based on configuration mode.
-
-    Returns TokenBufferMemory for conversation mode (default),
-    or NodeTokenBufferMemory for node mode (Chatflow only).
-
-    :param variable_pool: Variable pool containing system variables
-    :param app_id: Application ID
-    :param tenant_id: Tenant ID
-    :param node_data_memory: Memory configuration
-    :param model_instance: Model instance for token counting
-    :param node_id: Node ID in the workflow (required for node mode)
-    :return: Memory instance or None if not applicable
-    """
+    variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
+) -> TokenBufferMemory | None:
     if not node_data_memory:
         return None

-    # Get conversation_id from variable pool (required for both modes in Chatflow)
+    # get conversation id
     conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
     if not isinstance(conversation_id_variable, StringSegment):
         return None
     conversation_id = conversation_id_variable.value

-    # Return appropriate memory type based on mode
-    if node_data_memory.mode == MemoryMode.NODE:
-        # Node-level memory (Chatflow only)
-        if not node_id:
-            return None
-        return NodeTokenBufferMemory(
-            app_id=app_id,
-            conversation_id=conversation_id,
-            node_id=node_id,
-            tenant_id=tenant_id,
-            model_instance=model_instance,
-        )
-    else:
-        # Conversation-level memory (default)
-        with Session(db.engine, expire_on_commit=False) as session:
-            stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
-            conversation = session.scalar(stmt)
-            if not conversation:
-                return None
-            return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
+    with Session(db.engine, expire_on_commit=False) as session:
+        stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
+        conversation = session.scalar(stmt)
+
+    if not conversation:
+        return None
+
+    memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
+    return memory


 def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
@@ -209,87 +170,3 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
         )
         session.execute(stmt)
         session.commit()
-
-
-def build_context(
-    prompt_messages: Sequence[PromptMessage],
-    assistant_response: str,
-) -> list[PromptMessage]:
-    """
-    Build context from prompt messages and assistant response.
-    Excludes system messages and includes the current LLM response.
-    Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
-
-    Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
-    """
-    context_messages: list[PromptMessage] = [
-        _truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
-    ]
-    context_messages.append(AssistantPromptMessage(content=assistant_response))
-    return context_messages
-
-
-def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
-    """
-    Truncate multi-modal content base64 data in a message to avoid storing large data.
-    Preserves the PromptMessage structure for ArrayPromptMessageSegment compatibility.
-
-    If file_ref is present, clears base64_data and url (they can be restored later).
-    Otherwise, truncates base64_data as fallback for legacy data.
-    """
-    content = message.content
-    if content is None or isinstance(content, str):
-        return message
-
-    # Process list content, handling multi-modal data based on file_ref availability
-    new_content: list[PromptMessageContentUnionTypes] = []
-    for item in content:
-        if isinstance(item, MultiModalPromptMessageContent):
-            if item.file_ref:
-                # Clear base64 and url, keep file_ref for later restoration
-                new_content.append(item.model_copy(update={"base64_data": "", "url": ""}))
-            else:
-                # Fallback: truncate base64_data if no file_ref (legacy data)
-                truncated_base64 = ""
-                if item.base64_data:
-                    truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:]
-                new_content.append(item.model_copy(update={"base64_data": truncated_base64}))
-        else:
-            new_content.append(item)
-
-    return message.model_copy(update={"content": new_content})
-
-
-def restore_multimodal_content_in_messages(messages: Sequence[PromptMessage]) -> list[PromptMessage]:
-    """
-    Restore multimodal content (base64 or url) in a list of PromptMessages.
-
-    When context is saved, base64_data is cleared to save storage space.
-    This function restores the content by parsing file_ref in each MultiModalPromptMessageContent.
-
-    Args:
-        messages: List of PromptMessages that may contain truncated multimodal content
-
-    Returns:
-        List of PromptMessages with restored multimodal content
-    """
-    from core.file import file_manager
-
-    return [_restore_message_content(msg, file_manager) for msg in messages]
-
-
-def _restore_message_content(message: PromptMessage, file_manager) -> PromptMessage:
-    """Restore multimodal content in a single PromptMessage."""
-    content = message.content
-    if content is None or isinstance(content, str):
-        return message
-
-    restored_content: list[PromptMessageContentUnionTypes] = []
-    for item in content:
-        if isinstance(item, MultiModalPromptMessageContent):
-            restored_item = file_manager.restore_multimodal_content(item)
-            restored_content.append(cast(PromptMessageContentUnionTypes, restored_item))
-        else:
-            restored_content.append(item)
-
-    return message.model_copy(update={"content": restored_content})
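The deleted helpers above shrink multimodal base64 payloads before a context is stored and restore them on load via file_ref. The legacy truncation rule keeps the first and last 10 characters around a fixed marker; a sketch:

def truncate_base64(data: str) -> str:
    # Keep the first and last 10 characters around a fixed marker, as the
    # removed _truncate_multimodal_content did for legacy payloads.
    if not data:
        return ""
    return data[:10] + "...[TRUNCATED]..." + data[-10:]


blob = "A" * 50_000  # stands in for an image's base64 payload
print(len(truncate_base64(blob)))  # 37 = 10 + 17 (marker) + 10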
@@ -7,7 +7,7 @@ import logging
import re
import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal

from sqlalchemy import select

@@ -16,7 +16,7 @@ from core.file import File, FileTransferMethod, FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
    ImagePromptMessageContent,
@@ -51,7 +51,6 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from core.variables import (
    ArrayFileSegment,
    ArrayPromptMessageSegment,
    ArraySegment,
    FileSegment,
    NoneSegment,
@@ -88,7 +87,6 @@ from .entities import (
    LLMNodeCompletionModelPromptTemplate,
    LLMNodeData,
    ModelConfig,
    PromptMessageContext,
)
from .exc import (
    InvalidContextStructureError,
@@ -161,9 +159,8 @@ class LLMNode(Node[LLMNodeData]):
        variable_pool = self.graph_runtime_state.variable_pool

        try:
            # Parse prompt template to separate static messages and context references
            prompt_template = self.node_data.prompt_template
            static_messages, context_refs, template_order = self._parse_prompt_template()
            # init messages template
            self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)

            # fetch variables and fetch values from variable pool
            inputs = self._fetch_inputs(node_data=self.node_data)
@@ -211,10 +208,8 @@ class LLMNode(Node[LLMNodeData]):
            memory = llm_utils.fetch_memory(
                variable_pool=variable_pool,
                app_id=self.app_id,
                tenant_id=self.tenant_id,
                node_data_memory=self.node_data.memory,
                model_instance=model_instance,
                node_id=self._node_id,
            )

            query: str | None = None
@@ -225,40 +220,21 @@ class LLMNode(Node[LLMNodeData]):
            ):
                query = query_variable.text

            # Get prompt messages
            prompt_messages: Sequence[PromptMessage]
            stop: Sequence[str] | None
            if isinstance(prompt_template, list) and context_refs:
                prompt_messages, stop = self._build_prompt_messages_with_context(
                    context_refs=context_refs,
                    template_order=template_order,
                    static_messages=static_messages,
                    query=query,
                    files=files,
                    context=context,
                    memory=memory,
                    model_config=model_config,
                    context_files=context_files,
                )
            else:
                prompt_messages, stop = LLMNode.fetch_prompt_messages(
                    sys_query=query,
                    sys_files=files,
                    context=context,
                    memory=memory,
                    model_config=model_config,
                    prompt_template=cast(
                        Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
                        self.node_data.prompt_template,
                    ),
                    memory_config=self.node_data.memory,
                    vision_enabled=self.node_data.vision.enabled,
                    vision_detail=self.node_data.vision.configs.detail,
                    variable_pool=variable_pool,
                    jinja2_variables=self.node_data.prompt_config.jinja2_variables,
                    tenant_id=self.tenant_id,
                    context_files=context_files,
                )
            prompt_messages, stop = LLMNode.fetch_prompt_messages(
                sys_query=query,
                sys_files=files,
                context=context,
                memory=memory,
                model_config=model_config,
                prompt_template=self.node_data.prompt_template,
                memory_config=self.node_data.memory,
                vision_enabled=self.node_data.vision.enabled,
                vision_detail=self.node_data.vision.configs.detail,
                variable_pool=variable_pool,
                jinja2_variables=self.node_data.prompt_config.jinja2_variables,
                tenant_id=self.tenant_id,
                context_files=context_files,
            )

            # handle invoke result
            generator = LLMNode.invoke_llm(
@@ -274,7 +250,6 @@ class LLMNode(Node[LLMNodeData]):
                node_id=self._node_id,
                node_type=self.node_type,
                reasoning_format=self.node_data.reasoning_format,
                tenant_id=self.tenant_id,
            )

            structured_output: LLMStructuredOutput | None = None
@@ -326,7 +301,6 @@ class LLMNode(Node[LLMNodeData]):
                "reasoning_content": reasoning_content,
                "usage": jsonable_encoder(usage),
                "finish_reason": finish_reason,
                "context": llm_utils.build_context(prompt_messages, clean_text),
            }
            if structured_output:
                outputs["structured_output"] = structured_output.structured_output
@@ -393,7 +367,6 @@ class LLMNode(Node[LLMNodeData]):
        node_id: str,
        node_type: NodeType,
        reasoning_format: Literal["separated", "tagged"] = "tagged",
        tenant_id: str | None = None,
    ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
        model_schema = model_instance.model_type_instance.get_model_schema(
            node_data_model.name, model_instance.credentials
@@ -417,7 +390,6 @@ class LLMNode(Node[LLMNodeData]):
                stop=list(stop or []),
                stream=True,
                user=user_id,
                tenant_id=tenant_id,
            )
        else:
            request_start_time = time.perf_counter()
@@ -609,212 +581,6 @@ class LLMNode(Node[LLMNodeData]):

        return messages

    def _parse_prompt_template(
        self,
    ) -> tuple[list[LLMNodeChatModelMessage], list[PromptMessageContext], list[tuple[int, str]]]:
        """
        Parse prompt_template to separate static messages and context references.

        Returns:
            Tuple of (static_messages, context_refs, template_order)
            - static_messages: list of LLMNodeChatModelMessage
            - context_refs: list of PromptMessageContext
            - template_order: list of (index, type) tuples preserving original order
        """
        prompt_template = self.node_data.prompt_template
        static_messages: list[LLMNodeChatModelMessage] = []
        context_refs: list[PromptMessageContext] = []
        template_order: list[tuple[int, str]] = []

        if isinstance(prompt_template, list):
            for idx, item in enumerate(prompt_template):
                if isinstance(item, PromptMessageContext):
                    context_refs.append(item)
                    template_order.append((idx, "context"))
                else:
                    static_messages.append(item)
                    template_order.append((idx, "static"))
            # Transform static messages for jinja2
            if static_messages:
                self.node_data.prompt_template = self._transform_chat_messages(static_messages)

        return static_messages, context_refs, template_order

    def _build_prompt_messages_with_context(
        self,
        *,
        context_refs: list[PromptMessageContext],
        template_order: list[tuple[int, str]],
        static_messages: list[LLMNodeChatModelMessage],
        query: str | None,
        files: Sequence[File],
        context: str | None,
        memory: BaseMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        context_files: list[File],
    ) -> tuple[list[PromptMessage], Sequence[str] | None]:
        """
        Build prompt messages by combining static messages and context references in DSL order.

        Returns:
            Tuple of (prompt_messages, stop_sequences)
        """
        variable_pool = self.graph_runtime_state.variable_pool

        # Process messages in DSL order: iterate once and handle each type directly
        combined_messages: list[PromptMessage] = []
        context_idx = 0
        static_idx = 0

        for _, type_ in template_order:
            if type_ == "context":
                # Handle context reference
                ctx_ref = context_refs[context_idx]
                ctx_var = variable_pool.get(ctx_ref.value_selector)
                if ctx_var is None:
                    raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found")
                if not isinstance(ctx_var, ArrayPromptMessageSegment):
                    raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]")
                # Restore multimodal content (base64/url) that was truncated when saving context
                restored_messages = llm_utils.restore_multimodal_content_in_messages(ctx_var.value)
                combined_messages.extend(restored_messages)
                context_idx += 1
            else:
                # Handle static message
                static_msg = static_messages[static_idx]
                processed_msgs = LLMNode.handle_list_messages(
                    messages=[static_msg],
                    context=context,
                    jinja2_variables=self.node_data.prompt_config.jinja2_variables or [],
                    variable_pool=variable_pool,
                    vision_detail_config=self.node_data.vision.configs.detail,
                )
                combined_messages.extend(processed_msgs)
                static_idx += 1

        # Append memory messages
        memory_messages = _handle_memory_chat_mode(
            memory=memory,
            memory_config=self.node_data.memory,
            model_config=model_config,
        )
        combined_messages.extend(memory_messages)

        # Append current query if provided
        if query:
            query_message = LLMNodeChatModelMessage(
                text=query,
                role=PromptMessageRole.USER,
                edition_type="basic",
            )
            query_msgs = LLMNode.handle_list_messages(
                messages=[query_message],
                context="",
                jinja2_variables=[],
                variable_pool=variable_pool,
                vision_detail_config=self.node_data.vision.configs.detail,
            )
            combined_messages.extend(query_msgs)

        # Handle files (sys_files and context_files)
        combined_messages = self._append_files_to_messages(
            messages=combined_messages,
            sys_files=files,
            context_files=context_files,
            model_config=model_config,
        )

        # Filter empty messages and get stop sequences
        combined_messages = self._filter_messages(combined_messages, model_config)
        stop = self._get_stop_sequences(model_config)

        return combined_messages, stop

    def _append_files_to_messages(
        self,
        *,
        messages: list[PromptMessage],
        sys_files: Sequence[File],
        context_files: list[File],
        model_config: ModelConfigWithCredentialsEntity,
    ) -> list[PromptMessage]:
        """Append sys_files and context_files to messages."""
        vision_enabled = self.node_data.vision.enabled
        vision_detail = self.node_data.vision.configs.detail

        # Handle sys_files (will be deprecated later)
        if vision_enabled and sys_files:
            file_prompts = [
                file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in sys_files
            ]
            if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
                messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
            else:
                messages.append(UserPromptMessage(content=file_prompts))

        # Handle context_files
        if vision_enabled and context_files:
            file_prompts = [
                file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
                for file in context_files
            ]
            if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
                messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
            else:
                messages.append(UserPromptMessage(content=file_prompts))

        return messages

    def _filter_messages(
        self, messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
    ) -> list[PromptMessage]:
        """Filter empty messages and unsupported content types."""
        filtered_messages: list[PromptMessage] = []

        for message in messages:
            if isinstance(message.content, list):
                filtered_content: list[PromptMessageContentUnionTypes] = []
                for content_item in message.content:
                    # Skip non-text content if features are not defined
                    if not model_config.model_schema.features:
                        if content_item.type != PromptMessageContentType.TEXT:
                            continue
                        filtered_content.append(content_item)
                        continue

                    # Skip content if corresponding feature is not supported
                    feature_map = {
                        PromptMessageContentType.IMAGE: ModelFeature.VISION,
                        PromptMessageContentType.DOCUMENT: ModelFeature.DOCUMENT,
                        PromptMessageContentType.VIDEO: ModelFeature.VIDEO,
                        PromptMessageContentType.AUDIO: ModelFeature.AUDIO,
                    }
                    required_feature = feature_map.get(content_item.type)
                    if required_feature and required_feature not in model_config.model_schema.features:
                        continue
                    filtered_content.append(content_item)

                # Simplify single text content
                if len(filtered_content) == 1 and filtered_content[0].type == PromptMessageContentType.TEXT:
                    message.content = filtered_content[0].data
                else:
                    message.content = filtered_content

            if not message.is_empty():
                filtered_messages.append(message)

        if not filtered_messages:
            raise NoPromptFoundError(
                "No prompt found in the LLM configuration. "
                "Please ensure a prompt is properly configured before proceeding."
            )

        return filtered_messages

    def _get_stop_sequences(self, model_config: ModelConfigWithCredentialsEntity) -> Sequence[str] | None:
        """Get stop sequences from model config."""
        return model_config.stop

    def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
        variables: dict[str, Any] = {}

@@ -1012,7 +778,7 @@ class LLMNode(Node[LLMNodeData]):
        sys_query: str | None = None,
        sys_files: Sequence[File],
        context: str | None = None,
        memory: BaseMemory | None = None,
        memory: TokenBufferMemory | None = None,
        model_config: ModelConfigWithCredentialsEntity,
        prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
        memory_config: MemoryConfig | None = None,
@@ -1571,7 +1337,7 @@ def _calculate_rest_token(

def _handle_memory_chat_mode(
    *,
    memory: BaseMemory | None,
    memory: TokenBufferMemory | None,
    memory_config: MemoryConfig | None,
    model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
@@ -1588,7 +1354,7 @@ def _handle_memory_chat_mode(

def _handle_memory_completion_mode(
    *,
    memory: BaseMemory | None,
    memory: TokenBufferMemory | None,
    memory_config: MemoryConfig | None,
    model_config: ModelConfigWithCredentialsEntity,
) -> str:
@@ -7,7 +7,7 @@ from typing import Any, cast

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import File
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMUsage
@@ -145,10 +145,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        memory = llm_utils.fetch_memory(
            variable_pool=variable_pool,
            app_id=self.app_id,
            tenant_id=self.tenant_id,
            node_data_memory=node_data.memory,
            model_instance=model_instance,
            node_id=self._node_id,
        )

        if (
@@ -246,10 +244,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        # transform result into standard format
        result = self._transform_result(data=node_data, result=result or {})

        # Build context from prompt messages and response
        assistant_response = json.dumps(result, ensure_ascii=False)
        context = llm_utils.build_context(prompt_messages, assistant_response)

        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs=inputs,
@@ -258,7 +252,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
                "__is_success": 1 if not error else 0,
                "__reason": error,
                "__usage": jsonable_encoder(usage),
                "context": context,
                **result,
            },
            metadata={
@@ -306,7 +299,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        query: str,
        variable_pool: VariablePool,
        model_config: ModelConfigWithCredentialsEntity,
        memory: BaseMemory | None,
        memory: TokenBufferMemory | None,
        files: Sequence[File],
        vision_detail: ImagePromptMessageContent.DETAIL | None = None,
    ) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
@@ -388,7 +381,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        query: str,
        variable_pool: VariablePool,
        model_config: ModelConfigWithCredentialsEntity,
        memory: BaseMemory | None,
        memory: TokenBufferMemory | None,
        files: Sequence[File],
        vision_detail: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:
@@ -426,7 +419,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        query: str,
        variable_pool: VariablePool,
        model_config: ModelConfigWithCredentialsEntity,
        memory: BaseMemory | None,
        memory: TokenBufferMemory | None,
        files: Sequence[File],
        vision_detail: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:
@@ -460,7 +453,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        query: str,
        variable_pool: VariablePool,
        model_config: ModelConfigWithCredentialsEntity,
        memory: BaseMemory | None,
        memory: TokenBufferMemory | None,
        files: Sequence[File],
        vision_detail: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:
@@ -688,7 +681,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        node_data: ParameterExtractorNodeData,
        query: str,
        variable_pool: VariablePool,
        memory: BaseMemory | None,
        memory: TokenBufferMemory | None,
        max_token_limit: int = 2000,
    ) -> list[ChatModelMessage]:
        model_mode = ModelMode(node_data.model.mode)
@@ -715,7 +708,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
        node_data: ParameterExtractorNodeData,
        query: str,
        variable_pool: VariablePool,
        memory: BaseMemory | None,
        memory: TokenBufferMemory | None,
        max_token_limit: int = 2000,
    ):
        model_mode = ModelMode(node_data.model.mode)
@@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -96,10 +96,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
        memory = llm_utils.fetch_memory(
            variable_pool=variable_pool,
            app_id=self.app_id,
            tenant_id=self.tenant_id,
            node_data_memory=node_data.memory,
            model_instance=model_instance,
            node_id=self._node_id,
        )
        # fetch instruction
        node_data.instruction = node_data.instruction or ""
@@ -199,15 +197,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
            "model_provider": model_config.provider,
            "model_name": model_config.model,
        }
        # Build context from prompt messages and response
        assistant_response = f"class_name: {category_name}, class_id: {category_id}"
        context = llm_utils.build_context(prompt_messages, assistant_response)

        outputs = {
            "class_name": category_name,
            "class_id": category_id,
            "usage": jsonable_encoder(usage),
            "context": context,
        }

        return NodeRunResult(
@@ -319,7 +312,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
        self,
        node_data: QuestionClassifierNodeData,
        query: str,
        memory: BaseMemory | None,
        memory: TokenBufferMemory | None,
        max_token_limit: int = 2000,
    ):
        model_mode = ModelMode(node_data.model.mode)
@@ -1,63 +1,11 @@
import re
from collections.abc import Sequence
from typing import Any, Literal, Self, Union
from typing import Any, Literal, Union

from pydantic import BaseModel, field_validator, model_validator
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo

from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.base.entities import BaseNodeData

# Pattern to match mention value format: {{@node.context@}}instruction
# The placeholder {{@node.context@}} must appear at the beginning
# Format: {{@agent_node_id.context@}} where agent_node_id is dynamic, context is fixed
MENTION_VALUE_PATTERN = re.compile(r"^\{\{@([a-zA-Z0-9_]+)\.context@\}\}(.*)$", re.DOTALL)


def parse_mention_value(value: str) -> tuple[str, str]:
    """Parse mention value into (node_id, instruction).

    Args:
        value: The mention value string like "{{@llm.context@}}extract keywords"

    Returns:
        Tuple of (node_id, instruction)

    Raises:
        ValueError: If value format is invalid
    """
    match = MENTION_VALUE_PATTERN.match(value)
    if not match:
        raise ValueError(
            "For mention type, value must start with {{@node.context@}} placeholder, "
            "e.g., '{{@llm.context@}}extract keywords'"
        )
    return match.group(1), match.group(2)
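A quick usage sketch of parse_mention_value, following the format documented above:

node_id, instruction = parse_mention_value("{{@llm.context@}}extract keywords")
# node_id == "llm", instruction == "extract keywords"
# A value that does not start with the {{@node.context@}} placeholder raises ValueError.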


class MentionConfig(BaseModel):
    """Configuration for extracting value from context variable.

    Used when a tool parameter needs to be extracted from list[PromptMessage]
    context using an extractor LLM node.

    Note: instruction is embedded in the value field as "{{@node.context@}}instruction"
    """

    # ID of the extractor LLM node
    extractor_node_id: str

    # Output variable selector from extractor node
    # e.g., ["text"], ["structured_output", "query"]
    output_selector: Sequence[str]

    # Strategy when output is None
    null_strategy: Literal["raise_error", "use_default"] = "raise_error"

    # Default value when null_strategy is "use_default"
    # Type should match the parameter's expected type
    default_value: Any = None


class ToolEntity(BaseModel):
    provider_id: str
@@ -87,9 +35,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
    class ToolInput(BaseModel):
        # TODO: check this type
        value: Union[Any, list[str]]
        type: Literal["mixed", "variable", "constant", "mention"]
        # Required config for mention type, extracting value from context variable
        mention_config: MentionConfig | None = None
        type: Literal["mixed", "variable", "constant"]

        @field_validator("type", mode="before")
        @classmethod
@@ -102,9 +48,6 @@ class ToolNodeData(BaseNodeData, ToolEntity):

            if typ == "mixed" and not isinstance(value, str):
                raise ValueError("value must be a string")
            elif typ == "mention":
                # Skip here, will be validated in model_validator
                pass
            elif typ == "variable":
                if not isinstance(value, list):
                    raise ValueError("value must be a list")
@@ -115,26 +58,6 @@ class ToolNodeData(BaseNodeData, ToolEntity):
                raise ValueError("value must be a string, int, float, bool or dict")
            return typ

        @model_validator(mode="after")
        def check_mention_type(self) -> Self:
            """Validate mention type with mention_config."""
            if self.type != "mention":
                return self

            value = self.value
            if value is None:
                return self

            if not isinstance(value, str):
                raise ValueError("value must be a string for mention type")
            # For mention type, value must match format: {{@node.context@}}instruction
            # This will raise ValueError if format is invalid
            parse_mention_value(value)
            # mention_config is required for mention type
            if self.mention_config is None:
                raise ValueError("mention_config is required for mention type")
            return self

    tool_parameters: dict[str, ToolInput]
    # The version of the tool parameter.
    # If this value is None, it indicates this is a previous version
@@ -1,10 +1,7 @@
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any

from sqlalchemy import select

logger = logging.getLogger(__name__)
from sqlalchemy.orm import Session

from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
@@ -187,7 +184,6 @@ class ToolNode(Node[ToolNodeData]):
            tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
            variable_pool (VariablePool): The variable pool containing the variables.
            node_data (ToolNodeData): The data associated with the tool node.
            for_log (bool): Whether to generate parameters for logging.

        Returns:
            Mapping[str, Any]: A dictionary containing the generated parameters.
@@ -203,37 +199,14 @@ class ToolNode(Node[ToolNodeData]):
                continue
            tool_input = node_data.tool_parameters[parameter_name]
            if tool_input.type == "variable":
                if not isinstance(tool_input.value, list):
                    raise ToolParameterError(f"Invalid variable selector for parameter '{parameter_name}'")
                selector = tool_input.value
                variable = variable_pool.get(selector)
                variable = variable_pool.get(tool_input.value)
                if variable is None:
                    if parameter.required:
                        raise ToolParameterError(f"Variable {selector} does not exist")
                    raise ToolParameterError(f"Variable {tool_input.value} does not exist")
                    continue
                parameter_value = variable.value
            elif tool_input.type == "mention":
                # Mention type: get value from extractor node's output
                if tool_input.mention_config is None:
                    raise ToolParameterError(
                        f"mention_config is required for mention type parameter '{parameter_name}'"
                    )
                mention_config = tool_input.mention_config.model_dump()
                try:
                    parameter_value, found = variable_pool.resolve_mention(
                        mention_config, parameter_name=parameter_name
                    )
                    if not found and parameter.required:
                        raise ToolParameterError(
                            f"Extractor output not found for required parameter '{parameter_name}'"
                        )
                    if not found:
                        continue
                except ValueError as e:
                    raise ToolParameterError(str(e)) from e
            elif tool_input.type in {"mixed", "constant"}:
                template = str(tool_input.value)
                segment_group = variable_pool.convert_template(template)
                segment_group = variable_pool.convert_template(str(tool_input.value))
                parameter_value = segment_group.log if for_log else segment_group.text
            else:
                raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
@@ -515,12 +488,8 @@ class ToolNode(Node[ToolNodeData]):
            for selector in selectors:
                result[selector.variable] = selector.value_selector
        elif input.type == "variable":
            if isinstance(input.value, list):
                selector_key = ".".join(input.value)
                result[f"#{selector_key}#"] = input.value
        elif input.type == "mention":
            # Mention type: value is handled by extractor node, no direct variable reference
            pass
            selector_key = ".".join(input.value)
            result[f"#{selector_key}#"] = input.value
        elif input.type == "constant":
            pass
@@ -268,58 +268,6 @@ class VariablePool(BaseModel):
                continue
            self.add(selector, value)

    def resolve_mention(
        self,
        mention_config: Mapping[str, Any],
        /,
        *,
        parameter_name: str = "",
    ) -> tuple[Any, bool]:
        """
        Resolve a mention parameter value from an extractor node's output.

        Mention parameters reference values extracted by an extractor LLM node
        from list[PromptMessage] context.

        Args:
            mention_config: A dict containing:
                - extractor_node_id: ID of the extractor LLM node
                - output_selector: Selector path for the output variable (e.g., ["text"])
                - null_strategy: "raise_error" or "use_default"
                - default_value: Value to use when null_strategy is "use_default"
            parameter_name: Name of the parameter being resolved (for error messages)

        Returns:
            Tuple of (resolved_value, found):
                - resolved_value: The extracted value, or default_value if not found
                - found: True if value was found, False if using default

        Raises:
            ValueError: If extractor_node_id is missing, or if null_strategy is
                "raise_error" and the value is not found
        """
        extractor_node_id = mention_config.get("extractor_node_id")
        if not extractor_node_id:
            raise ValueError(f"Missing extractor_node_id for mention parameter '{parameter_name}'")

        output_selector = list(mention_config.get("output_selector", []))
        null_strategy = mention_config.get("null_strategy", "raise_error")
        default_value = mention_config.get("default_value")

        # Build full selector: [extractor_node_id, ...output_selector]
        full_selector = [extractor_node_id] + output_selector
        variable = self.get(full_selector)

        if variable is None:
            if null_strategy == "use_default":
                return default_value, False
            raise ValueError(
                f"Extractor node '{extractor_node_id}' output '{'.'.join(output_selector)}' "
                f"not found for parameter '{parameter_name}'"
            )

        return variable.value, True
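A minimal sketch of resolve_mention (the pool contents and node id here are hypothetical):

pool = VariablePool.empty()
pool.add(["llm_ext_query", "text"], "weather in Berlin")  # pretend extractor output
value, found = pool.resolve_mention(
    {
        "extractor_node_id": "llm_ext_query",
        "output_selector": ["text"],
        "null_strategy": "use_default",
        "default_value": "",
    },
    parameter_name="query",
)
# value == "weather in Berlin", found is True; if the variable were absent,
# "use_default" would return ("", False) while "raise_error" would raise ValueError.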

    @classmethod
    def empty(cls) -> VariablePool:
        """Create an empty variable pool."""

@@ -4,7 +4,6 @@ from uuid import uuid4

from configs import dify_config
from core.file import File
from core.model_runtime.entities import PromptMessage
from core.variables.exc import VariableError
from core.variables.segments import (
    ArrayAnySegment,
@@ -12,7 +11,6 @@ from core.variables.segments import (
    ArrayFileSegment,
    ArrayNumberSegment,
    ArrayObjectSegment,
    ArrayPromptMessageSegment,
    ArraySegment,
    ArrayStringSegment,
    BooleanSegment,
@@ -31,7 +29,6 @@ from core.variables.variables import (
    ArrayFileVariable,
    ArrayNumberVariable,
    ArrayObjectVariable,
    ArrayPromptMessageVariable,
    ArrayStringVariable,
    BooleanVariable,
    FileVariable,
@@ -64,7 +61,6 @@ SEGMENT_TO_VARIABLE_MAP = {
    ArrayFileSegment: ArrayFileVariable,
    ArrayNumberSegment: ArrayNumberVariable,
    ArrayObjectSegment: ArrayObjectVariable,
    ArrayPromptMessageSegment: ArrayPromptMessageVariable,
    ArrayStringSegment: ArrayStringVariable,
    BooleanSegment: BooleanVariable,
    FileSegment: FileVariable,
@@ -160,13 +156,7 @@ def build_segment(value: Any, /) -> Segment:
        return ObjectSegment(value=value)
    if isinstance(value, File):
        return FileSegment(value=value)
    if isinstance(value, PromptMessage):
        # Single PromptMessage should be wrapped in a list
        return ArrayPromptMessageSegment(value=[value])
    if isinstance(value, list):
        # Check if all items are PromptMessage
        if value and all(isinstance(item, PromptMessage) for item in value):
            return ArrayPromptMessageSegment(value=value)
        items = [build_segment(item) for item in value]
        types = {item.value_type for item in items}
        if all(isinstance(item, ArraySegment) for item in items):
@@ -210,7 +200,6 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
    SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
    SegmentType.ARRAY_FILE: ArrayFileSegment,
    SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment,
    SegmentType.ARRAY_PROMPT_MESSAGE: ArrayPromptMessageSegment,
}


@@ -285,10 +274,6 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
    ):
        segment_class = _segment_factory[inferred_type]
        return segment_class(value_type=inferred_type, value=value)
    elif segment_type == SegmentType.ARRAY_PROMPT_MESSAGE and inferred_type == SegmentType.ARRAY_OBJECT:
        # PromptMessage serializes to dict, so ARRAY_OBJECT is compatible with ARRAY_PROMPT_MESSAGE
        segment_class = _segment_factory[segment_type]
        return segment_class(value_type=segment_type, value=value)
    else:
        raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
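A short sketch of the PromptMessage branches in build_segment (the message classes are assumed to come from core.model_runtime.entities, as elsewhere in this diff):

from core.model_runtime.entities import AssistantPromptMessage, UserPromptMessage

seg = build_segment(UserPromptMessage(content="hi"))
# A single PromptMessage is wrapped in a one-element ArrayPromptMessageSegment.
seg2 = build_segment([UserPromptMessage(content="hi"), AssistantPromptMessage(content="hello")])
# A non-empty list where every item is a PromptMessage also maps to ArrayPromptMessageSegment.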
@@ -1,74 +0,0 @@
"""
Workspace permission helper functions.

These helpers check both billing/plan level and workspace-specific policy level permissions.
Checks are performed at two levels:
1. Billing/plan level - via FeatureService (e.g., SANDBOX plan restrictions)
2. Workspace policy level - via EnterpriseService (admin-configured per workspace)
"""

import logging

from werkzeug.exceptions import Forbidden

from configs import dify_config
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService

logger = logging.getLogger(__name__)


def check_workspace_member_invite_permission(workspace_id: str) -> None:
    """
    Check if workspace allows member invitations at both billing and policy levels.

    Checks performed:
    1. Billing/plan level - For future expansion (currently no plan-level restriction)
    2. Enterprise policy level - Admin-configured workspace permission

    Args:
        workspace_id: The workspace ID to check permissions for

    Raises:
        Forbidden: If either billing plan or workspace policy prohibits member invitations
    """
    # Check enterprise workspace policy level (only if enterprise enabled)
    if dify_config.ENTERPRISE_ENABLED:
        try:
            permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
            if not permission.allow_member_invite:
                raise Forbidden("Workspace policy prohibits member invitations")
        except Forbidden:
            raise
        except Exception:
            logger.exception("Failed to check workspace invite permission for %s", workspace_id)


def check_workspace_owner_transfer_permission(workspace_id: str) -> None:
    """
    Check if workspace allows owner transfer at both billing and policy levels.

    Checks performed:
    1. Billing/plan level - SANDBOX plan blocks owner transfer
    2. Enterprise policy level - Admin-configured workspace permission

    Args:
        workspace_id: The workspace ID to check permissions for

    Raises:
        Forbidden: If either billing plan or workspace policy prohibits ownership transfer
    """
    features = FeatureService.get_features(workspace_id)
    if not features.is_allow_transfer_workspace:
        raise Forbidden("Your current plan does not allow workspace ownership transfer")

    # Check enterprise workspace policy level (only if enterprise enabled)
    if dify_config.ENTERPRISE_ENABLED:
        try:
            permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
            if not permission.allow_owner_transfer:
                raise Forbidden("Workspace policy prohibits ownership transfer")
        except Forbidden:
            raise
        except Exception:
            logger.exception("Failed to check workspace transfer permission for %s", workspace_id)
@@ -1285,7 +1285,7 @@ class WorkflowDraftVariable(Base):
    # which may differ from the original value's type. Typically, they are the same,
    # but in cases where the structurally truncated value still exceeds the size limit,
    # text slicing is applied, and the `value_type` is converted to `STRING`.
    value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=21))
    value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20))

    # The variable's value serialized as a JSON string
    #
@@ -1659,7 +1659,7 @@ class WorkflowDraftVariableFile(Base):

    # The `value_type` field records the type of the original value.
    value_type: Mapped[SegmentType] = mapped_column(
        EnumText(SegmentType, length=21),
        EnumText(SegmentType, length=20),
        nullable=False,
    )
@@ -1381,11 +1381,6 @@ class RegisterService:
        normalized_email = email.lower()

        """Invite new member"""
        # Check workspace permission for member invitations
        from libs.workspace_permission import check_workspace_member_invite_permission

        check_workspace_member_invite_permission(tenant.id)

        with Session(db.engine) as session:
            account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
@@ -13,23 +13,6 @@ class WebAppSettings(BaseModel):
    )


class WorkspacePermission(BaseModel):
    workspace_id: str = Field(
        description="The ID of the workspace.",
        alias="workspaceId",
    )
    allow_member_invite: bool = Field(
        description="Whether to allow members to invite new members to the workspace.",
        default=False,
        alias="allowMemberInvite",
    )
    allow_owner_transfer: bool = Field(
        description="Whether to allow owners to transfer ownership of the workspace.",
        default=False,
        alias="allowOwnerTransfer",
    )


class EnterpriseService:
    @classmethod
    def get_info(cls):
@@ -61,16 +44,6 @@ class EnterpriseService:
            except ValueError as e:
                raise ValueError(f"Invalid date format: {data}") from e

    class WorkspacePermissionService:
        @classmethod
        def get_permission(cls, workspace_id: str):
            if not workspace_id:
                raise ValueError("workspace_id must be provided.")
            data = EnterpriseRequest.send_request("GET", f"/workspaces/{workspace_id}/permission")
            if not data or "permission" not in data:
                raise ValueError("No data found.")
            return WorkspacePermission.model_validate(data["permission"])

    class WebAppAuth:
        @classmethod
        def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str):
@@ -12,6 +12,8 @@ logger = logging.getLogger(__name__)
WORKSPACE_SYNC_QUEUE = "enterprise:workspace:sync:queue"
WORKSPACE_SYNC_PROCESSING = "enterprise:workspace:sync:processing"

TASK_TYPE_SYNC_TO_WORKSPACE = "sync_to_workspace"


class WorkspaceSyncService:
    """Service to publish workspace sync tasks to Redis queue for enterprise backend consumption"""
@@ -38,6 +40,7 @@ class WorkspaceSyncService:
            "retry_count": 0,
            "created_at": datetime.now(UTC).isoformat(),
            "source": source,
            "type": TASK_TYPE_SYNC_TO_WORKSPACE,
        }

        # Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP
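A minimal sketch of the queue discipline described in the comment above, using redis-py (the connection and the trimmed-down task payload are illustrative):

import json
import redis

r = redis.Redis()
task = {"retry_count": 0, "type": TASK_TYPE_SYNC_TO_WORKSPACE}  # abbreviated payload
r.lpush(WORKSPACE_SYNC_QUEUE, json.dumps(task))  # producer pushes onto the head
raw = r.rpop(WORKSPACE_SYNC_QUEUE)  # worker pops from the tail, giving FIFO order overall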
@@ -7,7 +7,6 @@ from typing import Any, Generic, TypeAlias, TypeVar, overload

from configs import dify_config
from core.file.models import File
from core.model_runtime.entities import PromptMessage
from core.variables.segments import (
    ArrayFileSegment,
    ArraySegment,
@@ -288,10 +287,6 @@ class VariableTruncator(BaseTruncator):
            if isinstance(item, File):
                truncated_value.append(item)
                continue
            # Handle PromptMessage types - convert to dict for truncation
            if isinstance(item, PromptMessage):
                truncated_value.append(item)
                continue
            if i >= target_length:
                return _PartResult(truncated_value, used_size, True)
            if i > 0:
@@ -163,29 +163,3 @@ class WorkflowScheduleCFSPlanEntity(BaseModel):

    schedule_strategy: Strategy
    granularity: int = Field(default=-1)  # -1 means infinite


# ========== Mention Graph Entities ==========


class MentionParameterSchema(BaseModel):
    """Schema for the parameter to be extracted from mention context."""

    name: str = Field(description="Parameter name (e.g., 'query')")
    type: str = Field(default="string", description="Parameter type (e.g., 'string', 'number')")
    description: str = Field(default="", description="Parameter description for LLM")


class MentionGraphRequest(BaseModel):
    """Request payload for generating mention graph."""

    parent_node_id: str = Field(description="ID of the parent node that uses the extracted value")
    parameter_key: str = Field(description="Key of the parameter being extracted")
    context_source: list[str] = Field(description="Variable selector for the context source")
    parameter_schema: MentionParameterSchema = Field(description="Schema of the parameter to extract")


class MentionGraphResponse(BaseModel):
    """Response containing the generated mention graph."""

    graph: Mapping[str, Any] = Field(description="Complete graph structure with nodes, edges, viewport")
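For orientation, a hypothetical request built from these entities (node ids and selector are illustrative):

request = MentionGraphRequest(
    parent_node_id="tool_1",  # hypothetical tool node that consumes the extracted value
    parameter_key="query",
    context_source=["llm", "context"],  # hypothetical selector to a list[PromptMessage]
    parameter_schema=MentionParameterSchema(
        name="query",
        type="string",
        description="Search query to extract from the conversation",
    ),
)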
@@ -1,143 +0,0 @@
"""
Service for generating Mention LLM node graph structures.

This service creates graph structures containing LLM nodes configured for
extracting values from list[PromptMessage] variables.
"""

from typing import Any

from sqlalchemy.orm import Session

from core.model_runtime.entities import LLMMode
from core.workflow.enums import NodeType
from services.model_provider_service import ModelProviderService
from services.workflow.entities import MentionGraphRequest, MentionGraphResponse, MentionParameterSchema


class MentionGraphService:
    """Service for generating Mention LLM node graph structures."""

    def __init__(self, session: Session):
        self._session = session

    def generate_mention_node_id(self, node_id: str, parameter_name: str) -> str:
        """Generate mention node ID following the naming convention.

        Format: {node_id}_ext_{parameter_name}
        """
        return f"{node_id}_ext_{parameter_name}"

    def generate_mention_graph(self, tenant_id: str, request: MentionGraphRequest) -> MentionGraphResponse:
        """Generate a complete graph structure containing a Mention LLM node.

        Args:
            tenant_id: The tenant ID for fetching default model config
            request: The mention graph generation request

        Returns:
            Complete graph structure with nodes, edges, and viewport
        """
        node_id = self.generate_mention_node_id(request.parent_node_id, request.parameter_key)
        model_config = self._get_default_model_config(tenant_id)
        node = self._build_mention_llm_node(
            node_id=node_id,
            parent_node_id=request.parent_node_id,
            context_source=request.context_source,
            parameter_schema=request.parameter_schema,
            model_config=model_config,
        )

        graph = {
            "nodes": [node],
            "edges": [],
            "viewport": {},
        }

        return MentionGraphResponse(graph=graph)

    def _get_default_model_config(self, tenant_id: str) -> dict[str, Any]:
        """Get the default LLM model configuration for the tenant."""
        model_provider_service = ModelProviderService()
        default_model = model_provider_service.get_default_model_of_model_type(
            tenant_id=tenant_id,
            model_type="llm",
        )

        if default_model:
            return {
                "provider": default_model.provider.provider,
                "name": default_model.model,
                "mode": LLMMode.CHAT.value,
                "completion_params": {},
            }

        # Fallback to empty config if no default model is configured
        return {
            "provider": "",
            "name": "",
            "mode": LLMMode.CHAT.value,
            "completion_params": {},
        }

    def _build_mention_llm_node(
        self,
        *,
        node_id: str,
        parent_node_id: str,
        context_source: list[str],
        parameter_schema: MentionParameterSchema,
        model_config: dict[str, Any],
    ) -> dict[str, Any]:
        """Build the Mention LLM node structure.

        The node uses:
        - $context in prompt_template to reference the PromptMessage list
        - structured_output for extracting the specific parameter
        - parent_node_id to associate with the parent node
        """
        prompt_template = [
            {
                "role": "system",
                "text": "Extract the required parameter value from the conversation context above.",
            },
            {"$context": context_source},
            {"role": "user", "text": ""},
        ]

        structured_output = {
            "schema": {
                "type": "object",
                "properties": {
                    parameter_schema.name: {
                        "type": parameter_schema.type,
                        "description": parameter_schema.description,
                    }
                },
                "required": [parameter_schema.name],
                "additionalProperties": False,
            }
        }

        return {
            "id": node_id,
            "position": {"x": 0, "y": 0},
            "data": {
                "type": NodeType.LLM.value,
                "title": f"Mention: {parameter_schema.name}",
                "desc": f"Extract {parameter_schema.name} from conversation context",
                "parent_node_id": parent_node_id,
                "model": model_config,
                "prompt_template": prompt_template,
                "context": {
                    "enabled": False,
                    "variable_selector": None,
                },
                "vision": {
                    "enabled": False,
                },
                "memory": None,
                "structured_output_enabled": True,
                "structured_output": structured_output,
            },
        }
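Continuing the sketch above, generating a graph for that request (the tenant id and session setup are hypothetical; db.engine is the app's SQLAlchemy engine as used elsewhere in this diff):

with Session(db.engine) as session:
    service = MentionGraphService(session)
    response = service.generate_mention_graph(tenant_id="tenant-123", request=request)
    # Per generate_mention_node_id, the single generated node's id is "tool_1_ext_query".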
@@ -83,30 +83,7 @@
      <p class="content1">Dear {{ to }},</p>
      <p class="content2">{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.</p>
      <p class="content2">Click the button below to log in to Dify and join the workspace.</p>
      <div style="text-align: center; margin-bottom: 32px;">
        <a href="{{ url }}"
           style="background-color:#2563eb;
                  color:#ffffff !important;
                  text-decoration:none;
                  display:inline-block;
                  font-weight:600;
                  border-radius:4px;
                  font-size:14px;
                  line-height:18px;
                  font-family: Helvetica, Arial, sans-serif;
                  text-align:center;
                  border-top: 10px solid #2563eb;
                  border-bottom: 10px solid #2563eb;
                  border-left: 20px solid #2563eb;
                  border-right: 20px solid #2563eb;
                  ">Login Here</a>
        <p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
          If the button doesn't work, copy and paste this link into your browser:<br>
          <a href="{{ url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
            {{ url }}
          </a>
        </p>
      </div>
      <p style="text-align: center; margin: 0; margin-bottom: 32px;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">Login Here</a></p>
      <p class="content2">Best regards,</p>
      <p class="content2">Dify Team</p>
    </div>

@@ -83,30 +83,7 @@
      <p class="content1">尊敬的 {{ to }},</p>
      <p class="content2">{{ inviter_name }} 现邀请您加入我们在 Dify 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 Dify 上,您可以探索、创造和合作,构建和运营 AI 应用。</p>
      <p class="content2">点击下方按钮即可登录 Dify 并且加入空间。</p>
      <div style="text-align: center; margin-bottom: 32px;">
        <a href="{{ url }}"
           style="background-color:#2563eb;
                  color:#ffffff !important;
                  text-decoration:none;
                  display:inline-block;
                  font-weight:600;
                  border-radius:4px;
                  font-size:14px;
                  line-height:18px;
                  font-family: Helvetica, Arial, sans-serif;
                  text-align:center;
                  border-top: 10px solid #2563eb;
                  border-bottom: 10px solid #2563eb;
                  border-left: 20px solid #2563eb;
                  border-right: 20px solid #2563eb;
                  ">在此登录</a>
        <p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
          如果按钮无法使用,请将以下链接复制到浏览器打开:<br>
          <a href="{{ url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
            {{ url }}
          </a>
        </p>
      </div>
      <p style="text-align: center; margin: 0; margin-bottom: 32px;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">在此登录</a></p>
      <p class="content2">此致,</p>
      <p class="content2">Dify 团队</p>
    </div>

@@ -115,30 +115,7 @@
        We noticed you tried to sign up, but this email is already registered with an existing account.

        Please log in here: </p>
      <div style="text-align: center; margin-bottom: 20px;">
        <a href="{{ login_url }}"
           style="background-color:#2563eb;
                  color:#ffffff !important;
                  text-decoration:none;
                  display:inline-block;
                  font-weight:600;
                  border-radius:4px;
                  font-size:14px;
                  line-height:18px;
                  font-family: Helvetica, Arial, sans-serif;
                  text-align:center;
                  border-top: 10px solid #2563eb;
                  border-bottom: 10px solid #2563eb;
                  border-left: 20px solid #2563eb;
                  border-right: 20px solid #2563eb;
                  ">Log In</a>
        <p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
          If the button doesn't work, copy and paste this link into your browser:<br>
          <a href="{{ login_url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
            {{ login_url }}
          </a>
        </p>
      </div>
      <a href="{{ login_url }}" class="button">Log In</a>
      <p class="description">
        If you forgot your password, you can reset it here: <a href="{{ reset_password_url }}"
          class="reset-btn">Reset Password</a>

@@ -115,30 +115,7 @@
        我们注意到您尝试注册,但此电子邮件已注册。

        请在此登录: </p>
      <div style="text-align: center; margin-bottom: 20px;">
        <a href="{{ login_url }}"
           style="background-color:#2563eb;
                  color:#ffffff !important;
                  text-decoration:none;
                  display:inline-block;
                  font-weight:600;
                  border-radius:4px;
                  font-size:14px;
                  line-height:18px;
                  font-family: Helvetica, Arial, sans-serif;
                  text-align:center;
                  border-top: 10px solid #2563eb;
                  border-bottom: 10px solid #2563eb;
                  border-left: 20px solid #2563eb;
                  border-right: 20px solid #2563eb;
                  ">登录</a>
        <p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
          如果按钮无法使用,请将以下链接复制到浏览器打开:<br>
          <a href="{{ login_url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
            {{ login_url }}
          </a>
        </p>
      </div>
      <a href="{{ login_url }}" class="button">登录</a>
      <p class="description">
        如果您忘记了密码,可以在此重置: <a href="{{ reset_password_url }}" class="reset-btn">重置密码</a>
      </p>

@@ -92,34 +92,12 @@
        platform specifically designed for LLM application development. On {{application_title}}, you can explore,
        create, and collaborate to build and operate AI applications.</p>
      <p class="content2">Click the button below to log in to {{application_title}} and join the workspace.</p>
      <div style="text-align: center; margin-bottom: 32px;">
        <a href="{{ url }}"
           style="background-color:#2563eb;
                  color:#ffffff !important;
                  text-decoration:none;
                  display:inline-block;
                  font-weight:600;
                  border-radius:4px;
                  font-size:14px;
                  line-height:18px;
                  font-family: Helvetica, Arial, sans-serif;
                  text-align:center;
                  border-top: 10px solid #2563eb;
                  border-bottom: 10px solid #2563eb;
                  border-left: 20px solid #2563eb;
                  border-right: 20px solid #2563eb;
                  ">Login Here</a>
        <p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
          If the button doesn't work, copy and paste this link into your browser:<br>
          <a href="{{ url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
            {{ url }}
          </a>
        </p>
      </div>
      <p style="text-align: center; margin: 0; margin-bottom: 32px;"><a style="color: #fff; text-decoration: none"
          class="button" href="{{ url }}">Login Here</a></p>
      <p class="content2">Best regards,</p>
      <p class="content2">{{application_title}} Team</p>
    </div>
  </div>
</body>

</html>
</html>

@@ -81,30 +81,7 @@
      <p class="content1">尊敬的 {{ to }},</p>
      <p class="content2">{{ inviter_name }} 现邀请您加入我们在 {{application_title}} 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 {{application_title}} 上,您可以探索、创造和合作,构建和运营 AI 应用。</p>
      <p class="content2">点击下方按钮即可登录 {{application_title}} 并且加入空间。</p>
      <div style="text-align: center; margin-bottom: 32px;">
        <a href="{{ url }}"
           style="background-color:#2563eb;
                  color:#ffffff !important;
                  text-decoration:none;
                  display:inline-block;
                  font-weight:600;
                  border-radius:4px;
                  font-size:14px;
                  line-height:18px;
                  font-family: Helvetica, Arial, sans-serif;
                  text-align:center;
                  border-top: 10px solid #2563eb;
                  border-bottom: 10px solid #2563eb;
                  border-left: 20px solid #2563eb;
                  border-right: 20px solid #2563eb;
                  ">在此登录</a>
        <p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
          如果按钮无法使用,请将以下链接复制到浏览器打开:<br>
          <a href="{{ url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
            {{ url }}
          </a>
        </p>
      </div>
      <p style="text-align: center; margin: 0; margin-bottom: 32px;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">在此登录</a></p>
      <p class="content2">此致,</p>
      <p class="content2">{{application_title}} 团队</p>
    </div>

@@ -111,30 +111,7 @@
        We noticed you tried to sign up, but this email is already registered with an existing account.

        Please log in here: </p>
      <div style="text-align: center; margin-bottom: 20px;">
        <a href="{{ login_url }}"
           style="background-color:#2563eb;
                  color:#ffffff !important;
                  text-decoration:none;
                  display:inline-block;
                  font-weight:600;
                  border-radius:4px;
                  font-size:14px;
                  line-height:18px;
                  font-family: Helvetica, Arial, sans-serif;
                  text-align:center;
                  border-top: 10px solid #2563eb;
                  border-bottom: 10px solid #2563eb;
                  border-left: 20px solid #2563eb;
                  border-right: 20px solid #2563eb;
                  ">Log In</a>
        <p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
          If the button doesn't work, copy and paste this link into your browser:<br>
          <a href="{{ login_url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
            {{ login_url }}
          </a>
        </p>
      </div>
      <a href="{{ login_url }}" class="button">Log In</a>
      <p class="description">
        If you forgot your password, you can reset it here: <a href="{{ reset_password_url }}"
          class="reset-btn">Reset Password</a>

@@ -111,30 +111,7 @@
        我们注意到您尝试注册,但此电子邮件已注册。

        请在此登录: </p>
      <div style="text-align: center; margin-bottom: 20px;">
        <a href="{{ login_url }}"
           style="background-color:#2563eb;
                  color:#ffffff !important;
                  text-decoration:none;
                  display:inline-block;
                  font-weight:600;
                  border-radius:4px;
                  font-size:14px;
                  line-height:18px;
                  font-family: Helvetica, Arial, sans-serif;
                  text-align:center;
                  border-top: 10px solid #2563eb;
                  border-bottom: 10px solid #2563eb;
                  border-left: 20px solid #2563eb;
                  border-right: 20px solid #2563eb;
                  ">登录</a>
        <p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
          如果按钮无法使用,请将以下链接复制到浏览器打开:<br>
          <a href="{{ login_url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
            {{ login_url }}
          </a>
        </p>
      </div>
      <a href="{{ login_url }}" class="button">登录</a>
      <p class="description">
        如果您忘记了密码,可以在此重置: <a href="{{ reset_password_url }}" class="reset-btn">重置密码</a>
      </p>
181 api/tests/fixtures/file output schema.yml vendored
@ -1,181 +0,0 @@
app:
  description: ''
  icon: 🤖
  icon_background: '#FFEAD5'
  mode: advanced-chat
  name: file output schema
  use_icon_as_answer_icon: false
dependencies:
- current_identifier: null
  type: marketplace
  value:
    marketplace_plugin_unique_identifier: langgenius/openai:0.2.3@5a7f82fa86e28332ad51941d0b491c1e8a38ead539656442f7bf4c6129cd15fa
    version: null
kind: app
version: 0.5.0
workflow:
  conversation_variables: []
  environment_variables: []
  features:
    file_upload:
      allowed_file_extensions:
      - .JPG
      - .JPEG
      - .PNG
      - .GIF
      - .WEBP
      - .SVG
      allowed_file_types:
      - image
      allowed_file_upload_methods:
      - remote_url
      - local_file
      enabled: true
      fileUploadConfig:
        attachment_image_file_size_limit: 2
        audio_file_size_limit: 50
        batch_count_limit: 5
        file_size_limit: 15
        file_upload_limit: 10
        image_file_batch_limit: 10
        image_file_size_limit: 10
        single_chunk_attachment_limit: 10
        video_file_size_limit: 100
        workflow_file_upload_limit: 10
      number_limits: 3
    opening_statement: ''
    retriever_resource:
      enabled: true
    sensitive_word_avoidance:
      enabled: false
    speech_to_text:
      enabled: false
    suggested_questions: []
    suggested_questions_after_answer:
      enabled: false
    text_to_speech:
      enabled: false
      language: ''
      voice: ''
  graph:
    edges:
    - data:
        sourceType: start
        targetType: llm
      id: 1768292241666-llm
      source: '1768292241666'
      sourceHandle: source
      target: llm
      targetHandle: target
      type: custom
    - data:
        sourceType: llm
        targetType: answer
      id: llm-answer
      source: llm
      sourceHandle: source
      target: answer
      targetHandle: target
      type: custom
    nodes:
    - data:
        selected: false
        title: User Input
        type: start
        variables: []
      height: 73
      id: '1768292241666'
      position:
        x: 80
        y: 282
      positionAbsolute:
        x: 80
        y: 282
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    - data:
        context:
          enabled: false
          variable_selector: []
        memory:
          query_prompt_template: '{{#sys.query#}}


            {{#sys.files#}}'
          role_prefix:
            assistant: ''
            user: ''
          window:
            enabled: false
            size: 10
        model:
          completion_params:
            temperature: 0.7
          mode: chat
          name: gpt-4o-mini
          provider: langgenius/openai/openai
        prompt_template:
        - id: e30d75d7-7d85-49ec-be3c-3baf7f6d3c5a
          role: system
          text: ''
        selected: false
        structured_output:
          schema:
            additionalProperties: false
            properties:
              image:
                description: File ID (UUID) of the selected image
                format: dify-file-ref
                type: string
            required:
            - image
            type: object
        structured_output_enabled: true
        title: LLM
        type: llm
        vision:
          configs:
            detail: high
            variable_selector:
            - sys
            - files
          enabled: true
      height: 88
      id: llm
      position:
        x: 380
        y: 282
      positionAbsolute:
        x: 380
        y: 282
      selected: false
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    - data:
        answer: '{{#llm.structured_output.image#}}'
        selected: false
        title: Answer
        type: answer
        variables: []
      height: 103
      id: answer
      position:
        x: 680
        y: 282
      positionAbsolute:
        x: 680
        y: 282
      selected: true
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    viewport:
      x: -149
      y: 97.5
      zoom: 1
  rag_pipeline_variables: []
307 api/tests/fixtures/pav-test-extraction.yml vendored
@ -1,307 +0,0 @@
app:
  description: Test for variable extraction feature
  icon: 🤖
  icon_background: '#FFEAD5'
  mode: advanced-chat
  name: pav-test-extraction
  use_icon_as_answer_icon: false
dependencies:
- current_identifier: null
  type: marketplace
  value:
    marketplace_plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
    version: null
- current_identifier: null
  type: marketplace
  value:
    marketplace_plugin_unique_identifier: langgenius/openai:0.2.3@5a7f82fa86e28332ad51941d0b491c1e8a38ead539656442f7bf4c6129cd15fa
    version: null
- current_identifier: null
  type: marketplace
  value:
    marketplace_plugin_unique_identifier: langgenius/tongyi:0.1.16@d8bffbe45418f0c117fb3393e5e40e61faee98f9a2183f062e5a280e74b15d21
    version: null
kind: app
version: 0.5.0
workflow:
  conversation_variables: []
  environment_variables: []
  features:
    file_upload:
      allowed_file_extensions:
      - .JPG
      - .JPEG
      - .PNG
      - .GIF
      - .WEBP
      - .SVG
      allowed_file_types:
      - image
      allowed_file_upload_methods:
      - local_file
      - remote_url
      enabled: false
      image:
        enabled: false
        number_limits: 3
        transfer_methods:
        - local_file
        - remote_url
      number_limits: 3
    opening_statement: 你好!我是一个搜索助手,请告诉我你想搜索什么内容。
    retriever_resource:
      enabled: true
    sensitive_word_avoidance:
      enabled: false
    speech_to_text:
      enabled: false
    suggested_questions: []
    suggested_questions_after_answer:
      enabled: false
    text_to_speech:
      enabled: false
      language: ''
      voice: ''
  graph:
    edges:
    - data:
        sourceType: start
        targetType: llm
      id: 1767773675796-llm
      source: '1767773675796'
      sourceHandle: source
      target: llm
      targetHandle: target
      type: custom
    - data:
        isInIteration: false
        isInLoop: false
        sourceType: llm
        targetType: tool
      id: llm-source-1767773709491-target
      source: llm
      sourceHandle: source
      target: '1767773709491'
      targetHandle: target
      type: custom
      zIndex: 0
    - data:
        isInIteration: false
        isInLoop: false
        sourceType: tool
        targetType: answer
      id: tool-source-answer-target
      source: '1767773709491'
      sourceHandle: source
      target: answer
      targetHandle: target
      type: custom
      zIndex: 0
    nodes:
    - data:
        selected: false
        title: User Input
        type: start
        variables: []
      height: 73
      id: '1767773675796'
      position:
        x: 80
        y: 282
      positionAbsolute:
        x: 80
        y: 282
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    - data:
        context:
          enabled: false
          variable_selector: []
        memory:
          mode: node
          query_prompt_template: '{{#sys.query#}}'
          role_prefix:
            assistant: ''
            user: ''
          window:
            enabled: true
            size: 10
        model:
          completion_params:
            temperature: 0.7
          mode: chat
          name: qwen-max
          provider: langgenius/tongyi/tongyi
        prompt_template:
        - id: 11d06d15-914a-4915-a5b1-0e35ab4fba51
          role: system
          text: '你是一个智能搜索助手。用户会告诉你他们想搜索的内容。

            请与用户进行对话,了解他们的搜索需求。

            当用户明确表达了想要搜索的内容后,你可以回复"好的,我来帮你搜索"。

            '
        selected: false
        title: LLM
        type: llm
        vision:
          enabled: false
      height: 88
      id: llm
      position:
        x: 380
        y: 282
      positionAbsolute:
        x: 380
        y: 282
      selected: false
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    - data:
        is_team_authorization: true
        paramSchemas:
        - auto_generate: null
          default: null
          form: llm
          human_description:
            en_US: used for searching
            ja_JP: used for searching
            pt_BR: used for searching
            zh_Hans: 用于搜索网页内容
          label:
            en_US: Query string
            ja_JP: Query string
            pt_BR: Query string
            zh_Hans: 查询语句
          llm_description: key words for searching
          max: null
          min: null
          name: query
          options: []
          placeholder: null
          precision: null
          required: true
          scope: null
          template: null
          type: string
        params:
          query: ''
        plugin_id: langgenius/google
        plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
        provider_icon: http://localhost:5001/console/api/workspaces/current/plugin/icon?tenant_id=7217e801-f6f5-49ec-8103-d7de97a4b98f&filename=1c5871163478957bac64c3fe33d72d003f767497d921c74b742aad27a8344a74.svg
        provider_id: langgenius/google/google
        provider_name: langgenius/google/google
        provider_type: builtin
        selected: false
        title: GoogleSearch
        tool_configurations: {}
        tool_description: A tool for performing a Google SERP search and extracting
          snippets and webpages.Input should be a search query.
        tool_label: GoogleSearch
        tool_name: google_search
        tool_node_version: '2'
        tool_parameters:
          query:
            type: mention
            value: '{{@llm.context@}}请从对话历史中提取用户想要搜索的关键词,只返回关键词本身'
            mention_config:
              extractor_node_id: 1767773709491_ext_query
              output_selector:
              - structured_output
              - query
              null_strategy: use_default
              default_value: ''
        type: tool
      height: 52
      id: '1767773709491'
      position:
        x: 682
        y: 282
      positionAbsolute:
        x: 682
        y: 282
      selected: false
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    - data:
        context:
          enabled: false
          variable_selector: []
        model:
          completion_params:
            temperature: 0.7
          mode: chat
          name: gpt-4o-mini
          provider: langgenius/openai/openai
        parent_node_id: '1767773709491'
        prompt_template:
        - $context:
          - llm
          - context
          id: 75d58e22-dc59-40c8-ba6f-aeb28f4f305a
        - id: 18ba6710-77f5-47f4-b144-9191833bb547
          role: user
          text: 请从对话历史中提取用户想要搜索的关键词,只返回关键词本身,不要返回其他内容
        selected: false
        structured_output:
          schema:
            additionalProperties: false
            properties:
              query:
                description: 搜索的关键词
                type: string
            required:
            - query
            type: object
        structured_output_enabled: true
        title: 提取搜索关键词
        type: llm
        vision:
          enabled: false
      height: 88
      id: 1767773709491_ext_query
      position:
        x: 531
        y: 382
      positionAbsolute:
        x: 531
        y: 382
      selected: true
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    - data:
        answer: '搜索结果:

          {{#1767773709491.text#}}

          '
        selected: false
        title: Answer
        type: answer
      height: 103
      id: answer
      position:
        x: 984
        y: 282
      positionAbsolute:
        x: 984
        y: 282
      selected: false
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    viewport:
      x: -151
      y: 123
      zoom: 1
  rag_pipeline_variables: []
@ -1,182 +0,0 @@
"""Tests for file_manager module, specifically multimodal content handling."""

from unittest.mock import patch

from core.file import File, FileTransferMethod, FileType
from core.file.file_manager import (
    _encode_file_ref,
    restore_multimodal_content,
    to_prompt_message_content,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent


class TestEncodeFileRef:
    """Tests for _encode_file_ref function."""

    def test_encodes_local_file(self):
        """Local file should be encoded as 'local:id'."""
        file = File(
            tenant_id="t",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.LOCAL_FILE,
            related_id="abc123",
            storage_key="key",
        )
        assert _encode_file_ref(file) == "local:abc123"

    def test_encodes_tool_file(self):
        """Tool file should be encoded as 'tool:id'."""
        file = File(
            tenant_id="t",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.TOOL_FILE,
            related_id="xyz789",
            storage_key="key",
        )
        assert _encode_file_ref(file) == "tool:xyz789"

    def test_encodes_remote_url(self):
        """Remote URL should be encoded as 'remote:url'."""
        file = File(
            tenant_id="t",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.REMOTE_URL,
            remote_url="https://example.com/image.png",
            storage_key="",
        )
        assert _encode_file_ref(file) == "remote:https://example.com/image.png"


class TestToPromptMessageContent:
    """Tests for to_prompt_message_content function with file_ref field."""

    @patch("core.file.file_manager.dify_config")
    @patch("core.file.file_manager._get_encoded_string")
    def test_includes_file_ref(self, mock_get_encoded, mock_config):
        """Generated content should include file_ref field."""
        mock_config.MULTIMODAL_SEND_FORMAT = "base64"
        mock_get_encoded.return_value = "base64data"

        file = File(
            id="test-message-file-id",
            tenant_id="test-tenant",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.LOCAL_FILE,
            related_id="test-related-id",
            remote_url=None,
            extension=".png",
            mime_type="image/png",
            filename="test.png",
            storage_key="test-key",
        )

        result = to_prompt_message_content(file)

        assert isinstance(result, ImagePromptMessageContent)
        assert result.file_ref == "local:test-related-id"
        assert result.base64_data == "base64data"


class TestRestoreMultimodalContent:
    """Tests for restore_multimodal_content function."""

    def test_returns_content_unchanged_when_no_file_ref(self):
        """Content without file_ref should pass through unchanged."""
        content = ImagePromptMessageContent(
            format="png",
            base64_data="existing-data",
            mime_type="image/png",
            file_ref=None,
        )

        result = restore_multimodal_content(content)

        assert result.base64_data == "existing-data"

    def test_returns_content_unchanged_when_already_has_data(self):
        """Content that already has base64_data should not be reloaded."""
        content = ImagePromptMessageContent(
            format="png",
            base64_data="existing-data",
            mime_type="image/png",
            file_ref="local:file-id",
        )

        result = restore_multimodal_content(content)

        assert result.base64_data == "existing-data"

    def test_returns_content_unchanged_when_already_has_url(self):
        """Content that already has url should not be reloaded."""
        content = ImagePromptMessageContent(
            format="png",
            url="https://example.com/image.png",
            mime_type="image/png",
            file_ref="local:file-id",
        )

        result = restore_multimodal_content(content)

        assert result.url == "https://example.com/image.png"

    @patch("core.file.file_manager.dify_config")
    @patch("core.file.file_manager._build_file_from_ref")
    @patch("core.file.file_manager._to_url")
    def test_restores_url_from_file_ref(self, mock_to_url, mock_build_file, mock_config):
        """Content should be restored from file_ref when url is empty (url mode)."""
        mock_config.MULTIMODAL_SEND_FORMAT = "url"
        mock_build_file.return_value = "mock_file"
        mock_to_url.return_value = "https://restored-url.com/image.png"

        content = ImagePromptMessageContent(
            format="png",
            base64_data="",
            url="",
            mime_type="image/png",
            filename="test.png",
            file_ref="local:test-file-id",
        )

        result = restore_multimodal_content(content)

        assert result.url == "https://restored-url.com/image.png"
        mock_build_file.assert_called_once()

    @patch("core.file.file_manager.dify_config")
    @patch("core.file.file_manager._build_file_from_ref")
    @patch("core.file.file_manager._get_encoded_string")
    def test_restores_base64_from_file_ref(self, mock_get_encoded, mock_build_file, mock_config):
        """Content should be restored as base64 when in base64 mode."""
        mock_config.MULTIMODAL_SEND_FORMAT = "base64"
        mock_build_file.return_value = "mock_file"
        mock_get_encoded.return_value = "restored-base64-data"

        content = ImagePromptMessageContent(
            format="png",
            base64_data="",
            url="",
            mime_type="image/png",
            filename="test.png",
            file_ref="local:test-file-id",
        )

        result = restore_multimodal_content(content)

        assert result.base64_data == "restored-base64-data"
        mock_build_file.assert_called_once()

    def test_handles_invalid_file_ref_gracefully(self):
        """Invalid file_ref format should be handled gracefully."""
        content = ImagePromptMessageContent(
            format="png",
            base64_data="",
            url="",
            mime_type="image/png",
            file_ref="invalid_format_no_colon",
        )

        result = restore_multimodal_content(content)

        # Should return unchanged on error
        assert result.base64_data == ""
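
The assertions above pin down the file_ref encoding scheme precisely: a "scheme:identifier" string keyed on the transfer method. A minimal sketch of what they imply, inferred purely from the tests (the actual helper removed from core/file/file_manager.py may differ, e.g. in error handling):

# Hypothetical reconstruction from the test assertions above; not the shipped code.
from core.file import File, FileTransferMethod

def _encode_file_ref(file: File) -> str:
    """Encode a File into a compact 'scheme:identifier' reference string."""
    if file.transfer_method == FileTransferMethod.LOCAL_FILE:
        return f"local:{file.related_id}"    # e.g. "local:abc123"
    if file.transfer_method == FileTransferMethod.TOOL_FILE:
        return f"tool:{file.related_id}"     # e.g. "tool:xyz789"
    if file.transfer_method == FileTransferMethod.REMOTE_URL:
        return f"remote:{file.remote_url}"   # e.g. "remote:https://example.com/image.png"
    raise ValueError(f"unsupported transfer method: {file.transfer_method}")

The point of the scheme is that prompt message content can carry a small stable reference instead of the full base64 payload, and restore_multimodal_content can later resolve it back to bytes or a URL.
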
@ -1,269 +0,0 @@
"""
Unit tests for file reference detection and conversion.
"""

import uuid
from unittest.mock import MagicMock, patch

import pytest

from core.file import File, FileTransferMethod, FileType
from core.llm_generator.output_parser.file_ref import (
    FILE_REF_FORMAT,
    convert_file_refs_in_output,
    detect_file_ref_fields,
    is_file_ref_property,
)
from core.variables.segments import ArrayFileSegment, FileSegment


class TestIsFileRefProperty:
    """Tests for is_file_ref_property function."""

    def test_valid_file_ref(self):
        schema = {"type": "string", "format": FILE_REF_FORMAT}
        assert is_file_ref_property(schema) is True

    def test_invalid_type(self):
        schema = {"type": "number", "format": FILE_REF_FORMAT}
        assert is_file_ref_property(schema) is False

    def test_missing_format(self):
        schema = {"type": "string"}
        assert is_file_ref_property(schema) is False

    def test_wrong_format(self):
        schema = {"type": "string", "format": "uuid"}
        assert is_file_ref_property(schema) is False


class TestDetectFileRefFields:
    """Tests for detect_file_ref_fields function."""

    def test_simple_file_ref(self):
        schema = {
            "type": "object",
            "properties": {
                "image": {"type": "string", "format": FILE_REF_FORMAT},
            },
        }
        paths = detect_file_ref_fields(schema)
        assert paths == ["image"]

    def test_multiple_file_refs(self):
        schema = {
            "type": "object",
            "properties": {
                "image": {"type": "string", "format": FILE_REF_FORMAT},
                "document": {"type": "string", "format": FILE_REF_FORMAT},
                "name": {"type": "string"},
            },
        }
        paths = detect_file_ref_fields(schema)
        assert set(paths) == {"image", "document"}

    def test_array_of_file_refs(self):
        schema = {
            "type": "object",
            "properties": {
                "files": {
                    "type": "array",
                    "items": {"type": "string", "format": FILE_REF_FORMAT},
                },
            },
        }
        paths = detect_file_ref_fields(schema)
        assert paths == ["files[*]"]

    def test_nested_file_ref(self):
        schema = {
            "type": "object",
            "properties": {
                "data": {
                    "type": "object",
                    "properties": {
                        "image": {"type": "string", "format": FILE_REF_FORMAT},
                    },
                },
            },
        }
        paths = detect_file_ref_fields(schema)
        assert paths == ["data.image"]

    def test_no_file_refs(self):
        schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "count": {"type": "number"},
            },
        }
        paths = detect_file_ref_fields(schema)
        assert paths == []

    def test_empty_schema(self):
        schema = {}
        paths = detect_file_ref_fields(schema)
        assert paths == []

    def test_mixed_schema(self):
        schema = {
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                "image": {"type": "string", "format": FILE_REF_FORMAT},
                "documents": {
                    "type": "array",
                    "items": {"type": "string", "format": FILE_REF_FORMAT},
                },
            },
        }
        paths = detect_file_ref_fields(schema)
        assert set(paths) == {"image", "documents[*]"}


class TestConvertFileRefsInOutput:
    """Tests for convert_file_refs_in_output function."""

    @pytest.fixture
    def mock_file(self):
        """Create a mock File object with all required attributes."""
        file = MagicMock(spec=File)
        file.type = FileType.IMAGE
        file.transfer_method = FileTransferMethod.TOOL_FILE
        file.related_id = "test-related-id"
        file.remote_url = None
        file.tenant_id = "tenant_123"
        file.id = None
        file.filename = "test.png"
        file.extension = ".png"
        file.mime_type = "image/png"
        file.size = 1024
        file.dify_model_identity = "__dify__file__"
        return file

    @pytest.fixture
    def mock_build_from_mapping(self, mock_file):
        """Mock the build_from_mapping function."""
        with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
            mock.return_value = mock_file
            yield mock

    def test_convert_simple_file_ref(self, mock_build_from_mapping, mock_file):
        file_id = str(uuid.uuid4())
        output = {"image": file_id}
        schema = {
            "type": "object",
            "properties": {
                "image": {"type": "string", "format": FILE_REF_FORMAT},
            },
        }

        result = convert_file_refs_in_output(output, schema, "tenant_123")

        # Result should be wrapped in FileSegment
        assert isinstance(result["image"], FileSegment)
        assert result["image"].value == mock_file
        mock_build_from_mapping.assert_called_once_with(
            mapping={"transfer_method": "tool_file", "tool_file_id": file_id},
            tenant_id="tenant_123",
        )

    def test_convert_array_of_file_refs(self, mock_build_from_mapping, mock_file):
        file_id1 = str(uuid.uuid4())
        file_id2 = str(uuid.uuid4())
        output = {"files": [file_id1, file_id2]}
        schema = {
            "type": "object",
            "properties": {
                "files": {
                    "type": "array",
                    "items": {"type": "string", "format": FILE_REF_FORMAT},
                },
            },
        }

        result = convert_file_refs_in_output(output, schema, "tenant_123")

        # Result should be wrapped in ArrayFileSegment
        assert isinstance(result["files"], ArrayFileSegment)
        assert list(result["files"].value) == [mock_file, mock_file]
        assert mock_build_from_mapping.call_count == 2

    def test_no_conversion_without_file_refs(self):
        output = {"name": "test", "count": 5}
        schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "count": {"type": "number"},
            },
        }

        result = convert_file_refs_in_output(output, schema, "tenant_123")

        assert result == {"name": "test", "count": 5}

    def test_invalid_uuid_returns_none(self):
        output = {"image": "not-a-valid-uuid"}
        schema = {
            "type": "object",
            "properties": {
                "image": {"type": "string", "format": FILE_REF_FORMAT},
            },
        }

        result = convert_file_refs_in_output(output, schema, "tenant_123")

        assert result["image"] is None

    def test_file_not_found_returns_none(self):
        file_id = str(uuid.uuid4())
        output = {"image": file_id}
        schema = {
            "type": "object",
            "properties": {
                "image": {"type": "string", "format": FILE_REF_FORMAT},
            },
        }

        with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
            mock.side_effect = ValueError("File not found")
            result = convert_file_refs_in_output(output, schema, "tenant_123")

        assert result["image"] is None

    def test_preserves_non_file_fields(self, mock_build_from_mapping, mock_file):
        file_id = str(uuid.uuid4())
        output = {"query": "search term", "image": file_id, "count": 10}
        schema = {
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                "image": {"type": "string", "format": FILE_REF_FORMAT},
                "count": {"type": "number"},
            },
        }

        result = convert_file_refs_in_output(output, schema, "tenant_123")

        assert result["query"] == "search term"
        assert isinstance(result["image"], FileSegment)
        assert result["image"].value == mock_file
        assert result["count"] == 10

    def test_does_not_modify_original_output(self, mock_build_from_mapping, mock_file):
        file_id = str(uuid.uuid4())
        original = {"image": file_id}
        output = dict(original)
        schema = {
            "type": "object",
            "properties": {
                "image": {"type": "string", "format": FILE_REF_FORMAT},
            },
        }

        convert_file_refs_in_output(output, schema, "tenant_123")

        # Original should still contain the string ID
        assert original["image"] == file_id
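
These tests fully specify the path syntax the detector must emit: "image" for a top-level field, "data.image" for a nested object, and "files[*]" for array items. A hypothetical sketch matching that contract follows; the shipped function in core/llm_generator/output_parser/file_ref.py may have differed, and the FILE_REF_FORMAT value is assumed from the fixture's format: dify-file-ref:

# Sketch inferred from the deleted tests; not the shipped implementation.
FILE_REF_FORMAT = "dify-file-ref"  # assumed constant value, matching the YAML fixture

def detect_file_ref_fields(schema: dict, prefix: str = "") -> list[str]:
    """Walk a JSON schema and collect dotted paths to dify-file-ref strings."""
    paths: list[str] = []
    for name, prop in schema.get("properties", {}).items():
        path = f"{prefix}{name}"
        if prop.get("type") == "string" and prop.get("format") == FILE_REF_FORMAT:
            paths.append(path)                      # e.g. "image", "data.image"
        elif prop.get("type") == "array":
            items = prop.get("items", {})
            if items.get("type") == "string" and items.get("format") == FILE_REF_FORMAT:
                paths.append(f"{path}[*]")          # e.g. "files[*]"
        elif prop.get("type") == "object":
            paths.extend(detect_file_ref_fields(prop, prefix=f"{path}."))
    return paths

Conversion then resolves each detected path: valid UUIDs are looked up via build_from_mapping and wrapped in FileSegment/ArrayFileSegment, while invalid UUIDs or missing files degrade to None rather than raising.
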
@ -1 +0,0 @@
"""Tests for workflow context management."""
@ -1,258 +0,0 @@
"""Tests for execution context module."""

import contextvars
from typing import Any
from unittest.mock import MagicMock

import pytest

from core.workflow.context.execution_context import (
    AppContext,
    ExecutionContext,
    ExecutionContextBuilder,
    IExecutionContext,
    NullAppContext,
)


class TestAppContext:
    """Test AppContext abstract base class."""

    def test_app_context_is_abstract(self):
        """Test that AppContext cannot be instantiated directly."""
        with pytest.raises(TypeError):
            AppContext()  # type: ignore


class TestNullAppContext:
    """Test NullAppContext implementation."""

    def test_null_app_context_get_config(self):
        """Test get_config returns value from config dict."""
        config = {"key1": "value1", "key2": "value2"}
        ctx = NullAppContext(config=config)

        assert ctx.get_config("key1") == "value1"
        assert ctx.get_config("key2") == "value2"

    def test_null_app_context_get_config_default(self):
        """Test get_config returns default when key not found."""
        ctx = NullAppContext()

        assert ctx.get_config("nonexistent", "default") == "default"
        assert ctx.get_config("nonexistent") is None

    def test_null_app_context_get_extension(self):
        """Test get_extension returns stored extension."""
        ctx = NullAppContext()
        extension = MagicMock()
        ctx.set_extension("db", extension)

        assert ctx.get_extension("db") == extension

    def test_null_app_context_get_extension_not_found(self):
        """Test get_extension returns None when extension not found."""
        ctx = NullAppContext()

        assert ctx.get_extension("nonexistent") is None

    def test_null_app_context_enter_yield(self):
        """Test enter method yields without any side effects."""
        ctx = NullAppContext()

        with ctx.enter():
            # Should not raise any exception
            pass


class TestExecutionContext:
    """Test ExecutionContext class."""

    def test_initialization_with_all_params(self):
        """Test ExecutionContext initialization with all parameters."""
        app_ctx = NullAppContext()
        context_vars = contextvars.copy_context()
        user = MagicMock()

        ctx = ExecutionContext(
            app_context=app_ctx,
            context_vars=context_vars,
            user=user,
        )

        assert ctx.app_context == app_ctx
        assert ctx.context_vars == context_vars
        assert ctx.user == user

    def test_initialization_with_minimal_params(self):
        """Test ExecutionContext initialization with minimal parameters."""
        ctx = ExecutionContext()

        assert ctx.app_context is None
        assert ctx.context_vars is None
        assert ctx.user is None

    def test_enter_with_context_vars(self):
        """Test enter restores context variables."""
        test_var = contextvars.ContextVar("test_var")
        test_var.set("original_value")

        # Copy context with the variable
        context_vars = contextvars.copy_context()

        # Change the variable
        test_var.set("new_value")

        # Create execution context and enter it
        ctx = ExecutionContext(context_vars=context_vars)

        with ctx.enter():
            # Variable should be restored to original value
            assert test_var.get() == "original_value"

        # After exiting, variable stays at the value from within the context
        # (this is expected Python contextvars behavior)
        assert test_var.get() == "original_value"

    def test_enter_with_app_context(self):
        """Test enter enters app context if available."""
        app_ctx = NullAppContext()
        ctx = ExecutionContext(app_context=app_ctx)

        # Should not raise any exception
        with ctx.enter():
            pass

    def test_enter_without_app_context(self):
        """Test enter works without app context."""
        ctx = ExecutionContext(app_context=None)

        # Should not raise any exception
        with ctx.enter():
            pass

    def test_context_manager_protocol(self):
        """Test ExecutionContext supports context manager protocol."""
        ctx = ExecutionContext()

        with ctx:
            # Should not raise any exception
            pass

    def test_user_property(self):
        """Test user property returns set user."""
        user = MagicMock()
        ctx = ExecutionContext(user=user)

        assert ctx.user == user


class TestIExecutionContextProtocol:
    """Test IExecutionContext protocol."""

    def test_execution_context_implements_protocol(self):
        """Test that ExecutionContext implements IExecutionContext protocol."""
        ctx = ExecutionContext()

        # Should have __enter__ and __exit__ methods
        assert hasattr(ctx, "__enter__")
        assert hasattr(ctx, "__exit__")
        assert hasattr(ctx, "user")

    def test_protocol_compatibility(self):
        """Test that ExecutionContext can be used where IExecutionContext is expected."""

        def accept_context(context: IExecutionContext) -> Any:
            """Function that accepts IExecutionContext protocol."""
            # Just verify it has the required protocol attributes
            assert hasattr(context, "__enter__")
            assert hasattr(context, "__exit__")
            assert hasattr(context, "user")
            return context.user

        ctx = ExecutionContext(user="test_user")
        result = accept_context(ctx)

        assert result == "test_user"

    def test_protocol_with_flask_execution_context(self):
        """Test that IExecutionContext protocol is compatible with different implementations."""
        # Verify the protocol works with ExecutionContext
        ctx = ExecutionContext(user="test_user")

        # Should have the required protocol attributes
        assert hasattr(ctx, "__enter__")
        assert hasattr(ctx, "__exit__")
        assert hasattr(ctx, "user")
        assert ctx.user == "test_user"

        # Should work as context manager
        with ctx:
            assert ctx.user == "test_user"


class TestExecutionContextBuilder:
    """Test ExecutionContextBuilder class."""

    def test_builder_with_all_params(self):
        """Test builder with all parameters set."""
        app_ctx = NullAppContext()
        context_vars = contextvars.copy_context()
        user = MagicMock()

        ctx = (
            ExecutionContextBuilder().with_app_context(app_ctx).with_context_vars(context_vars).with_user(user).build()
        )

        assert ctx.app_context == app_ctx
        assert ctx.context_vars == context_vars
        assert ctx.user == user

    def test_builder_with_partial_params(self):
        """Test builder with only some parameters set."""
        app_ctx = NullAppContext()

        ctx = ExecutionContextBuilder().with_app_context(app_ctx).build()

        assert ctx.app_context == app_ctx
        assert ctx.context_vars is None
        assert ctx.user is None

    def test_builder_fluent_interface(self):
        """Test builder provides fluent interface."""
        builder = ExecutionContextBuilder()

        # Each method should return the builder
        assert isinstance(builder.with_app_context(NullAppContext()), ExecutionContextBuilder)
        assert isinstance(builder.with_context_vars(contextvars.copy_context()), ExecutionContextBuilder)
        assert isinstance(builder.with_user(None), ExecutionContextBuilder)


class TestCaptureCurrentContext:
    """Test capture_current_context function."""

    def test_capture_current_context_returns_context(self):
        """Test that capture_current_context returns a valid context."""
        from core.workflow.context.execution_context import capture_current_context

        result = capture_current_context()

        # Should return an object that implements IExecutionContext
        assert hasattr(result, "__enter__")
        assert hasattr(result, "__exit__")
        assert hasattr(result, "user")

    def test_capture_current_context_captures_contextvars(self):
        """Test that capture_current_context captures context variables."""
        # Set a context variable before capturing
        import contextvars

        test_var = contextvars.ContextVar("capture_test_var")
        test_var.set("test_value_123")

        from core.workflow.context.execution_context import capture_current_context

        result = capture_current_context()

        # Context variables should be captured
        assert result.context_vars is not None
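
The capture-then-enter pattern these tests exercise is what lets workflow code run off the request thread while keeping the caller's contextvars, app context, and user. A minimal usage sketch, assuming only the interfaces asserted above (the worker function itself is illustrative, not from the codebase):

# Illustrative only: module path is real, run_in_worker is a hypothetical caller.
import threading

from core.workflow.context.execution_context import capture_current_context

def run_in_worker(task):
    # Snapshot the caller's execution context on the current thread ...
    exec_ctx = capture_current_context()

    def _worker():
        # ... and re-enter it inside the worker before doing any work,
        # so contextvars and the app context look the same as at capture time.
        with exec_ctx:
            task()

    threading.Thread(target=_worker).start()
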
@ -1,316 +0,0 @@
"""Tests for Flask app context module."""

import contextvars
from unittest.mock import MagicMock, patch

import pytest


class TestFlaskAppContext:
    """Test FlaskAppContext implementation."""

    @pytest.fixture
    def mock_flask_app(self):
        """Create a mock Flask app."""
        app = MagicMock()
        app.config = {"TEST_KEY": "test_value"}
        app.extensions = {"db": MagicMock(), "cache": MagicMock()}
        app.app_context = MagicMock()
        app.app_context.return_value.__enter__ = MagicMock(return_value=None)
        app.app_context.return_value.__exit__ = MagicMock(return_value=None)
        return app

    def test_flask_app_context_initialization(self, mock_flask_app):
        """Test FlaskAppContext initialization."""
        # Import here to avoid Flask dependency in test environment
        from context.flask_app_context import FlaskAppContext

        ctx = FlaskAppContext(mock_flask_app)

        assert ctx.flask_app == mock_flask_app

    def test_flask_app_context_get_config(self, mock_flask_app):
        """Test get_config returns Flask app config value."""
        from context.flask_app_context import FlaskAppContext

        ctx = FlaskAppContext(mock_flask_app)

        assert ctx.get_config("TEST_KEY") == "test_value"

    def test_flask_app_context_get_config_default(self, mock_flask_app):
        """Test get_config returns default when key not found."""
        from context.flask_app_context import FlaskAppContext

        ctx = FlaskAppContext(mock_flask_app)

        assert ctx.get_config("NONEXISTENT", "default") == "default"

    def test_flask_app_context_get_extension(self, mock_flask_app):
        """Test get_extension returns Flask extension."""
        from context.flask_app_context import FlaskAppContext

        ctx = FlaskAppContext(mock_flask_app)
        db_ext = mock_flask_app.extensions["db"]

        assert ctx.get_extension("db") == db_ext

    def test_flask_app_context_get_extension_not_found(self, mock_flask_app):
        """Test get_extension returns None when extension not found."""
        from context.flask_app_context import FlaskAppContext

        ctx = FlaskAppContext(mock_flask_app)

        assert ctx.get_extension("nonexistent") is None

    def test_flask_app_context_enter(self, mock_flask_app):
        """Test enter method enters Flask app context."""
        from context.flask_app_context import FlaskAppContext

        ctx = FlaskAppContext(mock_flask_app)

        with ctx.enter():
            # Should not raise any exception
            pass

        # Verify app_context was called
        mock_flask_app.app_context.assert_called_once()


class TestFlaskExecutionContext:
    """Test FlaskExecutionContext class."""

    @pytest.fixture
    def mock_flask_app(self):
        """Create a mock Flask app."""
        app = MagicMock()
        app.config = {}
        app.app_context = MagicMock()
        app.app_context.return_value.__enter__ = MagicMock(return_value=None)
        app.app_context.return_value.__exit__ = MagicMock(return_value=None)
        return app

    def test_initialization(self, mock_flask_app):
        """Test FlaskExecutionContext initialization."""
        from context.flask_app_context import FlaskExecutionContext

        context_vars = contextvars.copy_context()
        user = MagicMock()

        ctx = FlaskExecutionContext(
            flask_app=mock_flask_app,
            context_vars=context_vars,
            user=user,
        )

        assert ctx.context_vars == context_vars
        assert ctx.user == user

    def test_app_context_property(self, mock_flask_app):
        """Test app_context property returns FlaskAppContext."""
        from context.flask_app_context import FlaskAppContext, FlaskExecutionContext

        ctx = FlaskExecutionContext(
            flask_app=mock_flask_app,
            context_vars=contextvars.copy_context(),
        )

        assert isinstance(ctx.app_context, FlaskAppContext)
        assert ctx.app_context.flask_app == mock_flask_app

    def test_context_manager_protocol(self, mock_flask_app):
        """Test FlaskExecutionContext supports context manager protocol."""
        from context.flask_app_context import FlaskExecutionContext

        ctx = FlaskExecutionContext(
            flask_app=mock_flask_app,
            context_vars=contextvars.copy_context(),
        )

        # Should have __enter__ and __exit__ methods
        assert hasattr(ctx, "__enter__")
        assert hasattr(ctx, "__exit__")

        # Should work as context manager
        with ctx:
            pass


class TestCaptureFlaskContext:
    """Test capture_flask_context function."""

    @patch("context.flask_app_context.current_app")
    @patch("context.flask_app_context.g")
    def test_capture_flask_context_captures_app(self, mock_g, mock_current_app):
        """Test capture_flask_context captures Flask app."""
        mock_app = MagicMock()
        mock_app._get_current_object = MagicMock(return_value=mock_app)
        mock_current_app._get_current_object = MagicMock(return_value=mock_app)

        from context.flask_app_context import capture_flask_context

        ctx = capture_flask_context()

        assert ctx._flask_app == mock_app

    @patch("context.flask_app_context.current_app")
    @patch("context.flask_app_context.g")
    def test_capture_flask_context_captures_user_from_g(self, mock_g, mock_current_app):
        """Test capture_flask_context captures user from Flask g object."""
        mock_app = MagicMock()
        mock_app._get_current_object = MagicMock(return_value=mock_app)
        mock_current_app._get_current_object = MagicMock(return_value=mock_app)

        mock_user = MagicMock()
        mock_user.id = "user_123"
        mock_g._login_user = mock_user

        from context.flask_app_context import capture_flask_context

        ctx = capture_flask_context()

        assert ctx.user == mock_user

    @patch("context.flask_app_context.current_app")
    def test_capture_flask_context_with_explicit_user(self, mock_current_app):
        """Test capture_flask_context uses explicit user parameter."""
        mock_app = MagicMock()
        mock_app._get_current_object = MagicMock(return_value=mock_app)
        mock_current_app._get_current_object = MagicMock(return_value=mock_app)

        explicit_user = MagicMock()
        explicit_user.id = "user_456"

        from context.flask_app_context import capture_flask_context

        ctx = capture_flask_context(user=explicit_user)

        assert ctx.user == explicit_user

    @patch("context.flask_app_context.current_app")
    def test_capture_flask_context_captures_contextvars(self, mock_current_app):
        """Test capture_flask_context captures context variables."""
        mock_app = MagicMock()
        mock_app._get_current_object = MagicMock(return_value=mock_app)
        mock_current_app._get_current_object = MagicMock(return_value=mock_app)

        # Set a context variable
        test_var = contextvars.ContextVar("test_var")
        test_var.set("test_value")

        from context.flask_app_context import capture_flask_context

        ctx = capture_flask_context()

        # Context variables should be captured
        assert ctx.context_vars is not None
        # Verify the variable is in the captured context
        captured_value = ctx.context_vars[test_var]
        assert captured_value == "test_value"


class TestFlaskExecutionContextIntegration:
    """Integration tests for FlaskExecutionContext."""

    @pytest.fixture
    def mock_flask_app(self):
        """Create a mock Flask app with proper app context."""
        app = MagicMock()
        app.config = {"TEST": "value"}
        app.extensions = {"db": MagicMock()}

        # Mock app context
        mock_app_context = MagicMock()
        mock_app_context.__enter__ = MagicMock(return_value=None)
        mock_app_context.__exit__ = MagicMock(return_value=None)
        app.app_context.return_value = mock_app_context

        return app

    def test_enter_restores_context_vars(self, mock_flask_app):
        """Test that enter restores captured context variables."""
        # Create a context variable and set a value
        test_var = contextvars.ContextVar("integration_test_var")
        test_var.set("original_value")

        # Capture the context
        context_vars = contextvars.copy_context()

        # Change the value
        test_var.set("new_value")

        # Create FlaskExecutionContext and enter it
        from context.flask_app_context import FlaskExecutionContext

        ctx = FlaskExecutionContext(
            flask_app=mock_flask_app,
            context_vars=context_vars,
        )

        with ctx:
            # Value should be restored to original
            assert test_var.get() == "original_value"

        # After exiting, variable stays at the value from within the context
        # (this is expected Python contextvars behavior)
        assert test_var.get() == "original_value"

    def test_enter_enters_flask_app_context(self, mock_flask_app):
        """Test that enter enters Flask app context."""
        from context.flask_app_context import FlaskExecutionContext

        ctx = FlaskExecutionContext(
            flask_app=mock_flask_app,
            context_vars=contextvars.copy_context(),
        )

        with ctx:
            # Verify app context was entered
            assert mock_flask_app.app_context.called

    @patch("context.flask_app_context.g")
    def test_enter_restores_user_in_g(self, mock_g, mock_flask_app):
        """Test that enter restores user in Flask g object."""
        mock_user = MagicMock()
        mock_user.id = "test_user"

        # Note: FlaskExecutionContext saves user from g before entering context,
        # then restores it after entering the app context.
        # The user passed to constructor is NOT restored to g.
        # So we need to test the actual behavior.

        # Create FlaskExecutionContext with user in constructor
        from context.flask_app_context import FlaskExecutionContext

        ctx = FlaskExecutionContext(
            flask_app=mock_flask_app,
            context_vars=contextvars.copy_context(),
            user=mock_user,
        )

        # Set user in g before entering (simulating existing user in g)
        mock_g._login_user = mock_user

        with ctx:
            # After entering, the user from g before entry should be restored
            assert mock_g._login_user == mock_user

        # The user in constructor is stored but not automatically restored to g
        # (it's available via ctx.user property)
        assert ctx.user == mock_user

    def test_enter_method_as_context_manager(self, mock_flask_app):
        """Test enter method returns a proper context manager."""
        from context.flask_app_context import FlaskExecutionContext

        ctx = FlaskExecutionContext(
            flask_app=mock_flask_app,
            context_vars=contextvars.copy_context(),
        )

        # enter() should return a generator/context manager
        with ctx.enter():
            # Should work without issues
            pass

        # Verify app context was called
        assert mock_flask_app.app_context.called
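
Taken together, the tests pin down what capture_flask_context must do: unwrap the current_app proxy, pick up the login user from g._login_user unless one is passed explicitly, and snapshot contextvars. A hypothetical sketch of the removed function, inferred only from those assertions (the deleted context/flask_app_context.py may have differed):

# Sketch of the removed capture function; FlaskExecutionContext is the class
# from the same (now-deleted) module, shown here only for illustration.
import contextvars

from flask import current_app, g

def capture_flask_context(user=None):
    """Snapshot the current Flask app, login user, and contextvars."""
    flask_app = current_app._get_current_object()   # unwrap the LocalProxy
    if user is None:
        user = getattr(g, "_login_user", None)      # where flask-login stores the user
    return FlaskExecutionContext(
        flask_app=flask_app,
        context_vars=contextvars.copy_context(),
        user=user,
    )
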
@ -25,12 +25,6 @@ class _StubErrorHandler:
    """Minimal error handler stub for tests."""


class _StubNodeData:
    """Simple node data stub with is_extractor_node property."""

    is_extractor_node = False


class _StubNode:
    """Simple node stub exposing the attributes needed by the state manager."""

@ -42,7 +36,6 @@ class _StubNode:
        self.error_strategy = None
        self.retry_config = RetryConfig()
        self.retry = False
        self.node_data = _StubNodeData()


def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]:

@ -1,174 +0,0 @@
"""Tests for llm_utils module, specifically multimodal content handling."""

import string
from unittest.mock import patch

from core.model_runtime.entities.message_entities import (
    ImagePromptMessageContent,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.workflow.nodes.llm.llm_utils import (
    _truncate_multimodal_content,
    build_context,
    restore_multimodal_content_in_messages,
)


class TestTruncateMultimodalContent:
    """Tests for _truncate_multimodal_content function."""

    def test_returns_message_unchanged_for_string_content(self):
        """String content should pass through unchanged."""
        message = UserPromptMessage(content="Hello, world!")
        result = _truncate_multimodal_content(message)
        assert result.content == "Hello, world!"

    def test_returns_message_unchanged_for_none_content(self):
        """None content should pass through unchanged."""
        message = UserPromptMessage(content=None)
        result = _truncate_multimodal_content(message)
        assert result.content is None

    def test_clears_base64_when_file_ref_present(self):
        """When file_ref is present, base64_data and url should be cleared."""
        image_content = ImagePromptMessageContent(
            format="png",
            base64_data=string.ascii_lowercase,
            url="https://example.com/image.png",
            mime_type="image/png",
            filename="test.png",
            file_ref="local:test-file-id",
        )
        message = UserPromptMessage(content=[image_content])

        result = _truncate_multimodal_content(message)

        assert isinstance(result.content, list)
        assert len(result.content) == 1
        result_content = result.content[0]
        assert isinstance(result_content, ImagePromptMessageContent)
        assert result_content.base64_data == ""
        assert result_content.url == ""
        # file_ref should be preserved
        assert result_content.file_ref == "local:test-file-id"

    def test_truncates_base64_when_no_file_ref(self):
        """When file_ref is missing (legacy), base64_data should be truncated."""
        long_base64 = "a" * 100
        image_content = ImagePromptMessageContent(
            format="png",
            base64_data=long_base64,
            mime_type="image/png",
            filename="test.png",
            file_ref=None,
        )
        message = UserPromptMessage(content=[image_content])

        result = _truncate_multimodal_content(message)

        assert isinstance(result.content, list)
        result_content = result.content[0]
        assert isinstance(result_content, ImagePromptMessageContent)
        # Should be truncated with marker
        assert "...[TRUNCATED]..." in result_content.base64_data
        assert len(result_content.base64_data) < len(long_base64)

    def test_preserves_text_content(self):
        """Text content should pass through unchanged."""
        text_content = TextPromptMessageContent(data="Hello!")
        image_content = ImagePromptMessageContent(
            format="png",
            base64_data="test123",
            mime_type="image/png",
            file_ref="local:file-id",
        )
        message = UserPromptMessage(content=[text_content, image_content])

        result = _truncate_multimodal_content(message)

        assert isinstance(result.content, list)
        assert len(result.content) == 2
        # Text content unchanged
        assert result.content[0].data == "Hello!"
        # Image content base64 cleared
        assert result.content[1].base64_data == ""


class TestBuildContext:
    """Tests for build_context function."""

    def test_excludes_system_messages(self):
        """System messages should be excluded from context."""
        from core.model_runtime.entities.message_entities import SystemPromptMessage

        messages = [
            SystemPromptMessage(content="You are a helpful assistant."),
            UserPromptMessage(content="Hello!"),
        ]

        context = build_context(messages, "Hi there!")

        # Should have user message + assistant response, no system message
        assert len(context) == 2
        assert context[0].content == "Hello!"
        assert context[1].content == "Hi there!"

    def test_appends_assistant_response(self):
        """Assistant response should be appended to context."""
        messages = [UserPromptMessage(content="What is 2+2?")]

        context = build_context(messages, "The answer is 4.")

        assert len(context) == 2
        assert context[1].content == "The answer is 4."


class TestRestoreMultimodalContentInMessages:
    """Tests for restore_multimodal_content_in_messages function."""

    @patch("core.file.file_manager.restore_multimodal_content")
    def test_restores_multimodal_content(self, mock_restore):
        """Should restore multimodal content in messages."""
        # Setup mock
        restored_content = ImagePromptMessageContent(
            format="png",
            base64_data="restored-base64",
            mime_type="image/png",
            file_ref="local:abc123",
        )
        mock_restore.return_value = restored_content

        # Create message with truncated content
        truncated_content = ImagePromptMessageContent(
            format="png",
            base64_data="",
            mime_type="image/png",
            file_ref="local:abc123",
        )
        message = UserPromptMessage(content=[truncated_content])

        result = restore_multimodal_content_in_messages([message])

        assert len(result) == 1
        assert result[0].content[0].base64_data == "restored-base64"
        mock_restore.assert_called_once()

    def test_passes_through_string_content(self):
        """String content should pass through unchanged."""
        message = UserPromptMessage(content="Hello!")

        result = restore_multimodal_content_in_messages([message])

        assert len(result) == 1
        assert result[0].content == "Hello!"

    def test_passes_through_text_content(self):
        """TextPromptMessageContent should pass through unchanged."""
        text_content = TextPromptMessageContent(data="Hello!")
        message = UserPromptMessage(content=[text_content])

        result = restore_multimodal_content_in_messages([message])

        assert len(result) == 1
        assert result[0].content[0].data == "Hello!"
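
The truncate and restore helpers are two halves of one round trip: strip heavy base64 payloads before persisting conversation memory (keeping only file_ref, or a "...[TRUNCATED]..." marker for legacy content without one), then re-materialize the payloads when the history is replayed into a new model call. An illustrative round trip under the behavior asserted above; the helper names are from the deleted test's imports, but this exact flow is a sketch rather than a shipped call site:

# Illustrative sketch, not the shipped call site.
from core.workflow.nodes.llm.llm_utils import (
    _truncate_multimodal_content,
    restore_multimodal_content_in_messages,
)

def store_and_reload(messages):
    # Before persisting history: drop heavy base64/url payloads, keep file_ref.
    slim = [_truncate_multimodal_content(m) for m in messages]
    # When replaying history into a new LLM call: re-fetch content by file_ref.
    return restore_multimodal_content_in_messages(slim)
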
@@ -1,142 +0,0 @@
from unittest.mock import Mock, patch

import pytest
from werkzeug.exceptions import Forbidden

from libs.workspace_permission import (
    check_workspace_member_invite_permission,
    check_workspace_owner_transfer_permission,
)


class TestWorkspacePermissionHelper:
    """Test workspace permission helper functions."""

    @patch("libs.workspace_permission.dify_config")
    @patch("libs.workspace_permission.EnterpriseService")
    def test_community_edition_allows_invite(self, mock_enterprise_service, mock_config):
        """Community edition should always allow invitations without calling any service."""
        mock_config.ENTERPRISE_ENABLED = False

        # Should not raise
        check_workspace_member_invite_permission("test-workspace-id")

        # EnterpriseService should NOT be called in community edition
        mock_enterprise_service.WorkspacePermissionService.get_permission.assert_not_called()

    @patch("libs.workspace_permission.dify_config")
    @patch("libs.workspace_permission.FeatureService")
    def test_community_edition_allows_transfer(self, mock_feature_service, mock_config):
        """Community edition should check billing plan but not call enterprise service."""
        mock_config.ENTERPRISE_ENABLED = False
        mock_features = Mock()
        mock_features.is_allow_transfer_workspace = True
        mock_feature_service.get_features.return_value = mock_features

        # Should not raise
        check_workspace_owner_transfer_permission("test-workspace-id")

        mock_feature_service.get_features.assert_called_once_with("test-workspace-id")

    @patch("libs.workspace_permission.EnterpriseService")
    @patch("libs.workspace_permission.dify_config")
    def test_enterprise_blocks_invite_when_disabled(self, mock_config, mock_enterprise_service):
        """Enterprise edition should block invitations when workspace policy is False."""
        mock_config.ENTERPRISE_ENABLED = True

        mock_permission = Mock()
        mock_permission.allow_member_invite = False
        mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission

        with pytest.raises(Forbidden, match="Workspace policy prohibits member invitations"):
            check_workspace_member_invite_permission("test-workspace-id")

        mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")

    @patch("libs.workspace_permission.EnterpriseService")
    @patch("libs.workspace_permission.dify_config")
    def test_enterprise_allows_invite_when_enabled(self, mock_config, mock_enterprise_service):
        """Enterprise edition should allow invitations when workspace policy is True."""
        mock_config.ENTERPRISE_ENABLED = True

        mock_permission = Mock()
        mock_permission.allow_member_invite = True
        mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission

        # Should not raise
        check_workspace_member_invite_permission("test-workspace-id")

        mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")

    @patch("libs.workspace_permission.EnterpriseService")
    @patch("libs.workspace_permission.dify_config")
    @patch("libs.workspace_permission.FeatureService")
    def test_billing_plan_blocks_transfer(self, mock_feature_service, mock_config, mock_enterprise_service):
        """SANDBOX billing plan should block owner transfer before checking enterprise policy."""
        mock_config.ENTERPRISE_ENABLED = True
        mock_features = Mock()
        mock_features.is_allow_transfer_workspace = False  # SANDBOX plan
        mock_feature_service.get_features.return_value = mock_features

        with pytest.raises(Forbidden, match="Your current plan does not allow workspace ownership transfer"):
            check_workspace_owner_transfer_permission("test-workspace-id")

        # Enterprise service should NOT be called since billing plan already blocks
        mock_enterprise_service.WorkspacePermissionService.get_permission.assert_not_called()

    @patch("libs.workspace_permission.EnterpriseService")
    @patch("libs.workspace_permission.dify_config")
    @patch("libs.workspace_permission.FeatureService")
    def test_enterprise_blocks_transfer_when_disabled(self, mock_feature_service, mock_config, mock_enterprise_service):
        """Enterprise edition should block transfer when workspace policy is False."""
        mock_config.ENTERPRISE_ENABLED = True
        mock_features = Mock()
        mock_features.is_allow_transfer_workspace = True  # Billing plan allows
        mock_feature_service.get_features.return_value = mock_features

        mock_permission = Mock()
        mock_permission.allow_owner_transfer = False  # Workspace policy blocks
        mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission

        with pytest.raises(Forbidden, match="Workspace policy prohibits ownership transfer"):
            check_workspace_owner_transfer_permission("test-workspace-id")

        mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")

    @patch("libs.workspace_permission.EnterpriseService")
    @patch("libs.workspace_permission.dify_config")
    @patch("libs.workspace_permission.FeatureService")
    def test_enterprise_allows_transfer_when_both_enabled(
        self, mock_feature_service, mock_config, mock_enterprise_service
    ):
        """Enterprise edition should allow transfer when both billing and workspace policy allow."""
        mock_config.ENTERPRISE_ENABLED = True
        mock_features = Mock()
        mock_features.is_allow_transfer_workspace = True  # Billing plan allows
        mock_feature_service.get_features.return_value = mock_features

        mock_permission = Mock()
        mock_permission.allow_owner_transfer = True  # Workspace policy allows
        mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission

        # Should not raise
        check_workspace_owner_transfer_permission("test-workspace-id")

        mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")

    @patch("libs.workspace_permission.logger")
    @patch("libs.workspace_permission.EnterpriseService")
    @patch("libs.workspace_permission.dify_config")
    def test_enterprise_service_error_fails_open(self, mock_config, mock_enterprise_service, mock_logger):
        """On enterprise service error, should fail open (allow) and log the error."""
        mock_config.ENTERPRISE_ENABLED = True

        # Simulate enterprise service error
        mock_enterprise_service.WorkspacePermissionService.get_permission.side_effect = Exception("Service unavailable")

        # Should not raise (fail-open)
        check_workspace_member_invite_permission("test-workspace-id")

        # Should log the error
        mock_logger.exception.assert_called_once()
        assert "Failed to check workspace invite permission" in str(mock_logger.exception.call_args)
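Taken together, these tests pin down a clear decision order: community edition short-circuits, the billing plan is checked before the enterprise workspace policy, and enterprise-service failures fail open rather than locking users out. A rough sketch of that control flow, inferred from the assertions above; import paths and the exact structure of libs.workspace_permission are assumptions:

# Sketch only: control flow reverse-engineered from the tests above.
import logging

from werkzeug.exceptions import Forbidden

from configs import dify_config  # assumed import path
from services.enterprise.enterprise_service import EnterpriseService  # assumed import path
from services.feature_service import FeatureService  # assumed import path

logger = logging.getLogger(__name__)


def check_workspace_member_invite_permission(workspace_id: str) -> None:
    if not dify_config.ENTERPRISE_ENABLED:
        return  # Community edition always allows invitations.
    try:
        permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
    except Exception:
        # Fail open: a broken enterprise service must not block invitations.
        logger.exception("Failed to check workspace invite permission")
        return
    if not permission.allow_member_invite:
        raise Forbidden("Workspace policy prohibits member invitations.")


def check_workspace_owner_transfer_permission(workspace_id: str) -> None:
    # Billing plan is checked first, in both community and enterprise editions.
    features = FeatureService.get_features(workspace_id)
    if not features.is_allow_transfer_workspace:
        raise Forbidden("Your current plan does not allow workspace ownership transfer.")

    if not dify_config.ENTERPRISE_ENABLED:
        return  # Community edition: nothing more to check.

    try:
        permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
    except Exception:
        # Fail open, mirroring the invite path above.
        logger.exception("Failed to check workspace transfer permission")
        return

    if not permission.allow_owner_transfer:
        raise Forbidden("Workspace policy prohibits ownership transfer.")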
@@ -2,13 +2,9 @@ from decimal import Decimal
from unittest.mock import MagicMock, patch

import pytest
from pydantic import BaseModel, ConfigDict

from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import (
    invoke_llm_with_pydantic_model,
    invoke_llm_with_structured_output,
)
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_runtime.entities.llm_entities import (
    LLMResult,
    LLMResultChunk,
@@ -465,68 +461,3 @@ def test_model_specific_schema_preparation():

    # For Gemini, the schema should not have additionalProperties and boolean should be converted to string
    assert "json_schema" in call_args.kwargs["model_parameters"]


class ExampleOutput(BaseModel):
    model_config = ConfigDict(extra="forbid")

    name: str


def test_structured_output_with_pydantic_model():
    model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
    model_instance = get_model_instance()
    model_instance.invoke_llm.return_value = LLMResult(
        model="gpt-4o",
        message=AssistantPromptMessage(content='{"name": "test"}'),
        usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
    )

    prompt_messages = [UserPromptMessage(content="Return a JSON object with name.")]

    result = invoke_llm_with_pydantic_model(
        provider="openai",
        model_schema=model_schema,
        model_instance=model_instance,
        prompt_messages=prompt_messages,
        output_model=ExampleOutput,
        stream=False,
    )

    assert isinstance(result, LLMResultWithStructuredOutput)
    assert result.structured_output == {"name": "test"}


def test_structured_output_with_pydantic_model_streaming_rejected():
    model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
    model_instance = get_model_instance()

    with pytest.raises(ValueError):
        invoke_llm_with_pydantic_model(
            provider="openai",
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=[UserPromptMessage(content="test")],
            output_model=ExampleOutput,
            stream=True,
        )


def test_structured_output_with_pydantic_model_validation_error():
    model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
    model_instance = get_model_instance()
    model_instance.invoke_llm.return_value = LLMResult(
        model="gpt-4o",
        message=AssistantPromptMessage(content='{"name": 123}'),
        usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
    )

    with pytest.raises(OutputParserError):
        invoke_llm_with_pydantic_model(
            provider="openai",
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=[UserPromptMessage(content="test")],
            output_model=ExampleOutput,
            stream=False,
        )
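The deleted tests also document what invoke_llm_with_pydantic_model did before its removal: derive a JSON schema from the Pydantic model, refuse streaming, delegate to invoke_llm_with_structured_output, and validate the result against the model. A sketch under those assumptions; the delegation parameter name and the removed helper's real body are guesses, not the original source:

# Sketch only: behavior reverse-engineered from the deleted tests above.
from pydantic import BaseModel, ValidationError

from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output


def invoke_llm_with_pydantic_model(
    *,
    provider,
    model_schema,
    model_instance,
    prompt_messages,
    output_model: type[BaseModel],
    stream: bool = False,
):
    if stream:
        # The tests expect streaming to be rejected outright.
        raise ValueError("invoke_llm_with_pydantic_model does not support streaming")

    result = invoke_llm_with_structured_output(
        provider=provider,
        model_schema=model_schema,
        model_instance=model_instance,
        prompt_messages=prompt_messages,
        json_schema=output_model.model_json_schema(),  # assumed parameter name
        stream=False,
    )

    try:
        # Validate the structured output against the Pydantic model; a type
        # mismatch (e.g. {"name": 123} for name: str) becomes an OutputParserError.
        output_model.model_validate(result.structured_output)
    except ValidationError as e:
        raise OutputParserError(str(e)) from e
    return result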
@@ -203,7 +203,7 @@ const Annotation: FC<Props> = (props) => {
      </Filter>
      {isLoading
        ? <Loading type="app" />
        // eslint-disable-next-line sonarjs/no-nested-conditional
        : total > 0
          ? (
            <List
@@ -134,6 +134,7 @@ const GetAutomaticRes: FC<IGetAutomaticResProps> = ({
    },
  ] as const

  // eslint-disable-next-line sonarjs/no-nested-template-literals, sonarjs/no-nested-conditional
  const [instructionFromSessionStorage, setInstruction] = useSessionStorageState<string>(`improve-instruction-${flowId}${isBasicMode ? '' : `-${nodeId}${editorId ? `-${editorId}` : ''}`}`)
  const instruction = instructionFromSessionStorage || ''
  const [ideaOutput, setIdeaOutput] = useState<string>('')
@@ -175,7 +175,7 @@ describe('SettingsModal', () => {
    renderSettingsModal()
    fireEvent.click(screen.getByText('appOverview.overview.appInfo.settings.more.entry'))
    const privacyInput = screen.getByPlaceholderText('appOverview.overview.appInfo.settings.more.privacyPolicyPlaceholder')

    // eslint-disable-next-line sonarjs/no-clear-text-protocols
    fireEvent.change(privacyInput, { target: { value: 'ftp://invalid-url' } })

    fireEvent.click(screen.getByText('common.operation.save'))
@@ -14,6 +14,7 @@ import { BlockEnum } from '@/app/components/workflow/types'
import { useAppContext } from '@/context/app-context'
import { useDocLink } from '@/context/i18n'
import {
  useAppTriggers,
  useInvalidateAppTriggers,
  useUpdateTriggerStatus,
Some files were not shown because too many files have changed in this diff.