Compare commits


1 Commit

SHA1: fcb288c031   Message: feat: add task type   Date: 2026-01-18 19:19:22 -08:00
236 changed files with 1475 additions and 24643 deletions

.gitignore (vendored), 1 change
View File

@ -209,7 +209,6 @@ api/.vscode
.history
.idea/
web/migration/
# pnpm
/.pnpm-store

View File

@ -71,8 +71,6 @@ def create_app() -> DifyApp:
def initialize_extensions(app: DifyApp):
# Initialize Flask context capture for workflow execution
from context.flask_app_context import init_flask_context
from extensions import (
ext_app_metrics,
ext_blueprints,
@ -102,8 +100,6 @@ def initialize_extensions(app: DifyApp):
ext_warnings,
)
init_flask_context()
extensions = [
ext_timezone,
ext_logging,

View File

@ -1,74 +0,0 @@
"""
Core Context - Framework-agnostic context management.
This module provides context management that is independent of any specific
web framework. Framework-specific implementations register their context
capture functions at application initialization time.
This ensures the workflow layer remains completely decoupled from Flask
or any other web framework.
"""
import contextvars
from collections.abc import Callable
from core.workflow.context.execution_context import (
ExecutionContext,
IExecutionContext,
NullAppContext,
)
# Global capturer function - set by framework-specific modules
_capturer: Callable[[], IExecutionContext] | None = None
def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None:
"""
Register a context capture function.
This should be called by framework-specific modules (e.g., Flask)
during application initialization.
Args:
capturer: Function that captures current context and returns IExecutionContext
"""
global _capturer
_capturer = capturer
def capture_current_context() -> IExecutionContext:
"""
Capture current execution context.
This function uses the registered context capturer. If no capturer
is registered, it returns a minimal context with only contextvars
(suitable for non-framework environments like tests or standalone scripts).
Returns:
IExecutionContext with captured context
"""
if _capturer is None:
# No framework registered - return minimal context
return ExecutionContext(
app_context=NullAppContext(),
context_vars=contextvars.copy_context(),
)
return _capturer()
def reset_context_provider() -> None:
"""
Reset the context capturer.
This is primarily useful for testing to ensure a clean state.
"""
global _capturer
_capturer = None
__all__ = [
"capture_current_context",
"register_context_capturer",
"reset_context_provider",
]
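For orientation, a minimal usage sketch of the capture pattern this module describes, run outside any web framework (it assumes ExecutionContext exposes a context_vars property like the Flask variant further below; the variable and function names are illustrative):

import contextvars
import threading

from context import capture_current_context

request_id: contextvars.ContextVar[str] = contextvars.ContextVar("request_id")

def spawn_worker():
    # No capturer registered, so this returns the minimal fallback:
    # a NullAppContext plus a copy of the current contextvars.
    ctx = capture_current_context()

    def worker():
        # The copied context still carries the caller's value.
        print(ctx.context_vars.get(request_id))

    threading.Thread(target=worker).start()

request_id.set("req-123")
spawn_worker()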

View File

@ -1,198 +0,0 @@
"""
Flask App Context - Flask implementation of AppContext interface.
"""
import contextvars
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, final
from flask import Flask, current_app, g
from context import register_context_capturer
from core.workflow.context.execution_context import (
AppContext,
IExecutionContext,
)
@final
class FlaskAppContext(AppContext):
"""
Flask implementation of AppContext.
This adapts Flask's app context to the AppContext interface.
"""
def __init__(self, flask_app: Flask) -> None:
"""
Initialize Flask app context.
Args:
flask_app: The Flask application instance
"""
self._flask_app = flask_app
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value from Flask app config."""
return self._flask_app.config.get(key, default)
def get_extension(self, name: str) -> Any:
"""Get Flask extension by name."""
return self._flask_app.extensions.get(name)
@contextmanager
def enter(self) -> Generator[None, None, None]:
"""Enter Flask app context."""
with self._flask_app.app_context():
yield
@property
def flask_app(self) -> Flask:
"""Get the underlying Flask app instance."""
return self._flask_app
def capture_flask_context(user: Any = None) -> IExecutionContext:
"""
Capture current Flask execution context.
This function captures the Flask app context and contextvars from the
current environment. It should be called from within a Flask request or
app context.
Args:
user: Optional user object to include in context
Returns:
IExecutionContext with captured Flask context
Raises:
RuntimeError: If called outside Flask context
"""
# Get Flask app instance
flask_app = current_app._get_current_object() # type: ignore
# Save current user if available
saved_user = user
if saved_user is None:
# Check for user in g (flask-login)
if hasattr(g, "_login_user"):
saved_user = g._login_user
# Capture contextvars
context_vars = contextvars.copy_context()
return FlaskExecutionContext(
flask_app=flask_app,
context_vars=context_vars,
user=saved_user,
)
@final
class FlaskExecutionContext:
"""
Flask-specific execution context.
This is a specialized version of ExecutionContext that includes Flask app
context. It provides the same interface as ExecutionContext but with
Flask-specific implementation.
"""
def __init__(
self,
flask_app: Flask,
context_vars: contextvars.Context,
user: Any = None,
) -> None:
"""
Initialize Flask execution context.
Args:
flask_app: Flask application instance
context_vars: Python contextvars
user: Optional user object
"""
self._app_context = FlaskAppContext(flask_app)
self._context_vars = context_vars
self._user = user
self._flask_app = flask_app
@property
def app_context(self) -> FlaskAppContext:
"""Get Flask app context."""
return self._app_context
@property
def context_vars(self) -> contextvars.Context:
"""Get context variables."""
return self._context_vars
@property
def user(self) -> Any:
"""Get user object."""
return self._user
def __enter__(self) -> "FlaskExecutionContext":
"""Enter the Flask execution context."""
# Restore context variables
for var, val in self._context_vars.items():
var.set(val)
# Save current user from g if available
saved_user = None
if hasattr(g, "_login_user"):
saved_user = g._login_user
# Enter Flask app context
self._cm = self._app_context.enter()
self._cm.__enter__()
# Restore user in new app context
if saved_user is not None:
g._login_user = saved_user
return self
def __exit__(self, *args: Any) -> None:
"""Exit the Flask execution context."""
if hasattr(self, "_cm"):
self._cm.__exit__(*args)
@contextmanager
def enter(self) -> Generator[None, None, None]:
"""Enter Flask execution context as context manager."""
# Restore context variables
for var, val in self._context_vars.items():
var.set(val)
# Save current user from g if available
saved_user = None
if hasattr(g, "_login_user"):
saved_user = g._login_user
# Enter Flask app context
with self._flask_app.app_context():
# Restore user in new app context
if saved_user is not None:
g._login_user = saved_user
yield
def init_flask_context() -> None:
"""
Initialize Flask context capture by registering the capturer.
This function should be called during Flask application initialization
to register the Flask-specific context capturer with the core context module.
Example:
app = Flask(__name__)
init_flask_context() # Register Flask context capturer
Note:
This function does not need the app instance as it uses Flask's
`current_app` to get the app when capturing context.
"""
register_context_capturer(capture_flask_context)
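A hedged sketch of how the pieces above fit together: register the capturer at startup, capture inside a request, then re-enter the captured context on a worker thread (the app and route below are placeholders):

import threading

from flask import Flask

from context import capture_current_context
from context.flask_app_context import init_flask_context

app = Flask(__name__)
init_flask_context()  # registers capture_flask_context as the global capturer

@app.route("/kick-off")
def kick_off():
    # Captured inside the request, so current_app and flask-login's user are visible.
    ctx = capture_current_context()

    def worker():
        # Re-enters the Flask app context and restores contextvars off the request thread.
        with ctx.enter():
            ...

    threading.Thread(target=worker).start()
    return {"status": "started"}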

View File

@ -55,35 +55,6 @@ class InstructionTemplatePayload(BaseModel):
type: str = Field(..., description="Instruction template type")
class ContextGeneratePayload(BaseModel):
"""Payload for generating extractor code node."""
workflow_id: str = Field(..., description="Workflow ID")
node_id: str = Field(..., description="Current tool/llm node ID")
parameter_name: str = Field(..., description="Parameter name to generate code for")
language: str = Field(default="python3", description="Code language (python3/javascript)")
prompt_messages: list[dict[str, Any]] = Field(
..., description="Multi-turn conversation history, last message is the current instruction"
)
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
class SuggestedQuestionsPayload(BaseModel):
"""Payload for generating suggested questions."""
workflow_id: str = Field(..., description="Workflow ID")
node_id: str = Field(..., description="Current tool/llm node ID")
parameter_name: str = Field(..., description="Parameter name")
language: str = Field(
default="English", description="Language for generated questions (e.g. English, Chinese, Japanese)"
)
model_config_data: dict[str, Any] | None = Field(
default=None,
alias="model_config",
description="Model configuration (optional, uses system default if not provided)",
)
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@ -93,8 +64,6 @@ reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
reg(ContextGeneratePayload)
reg(SuggestedQuestionsPayload)
@console_ns.route("/rule-generate")
@ -309,74 +278,3 @@ class InstructionGenerationTemplateApi(Resource):
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
case _:
raise ValueError(f"Invalid type: {args.type}")
@console_ns.route("/context-generate")
class ContextGenerateApi(Resource):
@console_ns.doc("generate_with_context")
@console_ns.doc(description="Generate with multi-turn conversation context")
@console_ns.expect(console_ns.models[ContextGeneratePayload.__name__])
@console_ns.response(200, "Content generated successfully")
@console_ns.response(400, "Invalid request parameters or workflow not found")
@console_ns.response(402, "Provider quota exceeded")
@setup_required
@login_required
@account_initialization_required
def post(self):
from core.llm_generator.utils import deserialize_prompt_messages
args = ContextGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
prompt_messages = deserialize_prompt_messages(args.prompt_messages)
try:
return LLMGenerator.generate_with_context(
tenant_id=current_tenant_id,
workflow_id=args.workflow_id,
node_id=args.node_id,
parameter_name=args.parameter_name,
language=args.language,
prompt_messages=prompt_messages,
model_config=args.model_config_data,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
@console_ns.route("/context-generate/suggested-questions")
class SuggestedQuestionsApi(Resource):
@console_ns.doc("generate_suggested_questions")
@console_ns.doc(description="Generate suggested questions for context generation")
@console_ns.expect(console_ns.models[SuggestedQuestionsPayload.__name__])
@console_ns.response(200, "Questions generated successfully")
@setup_required
@login_required
@account_initialization_required
def post(self):
args = SuggestedQuestionsPayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
return LLMGenerator.generate_suggested_questions(
tenant_id=current_tenant_id,
workflow_id=args.workflow_id,
node_id=args.node_id,
parameter_name=args.parameter_name,
language=args.language,
model_config=args.model_config_data,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
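For reference, a hedged sketch of a request body that ContextGeneratePayload above would accept (the IDs, message shape, and model settings are placeholders; note the model configuration travels under the alias "model_config"):

context_generate_request = {
    "workflow_id": "workflow-or-app-id",
    "node_id": "tool-node-1",
    "parameter_name": "query",
    "language": "python3",
    "prompt_messages": [
        {"role": "user", "content": "Extract the order id from the upstream HTTP response."},
    ],
    # The alias "model_config" maps onto the model_config_data field.
    "model_config": {"provider": "openai", "name": "gpt-4o", "completion_params": {}},
}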

View File

@ -46,8 +46,6 @@ from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.workflow.entities import MentionGraphRequest, MentionParameterSchema
from services.workflow.mention_graph_service import MentionGraphService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
logger = logging.getLogger(__name__)
@ -190,15 +188,6 @@ class DraftWorkflowTriggerRunAllPayload(BaseModel):
node_ids: list[str]
class MentionGraphPayload(BaseModel):
"""Request payload for generating mention graph."""
parent_node_id: str = Field(description="ID of the parent node that uses the extracted value")
parameter_key: str = Field(description="Key of the parameter being extracted")
context_source: list[str] = Field(description="Variable selector for the context source")
parameter_schema: dict[str, Any] = Field(description="Schema of the parameter to extract")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@ -216,7 +205,6 @@ reg(WorkflowListQuery)
reg(WorkflowUpdatePayload)
reg(DraftWorkflowTriggerRunPayload)
reg(DraftWorkflowTriggerRunAllPayload)
reg(MentionGraphPayload)
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
@ -1178,54 +1166,3 @@ class DraftWorkflowTriggerRunAllApi(Resource):
"status": "error",
}
), 400
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/mention-graph")
class MentionGraphApi(Resource):
"""
API for generating Mention LLM node graph structures.
This endpoint creates a complete graph structure containing an LLM node
configured to extract values from list[PromptMessage] variables.
"""
@console_ns.doc("generate_mention_graph")
@console_ns.doc(description="Generate a Mention LLM node graph structure")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[MentionGraphPayload.__name__])
@console_ns.response(200, "Mention graph generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
"""
Generate a Mention LLM node graph structure.
Returns a complete graph structure containing a single LLM node
configured for extracting values from list[PromptMessage] context.
"""
payload = MentionGraphPayload.model_validate(console_ns.payload or {})
parameter_schema = MentionParameterSchema(
name=payload.parameter_schema.get("name", payload.parameter_key),
type=payload.parameter_schema.get("type", "string"),
description=payload.parameter_schema.get("description", ""),
)
request = MentionGraphRequest(
parent_node_id=payload.parent_node_id,
parameter_key=payload.parameter_key,
context_source=payload.context_source,
parameter_schema=parameter_schema,
)
with Session(db.engine) as session:
service = MentionGraphService(session)
response = service.generate_mention_graph(tenant_id=app_model.tenant_id, request=request)
return response.model_dump()
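A hedged example of the body MentionGraphPayload above expects (identifiers and the parameter schema are placeholders):

mention_graph_request = {
    "parent_node_id": "llm-node-1",
    "parameter_key": "query",
    "context_source": ["start-node", "conversation_history"],
    "parameter_schema": {"name": "query", "type": "string", "description": "Search query"},
}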

View File

@ -17,7 +17,7 @@ from controllers.console.wraps import account_initialization_required, edit_perm
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.file import helpers as file_helpers
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
@ -58,8 +58,6 @@ def _convert_values_to_json_serializable_object(value: Segment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
return [i.model_dump() for i in value.value]
elif isinstance(value, ArrayPromptMessageSegment):
return value.to_object()
elif isinstance(value, SegmentGroup):
return [_convert_values_to_json_serializable_object(i) for i in value.value]
else:

View File

@ -69,13 +69,6 @@ class ActivateCheckApi(Resource):
if invitation:
data = invitation.get("data", {})
tenant = invitation.get("tenant", None)
# Check workspace permission
if tenant:
from libs.workspace_permission import check_workspace_member_invite_permission
check_workspace_member_invite_permission(tenant.id)
workspace_name = tenant.name if tenant else None
workspace_id = tenant.id if tenant else None
invitee_email = data.get("email") if data else None

View File

@ -107,12 +107,6 @@ class MemberInviteEmailApi(Resource):
inviter = current_user
if not inviter.current_tenant:
raise ValueError("No current tenant")
# Check workspace permission for member invitations
from libs.workspace_permission import check_workspace_member_invite_permission
check_workspace_member_invite_permission(inviter.current_tenant.id)
invitation_results = []
console_web_url = dify_config.CONSOLE_WEB_URL

View File

@ -20,7 +20,6 @@ from controllers.console.error import AccountNotLinkTenantError
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
only_edition_enterprise,
setup_required,
)
from enums.cloud_plan import CloudPlan
@ -29,7 +28,6 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.file_service import FileService
from services.workspace_service import WorkspaceService
@ -290,31 +288,3 @@ class WorkspaceInfoApi(Resource):
db.session.commit()
return {"result": "success", "tenant": marshal(WorkspaceService.get_tenant_info(tenant), tenant_fields)}
@console_ns.route("/workspaces/current/permission")
class WorkspacePermissionApi(Resource):
"""Get workspace permissions for the current workspace."""
@setup_required
@login_required
@account_initialization_required
@only_edition_enterprise
def get(self):
"""
Get workspace permission settings.
Returns permission flags that control workspace features like member invitations and owner transfer.
"""
_, current_tenant_id = current_account_with_tenant()
if not current_tenant_id:
raise ValueError("No current tenant")
# Get workspace permissions from enterprise service
permission = EnterpriseService.WorkspacePermissionService.get_permission(current_tenant_id)
return {
"workspace_id": permission.workspace_id,
"allow_member_invite": permission.allow_member_invite,
"allow_owner_transfer": permission.allow_owner_transfer,
}, 200

View File

@ -286,12 +286,13 @@ def enable_change_email(view: Callable[P, R]):
def is_allow_transfer_owner(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
from libs.workspace_permission import check_workspace_owner_transfer_permission
_, current_tenant_id = current_account_with_tenant()
# Check both billing/plan level and workspace policy level permissions
check_workspace_owner_transfer_permission(current_tenant_id)
return view(*args, **kwargs)
features = FeatureService.get_features(current_tenant_id)
if features.is_allow_transfer_workspace:
return view(*args, **kwargs)
# otherwise, return 403
abort(403)
return decorated

View File

@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@ -120,6 +120,6 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
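These converter lines differ only in whether exclude_none=True is passed to model_dump; a minimal illustration of what that flag does (the model here is hypothetical):

from pydantic import BaseModel

class Chunk(BaseModel):
    event: str
    loop_id: str | None = None

chunk = Chunk(event="node_started")
print(chunk.model_dump(mode="json"))                     # {'event': 'node_started', 'loop_id': None}
print(chunk.model_dump(mode="json", exclude_none=True))  # {'event': 'node_started'}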

View File

@ -81,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -109,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@ -117,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -81,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -109,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
response_chunk.update(sub_stream_response_dict)
@ -117,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -70,8 +70,6 @@ class _NodeSnapshot:
"""Empty string means the node is not executing inside an iteration."""
loop_id: str = ""
"""Empty string means the node is not executing inside a loop."""
mention_parent_id: str = ""
"""Empty string means the node is not an extractor node."""
class WorkflowResponseConverter:
@ -133,7 +131,6 @@ class WorkflowResponseConverter:
start_at=event.start_at,
iteration_id=event.in_iteration_id or "",
loop_id=event.in_loop_id or "",
mention_parent_id=event.in_mention_parent_id or "",
)
node_execution_id = NodeExecutionId(event.node_execution_id)
self._node_snapshots[node_execution_id] = snapshot
@ -290,7 +287,6 @@ class WorkflowResponseConverter:
created_at=int(snapshot.start_at.timestamp()),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
mention_parent_id=event.in_mention_parent_id,
agent_strategy=event.agent_strategy,
),
)
@ -377,7 +373,6 @@ class WorkflowResponseConverter:
files=self.fetch_files_from_node_outputs(event.outputs or {}),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
mention_parent_id=event.in_mention_parent_id,
),
)
@ -427,7 +422,6 @@ class WorkflowResponseConverter:
files=self.fetch_files_from_node_outputs(event.outputs or {}),
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
mention_parent_id=event.in_mention_parent_id,
retry_index=event.retry_index,
),
)

View File

@ -79,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -106,7 +106,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
}
if isinstance(sub_stream_response, MessageEndStreamResponse):
sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
metadata = sub_stream_response_dict.get("metadata", {})
if not isinstance(metadata, dict):
metadata = {}
@ -116,6 +116,6 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(cast(dict, data))
else:
response_chunk.update(sub_stream_response.model_dump(exclude_none=True))
response_chunk.update(sub_stream_response.model_dump())
yield response_chunk
@classmethod
@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
else:
response_chunk.update(sub_stream_response.model_dump(exclude_none=True))
response_chunk.update(sub_stream_response.model_dump())
yield response_chunk

View File

@ -8,7 +8,7 @@ from typing import Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
@ -23,7 +23,6 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
@ -477,7 +476,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:return:
"""
with preserve_flask_contexts(flask_app, context_vars=context):
with session_factory.create_session() as session:
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
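One variant above goes through the session factory, the other constructs Session(db.engine, expire_on_commit=False) directly; a small self-contained sketch of what that flag changes (the model and engine here are hypothetical):

from sqlalchemy import create_engine
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Item(Base):
    __tablename__ = "items"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine, expire_on_commit=False) as session:
    item = Item(name="workflow")
    session.add(item)
    session.commit()

# With the default expire_on_commit=True, commit() expires the instance's attributes
# and reading item.name after the session closes raises DetachedInstanceError.
# With expire_on_commit=False the loaded values survive the commit and the close.
print(item.name)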

View File

@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
data = cls._error_to_stream_response(sub_stream_response.err)
response_chunk.update(data)
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk
@classmethod
@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
else:
response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
response_chunk.update(sub_stream_response.model_dump(mode="json"))
yield response_chunk

View File

@ -385,7 +385,6 @@ class WorkflowBasedAppRunner:
start_at=event.start_at,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
inputs=inputs,
process_data=process_data,
outputs=outputs,
@ -406,7 +405,6 @@ class WorkflowBasedAppRunner:
start_at=event.start_at,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
agent_strategy=event.agent_strategy,
provider_type=event.provider_type,
provider_id=event.provider_id,
@ -430,7 +428,6 @@ class WorkflowBasedAppRunner:
execution_metadata=execution_metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
)
)
elif isinstance(event, NodeRunFailedEvent):
@ -447,7 +444,6 @@ class WorkflowBasedAppRunner:
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
)
)
elif isinstance(event, NodeRunExceptionEvent):
@ -464,7 +460,6 @@ class WorkflowBasedAppRunner:
execution_metadata=event.node_run_result.metadata,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
@ -474,7 +469,6 @@ class WorkflowBasedAppRunner:
from_variable_selector=list(event.selector),
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
)
)
elif isinstance(event, NodeRunRetrieverResourceEvent):
@ -483,7 +477,6 @@ class WorkflowBasedAppRunner:
retriever_resources=event.retriever_resources,
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
in_mention_parent_id=event.in_mention_parent_id,
)
)
elif isinstance(event, NodeRunAgentLogEvent):

View File

@ -190,8 +190,6 @@ class QueueTextChunkEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
class QueueAgentMessageEvent(AppQueueEvent):
@ -231,8 +229,6 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
class QueueAnnotationReplyEvent(AppQueueEvent):
@ -310,8 +306,6 @@ class QueueNodeStartedEvent(AppQueueEvent):
node_run_index: int = 1 # FIXME(-LAN-): may not used
in_iteration_id: str | None = None
in_loop_id: str | None = None
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
start_at: datetime
agent_strategy: AgentNodeStrategyInit | None = None
@ -334,8 +328,6 @@ class QueueNodeSucceededEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
start_at: datetime
inputs: Mapping[str, object] = Field(default_factory=dict)
@ -391,8 +383,6 @@ class QueueNodeExceptionEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
start_at: datetime
inputs: Mapping[str, object] = Field(default_factory=dict)
@ -417,8 +407,6 @@ class QueueNodeFailedEvent(AppQueueEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""parent node id if this is an extractor node event"""
start_at: datetime
inputs: Mapping[str, object] = Field(default_factory=dict)

View File

@ -262,7 +262,6 @@ class NodeStartStreamResponse(StreamResponse):
extras: dict[str, object] = Field(default_factory=dict)
iteration_id: str | None = None
loop_id: str | None = None
mention_parent_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
event: StreamEvent = StreamEvent.NODE_STARTED
@ -286,7 +285,6 @@ class NodeStartStreamResponse(StreamResponse):
"extras": {},
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
"mention_parent_id": self.data.mention_parent_id,
},
}
@ -322,7 +320,6 @@ class NodeFinishStreamResponse(StreamResponse):
files: Sequence[Mapping[str, Any]] | None = []
iteration_id: str | None = None
loop_id: str | None = None
mention_parent_id: str | None = None
event: StreamEvent = StreamEvent.NODE_FINISHED
workflow_run_id: str
@ -352,7 +349,6 @@ class NodeFinishStreamResponse(StreamResponse):
"files": [],
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
"mention_parent_id": self.data.mention_parent_id,
},
}
@ -388,7 +384,6 @@ class NodeRetryStreamResponse(StreamResponse):
files: Sequence[Mapping[str, Any]] | None = []
iteration_id: str | None = None
loop_id: str | None = None
mention_parent_id: str | None = None
retry_index: int = 0
event: StreamEvent = StreamEvent.NODE_RETRY
@ -419,7 +414,6 @@ class NodeRetryStreamResponse(StreamResponse):
"files": [],
"iteration_id": self.data.iteration_id,
"loop_id": self.data.loop_id,
"mention_parent_id": self.data.mention_parent_id,
"retry_index": self.data.retry_index,
},
}

View File

@ -1,5 +1,4 @@
import base64
import logging
from collections.abc import Mapping
from configs import dify_config
@ -11,10 +10,7 @@ from core.model_runtime.entities import (
TextPromptMessageContent,
VideoPromptMessageContent,
)
from core.model_runtime.entities.message_entities import (
MultiModalPromptMessageContent,
PromptMessageContentUnionTypes,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from core.tools.signature import sign_tool_file
from extensions.ext_storage import storage
@ -22,8 +18,6 @@ from . import helpers
from .enums import FileAttribute
from .models import File, FileTransferMethod, FileType
logger = logging.getLogger(__name__)
def get_attr(*, file: File, attr: FileAttribute):
match attr:
@ -95,8 +89,6 @@ def to_prompt_message_content(
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
"filename": f.filename or "",
# Encoded file reference for context restoration: "transfer_method:related_id" or "remote:url"
"file_ref": _encode_file_ref(f),
}
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
@ -104,17 +96,6 @@ def to_prompt_message_content(
return prompt_class_map[f.type].model_validate(params)
def _encode_file_ref(f: File) -> str | None:
"""Encode file reference as 'transfer_method:id_or_url' string."""
if f.transfer_method == FileTransferMethod.REMOTE_URL:
return f"remote:{f.remote_url}" if f.remote_url else None
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
return f"local:{f.related_id}" if f.related_id else None
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
return f"tool:{f.related_id}" if f.related_id else None
return None
def download(f: File, /):
if f.transfer_method in (
FileTransferMethod.TOOL_FILE,
@ -183,128 +164,3 @@ def _to_url(f: File, /):
return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
else:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
def restore_multimodal_content(
content: MultiModalPromptMessageContent,
) -> MultiModalPromptMessageContent:
"""
Restore base64_data or url for multimodal content from file_ref.
file_ref format: "transfer_method:id_or_url" (e.g., "local:abc123", "remote:https://...")
Args:
content: MultiModalPromptMessageContent with file_ref field
Returns:
MultiModalPromptMessageContent with restored base64_data or url
"""
# Skip if no file reference or content already has data
if not content.file_ref:
return content
if content.base64_data or content.url:
return content
try:
file = _build_file_from_ref(
file_ref=content.file_ref,
file_format=content.format,
mime_type=content.mime_type,
filename=content.filename,
)
if not file:
return content
# Restore content based on config
if dify_config.MULTIMODAL_SEND_FORMAT == "base64":
restored_base64 = _get_encoded_string(file)
return content.model_copy(update={"base64_data": restored_base64})
else:
restored_url = _to_url(file)
return content.model_copy(update={"url": restored_url})
except Exception as e:
logger.warning("Failed to restore multimodal content: %s", e)
return content
def _build_file_from_ref(
file_ref: str,
file_format: str | None,
mime_type: str | None,
filename: str | None,
) -> File | None:
"""
Build a File object from encoded file_ref string.
Args:
file_ref: Encoded reference "transfer_method:id_or_url"
file_format: The file format/extension (without dot)
mime_type: The mime type
filename: The filename
Returns:
File object with storage_key loaded, or None if not found
"""
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.model import UploadFile
from models.tools import ToolFile
# Parse file_ref: "method:value"
if ":" not in file_ref:
logger.warning("Invalid file_ref format: %s", file_ref)
return None
method, value = file_ref.split(":", 1)
extension = f".{file_format}" if file_format else None
if method == "remote":
return File(
tenant_id="",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=value,
extension=extension,
mime_type=mime_type,
filename=filename,
storage_key="",
)
# Query database for storage_key
with Session(db.engine) as session:
if method == "local":
stmt = select(UploadFile).where(UploadFile.id == value)
upload_file = session.scalar(stmt)
if upload_file:
return File(
tenant_id=upload_file.tenant_id,
type=FileType(upload_file.extension)
if hasattr(FileType, upload_file.extension.upper())
else FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id=value,
extension=extension or ("." + upload_file.extension if upload_file.extension else None),
mime_type=mime_type or upload_file.mime_type,
filename=filename or upload_file.name,
storage_key=upload_file.key,
)
elif method == "tool":
stmt = select(ToolFile).where(ToolFile.id == value)
tool_file = session.scalar(stmt)
if tool_file:
return File(
tenant_id=tool_file.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=value,
extension=extension,
mime_type=mime_type or tool_file.mimetype,
filename=filename or tool_file.name,
storage_key=tool_file.file_key,
)
logger.warning("File not found for file_ref: %s", file_ref)
return None
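The helpers above round-trip a compact "method:value" reference; a minimal standalone sketch of that convention (no database lookups, values are placeholders):

def encode_ref(method: str, value: str) -> str:
    # e.g. "remote:https://example.com/a.png", "local:<upload_file_id>", "tool:<tool_file_id>"
    return f"{method}:{value}"

def parse_ref(file_ref: str) -> tuple[str, str] | None:
    if ":" not in file_ref:
        return None
    method, value = file_ref.split(":", 1)  # split once so URLs keep their own colons
    return method, value

print(parse_ref(encode_ref("local", "abc123")))                      # ('local', 'abc123')
print(parse_ref(encode_ref("remote", "https://example.com/a.png")))  # ('remote', 'https://example.com/a.png')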

View File

@ -1,16 +1,11 @@
import json
import logging
import re
from collections.abc import Mapping, Sequence
from typing import Any, Protocol, cast
from collections.abc import Sequence
from typing import Protocol, cast
import json_repair
from core.llm_generator.output_models import (
CodeNodeStructuredOutput,
InstructionModifyOutput,
SuggestedQuestionsOutput,
)
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import (
@ -398,432 +393,6 @@ class LLMGenerator:
logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
@classmethod
def generate_with_context(
cls,
tenant_id: str,
workflow_id: str,
node_id: str,
parameter_name: str,
language: str,
prompt_messages: list[PromptMessage],
model_config: dict,
) -> dict:
"""
Generate extractor code node based on conversation context.
Args:
tenant_id: Tenant/workspace ID
workflow_id: Workflow ID
node_id: Current tool/llm node ID
parameter_name: Parameter name to generate code for
language: Code language (python3/javascript)
prompt_messages: Multi-turn conversation history (last message is instruction)
model_config: Model configuration (provider, name, completion_params)
Returns:
dict with CodeNodeData format:
- variables: Input variable selectors
- code_language: Code language
- code: Generated code
- outputs: Output definitions
- message: Explanation
- error: Error message if any
"""
from sqlalchemy import select
from sqlalchemy.orm import Session
from services.workflow_service import WorkflowService
# Get workflow
with Session(db.engine) as session:
stmt = select(App).where(App.id == workflow_id)
app = session.scalar(stmt)
if not app:
return cls._error_response(f"App {workflow_id} not found")
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
return cls._error_response(f"Workflow for app {workflow_id} not found")
# Get upstream nodes via edge backtracking
upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)
# Get current node info
current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
if not current_node:
return cls._error_response(f"Node {node_id} not found")
# Get parameter info
parameter_info = cls._get_parameter_info(
tenant_id=tenant_id,
node_data=current_node.get("data", {}),
parameter_name=parameter_name,
)
# Build system prompt
system_prompt = cls._build_extractor_system_prompt(
upstream_nodes=upstream_nodes,
current_node=current_node,
parameter_info=parameter_info,
language=language,
)
# Construct complete prompt_messages with system prompt
complete_messages: list[PromptMessage] = [
SystemPromptMessage(content=system_prompt),
*prompt_messages,
]
from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model
# Get model instance and schema
provider = model_config.get("provider", "")
model_name = model_config.get("name", "")
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=provider,
model=model_name,
)
model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
if not model_schema:
return cls._error_response(f"Model schema not found for {model_name}")
model_parameters = model_config.get("completion_params", {})
try:
response = invoke_llm_with_pydantic_model(
provider=provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=complete_messages,
output_model=CodeNodeStructuredOutput,
model_parameters=model_parameters,
stream=False,
tenant_id=tenant_id,
)
return cls._parse_code_node_output(
response.structured_output, language, parameter_info.get("type", "string")
)
except InvokeError as e:
return cls._error_response(str(e))
except Exception as e:
logger.exception("Failed to generate with context, model: %s", model_config.get("name"))
return cls._error_response(f"An unexpected error occurred: {str(e)}")
@classmethod
def _error_response(cls, error: str) -> dict:
"""Return error response in CodeNodeData format."""
return {
"variables": [],
"code_language": "python3",
"code": "",
"outputs": {},
"message": "",
"error": error,
}
@classmethod
def generate_suggested_questions(
cls,
tenant_id: str,
workflow_id: str,
node_id: str,
parameter_name: str,
language: str,
model_config: dict | None = None,
) -> dict:
"""
Generate suggested questions for context generation.
Returns dict with questions array and error field.
"""
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model
from services.workflow_service import WorkflowService
# Get workflow context (reuse existing logic)
with Session(db.engine) as session:
stmt = select(App).where(App.id == workflow_id)
app = session.scalar(stmt)
if not app:
return {"questions": [], "error": f"App {workflow_id} not found"}
workflow = WorkflowService().get_draft_workflow(app_model=app)
if not workflow:
return {"questions": [], "error": f"Workflow for app {workflow_id} not found"}
upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)
current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
if not current_node:
return {"questions": [], "error": f"Node {node_id} not found"}
parameter_info = cls._get_parameter_info(
tenant_id=tenant_id,
node_data=current_node.get("data", {}),
parameter_name=parameter_name,
)
# Build prompt
system_prompt = cls._build_suggested_questions_prompt(
upstream_nodes=upstream_nodes,
current_node=current_node,
parameter_info=parameter_info,
language=language,
)
prompt_messages: list[PromptMessage] = [
SystemPromptMessage(content=system_prompt),
]
# Get model instance - use default if model_config not provided
model_manager = ModelManager()
if model_config:
provider = model_config.get("provider", "")
model_name = model_config.get("name", "")
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=provider,
model=model_name,
)
else:
model_instance = model_manager.get_default_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
)
model_name = model_instance.model
model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
if not model_schema:
return {"questions": [], "error": f"Model schema not found for {model_name}"}
completion_params = model_config.get("completion_params", {}) if model_config else {}
model_parameters = {**completion_params, "max_tokens": 256}
try:
response = invoke_llm_with_pydantic_model(
provider=model_instance.provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
output_model=SuggestedQuestionsOutput,
model_parameters=model_parameters,
stream=False,
tenant_id=tenant_id,
)
questions = response.structured_output.get("questions", []) if response.structured_output else []
return {"questions": questions, "error": ""}
except InvokeError as e:
return {"questions": [], "error": str(e)}
except Exception as e:
logger.exception("Failed to generate suggested questions, model: %s", model_name)
return {"questions": [], "error": f"An unexpected error occurred: {str(e)}"}
@classmethod
def _build_suggested_questions_prompt(
cls,
upstream_nodes: list[dict],
current_node: dict,
parameter_info: dict,
language: str = "English",
) -> str:
"""Build minimal prompt for suggested questions generation."""
# Simplify upstream nodes to reduce tokens
sources = [f"{n['title']}({','.join(n.get('outputs', {}).keys())})" for n in upstream_nodes[:5]]
param_type = parameter_info.get("type", "string")
param_desc = parameter_info.get("description", "")[:100]
return f"""Suggest 3 code generation questions for extracting data.
Sources: {", ".join(sources)}
Target: {parameter_info.get("name")}({param_type}) - {param_desc}
Output 3 short, practical questions in {language}."""
@classmethod
def _get_upstream_nodes(cls, graph_dict: Mapping[str, Any], node_id: str) -> list[dict]:
"""
Get all upstream nodes via edge backtracking.
Traverses the graph backwards from node_id to collect all reachable nodes.
"""
from collections import defaultdict
nodes = {n["id"]: n for n in graph_dict.get("nodes", [])}
edges = graph_dict.get("edges", [])
# Build reverse adjacency list
reverse_adj: dict[str, list[str]] = defaultdict(list)
for edge in edges:
reverse_adj[edge["target"]].append(edge["source"])
# BFS to find all upstream nodes
visited: set[str] = set()
queue = [node_id]
upstream: list[dict] = []
while queue:
current = queue.pop(0)
for source in reverse_adj.get(current, []):
if source not in visited:
visited.add(source)
queue.append(source)
if source in nodes:
upstream.append(cls._extract_node_info(nodes[source]))
return upstream
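# Illustration of the reverse-adjacency BFS above on a tiny, hypothetical graph
# (node payloads trimmed to what _extract_node_info reads):
_example_graph = {
    "nodes": [
        {"id": "start", "data": {"type": "start", "title": "Start", "variables": []}},
        {"id": "llm_1", "data": {"type": "llm", "title": "LLM"}},
        {"id": "tool_1", "data": {"type": "tool", "title": "Search"}},
    ],
    "edges": [
        {"source": "start", "target": "llm_1"},
        {"source": "llm_1", "target": "tool_1"},
    ],
}
# LLMGenerator._get_upstream_nodes(_example_graph, "tool_1") walks the edges backwards
# from "tool_1" and returns the extracted info for "llm_1" and then "start".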
@classmethod
def _get_node_by_id(cls, graph_dict: Mapping[str, Any], node_id: str) -> dict | None:
"""Get node by ID from graph."""
for node in graph_dict.get("nodes", []):
if node["id"] == node_id:
return node
return None
@classmethod
def _extract_node_info(cls, node: dict) -> dict:
"""Extract minimal node info with outputs based on node type."""
node_type = node["data"]["type"]
node_data = node.get("data", {})
# Build outputs based on node type (only type, no description to reduce tokens)
outputs: dict[str, str] = {}
match node_type:
case "start":
for var in node_data.get("variables", []):
name = var.get("variable", var.get("name", ""))
outputs[name] = var.get("type", "string")
case "llm":
outputs["text"] = "string"
case "code":
for name, output in node_data.get("outputs", {}).items():
outputs[name] = output.get("type", "string")
case "http-request":
outputs = {"body": "string", "status_code": "number", "headers": "object"}
case "knowledge-retrieval":
outputs["result"] = "array[object]"
case "tool":
outputs = {"text": "string", "json": "object"}
case _:
outputs["output"] = "string"
info: dict = {
"id": node["id"],
"title": node_data.get("title", node["id"]),
"outputs": outputs,
}
# Only include description if not empty
desc = node_data.get("desc", "")
if desc:
info["desc"] = desc
return info
@classmethod
def _get_parameter_info(cls, tenant_id: str, node_data: dict, parameter_name: str) -> dict:
"""Get parameter info from tool schema using ToolManager."""
default_info = {"name": parameter_name, "type": "string", "description": ""}
if node_data.get("type") != "tool":
return default_info
try:
from core.app.entities.app_invoke_entities import InvokeFrom
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
provider_type_str = node_data.get("provider_type", "")
provider_type = ToolProviderType(provider_type_str) if provider_type_str else ToolProviderType.BUILT_IN
tool_runtime = ToolManager.get_tool_runtime(
provider_type=provider_type,
provider_id=node_data.get("provider_id", ""),
tool_name=node_data.get("tool_name", ""),
tenant_id=tenant_id,
invoke_from=InvokeFrom.DEBUGGER,
)
parameters = tool_runtime.get_merged_runtime_parameters()
for param in parameters:
if param.name == parameter_name:
return {
"name": param.name,
"type": param.type.value if hasattr(param.type, "value") else str(param.type),
"description": param.llm_description
or (param.human_description.en_US if param.human_description else ""),
"required": param.required,
}
except Exception as e:
logger.debug("Failed to get parameter info from ToolManager: %s", e)
return default_info
@classmethod
def _build_extractor_system_prompt(
cls,
upstream_nodes: list[dict],
current_node: dict,
parameter_info: dict,
language: str,
) -> str:
"""Build system prompt for extractor code generation."""
upstream_json = json.dumps(upstream_nodes, indent=2, ensure_ascii=False)
param_type = parameter_info.get("type", "string")
return f"""You are a code generator for workflow automation.
Generate {language} code to extract/transform upstream node outputs for the target parameter.
## Upstream Nodes
{upstream_json}
## Target
Node: {current_node["data"].get("title", current_node["id"])}
Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("description", "")}
## Requirements
- Write a main function that returns type: {param_type}
- Use value_selector format: ["node_id", "output_name"]
"""
@classmethod
def _parse_code_node_output(cls, content: Mapping[str, Any] | None, language: str, parameter_type: str) -> dict:
"""
Parse structured output to CodeNodeData format.
Args:
content: Structured output dict from invoke_llm_with_structured_output
language: Code language
parameter_type: Expected parameter type
Returns dict with variables, code_language, code, outputs, message, error.
"""
if content is None:
return cls._error_response("Empty or invalid response from LLM")
# Validate and normalize variables
variables = [
{"variable": v.get("variable", ""), "value_selector": v.get("value_selector", [])}
for v in content.get("variables", [])
if isinstance(v, dict)
]
outputs = content.get("outputs", {"result": {"type": parameter_type}})
return {
"variables": variables,
"code_language": language,
"code": content.get("code", ""),
"outputs": outputs,
"message": content.get("explanation", ""),
"error": "",
}
@staticmethod
def instruction_modify_legacy(
tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
@ -960,10 +529,6 @@ Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("de
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
model_name = model_config.get("name", "")
model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
if not model_schema:
return {"error": f"Model schema not found for {model_name}"}
match node_type:
case "llm" | "agent":
system_prompt = LLM_MODIFY_PROMPT_SYSTEM
@ -987,18 +552,20 @@ Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("de
model_parameters = {"temperature": 0.4}
try:
from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model
response = invoke_llm_with_pydantic_model(
provider=model_instance.provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=list(prompt_messages),
output_model=InstructionModifyOutput,
model_parameters=model_parameters,
stream=False,
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
return response.structured_output or {}
generated_raw = response.message.get_text_content()
first_brace = generated_raw.find("{")
last_brace = generated_raw.rfind("}")
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
json_str = generated_raw[first_brace : last_brace + 1]
data = json_repair.loads(json_str)
if not isinstance(data, dict):
raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
return data
except InvokeError as e:
error = str(e)
return {"error": f"Failed to generate code. Error: {error}"}
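The fallback shown here extracts the span from the first "{" to the last "}" of the raw completion and repairs it with json_repair; a minimal standalone sketch of that approach:

import json_repair

def extract_json_object(raw: str) -> dict:
    first, last = raw.find("{"), raw.rfind("}")
    if first == -1 or last == -1 or last < first:
        raise ValueError(f"Could not find a JSON object in: {raw!r}")
    data = json_repair.loads(raw[first : last + 1])
    if not isinstance(data, dict):
        raise TypeError(f"Expected a JSON object, got {type(data).__name__}")
    return data

print(extract_json_object('Sure, here it is:\n{"modified": "new prompt", "message": "done",}'))
# json_repair tolerates the trailing comma -> {'modified': 'new prompt', 'message': 'done'}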

View File

@ -1,34 +0,0 @@
from __future__ import annotations
from pydantic import BaseModel, ConfigDict, Field
from core.variables.types import SegmentType
from core.workflow.nodes.base.entities import VariableSelector
class SuggestedQuestionsOutput(BaseModel):
model_config = ConfigDict(extra="forbid")
questions: list[str] = Field(min_length=3, max_length=3)
class CodeNodeOutput(BaseModel):
model_config = ConfigDict(extra="forbid")
type: SegmentType
class CodeNodeStructuredOutput(BaseModel):
model_config = ConfigDict(extra="forbid")
variables: list[VariableSelector]
code: str
outputs: dict[str, CodeNodeOutput]
explanation: str
class InstructionModifyOutput(BaseModel):
model_config = ConfigDict(extra="forbid")
modified: str
message: str

View File

@ -1,188 +0,0 @@
"""
File reference detection and conversion for structured output.
This module provides utilities to:
1. Detect file reference fields in JSON Schema (format: "dify-file-ref")
2. Convert file ID strings to File objects after LLM returns
"""
import uuid
from collections.abc import Mapping
from typing import Any
from core.file import File
from core.variables.segments import ArrayFileSegment, FileSegment
from factories.file_factory import build_from_mapping
FILE_REF_FORMAT = "dify-file-ref"
def is_file_ref_property(schema: dict) -> bool:
"""Check if a schema property is a file reference."""
return schema.get("type") == "string" and schema.get("format") == FILE_REF_FORMAT
def detect_file_ref_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
"""
Recursively detect file reference fields in schema.
Args:
schema: JSON Schema to analyze
path: Current path in the schema (used for recursion)
Returns:
List of JSON paths containing file refs, e.g., ["image_id", "files[*]"]
"""
file_ref_paths: list[str] = []
schema_type = schema.get("type")
if schema_type == "object":
for prop_name, prop_schema in schema.get("properties", {}).items():
current_path = f"{path}.{prop_name}" if path else prop_name
if is_file_ref_property(prop_schema):
file_ref_paths.append(current_path)
elif isinstance(prop_schema, dict):
file_ref_paths.extend(detect_file_ref_fields(prop_schema, current_path))
elif schema_type == "array":
items_schema = schema.get("items", {})
array_path = f"{path}[*]" if path else "[*]"
if is_file_ref_property(items_schema):
file_ref_paths.append(array_path)
elif isinstance(items_schema, dict):
file_ref_paths.extend(detect_file_ref_fields(items_schema, array_path))
return file_ref_paths
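# Worked example for the detector above (the schema is hypothetical):
_example_schema = {
    "type": "object",
    "properties": {
        "image_id": {"type": "string", "format": "dify-file-ref"},
        "files": {"type": "array", "items": {"type": "string", "format": "dify-file-ref"}},
        "title": {"type": "string"},
    },
}
# detect_file_ref_fields(_example_schema) -> ["image_id", "files[*]"]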
def convert_file_refs_in_output(
output: Mapping[str, Any],
json_schema: Mapping[str, Any],
tenant_id: str,
) -> dict[str, Any]:
"""
Convert file ID strings to File objects based on schema.
Args:
output: The structured_output from LLM result
json_schema: The original JSON schema (to detect file ref fields)
tenant_id: Tenant ID for file lookup
Returns:
Output with file references converted to File objects
"""
file_ref_paths = detect_file_ref_fields(json_schema)
if not file_ref_paths:
return dict(output)
result = _deep_copy_dict(output)
for path in file_ref_paths:
_convert_path_in_place(result, path.split("."), tenant_id)
return result
def _deep_copy_dict(obj: Mapping[str, Any]) -> dict[str, Any]:
"""Deep copy a mapping to a mutable dict."""
result: dict[str, Any] = {}
for key, value in obj.items():
if isinstance(value, Mapping):
result[key] = _deep_copy_dict(value)
elif isinstance(value, list):
result[key] = [_deep_copy_dict(item) if isinstance(item, Mapping) else item for item in value]
else:
result[key] = value
return result
def _convert_path_in_place(obj: dict, path_parts: list[str], tenant_id: str) -> None:
"""Convert file refs at the given path in place, wrapping in Segment types."""
if not path_parts:
return
current = path_parts[0]
remaining = path_parts[1:]
# Handle array notation like "files[*]"
if current.endswith("[*]"):
key = current[:-3] if current != "[*]" else None
target = obj.get(key) if key else obj
if isinstance(target, list):
if remaining:
# Nested array with remaining path - recurse into each item
for item in target:
if isinstance(item, dict):
_convert_path_in_place(item, remaining, tenant_id)
else:
# Array of file IDs - convert all and wrap in ArrayFileSegment
files: list[File] = []
for item in target:
file = _convert_file_id(item, tenant_id)
if file is not None:
files.append(file)
# Replace the array with ArrayFileSegment
if key:
obj[key] = ArrayFileSegment(value=files)
return
if not remaining:
# Leaf node - convert the value and wrap in FileSegment
if current in obj:
file = _convert_file_id(obj[current], tenant_id)
if file is not None:
obj[current] = FileSegment(value=file)
else:
obj[current] = None
else:
# Recurse into nested object
if current in obj and isinstance(obj[current], dict):
_convert_path_in_place(obj[current], remaining, tenant_id)
def _convert_file_id(file_id: Any, tenant_id: str) -> File | None:
"""
Convert a file ID string to a File object.
Tries multiple file sources in order:
1. ToolFile (files generated by tools/workflows)
2. UploadFile (files uploaded by users)
"""
if not isinstance(file_id, str):
return None
# Validate UUID format
try:
uuid.UUID(file_id)
except ValueError:
return None
# Try ToolFile first (files generated by tools/workflows)
try:
return build_from_mapping(
mapping={
"transfer_method": "tool_file",
"tool_file_id": file_id,
},
tenant_id=tenant_id,
)
except ValueError:
pass
# Try UploadFile (files uploaded by users)
try:
return build_from_mapping(
mapping={
"transfer_method": "local_file",
"upload_file_id": file_id,
},
tenant_id=tenant_id,
)
except ValueError:
pass
# File not found in any source
return None

View File

@ -2,13 +2,12 @@ import json
from collections.abc import Generator, Mapping, Sequence
from copy import deepcopy
from enum import StrEnum
from typing import Any, Literal, TypeVar, cast, overload
from typing import Any, Literal, cast, overload
import json_repair
from pydantic import BaseModel, TypeAdapter, ValidationError
from pydantic import TypeAdapter, ValidationError
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
from core.model_manager import ModelInstance
from core.model_runtime.callbacks.base_callback import Callback
@ -44,9 +43,6 @@ class SpecialModelType(StrEnum):
OLLAMA = "ollama"
T = TypeVar("T", bound=BaseModel)
@overload
def invoke_llm_with_structured_output(
*,
@ -61,7 +57,6 @@ def invoke_llm_with_structured_output(
stream: Literal[True],
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@overload
def invoke_llm_with_structured_output(
@ -77,7 +72,6 @@ def invoke_llm_with_structured_output(
stream: Literal[False],
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput: ...
@overload
def invoke_llm_with_structured_output(
@ -93,7 +87,6 @@ def invoke_llm_with_structured_output(
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
*,
@ -108,30 +101,23 @@ def invoke_llm_with_structured_output(
stream: bool = True,
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
"""
Invoke large language model with structured output.
Invoke large language model with structured output
1. This method invokes model_instance.invoke_llm with json_schema
2. Try to parse the result as structured output
This method invokes model_instance.invoke_llm with json_schema and parses
the result as structured output.
:param provider: model provider name
:param model_schema: model schema entity
:param model_instance: model instance to invoke
:param prompt_messages: prompt messages
:param json_schema: json schema for structured output
:param json_schema: json schema
:param model_parameters: model parameters
:param tools: tools for tool calling
:param stop: stop words
:param stream: is stream response
:param user: unique user id
:param callbacks: callbacks
:param tenant_id: tenant ID for file reference conversion. When provided and
json_schema contains file reference fields (format: "dify-file-ref"),
file IDs in the output will be automatically converted to File objects.
:return: full response or stream response chunk generator result
"""
# handle native json schema
model_parameters_with_json_schema: dict[str, Any] = {
**(model_parameters or {}),
@ -167,18 +153,8 @@ def invoke_llm_with_structured_output(
f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}"
)
structured_output = _parse_structured_output(llm_result.message.content)
# Convert file references if tenant_id is provided
if tenant_id is not None:
structured_output = convert_file_refs_in_output(
output=structured_output,
json_schema=json_schema,
tenant_id=tenant_id,
)
return LLMResultWithStructuredOutput(
structured_output=structured_output,
structured_output=_parse_structured_output(llm_result.message.content),
model=llm_result.model,
message=llm_result.message,
usage=llm_result.usage,
@ -210,18 +186,8 @@ def invoke_llm_with_structured_output(
delta=event.delta,
)
structured_output = _parse_structured_output(result_text)
# Convert file references if tenant_id is provided
if tenant_id is not None:
structured_output = convert_file_refs_in_output(
output=structured_output,
json_schema=json_schema,
tenant_id=tenant_id,
)
yield LLMResultChunkWithStructuredOutput(
structured_output=structured_output,
structured_output=_parse_structured_output(result_text),
model=model_schema.model,
prompt_messages=prompt_messages,
system_fingerprint=system_fingerprint,
@ -236,87 +202,6 @@ def invoke_llm_with_structured_output(
return generator()
@overload
def invoke_llm_with_pydantic_model(
*,
provider: str,
model_schema: AIModelEntity,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
output_model: type[T],
model_parameters: Mapping | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: Literal[False] = False,
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput: ...
def invoke_llm_with_pydantic_model(
*,
provider: str,
model_schema: AIModelEntity,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
output_model: type[T],
model_parameters: Mapping | None = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: list[str] | None = None,
stream: bool = False,
user: str | None = None,
callbacks: list[Callback] | None = None,
tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput:
"""
Invoke large language model with a Pydantic output model.
This helper generates a JSON schema from the Pydantic model, invokes the
structured-output LLM path, and validates the result in non-streaming mode.
"""
if stream:
raise ValueError("invoke_llm_with_pydantic_model only supports stream=False")
json_schema = _schema_from_pydantic(output_model)
result = invoke_llm_with_structured_output(
provider=provider,
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
json_schema=json_schema,
model_parameters=model_parameters,
tools=tools,
stop=stop,
stream=False,
user=user,
callbacks=callbacks,
tenant_id=tenant_id,
)
structured_output = result.structured_output
if structured_output is None:
raise OutputParserError("Structured output is empty")
validated_output = _validate_structured_output(output_model, structured_output)
return result.model_copy(update={"structured_output": validated_output})
def _schema_from_pydantic(output_model: type[BaseModel]) -> dict[str, Any]:
return output_model.model_json_schema()
def _validate_structured_output(
output_model: type[T],
structured_output: Mapping[str, Any],
) -> dict[str, Any]:
try:
validated_output = output_model.model_validate(structured_output)
except ValidationError as exc:
raise OutputParserError(f"Structured output validation failed: {exc}") from exc
return validated_output.model_dump(mode="python")
def _handle_native_json_schema(
provider: str,
model_schema: AIModelEntity,

View File

@ -1,45 +0,0 @@
"""Utility functions for LLM generator."""
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageRole,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
def deserialize_prompt_messages(messages: list[dict]) -> list[PromptMessage]:
"""
Deserialize list of dicts to list[PromptMessage].
Expected format:
[
{"role": "user", "content": "..."},
{"role": "assistant", "content": "..."},
]
"""
result: list[PromptMessage] = []
for msg in messages:
role = PromptMessageRole.value_of(msg["role"])
content = msg.get("content", "")
match role:
case PromptMessageRole.USER:
result.append(UserPromptMessage(content=content))
case PromptMessageRole.ASSISTANT:
result.append(AssistantPromptMessage(content=content))
case PromptMessageRole.SYSTEM:
result.append(SystemPromptMessage(content=content))
case PromptMessageRole.TOOL:
result.append(ToolPromptMessage(content=content, tool_call_id=msg.get("tool_call_id", "")))
return result
def serialize_prompt_messages(messages: list[PromptMessage]) -> list[dict]:
"""
Serialize list[PromptMessage] to list of dicts.
"""
return [{"role": msg.role.value, "content": msg.content} for msg in messages]
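# Illustrative round trip (not part of the original module):
#   msgs = deserialize_prompt_messages([{"role": "user", "content": "Hi"}])
#   serialize_prompt_messages(msgs)  # -> [{"role": "user", "content": "Hi"}]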

View File

@ -1,267 +0,0 @@
# Memory Module
This module provides memory management for LLM conversations, enabling context retention across dialogue turns.
## Overview
The memory module contains two types of memory implementations:
1. **TokenBufferMemory** - Conversation-level memory (existing)
2. **NodeTokenBufferMemory** - Node-level memory (**Chatflow only**)
> **Note**: `NodeTokenBufferMemory` is only available in **Chatflow** (advanced-chat mode).
> This is because it requires both `conversation_id` and `node_id`, which are only present in Chatflow.
> Standard Workflow mode does not have `conversation_id` and therefore cannot use node-level memory.
```
┌─────────────────────────────────────────────────────────────────────────────┐
│ Memory Architecture │
├─────────────────────────────────────────────────────────────────────────────┤
│ │
│ ┌─────────────────────────────────────────────────────────────────────-┐ │
│ │ TokenBufferMemory │ │
│ │ Scope: Conversation │ │
│ │ Storage: Database (Message table) │ │
│ │ Key: conversation_id │ │
│ └─────────────────────────────────────────────────────────────────────-┘ │
│ │
│ ┌─────────────────────────────────────────────────────────────────────-┐ │
│ │ NodeTokenBufferMemory │ │
│ │ Scope: Node within Conversation │ │
│ │ Storage: WorkflowNodeExecutionModel.outputs["context"] │ │
│ │ Key: (conversation_id, node_id, workflow_run_id) │ │
│ └─────────────────────────────────────────────────────────────────────-┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
```
---
## TokenBufferMemory (Existing)
### Purpose
`TokenBufferMemory` retrieves conversation history from the `Message` table and converts it to `PromptMessage` objects for LLM context.
### Key Features
- **Conversation-scoped**: All messages within a conversation are candidates
- **Thread-aware**: Uses `parent_message_id` to extract only the current thread (supports regeneration scenarios)
- **Token-limited**: Truncates history to fit within `max_token_limit`
- **File support**: Handles `MessageFile` attachments (images, documents, etc.)
### Data Flow
```
Message Table TokenBufferMemory LLM
│ │ │
│ SELECT * FROM messages │ │
│ WHERE conversation_id = ? │ │
│ ORDER BY created_at DESC │ │
├─────────────────────────────────▶│ │
│ │ │
│ extract_thread_messages() │
│ │ │
│ build_prompt_message_with_files() │
│ │ │
│ truncate by max_token_limit │
│ │ │
│ │ Sequence[PromptMessage]
│ ├───────────────────────▶│
│ │ │
```
### Thread Extraction
When a user regenerates a response, a new thread is created:
```
Message A (user)
└── Message A' (assistant)
└── Message B (user)
└── Message B' (assistant)
└── Message A'' (assistant, regenerated) ← New thread
└── Message C (user)
└── Message C' (assistant)
```
`extract_thread_messages()` traces back from the latest message using `parent_message_id` to get only the current thread: `[A, A'', C, C']`
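A minimal sketch of that backtracking, assuming only `id` and `parent_message_id` fields on each message (the real `extract_thread_messages()` in `core/prompt/utils/extract_thread_messages.py` handles additional edge cases):
```python
from dataclasses import dataclass


@dataclass
class Msg:
    id: str
    parent_message_id: str | None


def extract_thread_sketch(messages: list[Msg]) -> list[Msg]:
    """Follow parent_message_id back from the newest message (index 0),
    keeping only the current thread, newest-first."""
    by_id = {m.id: m for m in messages}
    thread: list[Msg] = []
    current = messages[0] if messages else None
    while current is not None:
        thread.append(current)
        current = by_id.get(current.parent_message_id) if current.parent_message_id else None
    return thread
```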
### Usage
```python
from core.memory.token_buffer_memory import TokenBufferMemory
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
history = memory.get_history_prompt_messages(max_token_limit=2000, message_limit=100)
```
---
## NodeTokenBufferMemory
### Purpose
`NodeTokenBufferMemory` provides **node-scoped memory** within a conversation. Each LLM node in a workflow can maintain its own independent conversation history.
### Use Cases
1. **Multi-LLM Workflows**: Different LLM nodes need separate context
2. **Iterative Processing**: An LLM node in a loop needs to accumulate context across iterations
3. **Specialized Agents**: Each agent node maintains its own dialogue history
### Design: Zero Extra Storage
**Key insight**: the LLM node already saves the complete context in `outputs["context"]`.
Each LLM node execution outputs:
```python
outputs = {
"text": clean_text,
"context": self._build_context(prompt_messages, clean_text), # Complete dialogue history!
...
}
```
This `outputs["context"]` contains:
- All previous user/assistant messages (excluding system prompt)
- The current assistant response
**No separate storage needed** - we just read from the last execution's `outputs["context"]`.
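As a concrete illustration (values invented for this example), the serialized context is a list of role/content dicts stored alongside the node's other outputs; real entries may also carry multimodal parts that are restored later via `file_ref`:
```python
outputs = {
    "text": "1. Review pricing. 2. Follow up with churned accounts.",
    "context": [
        {"role": "user", "content": "Summarize the attached report."},
        {"role": "assistant", "content": "The report covers Q3 revenue and churn."},
        {"role": "user", "content": "Now list the action items."},
        {"role": "assistant", "content": "1. Review pricing. 2. Follow up with churned accounts."},
    ],
}
```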
### Benefits
| Aspect | Old Design (Object Storage) | New Design (outputs["context"]) |
|--------|----------------------------|--------------------------------|
| Storage | Separate JSON file | Already in WorkflowNodeExecutionModel |
| Concurrency | Race condition risk | No issue (each execution is INSERT) |
| Cleanup | Need separate cleanup task | Follows node execution lifecycle |
| Migration | Required | None |
| Complexity | High | Low |
### Data Flow
```
WorkflowNodeExecutionModel NodeTokenBufferMemory LLM Node
│ │ │
│ │◀── get_history_prompt_messages()
│ │ │
│ SELECT outputs FROM │ │
│ workflow_node_executions │ │
│ WHERE workflow_run_id = ? │ │
│ AND node_id = ? │ │
│◀─────────────────────────────────┤ │
│ │ │
│ outputs["context"] │ │
├─────────────────────────────────▶│ │
│ │ │
│ deserialize PromptMessages │
│ │ │
│ truncate by max_token_limit │
│ │ │
│ │ Sequence[PromptMessage] │
│ ├──────────────────────────▶│
│ │ │
```
### Thread Tracking
Thread extraction still uses `Message` table's `parent_message_id` structure:
1. Query `Message` table for conversation → get thread's `workflow_run_ids`
2. Get the last completed `workflow_run_id` in the thread
3. Query `WorkflowNodeExecutionModel` for that execution's `outputs["context"]` (see the sketch below)
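A condensed sketch of that lookup, reusing the models imported by the real implementation (`Message`, `WorkflowNodeExecutionModel`) but omitting thread extraction and the skip of not-yet-answered messages:
```python
from sqlalchemy import select
from sqlalchemy.orm import Session

from models.model import Message
from models.workflow import WorkflowNodeExecutionModel


def load_node_context(session: Session, conversation_id: str, node_id: str) -> list[dict]:
    # 1. Most recent workflow run in the conversation (the real code walks
    #    parent_message_id to restrict this to the current thread).
    last_run_id = session.scalars(
        select(Message.workflow_run_id)
        .where(Message.conversation_id == conversation_id)
        .order_by(Message.created_at.desc())
    ).first()
    if not last_run_id:
        return []
    # 2. That run's succeeded execution of this node.
    execution = session.scalars(
        select(WorkflowNodeExecutionModel).where(
            WorkflowNodeExecutionModel.workflow_run_id == last_run_id,
            WorkflowNodeExecutionModel.node_id == node_id,
            WorkflowNodeExecutionModel.status == "succeeded",
        )
    ).first()
    # 3. The accumulated dialogue saved by the node during execution.
    return (execution.outputs_dict or {}).get("context", []) if execution else []
```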
### API
```python
class NodeTokenBufferMemory:
def __init__(
self,
app_id: str,
conversation_id: str,
node_id: str,
tenant_id: str,
model_instance: ModelInstance,
):
"""Initialize node-level memory."""
...
def get_history_prompt_messages(
self,
*,
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> Sequence[PromptMessage]:
"""
Retrieve history as PromptMessage sequence.
Reads from last completed execution's outputs["context"].
"""
...
# Legacy methods (no-op, kept for compatibility)
def add_messages(self, *args, **kwargs) -> None: pass
def flush(self) -> None: pass
def clear(self) -> None: pass
```
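A hedged usage sketch based on the constructor and method above; how `app_id`, `conversation_id`, `node_id`, `tenant_id`, and `model_instance` are obtained depends on the calling node:
```python
memory = NodeTokenBufferMemory(
    app_id=app_id,
    conversation_id=conversation_id,
    node_id=node_id,
    tenant_id=tenant_id,
    model_instance=model_instance,
)
history = memory.get_history_prompt_messages(max_token_limit=2000)
```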
### Configuration
Add to `MemoryConfig` in `core/workflow/nodes/llm/entities.py`:
```python
class MemoryMode(StrEnum):
CONVERSATION = "conversation" # Use TokenBufferMemory (default)
NODE = "node" # Use NodeTokenBufferMemory (Chatflow only)
class MemoryConfig(BaseModel):
role_prefix: RolePrefix | None = None
window: MemoryWindowConfig | None = None
query_prompt_template: str | None = None
mode: MemoryMode = MemoryMode.CONVERSATION
```
**Mode Behavior:**
| Mode | Memory Class | Scope | Availability |
| -------------- | --------------------- | ------------------------ | ------------- |
| `conversation` | TokenBufferMemory | Entire conversation | All app modes |
| `node` | NodeTokenBufferMemory | Per-node in conversation | Chatflow only |
> When `mode=node` is used in a non-Chatflow context (no conversation_id), it falls back to no memory.
---
## Comparison
| Feature | TokenBufferMemory | NodeTokenBufferMemory |
| -------------- | ------------------------ | ---------------------------------- |
| Scope | Conversation | Node within Conversation |
| Storage | Database (Message table) | WorkflowNodeExecutionModel.outputs |
| Thread Support | Yes | Yes |
| File Support | Yes (via MessageFile) | Yes (via context serialization) |
| Token Limit | Yes | Yes |
| Use Case | Standard chat apps | Complex workflows |
---
## Extending to Other Nodes
Currently only the **LLM Node** includes `context` in its outputs. To enable node memory for other nodes:
1. Add `outputs["context"] = self._build_context(prompt_messages, response)` in the node
2. The `NodeTokenBufferMemory` will automatically pick it up
Nodes that could potentially support this:
- `question_classifier`
- `parameter_extractor`
- `agent`
---
## Future Considerations
1. **Cleanup**: Node memory lifecycle follows `WorkflowNodeExecutionModel`, which already has cleanup mechanisms
2. **Compression**: For very long conversations, consider summarization strategies
3. **Extension**: Other nodes may benefit from node-level memory

View File

@ -1,11 +0,0 @@
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import (
NodeTokenBufferMemory,
)
from core.memory.token_buffer_memory import TokenBufferMemory
__all__ = [
"BaseMemory",
"NodeTokenBufferMemory",
"TokenBufferMemory",
]

View File

@ -1,83 +0,0 @@
"""
Base memory interfaces and types.
This module defines the common protocol for memory implementations.
"""
from abc import ABC, abstractmethod
from collections.abc import Sequence
from core.model_runtime.entities import ImagePromptMessageContent, PromptMessage
class BaseMemory(ABC):
"""
Abstract base class for memory implementations.
Provides a common interface for both conversation-level and node-level memory.
"""
@abstractmethod
def get_history_prompt_messages(
self,
*,
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> Sequence[PromptMessage]:
"""
Get history prompt messages.
:param max_token_limit: Maximum tokens for history
:param message_limit: Maximum number of messages
:return: Sequence of PromptMessage for LLM context
"""
pass
def get_history_prompt_text(
self,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> str:
"""
Get history prompt as formatted text.
:param human_prefix: Prefix for human messages
:param ai_prefix: Prefix for assistant messages
:param max_token_limit: Maximum tokens for history
:param message_limit: Maximum number of messages
:return: Formatted history text
"""
from core.model_runtime.entities import (
PromptMessageRole,
TextPromptMessageContent,
)
prompt_messages = self.get_history_prompt_messages(
max_token_limit=max_token_limit,
message_limit=message_limit,
)
string_messages = []
for m in prompt_messages:
if m.role == PromptMessageRole.USER:
role = human_prefix
elif m.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(m.content, list):
inner_msg = ""
for content in m.content:
if isinstance(content, TextPromptMessageContent):
inner_msg += f"{content.data}\n"
elif isinstance(content, ImagePromptMessageContent):
inner_msg += "[image]\n"
string_messages.append(f"{role}: {inner_msg.strip()}")
else:
message = f"{role}: {m.content}"
string_messages.append(message)
return "\n".join(string_messages)

View File

@ -1,197 +0,0 @@
"""
Node-level Token Buffer Memory for Chatflow.
This module provides node-scoped memory within a conversation.
Each LLM node in a workflow can maintain its own independent conversation history.
Note: This is only available in Chatflow (advanced-chat mode) because it requires
both conversation_id and node_id.
Design:
- History is read directly from WorkflowNodeExecutionModel.outputs["context"]
- No separate storage needed - the context is already saved during node execution
- Thread tracking leverages Message table's parent_message_id structure
"""
import logging
from collections.abc import Sequence
from typing import cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.file import file_manager
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
MultiModalPromptMessageContent,
PromptMessage,
PromptMessageRole,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import Message
from models.workflow import WorkflowNodeExecutionModel
logger = logging.getLogger(__name__)
class NodeTokenBufferMemory(BaseMemory):
"""
Node-level Token Buffer Memory.
Provides node-scoped memory within a conversation. Each LLM node can maintain
its own independent conversation history.
Key design: History is read directly from WorkflowNodeExecutionModel.outputs["context"],
which is already saved during node execution. No separate storage needed.
"""
def __init__(
self,
app_id: str,
conversation_id: str,
node_id: str,
tenant_id: str,
model_instance: ModelInstance,
):
self.app_id = app_id
self.conversation_id = conversation_id
self.node_id = node_id
self.tenant_id = tenant_id
self.model_instance = model_instance
def _get_thread_workflow_run_ids(self) -> list[str]:
"""
Get workflow_run_ids for the current thread by querying Message table.
Returns workflow_run_ids in chronological order (oldest first).
"""
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(Message)
.where(Message.conversation_id == self.conversation_id)
.order_by(Message.created_at.desc())
.limit(500)
)
messages = list(session.scalars(stmt).all())
if not messages:
return []
# Extract thread messages using existing logic
thread_messages = extract_thread_messages(messages)
# A newly created message has a temporarily empty answer; skip it
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
thread_messages.pop(0)
# Reverse to get chronological order, extract workflow_run_ids
return [msg.workflow_run_id for msg in reversed(thread_messages) if msg.workflow_run_id]
def _deserialize_prompt_message(self, msg_dict: dict) -> PromptMessage:
"""Deserialize a dict to PromptMessage based on role."""
role = msg_dict.get("role")
if role in (PromptMessageRole.USER, "user"):
return UserPromptMessage.model_validate(msg_dict)
elif role in (PromptMessageRole.ASSISTANT, "assistant"):
return AssistantPromptMessage.model_validate(msg_dict)
elif role in (PromptMessageRole.SYSTEM, "system"):
return SystemPromptMessage.model_validate(msg_dict)
elif role in (PromptMessageRole.TOOL, "tool"):
return ToolPromptMessage.model_validate(msg_dict)
else:
return PromptMessage.model_validate(msg_dict)
def _deserialize_context(self, context_data: list[dict]) -> list[PromptMessage]:
"""Deserialize context data from outputs to list of PromptMessage."""
messages = []
for msg_dict in context_data:
try:
msg = self._deserialize_prompt_message(msg_dict)
msg = self._restore_multimodal_content(msg)
messages.append(msg)
except Exception as e:
logger.warning("Failed to deserialize prompt message: %s", e)
return messages
def _restore_multimodal_content(self, message: PromptMessage) -> PromptMessage:
"""
Restore multimodal content (base64 or url) from file_ref.
When context is saved, base64_data is cleared to save storage space.
This method restores the content by parsing file_ref (format: "method:id_or_url").
"""
content = message.content
if content is None or isinstance(content, str):
return message
# Process list content, restoring multimodal data from file references
restored_content: list[PromptMessageContentUnionTypes] = []
for item in content:
if isinstance(item, MultiModalPromptMessageContent):
# restore_multimodal_content preserves the concrete subclass type
restored_item = file_manager.restore_multimodal_content(item)
restored_content.append(cast(PromptMessageContentUnionTypes, restored_item))
else:
restored_content.append(item)
return message.model_copy(update={"content": restored_content})
def get_history_prompt_messages(
self,
*,
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> Sequence[PromptMessage]:
"""
Retrieve history as PromptMessage sequence.
History is read directly from the last completed node execution's outputs["context"].
"""
_ = message_limit # unused, kept for interface compatibility
thread_workflow_run_ids = self._get_thread_workflow_run_ids()
if not thread_workflow_run_ids:
return []
# Get the last completed workflow_run_id (contains accumulated context)
last_run_id = thread_workflow_run_ids[-1]
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.workflow_run_id == last_run_id,
WorkflowNodeExecutionModel.node_id == self.node_id,
WorkflowNodeExecutionModel.status == "succeeded",
)
execution = session.scalars(stmt).first()
if not execution:
return []
outputs = execution.outputs_dict
if not outputs:
return []
context_data = outputs.get("context")
if not context_data or not isinstance(context_data, list):
return []
prompt_messages = self._deserialize_context(context_data)
if not prompt_messages:
return []
# Truncate by token limit
try:
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
while current_tokens > max_token_limit and len(prompt_messages) > 1:
prompt_messages.pop(0)
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
except Exception as e:
logger.warning("Failed to count tokens for truncation: %s", e)
return prompt_messages

View File

@ -5,12 +5,12 @@ from sqlalchemy.orm import sessionmaker
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
PromptMessageRole,
TextPromptMessageContent,
UserPromptMessage,
)
@ -24,7 +24,7 @@ from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
class TokenBufferMemory(BaseMemory):
class TokenBufferMemory:
def __init__(
self,
conversation: Conversation,
@ -115,14 +115,10 @@ class TokenBufferMemory(BaseMemory):
return AssistantPromptMessage(content=prompt_message_contents)
def get_history_prompt_messages(
self,
*,
max_token_limit: int = 2000,
message_limit: int | None = None,
self, max_token_limit: int = 2000, message_limit: int | None = None
) -> Sequence[PromptMessage]:
"""
Get history prompt messages.
:param max_token_limit: max token limit
:param message_limit: message limit
"""
@ -204,3 +200,44 @@ class TokenBufferMemory(BaseMemory):
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
return prompt_messages
def get_history_prompt_text(
self,
human_prefix: str = "Human",
ai_prefix: str = "Assistant",
max_token_limit: int = 2000,
message_limit: int | None = None,
) -> str:
"""
Get history prompt text.
:param human_prefix: human prefix
:param ai_prefix: ai prefix
:param max_token_limit: max token limit
:param message_limit: message limit
:return:
"""
prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
string_messages = []
for m in prompt_messages:
if m.role == PromptMessageRole.USER:
role = human_prefix
elif m.role == PromptMessageRole.ASSISTANT:
role = ai_prefix
else:
continue
if isinstance(m.content, list):
inner_msg = ""
for content in m.content:
if isinstance(content, TextPromptMessageContent):
inner_msg += f"{content.data}\n"
elif isinstance(content, ImagePromptMessageContent):
inner_msg += "[image]\n"
string_messages.append(f"{role}: {inner_msg.strip()}")
else:
message = f"{role}: {m.content}"
string_messages.append(message)
return "\n".join(string_messages)

View File

@ -91,9 +91,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
filename: str = Field(default="", description="the filename of multi-modal file")
# File reference for context restoration, format: "transfer_method:related_id" or "remote:url"
file_ref: str | None = Field(default=None, description="Encoded file reference for restoration")
@property
def data(self):
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
@ -279,5 +276,7 @@ class ToolPromptMessage(PromptMessage):
:return: True if prompt message is empty, False otherwise
"""
# ToolPromptMessage is not empty if it has content OR has a tool_call_id
return super().is_empty() and not self.tool_call_id
if not super().is_empty() and not self.tool_call_id:
return False
return True

View File

@ -5,7 +5,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.file import file_manager
from core.file.models import File
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities import (
AssistantPromptMessage,
PromptMessage,
@ -43,7 +43,7 @@ class AdvancedPromptTransform(PromptTransform):
files: Sequence[File],
context: str | None,
memory_config: MemoryConfig | None,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@ -84,7 +84,7 @@ class AdvancedPromptTransform(PromptTransform):
files: Sequence[File],
context: str | None,
memory_config: MemoryConfig | None,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@ -145,7 +145,7 @@ class AdvancedPromptTransform(PromptTransform):
files: Sequence[File],
context: str | None,
memory_config: MemoryConfig | None,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
model_config: ModelConfigWithCredentialsEntity,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@ -270,7 +270,7 @@ class AdvancedPromptTransform(PromptTransform):
def _set_histories_variable(
self,
memory: BaseMemory,
memory: TokenBufferMemory,
memory_config: MemoryConfig,
raw_prompt: str,
role_prefix: MemoryConfig.RolePrefix,

View File

@ -1,4 +1,3 @@
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel
@ -6,13 +5,6 @@ from pydantic import BaseModel
from core.model_runtime.entities.message_entities import PromptMessageRole
class MemoryMode(StrEnum):
"""Memory mode for LLM nodes."""
CONVERSATION = "conversation" # Use TokenBufferMemory (default, existing behavior)
NODE = "node" # Use NodeTokenBufferMemory (Chatflow only)
class ChatModelMessage(BaseModel):
"""
Chat Message.
@ -56,4 +48,3 @@ class MemoryConfig(BaseModel):
role_prefix: RolePrefix | None = None
window: WindowConfig
query_prompt_template: str | None = None
mode: MemoryMode = MemoryMode.CONVERSATION

View File

@ -1,7 +1,7 @@
from typing import Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey
@ -11,7 +11,7 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
class PromptTransform:
def _append_chat_histories(
self,
memory: BaseMemory,
memory: TokenBufferMemory,
memory_config: MemoryConfig,
prompt_messages: list[PromptMessage],
model_config: ModelConfigWithCredentialsEntity,
@ -52,7 +52,7 @@ class PromptTransform:
def _get_history_messages_from_memory(
self,
memory: BaseMemory,
memory: TokenBufferMemory,
memory_config: MemoryConfig,
max_token_limit: int,
human_prefix: str | None = None,
@ -73,7 +73,7 @@ class PromptTransform:
return memory.get_history_prompt_text(**kwargs)
def _get_history_messages_list_from_memory(
self, memory: BaseMemory, memory_config: MemoryConfig, max_token_limit: int
self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
) -> list[PromptMessage]:
"""Get memory messages."""
return list(

View File

@ -1047,8 +1047,6 @@ class ToolManager:
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
if not isinstance(tool_input.value, list):
raise ToolParameterError(f"Invalid variable selector for {parameter.name}")
variable = variable_pool.get(tool_input.value)
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
@ -1058,11 +1056,6 @@ class ToolManager:
elif tool_input.type == "mixed":
segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.text
elif tool_input.type == "mention":
# Mention type not supported in agent mode
raise ToolParameterError(
f"Mention type not supported in agent for parameter '{parameter.name}'"
)
else:
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
runtime_parameters[parameter.name] = parameter_value

View File

@ -5,6 +5,7 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from flask import has_request_context
from sqlalchemy import select
from core.db.session_factory import session_factory
@ -28,21 +29,6 @@ from models.workflow import Workflow
logger = logging.getLogger(__name__)
def _try_resolve_user_from_request() -> Account | EndUser | None:
"""
Try to resolve user from Flask request context.
Returns None if not in a request context or if user is not available.
"""
# Note: `current_user` is a LocalProxy. Never compare it with None directly.
# Use _get_current_object() to dereference the proxy
user = getattr(current_user, "_get_current_object", lambda: current_user)()
# Check if we got a valid user object
if user is not None and hasattr(user, "id"):
return user
return None
class WorkflowTool(Tool):
"""
Workflow tool.
@ -223,13 +209,21 @@ class WorkflowTool(Tool):
Returns:
Account | EndUser | None: The resolved user object, or None if resolution fails.
"""
# Try to resolve user from request context first
user = _try_resolve_user_from_request()
if user is not None:
return user
if has_request_context():
return self._resolve_user_from_request()
else:
return self._resolve_user_from_database(user_id=user_id)
# Fall back to database resolution
return self._resolve_user_from_database(user_id=user_id)
def _resolve_user_from_request(self) -> Account | EndUser | None:
"""
Resolve user from Flask request context.
"""
try:
# Note: `current_user` is a LocalProxy. Never compare it with None directly.
return getattr(current_user, "_get_current_object", lambda: current_user)()
except Exception as e:
logger.warning("Failed to resolve user from request context: %s", e)
return None
def _resolve_user_from_database(self, user_id: str) -> Account | EndUser | None:
"""

View File

@ -4,7 +4,6 @@ from .segments import (
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayPromptMessageSegment,
ArraySegment,
ArrayStringSegment,
FileSegment,
@ -21,7 +20,6 @@ from .variables import (
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayPromptMessageVariable,
ArrayStringVariable,
ArrayVariable,
FileVariable,
@ -44,8 +42,6 @@ __all__ = [
"ArrayNumberVariable",
"ArrayObjectSegment",
"ArrayObjectVariable",
"ArrayPromptMessageSegment",
"ArrayPromptMessageVariable",
"ArraySegment",
"ArrayStringSegment",
"ArrayStringVariable",

View File

@ -6,7 +6,6 @@ from typing import Annotated, Any, TypeAlias
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
from core.file import File
from core.model_runtime.entities import PromptMessage
from .types import SegmentType
@ -209,15 +208,6 @@ class ArrayBooleanSegment(ArraySegment):
value: Sequence[bool]
class ArrayPromptMessageSegment(ArraySegment):
value_type: SegmentType = SegmentType.ARRAY_PROMPT_MESSAGE
value: Sequence[PromptMessage]
def to_object(self):
"""Convert to JSON-serializable format for database storage and frontend."""
return [msg.model_dump() for msg in self.value]
def get_segment_discriminator(v: Any) -> SegmentType | None:
if isinstance(v, Segment):
return v.value_type
@ -258,7 +248,6 @@ SegmentUnion: TypeAlias = Annotated[
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
| Annotated[ArrayPromptMessageSegment, Tag(SegmentType.ARRAY_PROMPT_MESSAGE)]
),
Discriminator(get_segment_discriminator),
]

View File

@ -45,7 +45,6 @@ class SegmentType(StrEnum):
ARRAY_OBJECT = "array[object]"
ARRAY_FILE = "array[file]"
ARRAY_BOOLEAN = "array[boolean]"
ARRAY_PROMPT_MESSAGE = "array[message]"
NONE = "none"

View File

@ -3,10 +3,8 @@ from typing import Any
import orjson
from core.model_runtime.entities import PromptMessage
from .segment_group import SegmentGroup
from .segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
from .segments import ArrayFileSegment, FileSegment, Segment
def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
@ -18,7 +16,7 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[
def segment_orjson_default(o: Any):
"""Default function for orjson serialization of Segment types"""
if isinstance(o, (ArrayFileSegment, ArrayPromptMessageSegment)):
if isinstance(o, ArrayFileSegment):
return [v.model_dump() for v in o.value]
elif isinstance(o, FileSegment):
return o.value.model_dump()
@ -26,8 +24,6 @@ def segment_orjson_default(o: Any):
return [segment_orjson_default(seg) for seg in o.value]
elif isinstance(o, Segment):
return o.value
elif isinstance(o, PromptMessage):
return o.model_dump()
raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")

View File

@ -12,7 +12,6 @@ from .segments import (
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayPromptMessageSegment,
ArraySegment,
ArrayStringSegment,
BooleanSegment,
@ -111,10 +110,6 @@ class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
pass
class ArrayPromptMessageVariable(ArrayPromptMessageSegment, ArrayVariable):
pass
class RAGPipelineVariable(BaseModel):
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
@ -165,7 +160,6 @@ Variable: TypeAlias = Annotated[
| Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
| Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
| Annotated[ArrayPromptMessageVariable, Tag(SegmentType.ARRAY_PROMPT_MESSAGE)]
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
),
Discriminator(get_segment_discriminator),

View File

@ -1,22 +0,0 @@
"""
Execution Context - Context management for workflow execution.
This package provides Flask-independent context management for workflow
execution in multi-threaded environments.
"""
from core.workflow.context.execution_context import (
AppContext,
ExecutionContext,
IExecutionContext,
NullAppContext,
capture_current_context,
)
__all__ = [
"AppContext",
"ExecutionContext",
"IExecutionContext",
"NullAppContext",
"capture_current_context",
]

View File

@ -1,216 +0,0 @@
"""
Execution Context - Abstracted context management for workflow execution.
"""
import contextvars
from abc import ABC, abstractmethod
from collections.abc import Generator
from contextlib import AbstractContextManager, contextmanager
from typing import Any, Protocol, final, runtime_checkable
class AppContext(ABC):
"""
Abstract application context interface.
This abstraction allows workflow execution to work with or without Flask
by providing a common interface for application context management.
"""
@abstractmethod
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value by key."""
pass
@abstractmethod
def get_extension(self, name: str) -> Any:
"""Get Flask extension by name (e.g., 'db', 'cache')."""
pass
@abstractmethod
def enter(self) -> AbstractContextManager[None]:
"""Enter the application context."""
pass
@runtime_checkable
class IExecutionContext(Protocol):
"""
Protocol for execution context.
This protocol defines the interface that all execution contexts must implement,
allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably.
"""
def __enter__(self) -> "IExecutionContext":
"""Enter the execution context."""
...
def __exit__(self, *args: Any) -> None:
"""Exit the execution context."""
...
@property
def user(self) -> Any:
"""Get user object."""
...
@final
class ExecutionContext:
"""
Execution context for workflow execution in worker threads.
This class encapsulates all context needed for workflow execution:
- Application context (Flask app or standalone)
- Context variables for Python contextvars
- User information (optional)
It is designed to be serializable and passable to worker threads.
"""
def __init__(
self,
app_context: AppContext | None = None,
context_vars: contextvars.Context | None = None,
user: Any = None,
) -> None:
"""
Initialize execution context.
Args:
app_context: Application context (Flask or standalone)
context_vars: Python contextvars to preserve
user: User object (optional)
"""
self._app_context = app_context
self._context_vars = context_vars
self._user = user
@property
def app_context(self) -> AppContext | None:
"""Get application context."""
return self._app_context
@property
def context_vars(self) -> contextvars.Context | None:
"""Get context variables."""
return self._context_vars
@property
def user(self) -> Any:
"""Get user object."""
return self._user
@contextmanager
def enter(self) -> Generator[None, None, None]:
"""
Enter this execution context.
This is a convenience method that creates a context manager.
"""
# Restore context variables if provided
if self._context_vars:
for var, val in self._context_vars.items():
var.set(val)
# Enter app context if available
if self._app_context is not None:
with self._app_context.enter():
yield
else:
yield
def __enter__(self) -> "ExecutionContext":
"""Enter the execution context."""
self._cm = self.enter()
self._cm.__enter__()
return self
def __exit__(self, *args: Any) -> None:
"""Exit the execution context."""
if hasattr(self, "_cm"):
self._cm.__exit__(*args)
class NullAppContext(AppContext):
"""
Null implementation of AppContext for non-Flask environments.
This is used when running without Flask (e.g., in tests or standalone mode).
"""
def __init__(self, config: dict[str, Any] | None = None) -> None:
"""
Initialize null app context.
Args:
config: Optional configuration dictionary
"""
self._config = config or {}
self._extensions: dict[str, Any] = {}
def get_config(self, key: str, default: Any = None) -> Any:
"""Get configuration value by key."""
return self._config.get(key, default)
def get_extension(self, name: str) -> Any:
"""Get extension by name."""
return self._extensions.get(name)
def set_extension(self, name: str, extension: Any) -> None:
"""Set extension by name."""
self._extensions[name] = extension
@contextmanager
def enter(self) -> Generator[None, None, None]:
"""Enter null context (no-op)."""
yield
class ExecutionContextBuilder:
"""
Builder for creating ExecutionContext instances.
This provides a fluent API for building execution contexts.
"""
def __init__(self) -> None:
self._app_context: AppContext | None = None
self._context_vars: contextvars.Context | None = None
self._user: Any = None
def with_app_context(self, app_context: AppContext) -> "ExecutionContextBuilder":
"""Set application context."""
self._app_context = app_context
return self
def with_context_vars(self, context_vars: contextvars.Context) -> "ExecutionContextBuilder":
"""Set context variables."""
self._context_vars = context_vars
return self
def with_user(self, user: Any) -> "ExecutionContextBuilder":
"""Set user."""
self._user = user
return self
def build(self) -> ExecutionContext:
"""Build the execution context."""
return ExecutionContext(
app_context=self._app_context,
context_vars=self._context_vars,
user=self._user,
)
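# Illustrative usage (not part of the original module):
#   ctx = (
#       ExecutionContextBuilder()
#       .with_app_context(NullAppContext({"DEBUG": True}))
#       .with_context_vars(contextvars.copy_context())
#       .build()
#   )
#   with ctx:
#       ...  # node execution runs with the captured context restored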
def capture_current_context() -> IExecutionContext:
"""
Capture current execution context from the calling environment.
Returns:
IExecutionContext with captured context
"""
from context import capture_current_context
return capture_current_context()

File diff suppressed because it is too large

View File

@ -63,7 +63,6 @@ class NodeType(StrEnum):
TRIGGER_SCHEDULE = "trigger-schedule"
TRIGGER_PLUGIN = "trigger-plugin"
HUMAN_INPUT = "human-input"
GROUP = "group"
@property
def is_trigger_node(self) -> bool:
@ -253,7 +252,6 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
COMPLETED_REASON = "completed_reason" # completed reason for loop node
MENTION_PARENT_ID = "mention_parent_id" # parent node id for extractor nodes
class WorkflowNodeExecutionStatus(StrEnum):

View File

@ -307,14 +307,7 @@ class Graph:
if not node_configs:
raise ValueError("Graph must have at least one node")
# Filter out UI-only node types:
# - custom-note: top-level type (node_config.type == "custom-note")
# - group: data-level type (node_config.data.type == "group")
node_configs = [
node_config
for node_config in node_configs
if node_config.get("type", "") != "custom-note" and node_config.get("data", {}).get("type", "") != "group"
]
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
# Parse node configurations
node_configs_map = cls._parse_node_configs(node_configs)

View File

@ -93,8 +93,8 @@ class EventHandler:
Args:
event: The event to handle
"""
# Events in loops, iterations, or extractor groups are always collected
if event.in_loop_id or event.in_iteration_id or event.in_mention_parent_id:
# Events in loops or iterations are always collected
if event.in_loop_id or event.in_iteration_id:
self._event_collector.collect(event)
return
return self._dispatch(event)
@ -125,11 +125,6 @@ class EventHandler:
Args:
event: The node started event
"""
# Check if this is an extractor node (has parent_node_id)
if self._is_extractor_node(event.node_id):
self._handle_extractor_node_started(event)
return
# Track execution in domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
is_initial_attempt = node_execution.retry_count == 0
@ -169,11 +164,6 @@ class EventHandler:
Args:
event: The node succeeded event
"""
# Check if this is an extractor node (has parent_node_id)
if self._is_extractor_node(event.node_id):
self._handle_extractor_node_success(event)
return
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
@ -236,11 +226,6 @@ class EventHandler:
Args:
event: The node failed event
"""
# Check if this is an extractor node (has parent_node_id)
if self._is_extractor_node(event.node_id):
self._handle_extractor_node_failed(event)
return
# Update domain model
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_failed(event.error)
@ -360,57 +345,3 @@ class EventHandler:
self._graph_runtime_state.set_output("answer", value)
else:
self._graph_runtime_state.set_output(key, value)
def _is_extractor_node(self, node_id: str) -> bool:
"""
Check if node_id represents an extractor node (has parent_node_id).
Extractor nodes extract values from list[PromptMessage] for their parent node.
They have a parent_node_id field pointing to their parent node.
"""
node = self._graph.nodes.get(node_id)
if node is None:
return False
return node.node_data.is_extractor_node
def _handle_extractor_node_started(self, event: NodeRunStartedEvent) -> None:
"""
Handle extractor node started event.
Extractor nodes don't need full execution tracking, just collect the event.
"""
# Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id)
# Collect the event
self._event_collector.collect(event)
def _handle_extractor_node_success(self, event: NodeRunSucceededEvent) -> None:
"""
Handle extractor node success event.
Extractor nodes need special handling:
- Store outputs in variable pool (for reference by other nodes)
- Accumulate token usage
- Collect the event for logging
- Do NOT process edges or enqueue next nodes (parent node handles that)
"""
self._accumulate_node_usage(event.node_run_result.llm_usage)
# Store outputs in variable pool
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
# Collect the event
self._event_collector.collect(event)
def _handle_extractor_node_failed(self, event: NodeRunFailedEvent) -> None:
"""
Handle extractor node failed event.
Extractor node failures are collected for logging,
but the parent node is responsible for handling the error.
"""
self._accumulate_node_usage(event.node_run_result.llm_usage)
# Collect the event for logging
self._event_collector.collect(event)

View File

@ -7,13 +7,15 @@ Domain-Driven Design principles for improved maintainability and testability.
from __future__ import annotations
import contextvars
import logging
import queue
import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
from core.workflow.context import capture_current_context
from flask import Flask, current_app
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
@ -157,8 +159,17 @@ class GraphEngine:
self._layers: list[GraphEngineLayer] = []
# === Worker Pool Setup ===
# Capture execution context for worker threads
execution_context = capture_current_context()
# Capture Flask app context for worker threads
flask_app: Flask | None = None
try:
app = current_app._get_current_object() # type: ignore
if isinstance(app, Flask):
flask_app = app
except RuntimeError:
pass
# Capture context variables for worker threads
context_vars = contextvars.copy_context()
# Create worker pool for parallel node execution
self._worker_pool = WorkerPool(
@ -166,7 +177,8 @@ class GraphEngine:
event_queue=self._event_queue,
graph=self._graph,
layers=self._layers,
execution_context=execution_context,
flask_app=flask_app,
context_vars=context_vars,
min_workers=self._min_workers,
max_workers=self._max_workers,
scale_up_threshold=self._scale_up_threshold,

View File

@ -68,7 +68,6 @@ class _NodeRuntimeSnapshot:
predecessor_node_id: str | None
iteration_id: str | None
loop_id: str | None
mention_parent_id: str | None
created_at: datetime
@ -231,7 +230,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
metadata = {
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
WorkflowNodeExecutionMetadataKey.MENTION_PARENT_ID: event.in_mention_parent_id,
}
domain_execution = WorkflowNodeExecution(
@ -258,7 +256,6 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
predecessor_node_id=event.predecessor_node_id,
iteration_id=event.in_iteration_id,
loop_id=event.in_loop_id,
mention_parent_id=event.in_mention_parent_id,
created_at=event.start_at,
)
self._node_snapshots[event.id] = snapshot

View File

@ -5,27 +5,26 @@ Workers pull node IDs from the ready_queue, execute nodes, and push events
to the event_queue for the dispatcher to process.
"""
import contextvars
import queue
import threading
import time
from collections.abc import Sequence
from datetime import datetime
from typing import TYPE_CHECKING, final
from typing import final
from uuid import uuid4
from flask import Flask
from typing_extensions import override
from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphNodeEventBase, NodeRunFailedEvent
from core.workflow.nodes.base.node import Node
from libs.flask_utils import preserve_flask_contexts
from .ready_queue import ReadyQueue
if TYPE_CHECKING:
pass
@final
class Worker(threading.Thread):
@ -45,7 +44,8 @@ class Worker(threading.Thread):
layers: Sequence[GraphEngineLayer],
stop_event: threading.Event,
worker_id: int = 0,
execution_context: IExecutionContext | None = None,
flask_app: Flask | None = None,
context_vars: contextvars.Context | None = None,
) -> None:
"""
Initialize worker thread.
@ -56,17 +56,19 @@ class Worker(threading.Thread):
graph: Graph containing nodes to execute
layers: Graph engine layers for node execution hooks
worker_id: Unique identifier for this worker
execution_context: Optional execution context for context preservation
flask_app: Optional Flask application for context preservation
context_vars: Optional context variables to preserve in worker thread
"""
super().__init__(name=f"GraphWorker-{worker_id}", daemon=True)
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._worker_id = worker_id
self._execution_context = execution_context
self._flask_app = flask_app
self._context_vars = context_vars
self._last_task_time = time.time()
self._stop_event = stop_event
self._layers = layers if layers is not None else []
self._last_task_time = time.time()
def stop(self) -> None:
"""Worker is controlled via shared stop_event from GraphEngine.
@ -133,9 +135,11 @@ class Worker(threading.Thread):
error: Exception | None = None
# Execute the node with preserved context if execution context is provided
if self._execution_context is not None:
with self._execution_context:
if self._flask_app and self._context_vars:
with preserve_flask_contexts(
flask_app=self._flask_app,
context_vars=self._context_vars,
):
self._invoke_node_run_start_hooks(node)
try:
node_events = node.run()

View File

@ -8,10 +8,9 @@ DynamicScaler, and WorkerFactory into a single class.
import logging
import queue
import threading
from typing import final
from typing import TYPE_CHECKING, final
from configs import dify_config
from core.workflow.context import IExecutionContext
from core.workflow.graph import Graph
from core.workflow.graph_events import GraphNodeEventBase
@ -21,6 +20,11 @@ from ..worker import Worker
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from contextvars import Context
from flask import Flask
@final
class WorkerPool:
@ -38,7 +42,8 @@ class WorkerPool:
graph: Graph,
layers: list[GraphEngineLayer],
stop_event: threading.Event,
execution_context: IExecutionContext | None = None,
flask_app: "Flask | None" = None,
context_vars: "Context | None" = None,
min_workers: int | None = None,
max_workers: int | None = None,
scale_up_threshold: int | None = None,
@ -52,7 +57,8 @@ class WorkerPool:
event_queue: Queue for worker events
graph: The workflow graph
layers: Graph engine layers for node execution hooks
execution_context: Optional execution context for context preservation
flask_app: Optional Flask app for context preservation
context_vars: Optional context variables
min_workers: Minimum number of workers
max_workers: Maximum number of workers
scale_up_threshold: Queue depth to trigger scale up
@ -61,7 +67,8 @@ class WorkerPool:
self._ready_queue = ready_queue
self._event_queue = event_queue
self._graph = graph
self._execution_context = execution_context
self._flask_app = flask_app
self._context_vars = context_vars
self._layers = layers
# Scaling parameters with defaults
@ -145,7 +152,8 @@ class WorkerPool:
graph=self._graph,
layers=self._layers,
worker_id=worker_id,
execution_context=self._execution_context,
flask_app=self._flask_app,
context_vars=self._context_vars,
stop_event=self._stop_event,
)

View File

@ -21,12 +21,6 @@ class GraphNodeEventBase(GraphEngineEvent):
"""iteration id if node is in iteration"""
in_loop_id: str | None = None
"""loop id if node is in loop"""
in_mention_parent_id: str | None = None
"""Parent node id if this is an extractor node event.
When set, indicates this event belongs to an extractor node that
is extracting values for the specified parent node.
"""
# The version of the node, or "1" if not specified.
node_version: str = "1"

View File

@ -12,20 +12,11 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.file import File, FileTransferMethod
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import MemoryMode
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import (
ToolIdentity,
@ -145,9 +136,6 @@ class AgentNode(Node[AgentNodeData]):
)
return
# Fetch memory for node memory saving
memory = self._fetch_memory_for_save()
try:
yield from self._transform_message(
messages=message_stream,
@ -161,7 +149,6 @@ class AgentNode(Node[AgentNodeData]):
node_type=self.node_type,
node_id=self._node_id,
node_execution_id=self.id,
memory=memory,
)
except PluginDaemonClientSideError as e:
transform_error = AgentMessageTransformError(
@ -408,20 +395,8 @@ class AgentNode(Node[AgentNodeData]):
icon = None
return icon
def _fetch_memory(self, model_instance: ModelInstance) -> BaseMemory | None:
"""
Fetch memory based on configuration mode.
Returns TokenBufferMemory for conversation mode (default),
or NodeTokenBufferMemory for node mode (Chatflow only).
"""
node_data = self.node_data
memory_config = node_data.memory
if not memory_config:
return None
# get conversation id (required for both modes in Chatflow)
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
# get conversation id
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.CONVERSATION_ID]
)
@ -429,26 +404,16 @@ class AgentNode(Node[AgentNodeData]):
return None
conversation_id = conversation_id_variable.value
# Return appropriate memory type based on mode
if memory_config.mode == MemoryMode.NODE:
# Node-level memory (Chatflow only)
return NodeTokenBufferMemory(
app_id=self.app_id,
conversation_id=conversation_id,
node_id=self._node_id,
tenant_id=self.tenant_id,
model_instance=model_instance,
)
else:
# Conversation-level memory (default)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(
Conversation.app_id == self.app_id, Conversation.id == conversation_id
)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = ProviderManager()
@ -492,136 +457,6 @@ class AgentNode(Node[AgentNodeData]):
else:
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
def _fetch_memory_for_save(self) -> BaseMemory | None:
"""
Fetch memory instance for saving node memory.
This is a simplified version that doesn't require model_instance.
"""
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
node_data = self.node_data
if not node_data.memory:
return None
# Get conversation_id
conversation_id_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if not isinstance(conversation_id_var, StringSegment):
return None
conversation_id = conversation_id_var.value
# Return appropriate memory type based on mode
if node_data.memory.mode == MemoryMode.NODE:
# For node memory, we need a model_instance for token counting
# Use a simple default model for this purpose
try:
model_instance = ModelManager().get_default_model_instance(
tenant_id=self.tenant_id,
model_type=ModelType.LLM,
)
except Exception:
return None
return NodeTokenBufferMemory(
app_id=self.app_id,
conversation_id=conversation_id,
node_id=self._node_id,
tenant_id=self.tenant_id,
model_instance=model_instance,
)
else:
# Conversation-level memory doesn't need saving here
return None
def _build_context(
self,
parameters_for_log: dict[str, Any],
user_query: str,
assistant_response: str,
agent_logs: list[AgentLogEvent],
) -> list[PromptMessage]:
"""
Build context from user query, tool calls, and assistant response.
Format: user -> assistant(with tool_calls) -> tool -> assistant
The context includes:
- Current user query (always present, may be empty)
- Assistant message with tool_calls (if tools were called)
- Tool results
- Assistant's final response
"""
context_messages: list[PromptMessage] = []
# Always add user query (even if empty, to maintain conversation structure)
context_messages.append(UserPromptMessage(content=user_query or ""))
# Extract actual tool calls from agent logs
# Only include logs with label starting with "CALL " - these are real tool invocations
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_results: list[tuple[str, str, str]] = [] # (tool_call_id, tool_name, result)
for log in agent_logs:
if log.status == "success" and log.label and log.label.startswith("CALL "):
# Extract tool name from label (format: "CALL tool_name")
tool_name = log.label[5:] # Remove "CALL " prefix
tool_call_id = log.message_id
# Parse tool response from data
data = log.data or {}
tool_response = ""
# Try to extract the actual tool response
if "tool_response" in data:
tool_response = data["tool_response"]
elif "output" in data:
tool_response = data["output"]
elif "result" in data:
tool_response = data["result"]
if isinstance(tool_response, dict):
tool_response = str(tool_response)
# Get tool input for arguments
tool_input = data.get("tool_call_input", {}) or data.get("input", {})
if isinstance(tool_input, dict):
import json
tool_input_str = json.dumps(tool_input, ensure_ascii=False)
else:
tool_input_str = str(tool_input) if tool_input else ""
if tool_response:
tool_calls.append(
AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_name,
arguments=tool_input_str,
),
)
)
tool_results.append((tool_call_id, tool_name, str(tool_response)))
# Add assistant message with tool_calls if there were tool calls
if tool_calls:
context_messages.append(AssistantPromptMessage(content="", tool_calls=tool_calls))
# Add tool result messages
for tool_call_id, tool_name, result in tool_results:
context_messages.append(
ToolPromptMessage(
content=result,
tool_call_id=tool_call_id,
name=tool_name,
)
)
# Add final assistant response
context_messages.append(AssistantPromptMessage(content=assistant_response))
return context_messages
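
Note: the _build_context docstring above describes the ordering user -> assistant (with tool_calls) -> tool -> assistant. A small illustrative example of the sequence it produces, with made-up ids and values, assuming one successful "CALL search" log:

from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)

example_context = [
    UserPromptMessage(content="What's the weather in Paris?"),
    AssistantPromptMessage(
        content="",
        tool_calls=[
            AssistantPromptMessage.ToolCall(
                id="call-1",
                type="function",
                function=AssistantPromptMessage.ToolCall.ToolCallFunction(
                    name="search",
                    arguments='{"query": "Paris weather"}',
                ),
            )
        ],
    ),
    ToolPromptMessage(content="18C, clear", tool_call_id="call-1", name="search"),
    AssistantPromptMessage(content="It is 18C and clear in Paris."),
]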
def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
@ -632,7 +467,6 @@ class AgentNode(Node[AgentNodeData]):
node_type: NodeType,
node_id: str,
node_execution_id: str,
memory: BaseMemory | None = None,
) -> Generator[NodeEventBase, None, None]:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
@ -877,12 +711,6 @@ class AgentNode(Node[AgentNodeData]):
is_final=True,
)
# Get user query from parameters for building context
user_query = parameters_for_log.get("query", "")
# Build context from history, user query, tool calls and assistant response
context = self._build_context(parameters_for_log, user_query, text, agent_logs)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -891,7 +719,6 @@ class AgentNode(Node[AgentNodeData]):
"usage": jsonable_encoder(llm_usage),
"files": ArrayFileSegment(value=files),
"json": json_output,
"context": context,
**variables,
},
metadata={

View File

@ -1,10 +1,4 @@
from .entities import (
BaseIterationNodeData,
BaseIterationState,
BaseLoopNodeData,
BaseLoopState,
BaseNodeData,
)
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .usage_tracking_mixin import LLMUsageTrackingMixin
__all__ = [

View File

@ -175,16 +175,6 @@ class BaseNodeData(ABC, BaseModel):
default_value: list[DefaultValue] | None = None
retry_config: RetryConfig = RetryConfig()
# Parent node ID when this node is used as an extractor.
# If set, this node is an "attached" extractor node that extracts values
# from list[PromptMessage] for the parent node's parameters.
parent_node_id: str | None = None
@property
def is_extractor_node(self) -> bool:
"""Check if this node is an extractor node (has parent_node_id)."""
return self.parent_node_id is not None
@property
def default_value_dict(self) -> dict[str, Any]:
if self.default_value:

View File

@ -270,87 +270,10 @@ class Node(Generic[NodeDataT]):
"""Check if execution should be stopped."""
return self.graph_runtime_state.stop_event.is_set()
def _find_extractor_node_configs(self) -> list[dict[str, Any]]:
"""
Find all extractor node configurations that have parent_node_id == self._node_id.
Returns:
List of node configuration dicts for extractor nodes
"""
nodes = self.graph_config.get("nodes", [])
extractor_configs = []
for node_config in nodes:
node_data = node_config.get("data", {})
if node_data.get("parent_node_id") == self._node_id:
extractor_configs.append(node_config)
return extractor_configs
def _execute_mention_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
"""
Execute all extractor nodes associated with this node.
Extractor nodes are nodes with parent_node_id == self._node_id.
They are executed before the main node to extract values from list[PromptMessage].
"""
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
extractor_configs = self._find_extractor_node_configs()
logger.debug("[Extractor] Found %d extractor nodes for parent '%s'", len(extractor_configs), self._node_id)
if not extractor_configs:
return
for config in extractor_configs:
node_id = config.get("id")
node_data = config.get("data", {})
node_type_str = node_data.get("type")
if not node_id or not node_type_str:
continue
# Get node class
try:
node_type = NodeType(node_type_str)
except ValueError:
continue
node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
if not node_mapping:
continue
node_version = str(node_data.get("version", "1"))
node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
if not node_cls:
continue
# Instantiate and execute the extractor node
extractor_node = node_cls(
id=node_id,
config=config,
graph_init_params=self._graph_init_params,
graph_runtime_state=self.graph_runtime_state,
)
# Execute and process extractor node events
for event in extractor_node.run():
# Tag event with parent node id for stream ordering and history tracking
if isinstance(event, GraphNodeEventBase):
event.in_mention_parent_id = self._node_id
if isinstance(event, NodeRunSucceededEvent):
# Store extractor node outputs in variable pool
outputs: Mapping[str, Any] = event.node_run_result.outputs
for variable_name, variable_value in outputs.items():
self.graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
if not isinstance(event, NodeRunStreamChunkEvent):
yield event
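
Note: the loop above stores each extractor output in the variable pool under the extractor node's id, which is what mention resolution later reads back. A tiny sketch with hypothetical names and values:

from core.workflow.runtime import VariablePool

pool = VariablePool.empty()
# Equivalent of the add() call above after an extractor node "extract_query"
# succeeds with outputs {"text": "Paris weather"}.
pool.add(("extract_query", "text"), "Paris weather")

segment = pool.get(["extract_query", "text"])
assert segment is not None and segment.value == "Paris weather"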
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
# Step 1: Execute associated extractor nodes before main node execution
yield from self._execute_mention_nodes()
# Create and push start event with required fields
start_event = NodeRunStartedEvent(
id=execution_id,

View File

@ -1,9 +1,11 @@
import contextvars
import logging
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, NewType, cast
from flask import Flask, current_app
from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
@ -37,6 +39,7 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from core.workflow.runtime import VariablePool
from libs.datetime_utils import naive_utc_now
from libs.flask_utils import preserve_flask_contexts
from .exc import (
InvalidIteratorValueError,
@ -48,7 +51,6 @@ from .exc import (
)
if TYPE_CHECKING:
from core.workflow.context import IExecutionContext
from core.workflow.graph_engine import GraphEngine
logger = logging.getLogger(__name__)
@ -250,7 +252,8 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
self._execute_single_iteration_parallel,
index=index,
item=item,
execution_context=self._capture_execution_context(),
flask_app=current_app._get_current_object(), # type: ignore
context_vars=contextvars.copy_context(),
)
future_to_index[future] = index
@ -303,10 +306,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
self,
index: int,
item: object,
execution_context: "IExecutionContext",
flask_app: Flask,
context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
with execution_context:
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
events: list[GraphNodeEventBase] = []
outputs_temp: list[object] = []
@ -335,12 +339,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
graph_engine.graph_runtime_state.llm_usage,
)
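
Note: the parallel branch above captures the concrete Flask app object and a snapshot of the current contextvars for every submitted iteration. A condensed sketch of that capture-and-submit pattern (names are illustrative):

import contextvars
from collections.abc import Callable, Iterable
from concurrent.futures import Future, ThreadPoolExecutor

from flask import current_app


def submit_iterations(
    executor: ThreadPoolExecutor,
    items: Iterable[object],
    run_one: Callable[..., object],
) -> dict[Future[object], int]:
    future_to_index: dict[Future[object], int] = {}
    for index, item in enumerate(items):
        future = executor.submit(
            run_one,
            index=index,
            item=item,
            # Each future gets its own snapshot so the worker thread can
            # restore the submitting thread's Flask and contextvars state.
            flask_app=current_app._get_current_object(),  # type: ignore
            context_vars=contextvars.copy_context(),
        )
        future_to_index[future] = index
    return future_to_index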
def _capture_execution_context(self) -> "IExecutionContext":
"""Capture current execution context for parallel iterations."""
from core.workflow.context import capture_current_context
return capture_current_context()
def _handle_iteration_success(
self,
started_at: datetime,

View File

@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias
from typing import Any, Literal
from pydantic import BaseModel, ConfigDict, Field, field_validator
from pydantic import BaseModel, Field, field_validator
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
@ -58,28 +58,9 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
jinja2_text: str | None = None
class PromptMessageContext(BaseModel):
"""Context variable reference in prompt template.
YAML/JSON format: { "$context": ["node_id", "variable_name"] }
This will be expanded to list[PromptMessage] at runtime.
"""
model_config = ConfigDict(populate_by_name=True)
value_selector: Sequence[str] = Field(alias="$context")
# Union type for prompt template items (static message or context variable reference)
PromptTemplateItem: TypeAlias = Annotated[
LLMNodeChatModelMessage | PromptMessageContext,
Field(discriminator=None),
]
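
Note: with the removed PromptMessageContext support, a prompt_template could mix static chat messages with a "$context" reference that expands to list[PromptMessage] at runtime. A hypothetical DSL fragment (keys and node ids are guesses based on the aliases above, not a confirmed schema):

example_prompt_template = [
    {"role": "system", "text": "You are a helpful assistant.", "edition_type": "basic"},
    # Expanded at runtime to the list[PromptMessage] stored by agent_node_1.
    {"$context": ["agent_node_1", "context"]},
    {"role": "user", "text": "{{#sys.query#}}", "edition_type": "basic"},
]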
class LLMNodeData(BaseNodeData):
model: ModelConfig
prompt_template: Sequence[PromptTemplateItem] | LLMNodeCompletionModelPromptTemplate
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
memory: MemoryConfig | None = None
context: ContextConfig

View File

@ -8,20 +8,12 @@ from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.file.models import File
from core.memory import NodeTokenBufferMemory, TokenBufferMemory
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
MultiModalPromptMessageContent,
PromptMessage,
PromptMessageContentUnionTypes,
PromptMessageRole,
)
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
@ -94,56 +86,25 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc
def fetch_memory(
variable_pool: VariablePool,
app_id: str,
tenant_id: str,
node_data_memory: MemoryConfig | None,
model_instance: ModelInstance,
node_id: str = "",
) -> BaseMemory | None:
"""
Fetch memory based on configuration mode.
Returns TokenBufferMemory for conversation mode (default),
or NodeTokenBufferMemory for node mode (Chatflow only).
:param variable_pool: Variable pool containing system variables
:param app_id: Application ID
:param tenant_id: Tenant ID
:param node_data_memory: Memory configuration
:param model_instance: Model instance for token counting
:param node_id: Node ID in the workflow (required for node mode)
:return: Memory instance or None if not applicable
"""
variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
) -> TokenBufferMemory | None:
if not node_data_memory:
return None
# Get conversation_id from variable pool (required for both modes in Chatflow)
# get conversation id
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if not isinstance(conversation_id_variable, StringSegment):
return None
conversation_id = conversation_id_variable.value
# Return appropriate memory type based on mode
if node_data_memory.mode == MemoryMode.NODE:
# Node-level memory (Chatflow only)
if not node_id:
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
return NodeTokenBufferMemory(
app_id=app_id,
conversation_id=conversation_id,
node_id=node_id,
tenant_id=tenant_id,
model_instance=model_instance,
)
else:
# Conversation-level memory (default)
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
conversation = session.scalar(stmt)
if not conversation:
return None
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
@ -209,87 +170,3 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
)
session.execute(stmt)
session.commit()
def build_context(
prompt_messages: Sequence[PromptMessage],
assistant_response: str,
) -> list[PromptMessage]:
"""
Build context from prompt messages and assistant response.
Excludes system messages and includes the current LLM response.
Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
"""
context_messages: list[PromptMessage] = [
_truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
]
context_messages.append(AssistantPromptMessage(content=assistant_response))
return context_messages
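
Note: an illustrative use of the build_context helper removed above; system messages are dropped and the assistant's final answer is appended, producing a list[PromptMessage] suitable for ArrayPromptMessageSegment.

from core.model_runtime.entities.message_entities import (
    SystemPromptMessage,
    UserPromptMessage,
)

prompt_messages = [
    SystemPromptMessage(content="You are terse."),
    UserPromptMessage(content="Capital of France?"),
]
# build_context as defined above; the result is
# [UserPromptMessage("Capital of France?"), AssistantPromptMessage("Paris.")]
context = build_context(prompt_messages, assistant_response="Paris.")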
def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
"""
Truncate multi-modal content base64 data in a message to avoid storing large data.
Preserves the PromptMessage structure for ArrayPromptMessageSegment compatibility.
If file_ref is present, clears base64_data and url (they can be restored later).
Otherwise, truncates base64_data as fallback for legacy data.
"""
content = message.content
if content is None or isinstance(content, str):
return message
# Process list content, handling multi-modal data based on file_ref availability
new_content: list[PromptMessageContentUnionTypes] = []
for item in content:
if isinstance(item, MultiModalPromptMessageContent):
if item.file_ref:
# Clear base64 and url, keep file_ref for later restoration
new_content.append(item.model_copy(update={"base64_data": "", "url": ""}))
else:
# Fallback: truncate base64_data if no file_ref (legacy data)
truncated_base64 = ""
if item.base64_data:
truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:]
new_content.append(item.model_copy(update={"base64_data": truncated_base64}))
else:
new_content.append(item)
return message.model_copy(update={"content": new_content})
def restore_multimodal_content_in_messages(messages: Sequence[PromptMessage]) -> list[PromptMessage]:
"""
Restore multimodal content (base64 or url) in a list of PromptMessages.
When context is saved, base64_data is cleared to save storage space.
This function restores the content by parsing file_ref in each MultiModalPromptMessageContent.
Args:
messages: List of PromptMessages that may contain truncated multimodal content
Returns:
List of PromptMessages with restored multimodal content
"""
from core.file import file_manager
return [_restore_message_content(msg, file_manager) for msg in messages]
def _restore_message_content(message: PromptMessage, file_manager) -> PromptMessage:
"""Restore multimodal content in a single PromptMessage."""
content = message.content
if content is None or isinstance(content, str):
return message
restored_content: list[PromptMessageContentUnionTypes] = []
for item in content:
if isinstance(item, MultiModalPromptMessageContent):
restored_item = file_manager.restore_multimodal_content(item)
restored_content.append(cast(PromptMessageContentUnionTypes, restored_item))
else:
restored_content.append(item)
return message.model_copy(update={"content": restored_content})

View File

@ -7,7 +7,7 @@ import logging
import re
import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import select
@ -16,7 +16,7 @@ from core.file import File, FileTransferMethod, FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
ImagePromptMessageContent,
@ -51,7 +51,6 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.signature import sign_upload_file
from core.variables import (
ArrayFileSegment,
ArrayPromptMessageSegment,
ArraySegment,
FileSegment,
NoneSegment,
@ -88,7 +87,6 @@ from .entities import (
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
PromptMessageContext,
)
from .exc import (
InvalidContextStructureError,
@ -161,9 +159,8 @@ class LLMNode(Node[LLMNodeData]):
variable_pool = self.graph_runtime_state.variable_pool
try:
# Parse prompt template to separate static messages and context references
prompt_template = self.node_data.prompt_template
static_messages, context_refs, template_order = self._parse_prompt_template()
# init messages template
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
# fetch variables and fetch values from variable pool
inputs = self._fetch_inputs(node_data=self.node_data)
@ -211,10 +208,8 @@ class LLMNode(Node[LLMNodeData]):
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
tenant_id=self.tenant_id,
node_data_memory=self.node_data.memory,
model_instance=model_instance,
node_id=self._node_id,
)
query: str | None = None
@ -225,40 +220,21 @@ class LLMNode(Node[LLMNodeData]):
):
query = query_variable.text
# Get prompt messages
prompt_messages: Sequence[PromptMessage]
stop: Sequence[str] | None
if isinstance(prompt_template, list) and context_refs:
prompt_messages, stop = self._build_prompt_messages_with_context(
context_refs=context_refs,
template_order=template_order,
static_messages=static_messages,
query=query,
files=files,
context=context,
memory=memory,
model_config=model_config,
context_files=context_files,
)
else:
prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
prompt_template=cast(
Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
self.node_data.prompt_template,
),
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
context_files=context_files,
)
prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
prompt_template=self.node_data.prompt_template,
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
context_files=context_files,
)
# handle invoke result
generator = LLMNode.invoke_llm(
@ -274,7 +250,6 @@ class LLMNode(Node[LLMNodeData]):
node_id=self._node_id,
node_type=self.node_type,
reasoning_format=self.node_data.reasoning_format,
tenant_id=self.tenant_id,
)
structured_output: LLMStructuredOutput | None = None
@ -326,7 +301,6 @@ class LLMNode(Node[LLMNodeData]):
"reasoning_content": reasoning_content,
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"context": llm_utils.build_context(prompt_messages, clean_text),
}
if structured_output:
outputs["structured_output"] = structured_output.structured_output
@ -393,7 +367,6 @@ class LLMNode(Node[LLMNodeData]):
node_id: str,
node_type: NodeType,
reasoning_format: Literal["separated", "tagged"] = "tagged",
tenant_id: str | None = None,
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
model_schema = model_instance.model_type_instance.get_model_schema(
node_data_model.name, model_instance.credentials
@ -417,7 +390,6 @@ class LLMNode(Node[LLMNodeData]):
stop=list(stop or []),
stream=True,
user=user_id,
tenant_id=tenant_id,
)
else:
request_start_time = time.perf_counter()
@ -609,212 +581,6 @@ class LLMNode(Node[LLMNodeData]):
return messages
def _parse_prompt_template(
self,
) -> tuple[list[LLMNodeChatModelMessage], list[PromptMessageContext], list[tuple[int, str]]]:
"""
Parse prompt_template to separate static messages and context references.
Returns:
Tuple of (static_messages, context_refs, template_order)
- static_messages: list of LLMNodeChatModelMessage
- context_refs: list of PromptMessageContext
- template_order: list of (index, type) tuples preserving original order
"""
prompt_template = self.node_data.prompt_template
static_messages: list[LLMNodeChatModelMessage] = []
context_refs: list[PromptMessageContext] = []
template_order: list[tuple[int, str]] = []
if isinstance(prompt_template, list):
for idx, item in enumerate(prompt_template):
if isinstance(item, PromptMessageContext):
context_refs.append(item)
template_order.append((idx, "context"))
else:
static_messages.append(item)
template_order.append((idx, "static"))
# Transform static messages for jinja2
if static_messages:
self.node_data.prompt_template = self._transform_chat_messages(static_messages)
return static_messages, context_refs, template_order
def _build_prompt_messages_with_context(
self,
*,
context_refs: list[PromptMessageContext],
template_order: list[tuple[int, str]],
static_messages: list[LLMNodeChatModelMessage],
query: str | None,
files: Sequence[File],
context: str | None,
memory: BaseMemory | None,
model_config: ModelConfigWithCredentialsEntity,
context_files: list[File],
) -> tuple[list[PromptMessage], Sequence[str] | None]:
"""
Build prompt messages by combining static messages and context references in DSL order.
Returns:
Tuple of (prompt_messages, stop_sequences)
"""
variable_pool = self.graph_runtime_state.variable_pool
# Process messages in DSL order: iterate once and handle each type directly
combined_messages: list[PromptMessage] = []
context_idx = 0
static_idx = 0
for _, type_ in template_order:
if type_ == "context":
# Handle context reference
ctx_ref = context_refs[context_idx]
ctx_var = variable_pool.get(ctx_ref.value_selector)
if ctx_var is None:
raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found")
if not isinstance(ctx_var, ArrayPromptMessageSegment):
raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]")
# Restore multimodal content (base64/url) that was truncated when saving context
restored_messages = llm_utils.restore_multimodal_content_in_messages(ctx_var.value)
combined_messages.extend(restored_messages)
context_idx += 1
else:
# Handle static message
static_msg = static_messages[static_idx]
processed_msgs = LLMNode.handle_list_messages(
messages=[static_msg],
context=context,
jinja2_variables=self.node_data.prompt_config.jinja2_variables or [],
variable_pool=variable_pool,
vision_detail_config=self.node_data.vision.configs.detail,
)
combined_messages.extend(processed_msgs)
static_idx += 1
# Append memory messages
memory_messages = _handle_memory_chat_mode(
memory=memory,
memory_config=self.node_data.memory,
model_config=model_config,
)
combined_messages.extend(memory_messages)
# Append current query if provided
if query:
query_message = LLMNodeChatModelMessage(
text=query,
role=PromptMessageRole.USER,
edition_type="basic",
)
query_msgs = LLMNode.handle_list_messages(
messages=[query_message],
context="",
jinja2_variables=[],
variable_pool=variable_pool,
vision_detail_config=self.node_data.vision.configs.detail,
)
combined_messages.extend(query_msgs)
# Handle files (sys_files and context_files)
combined_messages = self._append_files_to_messages(
messages=combined_messages,
sys_files=files,
context_files=context_files,
model_config=model_config,
)
# Filter empty messages and get stop sequences
combined_messages = self._filter_messages(combined_messages, model_config)
stop = self._get_stop_sequences(model_config)
return combined_messages, stop
def _append_files_to_messages(
self,
*,
messages: list[PromptMessage],
sys_files: Sequence[File],
context_files: list[File],
model_config: ModelConfigWithCredentialsEntity,
) -> list[PromptMessage]:
"""Append sys_files and context_files to messages."""
vision_enabled = self.node_data.vision.enabled
vision_detail = self.node_data.vision.configs.detail
# Handle sys_files (will be deprecated later)
if vision_enabled and sys_files:
file_prompts = [
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in sys_files
]
if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
else:
messages.append(UserPromptMessage(content=file_prompts))
# Handle context_files
if vision_enabled and context_files:
file_prompts = [
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
for file in context_files
]
if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
else:
messages.append(UserPromptMessage(content=file_prompts))
return messages
def _filter_messages(
self, messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> list[PromptMessage]:
"""Filter empty messages and unsupported content types."""
filtered_messages: list[PromptMessage] = []
for message in messages:
if isinstance(message.content, list):
filtered_content: list[PromptMessageContentUnionTypes] = []
for content_item in message.content:
# Skip non-text content if features are not defined
if not model_config.model_schema.features:
if content_item.type != PromptMessageContentType.TEXT:
continue
filtered_content.append(content_item)
continue
# Skip content if corresponding feature is not supported
feature_map = {
PromptMessageContentType.IMAGE: ModelFeature.VISION,
PromptMessageContentType.DOCUMENT: ModelFeature.DOCUMENT,
PromptMessageContentType.VIDEO: ModelFeature.VIDEO,
PromptMessageContentType.AUDIO: ModelFeature.AUDIO,
}
required_feature = feature_map.get(content_item.type)
if required_feature and required_feature not in model_config.model_schema.features:
continue
filtered_content.append(content_item)
# Simplify single text content
if len(filtered_content) == 1 and filtered_content[0].type == PromptMessageContentType.TEXT:
message.content = filtered_content[0].data
else:
message.content = filtered_content
if not message.is_empty():
filtered_messages.append(message)
if not filtered_messages:
raise NoPromptFoundError(
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
return filtered_messages
def _get_stop_sequences(self, model_config: ModelConfigWithCredentialsEntity) -> Sequence[str] | None:
"""Get stop sequences from model config."""
return model_config.stop
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
variables: dict[str, Any] = {}
@ -1012,7 +778,7 @@ class LLMNode(Node[LLMNodeData]):
sys_query: str | None = None,
sys_files: Sequence[File],
context: str | None = None,
memory: BaseMemory | None = None,
memory: TokenBufferMemory | None = None,
model_config: ModelConfigWithCredentialsEntity,
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
memory_config: MemoryConfig | None = None,
@ -1571,7 +1337,7 @@ def _calculate_rest_token(
def _handle_memory_chat_mode(
*,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> Sequence[PromptMessage]:
@ -1588,7 +1354,7 @@ def _handle_memory_chat_mode(
def _handle_memory_completion_mode(
*,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
memory_config: MemoryConfig | None,
model_config: ModelConfigWithCredentialsEntity,
) -> str:

View File

@ -7,7 +7,7 @@ from typing import Any, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import File
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import ImagePromptMessageContent
from core.model_runtime.entities.llm_entities import LLMUsage
@ -145,10 +145,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
tenant_id=self.tenant_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
node_id=self._node_id,
)
if (
@ -246,10 +244,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
# transform result into standard format
result = self._transform_result(data=node_data, result=result or {})
# Build context from prompt messages and response
assistant_response = json.dumps(result, ensure_ascii=False)
context = llm_utils.build_context(prompt_messages, assistant_response)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
@ -258,7 +252,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"__is_success": 1 if not error else 0,
"__reason": error,
"__usage": jsonable_encoder(usage),
"context": context,
**result,
},
metadata={
@ -306,7 +299,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
@ -388,7 +381,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@ -426,7 +419,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@ -460,7 +453,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
query: str,
variable_pool: VariablePool,
model_config: ModelConfigWithCredentialsEntity,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
files: Sequence[File],
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
) -> list[PromptMessage]:
@ -688,7 +681,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
model_mode = ModelMode(node_data.model.mode)
@ -715,7 +708,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
node_data: ParameterExtractorNodeData,
query: str,
variable_pool: VariablePool,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)

View File

@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.base import BaseMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
from core.model_runtime.utils.encoders import jsonable_encoder
@ -96,10 +96,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
tenant_id=self.tenant_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
node_id=self._node_id,
)
# fetch instruction
node_data.instruction = node_data.instruction or ""
@ -199,15 +197,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
"model_provider": model_config.provider,
"model_name": model_config.model,
}
# Build context from prompt messages and response
assistant_response = f"class_name: {category_name}, class_id: {category_id}"
context = llm_utils.build_context(prompt_messages, assistant_response)
outputs = {
"class_name": category_name,
"class_id": category_id,
"usage": jsonable_encoder(usage),
"context": context,
}
return NodeRunResult(
@ -319,7 +312,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self,
node_data: QuestionClassifierNodeData,
query: str,
memory: BaseMemory | None,
memory: TokenBufferMemory | None,
max_token_limit: int = 2000,
):
model_mode = ModelMode(node_data.model.mode)

View File

@ -1,63 +1,11 @@
import re
from collections.abc import Sequence
from typing import Any, Literal, Self, Union
from typing import Any, Literal, Union
from pydantic import BaseModel, field_validator, model_validator
from pydantic import BaseModel, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.base.entities import BaseNodeData
# Pattern to match mention value format: {{@node.context@}}instruction
# The placeholder {{@node.context@}} must appear at the beginning
# Format: {{@agent_node_id.context@}} where agent_node_id is dynamic, context is fixed
MENTION_VALUE_PATTERN = re.compile(r"^\{\{@([a-zA-Z0-9_]+)\.context@\}\}(.*)$", re.DOTALL)
def parse_mention_value(value: str) -> tuple[str, str]:
"""Parse mention value into (node_id, instruction).
Args:
value: The mention value string like "{{@llm.context@}}extract keywords"
Returns:
Tuple of (node_id, instruction)
Raises:
ValueError: If value format is invalid
"""
match = MENTION_VALUE_PATTERN.match(value)
if not match:
raise ValueError(
"For mention type, value must start with {{@node.context@}} placeholder, "
"e.g., '{{@llm.context@}}extract keywords'"
)
return match.group(1), match.group(2)
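
Note: a short usage example of parse_mention_value, matching the format documented above:

node_id, instruction = parse_mention_value("{{@llm.context@}}extract keywords")
assert node_id == "llm"
assert instruction == "extract keywords"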
class MentionConfig(BaseModel):
"""Configuration for extracting value from context variable.
Used when a tool parameter needs to be extracted from list[PromptMessage]
context using an extractor LLM node.
Note: instruction is embedded in the value field as "{{@node.context@}}instruction"
"""
# ID of the extractor LLM node
extractor_node_id: str
# Output variable selector from extractor node
# e.g., ["text"], ["structured_output", "query"]
output_selector: Sequence[str]
# Strategy when output is None
null_strategy: Literal["raise_error", "use_default"] = "raise_error"
# Default value when null_strategy is "use_default"
# Type should match the parameter's expected type
default_value: Any = None
class ToolEntity(BaseModel):
provider_id: str
@ -87,9 +35,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
class ToolInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant", "mention"]
# Required config for mention type, extracting value from context variable
mention_config: MentionConfig | None = None
type: Literal["mixed", "variable", "constant"]
@field_validator("type", mode="before")
@classmethod
@ -102,9 +48,6 @@ class ToolNodeData(BaseNodeData, ToolEntity):
if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
elif typ == "mention":
# Skip here, will be validated in model_validator
pass
elif typ == "variable":
if not isinstance(value, list):
raise ValueError("value must be a list")
@ -115,26 +58,6 @@ class ToolNodeData(BaseNodeData, ToolEntity):
raise ValueError("value must be a string, int, float, bool or dict")
return typ
@model_validator(mode="after")
def check_mention_type(self) -> Self:
"""Validate mention type with mention_config."""
if self.type != "mention":
return self
value = self.value
if value is None:
return self
if not isinstance(value, str):
raise ValueError("value must be a string for mention type")
# For mention type, value must match format: {{@node.context@}}instruction
# This will raise ValueError if format is invalid
parse_mention_value(value)
# mention_config is required for mention type
if self.mention_config is None:
raise ValueError("mention_config is required for mention type")
return self
tool_parameters: dict[str, ToolInput]
# The version of the tool parameter.
# If this value is None, it indicates this is a previous version

View File

@ -1,10 +1,7 @@
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from sqlalchemy import select
logger = logging.getLogger(__name__)
from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
@ -187,7 +184,6 @@ class ToolNode(Node[ToolNodeData]):
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
variable_pool (VariablePool): The variable pool containing the variables.
node_data (ToolNodeData): The data associated with the tool node.
for_log (bool): Whether to generate parameters for logging.
Returns:
Mapping[str, Any]: A dictionary containing the generated parameters.
@ -203,37 +199,14 @@ class ToolNode(Node[ToolNodeData]):
continue
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable":
if not isinstance(tool_input.value, list):
raise ToolParameterError(f"Invalid variable selector for parameter '{parameter_name}'")
selector = tool_input.value
variable = variable_pool.get(selector)
variable = variable_pool.get(tool_input.value)
if variable is None:
if parameter.required:
raise ToolParameterError(f"Variable {selector} does not exist")
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
continue
parameter_value = variable.value
elif tool_input.type == "mention":
# Mention type: get value from extractor node's output
if tool_input.mention_config is None:
raise ToolParameterError(
f"mention_config is required for mention type parameter '{parameter_name}'"
)
mention_config = tool_input.mention_config.model_dump()
try:
parameter_value, found = variable_pool.resolve_mention(
mention_config, parameter_name=parameter_name
)
if not found and parameter.required:
raise ToolParameterError(
f"Extractor output not found for required parameter '{parameter_name}'"
)
if not found:
continue
except ValueError as e:
raise ToolParameterError(str(e)) from e
elif tool_input.type in {"mixed", "constant"}:
template = str(tool_input.value)
segment_group = variable_pool.convert_template(template)
segment_group = variable_pool.convert_template(str(tool_input.value))
parameter_value = segment_group.log if for_log else segment_group.text
else:
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
@ -515,12 +488,8 @@ class ToolNode(Node[ToolNodeData]):
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == "variable":
if isinstance(input.value, list):
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
elif input.type == "mention":
# Mention type: value is handled by extractor node, no direct variable reference
pass
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
elif input.type == "constant":
pass

View File

@ -268,58 +268,6 @@ class VariablePool(BaseModel):
continue
self.add(selector, value)
def resolve_mention(
self,
mention_config: Mapping[str, Any],
/,
*,
parameter_name: str = "",
) -> tuple[Any, bool]:
"""
Resolve a mention parameter value from an extractor node's output.
Mention parameters reference values extracted by an extractor LLM node
from list[PromptMessage] context.
Args:
mention_config: A dict containing:
- extractor_node_id: ID of the extractor LLM node
- output_selector: Selector path for the output variable (e.g., ["text"])
- null_strategy: "raise_error" or "use_default"
- default_value: Value to use when null_strategy is "use_default"
parameter_name: Name of the parameter being resolved (for error messages)
Returns:
Tuple of (resolved_value, found):
- resolved_value: The extracted value, or default_value if not found
- found: True if value was found, False if using default
Raises:
ValueError: If extractor_node_id is missing, or if null_strategy is
"raise_error" and the value is not found
"""
extractor_node_id = mention_config.get("extractor_node_id")
if not extractor_node_id:
raise ValueError(f"Missing extractor_node_id for mention parameter '{parameter_name}'")
output_selector = list(mention_config.get("output_selector", []))
null_strategy = mention_config.get("null_strategy", "raise_error")
default_value = mention_config.get("default_value")
# Build full selector: [extractor_node_id, ...output_selector]
full_selector = [extractor_node_id] + output_selector
variable = self.get(full_selector)
if variable is None:
if null_strategy == "use_default":
return default_value, False
raise ValueError(
f"Extractor node '{extractor_node_id}' output '{'.'.join(output_selector)}' "
f"not found for parameter '{parameter_name}'"
)
return variable.value, True
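
Note: a usage sketch for the removed resolve_mention method, with hypothetical node ids and values:

from core.workflow.runtime import VariablePool

pool = VariablePool.empty()
pool.add(("extractor_llm", "text"), "Paris weather")

value, found = pool.resolve_mention(
    {
        "extractor_node_id": "extractor_llm",
        "output_selector": ["text"],
        "null_strategy": "use_default",
        "default_value": "",
    },
    parameter_name="query",
)
assert found is True and value == "Paris weather"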
@classmethod
def empty(cls) -> VariablePool:
"""Create an empty variable pool."""

View File

@ -4,7 +4,6 @@ from uuid import uuid4
from configs import dify_config
from core.file import File
from core.model_runtime.entities import PromptMessage
from core.variables.exc import VariableError
from core.variables.segments import (
ArrayAnySegment,
@ -12,7 +11,6 @@ from core.variables.segments import (
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayPromptMessageSegment,
ArraySegment,
ArrayStringSegment,
BooleanSegment,
@ -31,7 +29,6 @@ from core.variables.variables import (
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayPromptMessageVariable,
ArrayStringVariable,
BooleanVariable,
FileVariable,
@ -64,7 +61,6 @@ SEGMENT_TO_VARIABLE_MAP = {
ArrayFileSegment: ArrayFileVariable,
ArrayNumberSegment: ArrayNumberVariable,
ArrayObjectSegment: ArrayObjectVariable,
ArrayPromptMessageSegment: ArrayPromptMessageVariable,
ArrayStringSegment: ArrayStringVariable,
BooleanSegment: BooleanVariable,
FileSegment: FileVariable,
@ -160,13 +156,7 @@ def build_segment(value: Any, /) -> Segment:
return ObjectSegment(value=value)
if isinstance(value, File):
return FileSegment(value=value)
if isinstance(value, PromptMessage):
# Single PromptMessage should be wrapped in a list
return ArrayPromptMessageSegment(value=[value])
if isinstance(value, list):
# Check if all items are PromptMessage
if value and all(isinstance(item, PromptMessage) for item in value):
return ArrayPromptMessageSegment(value=value)
items = [build_segment(item) for item in value]
types = {item.value_type for item in items}
if all(isinstance(item, ArraySegment) for item in items):
@ -210,7 +200,6 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
SegmentType.ARRAY_FILE: ArrayFileSegment,
SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment,
SegmentType.ARRAY_PROMPT_MESSAGE: ArrayPromptMessageSegment,
}
@ -285,10 +274,6 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
):
segment_class = _segment_factory[inferred_type]
return segment_class(value_type=inferred_type, value=value)
elif segment_type == SegmentType.ARRAY_PROMPT_MESSAGE and inferred_type == SegmentType.ARRAY_OBJECT:
# PromptMessage serializes to dict, so ARRAY_OBJECT is compatible with ARRAY_PROMPT_MESSAGE
segment_class = _segment_factory[segment_type]
return segment_class(value_type=segment_type, value=value)
else:
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
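
Note: the removed PromptMessage branches in build_segment wrapped a single PromptMessage in a one-element list and turned a homogeneous list of PromptMessage into an ArrayPromptMessageSegment. A small illustration (build_segment as defined above):

from core.model_runtime.entities.message_entities import UserPromptMessage

single = build_segment(UserPromptMessage(content="hi"))
many = build_segment([UserPromptMessage(content="hi"), UserPromptMessage(content="there")])
# Both were ArrayPromptMessageSegment; many.value held the original list.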

View File

@ -1,74 +0,0 @@
"""
Workspace permission helper functions.
These helpers check both billing/plan level and workspace-specific policy level permissions.
Checks are performed at two levels:
1. Billing/plan level - via FeatureService (e.g., SANDBOX plan restrictions)
2. Workspace policy level - via EnterpriseService (admin-configured per workspace)
"""
import logging
from werkzeug.exceptions import Forbidden
from configs import dify_config
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
logger = logging.getLogger(__name__)
def check_workspace_member_invite_permission(workspace_id: str) -> None:
"""
Check if workspace allows member invitations at both billing and policy levels.
Checks performed:
1. Billing/plan level - For future expansion (currently no plan-level restriction)
2. Enterprise policy level - Admin-configured workspace permission
Args:
workspace_id: The workspace ID to check permissions for
Raises:
Forbidden: If either billing plan or workspace policy prohibits member invitations
"""
# Check enterprise workspace policy level (only if enterprise enabled)
if dify_config.ENTERPRISE_ENABLED:
try:
permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
if not permission.allow_member_invite:
raise Forbidden("Workspace policy prohibits member invitations")
except Forbidden:
raise
except Exception:
logger.exception("Failed to check workspace invite permission for %s", workspace_id)
def check_workspace_owner_transfer_permission(workspace_id: str) -> None:
"""
Check if workspace allows owner transfer at both billing and policy levels.
Checks performed:
1. Billing/plan level - SANDBOX plan blocks owner transfer
2. Enterprise policy level - Admin-configured workspace permission
Args:
workspace_id: The workspace ID to check permissions for
Raises:
Forbidden: If either billing plan or workspace policy prohibits ownership transfer
"""
features = FeatureService.get_features(workspace_id)
if not features.is_allow_transfer_workspace:
raise Forbidden("Your current plan does not allow workspace ownership transfer")
# Check enterprise workspace policy level (only if enterprise enabled)
if dify_config.ENTERPRISE_ENABLED:
try:
permission = EnterpriseService.WorkspacePermissionService.get_permission(workspace_id)
if not permission.allow_owner_transfer:
raise Forbidden("Workspace policy prohibits ownership transfer")
except Forbidden:
raise
except Exception:
logger.exception("Failed to check workspace transfer permission for %s", workspace_id)

View File

@ -1285,7 +1285,7 @@ class WorkflowDraftVariable(Base):
# which may differ from the original value's type. Typically, they are the same,
# but in cases where the structurally truncated value still exceeds the size limit,
# text slicing is applied, and the `value_type` is converted to `STRING`.
value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=21))
value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20))
# The variable's value serialized as a JSON string
#
@ -1659,7 +1659,7 @@ class WorkflowDraftVariableFile(Base):
# The `value_type` field records the type of the original value.
value_type: Mapped[SegmentType] = mapped_column(
EnumText(SegmentType, length=21),
EnumText(SegmentType, length=20),
nullable=False,
)

View File

@ -1381,11 +1381,6 @@ class RegisterService:
normalized_email = email.lower()
"""Invite new member"""
# Check workspace permission for member invitations
from libs.workspace_permission import check_workspace_member_invite_permission
check_workspace_member_invite_permission(tenant.id)
with Session(db.engine) as session:
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)

View File

@ -13,23 +13,6 @@ class WebAppSettings(BaseModel):
)
class WorkspacePermission(BaseModel):
workspace_id: str = Field(
description="The ID of the workspace.",
alias="workspaceId",
)
allow_member_invite: bool = Field(
description="Whether to allow members to invite new members to the workspace.",
default=False,
alias="allowMemberInvite",
)
allow_owner_transfer: bool = Field(
description="Whether to allow owners to transfer ownership of the workspace.",
default=False,
alias="allowOwnerTransfer",
)
class EnterpriseService:
@classmethod
def get_info(cls):
@ -61,16 +44,6 @@ class EnterpriseService:
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e
class WorkspacePermissionService:
@classmethod
def get_permission(cls, workspace_id: str):
if not workspace_id:
raise ValueError("workspace_id must be provided.")
data = EnterpriseRequest.send_request("GET", f"/workspaces/{workspace_id}/permission")
if not data or "permission" not in data:
raise ValueError("No data found.")
return WorkspacePermission.model_validate(data["permission"])
class WebAppAuth:
@classmethod
def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str):
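For reference, get_permission above expects the enterprise backend to return the permission object keyed by the camelCase aliases declared on WorkspacePermission. A hedged example of parsing such a payload (the field values are made up):

from services.enterprise.enterprise_service import WorkspacePermission

payload = {
    "permission": {
        "workspaceId": "ws-123",       # example value
        "allowMemberInvite": True,
        "allowOwnerTransfer": False,
    }
}
permission = WorkspacePermission.model_validate(payload["permission"])
assert permission.allow_member_invite is True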

View File

@ -12,6 +12,8 @@ logger = logging.getLogger(__name__)
WORKSPACE_SYNC_QUEUE = "enterprise:workspace:sync:queue"
WORKSPACE_SYNC_PROCESSING = "enterprise:workspace:sync:processing"
TASK_TYPE_SYNC_TO_WORKSPACE = "sync_to_workspace"
class WorkspaceSyncService:
"""Service to publish workspace sync tasks to Redis queue for enterprise backend consumption"""
@ -38,6 +40,7 @@ class WorkspaceSyncService:
"retry_count": 0,
"created_at": datetime.now(UTC).isoformat(),
"source": source,
"type": TASK_TYPE_SYNC_TO_WORKSPACE,
}
# Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP
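Since this commit adds a "type" field to every queued task, a consumer on the enterprise side would presumably dispatch on it. A rough sketch of such a worker loop (the Redis connection details and the handler are assumptions; only the queue name, the consumption direction, and the task shape come from this file):

import json

import redis


def handle_workspace_sync(task: dict) -> None:
    # Hypothetical handler; the real consumer lives in the enterprise backend.
    print("syncing workspace task from source:", task.get("source"))


r = redis.Redis(host="localhost", port=6379, db=0)  # connection details assumed

while True:
    raw = r.rpop("enterprise:workspace:sync:queue")  # LPUSH at the head, RPOP from the tail
    if raw is None:
        break  # a real worker would block or poll instead of exiting
    task = json.loads(raw)
    if task.get("type") == "sync_to_workspace":  # the field this commit adds
        handle_workspace_sync(task)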

View File

@ -7,7 +7,6 @@ from typing import Any, Generic, TypeAlias, TypeVar, overload
from configs import dify_config
from core.file.models import File
from core.model_runtime.entities import PromptMessage
from core.variables.segments import (
ArrayFileSegment,
ArraySegment,
@ -288,10 +287,6 @@ class VariableTruncator(BaseTruncator):
if isinstance(item, File):
truncated_value.append(item)
continue
# Handle PromptMessage types - pass them through unchanged (no truncation applied)

if isinstance(item, PromptMessage):
truncated_value.append(item)
continue
if i >= target_length:
return _PartResult(truncated_value, used_size, True)
if i > 0:

View File

@ -163,29 +163,3 @@ class WorkflowScheduleCFSPlanEntity(BaseModel):
schedule_strategy: Strategy
granularity: int = Field(default=-1) # -1 means infinite
# ========== Mention Graph Entities ==========
class MentionParameterSchema(BaseModel):
"""Schema for the parameter to be extracted from mention context."""
name: str = Field(description="Parameter name (e.g., 'query')")
type: str = Field(default="string", description="Parameter type (e.g., 'string', 'number')")
description: str = Field(default="", description="Parameter description for LLM")
class MentionGraphRequest(BaseModel):
"""Request payload for generating mention graph."""
parent_node_id: str = Field(description="ID of the parent node that uses the extracted value")
parameter_key: str = Field(description="Key of the parameter being extracted")
context_source: list[str] = Field(description="Variable selector for the context source")
parameter_schema: MentionParameterSchema = Field(description="Schema of the parameter to extract")
class MentionGraphResponse(BaseModel):
"""Response containing the generated mention graph."""
graph: Mapping[str, Any] = Field(description="Complete graph structure with nodes, edges, viewport")
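A hedged example of building one of these requests (the node ID and selector are illustrative values taken from the sample DSL later in this diff):

from services.workflow.entities import MentionGraphRequest, MentionParameterSchema

request = MentionGraphRequest(
    parent_node_id="1767773709491",        # the tool node that needs the extracted value
    parameter_key="query",
    context_source=["llm", "context"],     # selector pointing at a list[PromptMessage] variable
    parameter_schema=MentionParameterSchema(
        name="query",
        type="string",
        description="Key words the user wants to search for",
    ),
)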

View File

@ -1,143 +0,0 @@
"""
Service for generating Mention LLM node graph structures.
This service creates graph structures containing LLM nodes configured for
extracting values from list[PromptMessage] variables.
"""
from typing import Any
from sqlalchemy.orm import Session
from core.model_runtime.entities import LLMMode
from core.workflow.enums import NodeType
from services.model_provider_service import ModelProviderService
from services.workflow.entities import MentionGraphRequest, MentionGraphResponse, MentionParameterSchema
class MentionGraphService:
"""Service for generating Mention LLM node graph structures."""
def __init__(self, session: Session):
self._session = session
def generate_mention_node_id(self, node_id: str, parameter_name: str) -> str:
"""Generate mention node ID following the naming convention.
Format: {node_id}_ext_{parameter_name}
"""
return f"{node_id}_ext_{parameter_name}"
def generate_mention_graph(self, tenant_id: str, request: MentionGraphRequest) -> MentionGraphResponse:
"""Generate a complete graph structure containing a Mention LLM node.
Args:
tenant_id: The tenant ID for fetching default model config
request: The mention graph generation request
Returns:
Complete graph structure with nodes, edges, and viewport
"""
node_id = self.generate_mention_node_id(request.parent_node_id, request.parameter_key)
model_config = self._get_default_model_config(tenant_id)
node = self._build_mention_llm_node(
node_id=node_id,
parent_node_id=request.parent_node_id,
context_source=request.context_source,
parameter_schema=request.parameter_schema,
model_config=model_config,
)
graph = {
"nodes": [node],
"edges": [],
"viewport": {},
}
return MentionGraphResponse(graph=graph)
def _get_default_model_config(self, tenant_id: str) -> dict[str, Any]:
"""Get the default LLM model configuration for the tenant."""
model_provider_service = ModelProviderService()
default_model = model_provider_service.get_default_model_of_model_type(
tenant_id=tenant_id,
model_type="llm",
)
if default_model:
return {
"provider": default_model.provider.provider,
"name": default_model.model,
"mode": LLMMode.CHAT.value,
"completion_params": {},
}
# Fallback to empty config if no default model is configured
return {
"provider": "",
"name": "",
"mode": LLMMode.CHAT.value,
"completion_params": {},
}
def _build_mention_llm_node(
self,
*,
node_id: str,
parent_node_id: str,
context_source: list[str],
parameter_schema: MentionParameterSchema,
model_config: dict[str, Any],
) -> dict[str, Any]:
"""Build the Mention LLM node structure.
The node uses:
- $context in prompt_template to reference the PromptMessage list
- structured_output for extracting the specific parameter
- parent_node_id to associate with the parent node
"""
prompt_template = [
{
"role": "system",
"text": "Extract the required parameter value from the conversation context above.",
},
{"$context": context_source},
{"role": "user", "text": ""},
]
structured_output = {
"schema": {
"type": "object",
"properties": {
parameter_schema.name: {
"type": parameter_schema.type,
"description": parameter_schema.description,
}
},
"required": [parameter_schema.name],
"additionalProperties": False,
}
}
return {
"id": node_id,
"position": {"x": 0, "y": 0},
"data": {
"type": NodeType.LLM.value,
"title": f"Mention: {parameter_schema.name}",
"desc": f"Extract {parameter_schema.name} from conversation context",
"parent_node_id": parent_node_id,
"model": model_config,
"prompt_template": prompt_template,
"context": {
"enabled": False,
"variable_selector": None,
},
"vision": {
"enabled": False,
},
"memory": None,
"structured_output_enabled": True,
"structured_output": structured_output,
},
}
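A sketch of how the service might be driven (the module path, tenant ID, and session setup are assumptions; the request mirrors the earlier MentionGraphRequest example):

from sqlalchemy.orm import Session

from extensions.ext_database import db  # assumption: the usual Dify engine holder
from services.workflow.entities import MentionGraphRequest, MentionParameterSchema
from services.workflow.mention_graph_service import MentionGraphService  # assumed module path

request = MentionGraphRequest(
    parent_node_id="1767773709491",
    parameter_key="query",
    context_source=["llm", "context"],
    parameter_schema=MentionParameterSchema(name="query", description="Search keywords"),
)

with Session(db.engine) as session:
    service = MentionGraphService(session)
    response = service.generate_mention_graph("tenant-123", request)
    # The generated node ID follows the "{parent_node_id}_ext_{parameter_key}" convention.
    assert response.graph["nodes"][0]["id"] == "1767773709491_ext_query"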

View File

@ -83,30 +83,7 @@
<p class="content1">Dear {{ to }},</p>
<p class="content2">{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.</p>
<p class="content2">Click the button below to log in to Dify and join the workspace.</p>
<div style="text-align: center; margin-bottom: 32px;">
<a href="{{ url }}"
style="background-color:#2563eb;
color:#ffffff !important;
text-decoration:none;
display:inline-block;
font-weight:600;
border-radius:4px;
font-size:14px;
line-height:18px;
font-family: Helvetica, Arial, sans-serif;
text-align:center;
border-top: 10px solid #2563eb;
border-bottom: 10px solid #2563eb;
border-left: 20px solid #2563eb;
border-right: 20px solid #2563eb;
">Login Here</a>
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
If the button doesn't work, copy and paste this link into your browser:<br>
<a href="{{ url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
{{ url }}
</a>
</p>
</div>
<p style="text-align: center; margin: 0; margin-bottom: 32px;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">Login Here</a></p>
<p class="content2">Best regards,</p>
<p class="content2">Dify Team</p>
</div>

View File

@ -83,30 +83,7 @@
<p class="content1">尊敬的 {{ to }}</p>
<p class="content2">{{ inviter_name }} 现邀请您加入我们在 Dify 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 Dify 上,您可以探索、创造和合作,构建和运营 AI 应用。</p>
<p class="content2">点击下方按钮即可登录 Dify 并且加入空间。</p>
<div style="text-align: center; margin-bottom: 32px;">
<a href="{{ url }}"
style="background-color:#2563eb;
color:#ffffff !important;
text-decoration:none;
display:inline-block;
font-weight:600;
border-radius:4px;
font-size:14px;
line-height:18px;
font-family: Helvetica, Arial, sans-serif;
text-align:center;
border-top: 10px solid #2563eb;
border-bottom: 10px solid #2563eb;
border-left: 20px solid #2563eb;
border-right: 20px solid #2563eb;
">在此登录</a>
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
如果按钮无法使用,请将以下链接复制到浏览器打开:<br>
<a href="{{ url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
{{ url }}
</a>
</p>
</div>
<p style="text-align: center; margin: 0; margin-bottom: 32px;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">在此登录</a></p>
<p class="content2">此致,</p>
<p class="content2">Dify 团队</p>
</div>

View File

@ -115,30 +115,7 @@
We noticed you tried to sign up, but this email is already registered with an existing account.
Please log in here: </p>
<div style="text-align: center; margin-bottom: 20px;">
<a href="{{ login_url }}"
style="background-color:#2563eb;
color:#ffffff !important;
text-decoration:none;
display:inline-block;
font-weight:600;
border-radius:4px;
font-size:14px;
line-height:18px;
font-family: Helvetica, Arial, sans-serif;
text-align:center;
border-top: 10px solid #2563eb;
border-bottom: 10px solid #2563eb;
border-left: 20px solid #2563eb;
border-right: 20px solid #2563eb;
">Log In</a>
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
If the button doesn't work, copy and paste this link into your browser:<br>
<a href="{{ login_url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
{{ login_url }}
</a>
</p>
</div>
<a href="{{ login_url }}" class="button">Log In</a>
<p class="description">
If you forgot your password, you can reset it here: <a href="{{ reset_password_url }}"
class="reset-btn">Reset Password</a>

View File

@ -115,30 +115,7 @@
我们注意到您尝试注册,但此电子邮件已注册。
请在此登录: </p>
<div style="text-align: center; margin-bottom: 20px;">
<a href="{{ login_url }}"
style="background-color:#2563eb;
color:#ffffff !important;
text-decoration:none;
display:inline-block;
font-weight:600;
border-radius:4px;
font-size:14px;
line-height:18px;
font-family: Helvetica, Arial, sans-serif;
text-align:center;
border-top: 10px solid #2563eb;
border-bottom: 10px solid #2563eb;
border-left: 20px solid #2563eb;
border-right: 20px solid #2563eb;
">登录</a>
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
如果按钮无法使用,请将以下链接复制到浏览器打开:<br>
<a href="{{ login_url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
{{ login_url }}
</a>
</p>
</div>
<a href="{{ login_url }}" class="button">登录</a>
<p class="description">
如果您忘记了密码,可以在此重置: <a href="{{ reset_password_url }}" class="reset-btn">重置密码</a>
</p>

View File

@ -92,34 +92,12 @@
platform specifically designed for LLM application development. On {{application_title}}, you can explore,
create, and collaborate to build and operate AI applications.</p>
<p class="content2">Click the button below to log in to {{application_title}} and join the workspace.</p>
<div style="text-align: center; margin-bottom: 32px;">
<a href="{{ url }}"
style="background-color:#2563eb;
color:#ffffff !important;
text-decoration:none;
display:inline-block;
font-weight:600;
border-radius:4px;
font-size:14px;
line-height:18px;
font-family: Helvetica, Arial, sans-serif;
text-align:center;
border-top: 10px solid #2563eb;
border-bottom: 10px solid #2563eb;
border-left: 20px solid #2563eb;
border-right: 20px solid #2563eb;
">Login Here</a>
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
If the button doesn't work, copy and paste this link into your browser:<br>
<a href="{{ url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
{{ url }}
</a>
</p>
</div>
<p style="text-align: center; margin: 0; margin-bottom: 32px;"><a style="color: #fff; text-decoration: none"
class="button" href="{{ url }}">Login Here</a></p>
<p class="content2">Best regards,</p>
<p class="content2">{{application_title}} Team</p>
</div>
</div>
</body>
</html>
</html>

View File

@ -81,30 +81,7 @@
<p class="content1">尊敬的 {{ to }}</p>
<p class="content2">{{ inviter_name }} 现邀请您加入我们在 {{application_title}} 的工作区,这是一个专为 LLM 应用开发而设计的平台。在 {{application_title}} 上,您可以探索、创造和合作,构建和运营 AI 应用。</p>
<p class="content2">点击下方按钮即可登录 {{application_title}} 并且加入空间。</p>
<div style="text-align: center; margin-bottom: 32px;">
<a href="{{ url }}"
style="background-color:#2563eb;
color:#ffffff !important;
text-decoration:none;
display:inline-block;
font-weight:600;
border-radius:4px;
font-size:14px;
line-height:18px;
font-family: Helvetica, Arial, sans-serif;
text-align:center;
border-top: 10px solid #2563eb;
border-bottom: 10px solid #2563eb;
border-left: 20px solid #2563eb;
border-right: 20px solid #2563eb;
">在此登录</a>
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
如果按钮无法使用,请将以下链接复制到浏览器打开:<br>
<a href="{{ url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
{{ url }}
</a>
</p>
</div>
<p style="text-align: center; margin: 0; margin-bottom: 32px;"><a style="color: #fff; text-decoration: none" class="button" href="{{ url }}">在此登录</a></p>
<p class="content2">此致,</p>
<p class="content2">{{application_title}} 团队</p>
</div>

View File

@ -111,30 +111,7 @@
We noticed you tried to sign up, but this email is already registered with an existing account.
Please log in here: </p>
<div style="text-align: center; margin-bottom: 20px;">
<a href="{{ login_url }}"
style="background-color:#2563eb;
color:#ffffff !important;
text-decoration:none;
display:inline-block;
font-weight:600;
border-radius:4px;
font-size:14px;
line-height:18px;
font-family: Helvetica, Arial, sans-serif;
text-align:center;
border-top: 10px solid #2563eb;
border-bottom: 10px solid #2563eb;
border-left: 20px solid #2563eb;
border-right: 20px solid #2563eb;
">Log In</a>
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
If the button doesn't work, copy and paste this link into your browser:<br>
<a href="{{ login_url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
{{ login_url }}
</a>
</p>
</div>
<a href="{{ login_url }}" class="button">Log In</a>
<p class="description">
If you forgot your password, you can reset it here: <a href="{{ reset_password_url }}"
class="reset-btn">Reset Password</a>

View File

@ -111,30 +111,7 @@
我们注意到您尝试注册,但此电子邮件已注册。
请在此登录: </p>
<div style="text-align: center; margin-bottom: 20px;">
<a href="{{ login_url }}"
style="background-color:#2563eb;
color:#ffffff !important;
text-decoration:none;
display:inline-block;
font-weight:600;
border-radius:4px;
font-size:14px;
line-height:18px;
font-family: Helvetica, Arial, sans-serif;
text-align:center;
border-top: 10px solid #2563eb;
border-bottom: 10px solid #2563eb;
border-left: 20px solid #2563eb;
border-right: 20px solid #2563eb;
">登录</a>
<p style="font-size: 12px; color: #666666; margin-top: 20px; margin-bottom: 0;">
如果按钮无法使用,请将以下链接复制到浏览器打开:<br>
<a href="{{ login_url }}" style="color: #2563eb; text-decoration: underline; word-break: break-all;">
{{ login_url }}
</a>
</p>
</div>
<a href="{{ login_url }}" class="button">登录</a>
<p class="description">
如果您忘记了密码,可以在此重置: <a href="{{ reset_password_url }}" class="reset-btn">重置密码</a>
</p>

View File

@ -1,181 +0,0 @@
app:
description: ''
icon: 🤖
icon_background: '#FFEAD5'
mode: advanced-chat
name: file output schema
use_icon_as_answer_icon: false
dependencies:
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: langgenius/openai:0.2.3@5a7f82fa86e28332ad51941d0b491c1e8a38ead539656442f7bf4c6129cd15fa
version: null
kind: app
version: 0.5.0
workflow:
conversation_variables: []
environment_variables: []
features:
file_upload:
allowed_file_extensions:
- .JPG
- .JPEG
- .PNG
- .GIF
- .WEBP
- .SVG
allowed_file_types:
- image
allowed_file_upload_methods:
- remote_url
- local_file
enabled: true
fileUploadConfig:
attachment_image_file_size_limit: 2
audio_file_size_limit: 50
batch_count_limit: 5
file_size_limit: 15
file_upload_limit: 10
image_file_batch_limit: 10
image_file_size_limit: 10
single_chunk_attachment_limit: 10
video_file_size_limit: 100
workflow_file_upload_limit: 10
number_limits: 3
opening_statement: ''
retriever_resource:
enabled: true
sensitive_word_avoidance:
enabled: false
speech_to_text:
enabled: false
suggested_questions: []
suggested_questions_after_answer:
enabled: false
text_to_speech:
enabled: false
language: ''
voice: ''
graph:
edges:
- data:
sourceType: start
targetType: llm
id: 1768292241666-llm
source: '1768292241666'
sourceHandle: source
target: llm
targetHandle: target
type: custom
- data:
sourceType: llm
targetType: answer
id: llm-answer
source: llm
sourceHandle: source
target: answer
targetHandle: target
type: custom
nodes:
- data:
selected: false
title: User Input
type: start
variables: []
height: 73
id: '1768292241666'
position:
x: 80
y: 282
positionAbsolute:
x: 80
y: 282
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
context:
enabled: false
variable_selector: []
memory:
query_prompt_template: '{{#sys.query#}}
{{#sys.files#}}'
role_prefix:
assistant: ''
user: ''
window:
enabled: false
size: 10
model:
completion_params:
temperature: 0.7
mode: chat
name: gpt-4o-mini
provider: langgenius/openai/openai
prompt_template:
- id: e30d75d7-7d85-49ec-be3c-3baf7f6d3c5a
role: system
text: ''
selected: false
structured_output:
schema:
additionalProperties: false
properties:
image:
description: File ID (UUID) of the selected image
format: dify-file-ref
type: string
required:
- image
type: object
structured_output_enabled: true
title: LLM
type: llm
vision:
configs:
detail: high
variable_selector:
- sys
- files
enabled: true
height: 88
id: llm
position:
x: 380
y: 282
positionAbsolute:
x: 380
y: 282
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
answer: '{{#llm.structured_output.image#}}'
selected: false
title: Answer
type: answer
variables: []
height: 103
id: answer
position:
x: 680
y: 282
positionAbsolute:
x: 680
y: 282
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
viewport:
x: -149
y: 97.5
zoom: 1
rag_pipeline_variables: []

View File

@ -1,307 +0,0 @@
app:
description: Test for variable extraction feature
icon: 🤖
icon_background: '#FFEAD5'
mode: advanced-chat
name: pav-test-extraction
use_icon_as_answer_icon: false
dependencies:
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
version: null
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: langgenius/openai:0.2.3@5a7f82fa86e28332ad51941d0b491c1e8a38ead539656442f7bf4c6129cd15fa
version: null
- current_identifier: null
type: marketplace
value:
marketplace_plugin_unique_identifier: langgenius/tongyi:0.1.16@d8bffbe45418f0c117fb3393e5e40e61faee98f9a2183f062e5a280e74b15d21
version: null
kind: app
version: 0.5.0
workflow:
conversation_variables: []
environment_variables: []
features:
file_upload:
allowed_file_extensions:
- .JPG
- .JPEG
- .PNG
- .GIF
- .WEBP
- .SVG
allowed_file_types:
- image
allowed_file_upload_methods:
- local_file
- remote_url
enabled: false
image:
enabled: false
number_limits: 3
transfer_methods:
- local_file
- remote_url
number_limits: 3
opening_statement: 你好!我是一个搜索助手,请告诉我你想搜索什么内容。
retriever_resource:
enabled: true
sensitive_word_avoidance:
enabled: false
speech_to_text:
enabled: false
suggested_questions: []
suggested_questions_after_answer:
enabled: false
text_to_speech:
enabled: false
language: ''
voice: ''
graph:
edges:
- data:
sourceType: start
targetType: llm
id: 1767773675796-llm
source: '1767773675796'
sourceHandle: source
target: llm
targetHandle: target
type: custom
- data:
isInIteration: false
isInLoop: false
sourceType: llm
targetType: tool
id: llm-source-1767773709491-target
source: llm
sourceHandle: source
target: '1767773709491'
targetHandle: target
type: custom
zIndex: 0
- data:
isInIteration: false
isInLoop: false
sourceType: tool
targetType: answer
id: tool-source-answer-target
source: '1767773709491'
sourceHandle: source
target: answer
targetHandle: target
type: custom
zIndex: 0
nodes:
- data:
selected: false
title: User Input
type: start
variables: []
height: 73
id: '1767773675796'
position:
x: 80
y: 282
positionAbsolute:
x: 80
y: 282
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
context:
enabled: false
variable_selector: []
memory:
mode: node
query_prompt_template: '{{#sys.query#}}'
role_prefix:
assistant: ''
user: ''
window:
enabled: true
size: 10
model:
completion_params:
temperature: 0.7
mode: chat
name: qwen-max
provider: langgenius/tongyi/tongyi
prompt_template:
- id: 11d06d15-914a-4915-a5b1-0e35ab4fba51
role: system
text: '你是一个智能搜索助手。用户会告诉你他们想搜索的内容。
请与用户进行对话,了解他们的搜索需求。
当用户明确表达了想要搜索的内容后,你可以回复"好的,我来帮你搜索"。
'
selected: false
title: LLM
type: llm
vision:
enabled: false
height: 88
id: llm
position:
x: 380
y: 282
positionAbsolute:
x: 380
y: 282
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
is_team_authorization: true
paramSchemas:
- auto_generate: null
default: null
form: llm
human_description:
en_US: used for searching
ja_JP: used for searching
pt_BR: used for searching
zh_Hans: 用于搜索网页内容
label:
en_US: Query string
ja_JP: Query string
pt_BR: Query string
zh_Hans: 查询语句
llm_description: key words for searching
max: null
min: null
name: query
options: []
placeholder: null
precision: null
required: true
scope: null
template: null
type: string
params:
query: ''
plugin_id: langgenius/google
plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
provider_icon: http://localhost:5001/console/api/workspaces/current/plugin/icon?tenant_id=7217e801-f6f5-49ec-8103-d7de97a4b98f&filename=1c5871163478957bac64c3fe33d72d003f767497d921c74b742aad27a8344a74.svg
provider_id: langgenius/google/google
provider_name: langgenius/google/google
provider_type: builtin
selected: false
title: GoogleSearch
tool_configurations: {}
tool_description: A tool for performing a Google SERP search and extracting
snippets and webpages. Input should be a search query.
tool_label: GoogleSearch
tool_name: google_search
tool_node_version: '2'
tool_parameters:
query:
type: mention
value: '{{@llm.context@}}请从对话历史中提取用户想要搜索的关键词,只返回关键词本身'
mention_config:
extractor_node_id: 1767773709491_ext_query
output_selector:
- structured_output
- query
null_strategy: use_default
default_value: ''
type: tool
height: 52
id: '1767773709491'
position:
x: 682
y: 282
positionAbsolute:
x: 682
y: 282
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
context:
enabled: false
variable_selector: []
model:
completion_params:
temperature: 0.7
mode: chat
name: gpt-4o-mini
provider: langgenius/openai/openai
parent_node_id: '1767773709491'
prompt_template:
- $context:
- llm
- context
id: 75d58e22-dc59-40c8-ba6f-aeb28f4f305a
- id: 18ba6710-77f5-47f4-b144-9191833bb547
role: user
text: 请从对话历史中提取用户想要搜索的关键词,只返回关键词本身,不要返回其他内容
selected: false
structured_output:
schema:
additionalProperties: false
properties:
query:
description: 搜索的关键词
type: string
required:
- query
type: object
structured_output_enabled: true
title: 提取搜索关键词
type: llm
vision:
enabled: false
height: 88
id: 1767773709491_ext_query
position:
x: 531
y: 382
positionAbsolute:
x: 531
y: 382
selected: true
sourcePosition: right
targetPosition: left
type: custom
width: 242
- data:
answer: '搜索结果:
{{#1767773709491.text#}}
'
selected: false
title: Answer
type: answer
height: 103
id: answer
position:
x: 984
y: 282
positionAbsolute:
x: 984
y: 282
selected: false
sourcePosition: right
targetPosition: left
type: custom
width: 242
viewport:
x: -151
y: 123
zoom: 1
rag_pipeline_variables: []

View File

@ -1,182 +0,0 @@
"""Tests for file_manager module, specifically multimodal content handling."""
from unittest.mock import patch
from core.file import File, FileTransferMethod, FileType
from core.file.file_manager import (
_encode_file_ref,
restore_multimodal_content,
to_prompt_message_content,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
class TestEncodeFileRef:
"""Tests for _encode_file_ref function."""
def test_encodes_local_file(self):
"""Local file should be encoded as 'local:id'."""
file = File(
tenant_id="t",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="abc123",
storage_key="key",
)
assert _encode_file_ref(file) == "local:abc123"
def test_encodes_tool_file(self):
"""Tool file should be encoded as 'tool:id'."""
file = File(
tenant_id="t",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="xyz789",
storage_key="key",
)
assert _encode_file_ref(file) == "tool:xyz789"
def test_encodes_remote_url(self):
"""Remote URL should be encoded as 'remote:url'."""
file = File(
tenant_id="t",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.png",
storage_key="",
)
assert _encode_file_ref(file) == "remote:https://example.com/image.png"
class TestToPromptMessageContent:
"""Tests for to_prompt_message_content function with file_ref field."""
@patch("core.file.file_manager.dify_config")
@patch("core.file.file_manager._get_encoded_string")
def test_includes_file_ref(self, mock_get_encoded, mock_config):
"""Generated content should include file_ref field."""
mock_config.MULTIMODAL_SEND_FORMAT = "base64"
mock_get_encoded.return_value = "base64data"
file = File(
id="test-message-file-id",
tenant_id="test-tenant",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test-related-id",
remote_url=None,
extension=".png",
mime_type="image/png",
filename="test.png",
storage_key="test-key",
)
result = to_prompt_message_content(file)
assert isinstance(result, ImagePromptMessageContent)
assert result.file_ref == "local:test-related-id"
assert result.base64_data == "base64data"
class TestRestoreMultimodalContent:
"""Tests for restore_multimodal_content function."""
def test_returns_content_unchanged_when_no_file_ref(self):
"""Content without file_ref should pass through unchanged."""
content = ImagePromptMessageContent(
format="png",
base64_data="existing-data",
mime_type="image/png",
file_ref=None,
)
result = restore_multimodal_content(content)
assert result.base64_data == "existing-data"
def test_returns_content_unchanged_when_already_has_data(self):
"""Content that already has base64_data should not be reloaded."""
content = ImagePromptMessageContent(
format="png",
base64_data="existing-data",
mime_type="image/png",
file_ref="local:file-id",
)
result = restore_multimodal_content(content)
assert result.base64_data == "existing-data"
def test_returns_content_unchanged_when_already_has_url(self):
"""Content that already has url should not be reloaded."""
content = ImagePromptMessageContent(
format="png",
url="https://example.com/image.png",
mime_type="image/png",
file_ref="local:file-id",
)
result = restore_multimodal_content(content)
assert result.url == "https://example.com/image.png"
@patch("core.file.file_manager.dify_config")
@patch("core.file.file_manager._build_file_from_ref")
@patch("core.file.file_manager._to_url")
def test_restores_url_from_file_ref(self, mock_to_url, mock_build_file, mock_config):
"""Content should be restored from file_ref when url is empty (url mode)."""
mock_config.MULTIMODAL_SEND_FORMAT = "url"
mock_build_file.return_value = "mock_file"
mock_to_url.return_value = "https://restored-url.com/image.png"
content = ImagePromptMessageContent(
format="png",
base64_data="",
url="",
mime_type="image/png",
filename="test.png",
file_ref="local:test-file-id",
)
result = restore_multimodal_content(content)
assert result.url == "https://restored-url.com/image.png"
mock_build_file.assert_called_once()
@patch("core.file.file_manager.dify_config")
@patch("core.file.file_manager._build_file_from_ref")
@patch("core.file.file_manager._get_encoded_string")
def test_restores_base64_from_file_ref(self, mock_get_encoded, mock_build_file, mock_config):
"""Content should be restored as base64 when in base64 mode."""
mock_config.MULTIMODAL_SEND_FORMAT = "base64"
mock_build_file.return_value = "mock_file"
mock_get_encoded.return_value = "restored-base64-data"
content = ImagePromptMessageContent(
format="png",
base64_data="",
url="",
mime_type="image/png",
filename="test.png",
file_ref="local:test-file-id",
)
result = restore_multimodal_content(content)
assert result.base64_data == "restored-base64-data"
mock_build_file.assert_called_once()
def test_handles_invalid_file_ref_gracefully(self):
"""Invalid file_ref format should be handled gracefully."""
content = ImagePromptMessageContent(
format="png",
base64_data="",
url="",
mime_type="image/png",
file_ref="invalid_format_no_colon",
)
result = restore_multimodal_content(content)
# Should return unchanged on error
assert result.base64_data == ""
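Taken together, these tests describe a store-and-restore cycle for multimodal content. A hedged sketch of that cycle, using only names exercised above (the file ID is a placeholder; a real run needs a resolvable file in storage):

from core.file.file_manager import restore_multimodal_content
from core.model_runtime.entities.message_entities import ImagePromptMessageContent

# What a persisted prompt content might look like: payload stripped, reference kept.
stored = ImagePromptMessageContent(
    format="png",
    base64_data="",
    url="",
    mime_type="image/png",
    filename="img.png",
    file_ref="local:abc123",  # "<transfer_method>:<id>" as produced by _encode_file_ref
)

# Before the message is sent to a model again, the payload is reloaded from the reference.
restored = restore_multimodal_content(stored)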

View File

@ -1,269 +0,0 @@
"""
Unit tests for file reference detection and conversion.
"""
import uuid
from unittest.mock import MagicMock, patch
import pytest
from core.file import File, FileTransferMethod, FileType
from core.llm_generator.output_parser.file_ref import (
FILE_REF_FORMAT,
convert_file_refs_in_output,
detect_file_ref_fields,
is_file_ref_property,
)
from core.variables.segments import ArrayFileSegment, FileSegment
class TestIsFileRefProperty:
"""Tests for is_file_ref_property function."""
def test_valid_file_ref(self):
schema = {"type": "string", "format": FILE_REF_FORMAT}
assert is_file_ref_property(schema) is True
def test_invalid_type(self):
schema = {"type": "number", "format": FILE_REF_FORMAT}
assert is_file_ref_property(schema) is False
def test_missing_format(self):
schema = {"type": "string"}
assert is_file_ref_property(schema) is False
def test_wrong_format(self):
schema = {"type": "string", "format": "uuid"}
assert is_file_ref_property(schema) is False
class TestDetectFileRefFields:
"""Tests for detect_file_ref_fields function."""
def test_simple_file_ref(self):
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
paths = detect_file_ref_fields(schema)
assert paths == ["image"]
def test_multiple_file_refs(self):
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
"document": {"type": "string", "format": FILE_REF_FORMAT},
"name": {"type": "string"},
},
}
paths = detect_file_ref_fields(schema)
assert set(paths) == {"image", "document"}
def test_array_of_file_refs(self):
schema = {
"type": "object",
"properties": {
"files": {
"type": "array",
"items": {"type": "string", "format": FILE_REF_FORMAT},
},
},
}
paths = detect_file_ref_fields(schema)
assert paths == ["files[*]"]
def test_nested_file_ref(self):
schema = {
"type": "object",
"properties": {
"data": {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
},
},
}
paths = detect_file_ref_fields(schema)
assert paths == ["data.image"]
def test_no_file_refs(self):
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"count": {"type": "number"},
},
}
paths = detect_file_ref_fields(schema)
assert paths == []
def test_empty_schema(self):
schema = {}
paths = detect_file_ref_fields(schema)
assert paths == []
def test_mixed_schema(self):
schema = {
"type": "object",
"properties": {
"query": {"type": "string"},
"image": {"type": "string", "format": FILE_REF_FORMAT},
"documents": {
"type": "array",
"items": {"type": "string", "format": FILE_REF_FORMAT},
},
},
}
paths = detect_file_ref_fields(schema)
assert set(paths) == {"image", "documents[*]"}
class TestConvertFileRefsInOutput:
"""Tests for convert_file_refs_in_output function."""
@pytest.fixture
def mock_file(self):
"""Create a mock File object with all required attributes."""
file = MagicMock(spec=File)
file.type = FileType.IMAGE
file.transfer_method = FileTransferMethod.TOOL_FILE
file.related_id = "test-related-id"
file.remote_url = None
file.tenant_id = "tenant_123"
file.id = None
file.filename = "test.png"
file.extension = ".png"
file.mime_type = "image/png"
file.size = 1024
file.dify_model_identity = "__dify__file__"
return file
@pytest.fixture
def mock_build_from_mapping(self, mock_file):
"""Mock the build_from_mapping function."""
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
mock.return_value = mock_file
yield mock
def test_convert_simple_file_ref(self, mock_build_from_mapping, mock_file):
file_id = str(uuid.uuid4())
output = {"image": file_id}
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
result = convert_file_refs_in_output(output, schema, "tenant_123")
# Result should be wrapped in FileSegment
assert isinstance(result["image"], FileSegment)
assert result["image"].value == mock_file
mock_build_from_mapping.assert_called_once_with(
mapping={"transfer_method": "tool_file", "tool_file_id": file_id},
tenant_id="tenant_123",
)
def test_convert_array_of_file_refs(self, mock_build_from_mapping, mock_file):
file_id1 = str(uuid.uuid4())
file_id2 = str(uuid.uuid4())
output = {"files": [file_id1, file_id2]}
schema = {
"type": "object",
"properties": {
"files": {
"type": "array",
"items": {"type": "string", "format": FILE_REF_FORMAT},
},
},
}
result = convert_file_refs_in_output(output, schema, "tenant_123")
# Result should be wrapped in ArrayFileSegment
assert isinstance(result["files"], ArrayFileSegment)
assert list(result["files"].value) == [mock_file, mock_file]
assert mock_build_from_mapping.call_count == 2
def test_no_conversion_without_file_refs(self):
output = {"name": "test", "count": 5}
schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"count": {"type": "number"},
},
}
result = convert_file_refs_in_output(output, schema, "tenant_123")
assert result == {"name": "test", "count": 5}
def test_invalid_uuid_returns_none(self):
output = {"image": "not-a-valid-uuid"}
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
result = convert_file_refs_in_output(output, schema, "tenant_123")
assert result["image"] is None
def test_file_not_found_returns_none(self):
file_id = str(uuid.uuid4())
output = {"image": file_id}
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
mock.side_effect = ValueError("File not found")
result = convert_file_refs_in_output(output, schema, "tenant_123")
assert result["image"] is None
def test_preserves_non_file_fields(self, mock_build_from_mapping, mock_file):
file_id = str(uuid.uuid4())
output = {"query": "search term", "image": file_id, "count": 10}
schema = {
"type": "object",
"properties": {
"query": {"type": "string"},
"image": {"type": "string", "format": FILE_REF_FORMAT},
"count": {"type": "number"},
},
}
result = convert_file_refs_in_output(output, schema, "tenant_123")
assert result["query"] == "search term"
assert isinstance(result["image"], FileSegment)
assert result["image"].value == mock_file
assert result["count"] == 10
def test_does_not_modify_original_output(self, mock_build_from_mapping, mock_file):
file_id = str(uuid.uuid4())
original = {"image": file_id}
output = dict(original)
schema = {
"type": "object",
"properties": {
"image": {"type": "string", "format": FILE_REF_FORMAT},
},
}
convert_file_refs_in_output(output, schema, "tenant_123")
# Original should still contain the string ID
assert original["image"] == file_id

View File

@ -1 +0,0 @@
"""Tests for workflow context management."""

View File

@ -1,258 +0,0 @@
"""Tests for execution context module."""
import contextvars
from typing import Any
from unittest.mock import MagicMock
import pytest
from core.workflow.context.execution_context import (
AppContext,
ExecutionContext,
ExecutionContextBuilder,
IExecutionContext,
NullAppContext,
)
class TestAppContext:
"""Test AppContext abstract base class."""
def test_app_context_is_abstract(self):
"""Test that AppContext cannot be instantiated directly."""
with pytest.raises(TypeError):
AppContext() # type: ignore
class TestNullAppContext:
"""Test NullAppContext implementation."""
def test_null_app_context_get_config(self):
"""Test get_config returns value from config dict."""
config = {"key1": "value1", "key2": "value2"}
ctx = NullAppContext(config=config)
assert ctx.get_config("key1") == "value1"
assert ctx.get_config("key2") == "value2"
def test_null_app_context_get_config_default(self):
"""Test get_config returns default when key not found."""
ctx = NullAppContext()
assert ctx.get_config("nonexistent", "default") == "default"
assert ctx.get_config("nonexistent") is None
def test_null_app_context_get_extension(self):
"""Test get_extension returns stored extension."""
ctx = NullAppContext()
extension = MagicMock()
ctx.set_extension("db", extension)
assert ctx.get_extension("db") == extension
def test_null_app_context_get_extension_not_found(self):
"""Test get_extension returns None when extension not found."""
ctx = NullAppContext()
assert ctx.get_extension("nonexistent") is None
def test_null_app_context_enter_yield(self):
"""Test enter method yields without any side effects."""
ctx = NullAppContext()
with ctx.enter():
# Should not raise any exception
pass
class TestExecutionContext:
"""Test ExecutionContext class."""
def test_initialization_with_all_params(self):
"""Test ExecutionContext initialization with all parameters."""
app_ctx = NullAppContext()
context_vars = contextvars.copy_context()
user = MagicMock()
ctx = ExecutionContext(
app_context=app_ctx,
context_vars=context_vars,
user=user,
)
assert ctx.app_context == app_ctx
assert ctx.context_vars == context_vars
assert ctx.user == user
def test_initialization_with_minimal_params(self):
"""Test ExecutionContext initialization with minimal parameters."""
ctx = ExecutionContext()
assert ctx.app_context is None
assert ctx.context_vars is None
assert ctx.user is None
def test_enter_with_context_vars(self):
"""Test enter restores context variables."""
test_var = contextvars.ContextVar("test_var")
test_var.set("original_value")
# Copy context with the variable
context_vars = contextvars.copy_context()
# Change the variable
test_var.set("new_value")
# Create execution context and enter it
ctx = ExecutionContext(context_vars=context_vars)
with ctx.enter():
# Variable should be restored to original value
assert test_var.get() == "original_value"
# After exiting, variable stays at the value from within the context
# (this is expected Python contextvars behavior)
assert test_var.get() == "original_value"
def test_enter_with_app_context(self):
"""Test enter enters app context if available."""
app_ctx = NullAppContext()
ctx = ExecutionContext(app_context=app_ctx)
# Should not raise any exception
with ctx.enter():
pass
def test_enter_without_app_context(self):
"""Test enter works without app context."""
ctx = ExecutionContext(app_context=None)
# Should not raise any exception
with ctx.enter():
pass
def test_context_manager_protocol(self):
"""Test ExecutionContext supports context manager protocol."""
ctx = ExecutionContext()
with ctx:
# Should not raise any exception
pass
def test_user_property(self):
"""Test user property returns set user."""
user = MagicMock()
ctx = ExecutionContext(user=user)
assert ctx.user == user
class TestIExecutionContextProtocol:
"""Test IExecutionContext protocol."""
def test_execution_context_implements_protocol(self):
"""Test that ExecutionContext implements IExecutionContext protocol."""
ctx = ExecutionContext()
# Should have __enter__ and __exit__ methods
assert hasattr(ctx, "__enter__")
assert hasattr(ctx, "__exit__")
assert hasattr(ctx, "user")
def test_protocol_compatibility(self):
"""Test that ExecutionContext can be used where IExecutionContext is expected."""
def accept_context(context: IExecutionContext) -> Any:
"""Function that accepts IExecutionContext protocol."""
# Just verify it has the required protocol attributes
assert hasattr(context, "__enter__")
assert hasattr(context, "__exit__")
assert hasattr(context, "user")
return context.user
ctx = ExecutionContext(user="test_user")
result = accept_context(ctx)
assert result == "test_user"
def test_protocol_with_flask_execution_context(self):
"""Test that IExecutionContext protocol is compatible with different implementations."""
# Verify the protocol works with ExecutionContext
ctx = ExecutionContext(user="test_user")
# Should have the required protocol attributes
assert hasattr(ctx, "__enter__")
assert hasattr(ctx, "__exit__")
assert hasattr(ctx, "user")
assert ctx.user == "test_user"
# Should work as context manager
with ctx:
assert ctx.user == "test_user"
class TestExecutionContextBuilder:
"""Test ExecutionContextBuilder class."""
def test_builder_with_all_params(self):
"""Test builder with all parameters set."""
app_ctx = NullAppContext()
context_vars = contextvars.copy_context()
user = MagicMock()
ctx = (
ExecutionContextBuilder().with_app_context(app_ctx).with_context_vars(context_vars).with_user(user).build()
)
assert ctx.app_context == app_ctx
assert ctx.context_vars == context_vars
assert ctx.user == user
def test_builder_with_partial_params(self):
"""Test builder with only some parameters set."""
app_ctx = NullAppContext()
ctx = ExecutionContextBuilder().with_app_context(app_ctx).build()
assert ctx.app_context == app_ctx
assert ctx.context_vars is None
assert ctx.user is None
def test_builder_fluent_interface(self):
"""Test builder provides fluent interface."""
builder = ExecutionContextBuilder()
# Each method should return the builder
assert isinstance(builder.with_app_context(NullAppContext()), ExecutionContextBuilder)
assert isinstance(builder.with_context_vars(contextvars.copy_context()), ExecutionContextBuilder)
assert isinstance(builder.with_user(None), ExecutionContextBuilder)
class TestCaptureCurrentContext:
"""Test capture_current_context function."""
def test_capture_current_context_returns_context(self):
"""Test that capture_current_context returns a valid context."""
from core.workflow.context.execution_context import capture_current_context
result = capture_current_context()
# Should return an object that implements IExecutionContext
assert hasattr(result, "__enter__")
assert hasattr(result, "__exit__")
assert hasattr(result, "user")
def test_capture_current_context_captures_contextvars(self):
"""Test that capture_current_context captures context variables."""
# Set a context variable before capturing
import contextvars
test_var = contextvars.ContextVar("capture_test_var")
test_var.set("test_value_123")
from core.workflow.context.execution_context import capture_current_context
result = capture_current_context()
# Context variables should be captured
assert result.context_vars is not None
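The behavior these tests pin down suggests the intended usage pattern: capture the current context where it is available and re-enter it wherever the workflow actually runs, for example on a worker thread. A sketch under that assumption (the worker body is illustrative):

import threading

from core.workflow.context.execution_context import capture_current_context

ctx = capture_current_context()  # snapshot of contextvars (and app context, if one is registered)


def run_workflow() -> None:
    # Entering the captured context restores the snapshot inside this thread.
    with ctx:
        ...  # execute graph nodes with the captured contextvars/user available


threading.Thread(target=run_workflow).start()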

View File

@ -1,316 +0,0 @@
"""Tests for Flask app context module."""
import contextvars
from unittest.mock import MagicMock, patch
import pytest
class TestFlaskAppContext:
"""Test FlaskAppContext implementation."""
@pytest.fixture
def mock_flask_app(self):
"""Create a mock Flask app."""
app = MagicMock()
app.config = {"TEST_KEY": "test_value"}
app.extensions = {"db": MagicMock(), "cache": MagicMock()}
app.app_context = MagicMock()
app.app_context.return_value.__enter__ = MagicMock(return_value=None)
app.app_context.return_value.__exit__ = MagicMock(return_value=None)
return app
def test_flask_app_context_initialization(self, mock_flask_app):
"""Test FlaskAppContext initialization."""
# Import here to avoid Flask dependency in test environment
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
assert ctx.flask_app == mock_flask_app
def test_flask_app_context_get_config(self, mock_flask_app):
"""Test get_config returns Flask app config value."""
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
assert ctx.get_config("TEST_KEY") == "test_value"
def test_flask_app_context_get_config_default(self, mock_flask_app):
"""Test get_config returns default when key not found."""
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
assert ctx.get_config("NONEXISTENT", "default") == "default"
def test_flask_app_context_get_extension(self, mock_flask_app):
"""Test get_extension returns Flask extension."""
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
db_ext = mock_flask_app.extensions["db"]
assert ctx.get_extension("db") == db_ext
def test_flask_app_context_get_extension_not_found(self, mock_flask_app):
"""Test get_extension returns None when extension not found."""
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
assert ctx.get_extension("nonexistent") is None
def test_flask_app_context_enter(self, mock_flask_app):
"""Test enter method enters Flask app context."""
from context.flask_app_context import FlaskAppContext
ctx = FlaskAppContext(mock_flask_app)
with ctx.enter():
# Should not raise any exception
pass
# Verify app_context was called
mock_flask_app.app_context.assert_called_once()
class TestFlaskExecutionContext:
"""Test FlaskExecutionContext class."""
@pytest.fixture
def mock_flask_app(self):
"""Create a mock Flask app."""
app = MagicMock()
app.config = {}
app.app_context = MagicMock()
app.app_context.return_value.__enter__ = MagicMock(return_value=None)
app.app_context.return_value.__exit__ = MagicMock(return_value=None)
return app
def test_initialization(self, mock_flask_app):
"""Test FlaskExecutionContext initialization."""
from context.flask_app_context import FlaskExecutionContext
context_vars = contextvars.copy_context()
user = MagicMock()
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=context_vars,
user=user,
)
assert ctx.context_vars == context_vars
assert ctx.user == user
def test_app_context_property(self, mock_flask_app):
"""Test app_context property returns FlaskAppContext."""
from context.flask_app_context import FlaskAppContext, FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=contextvars.copy_context(),
)
assert isinstance(ctx.app_context, FlaskAppContext)
assert ctx.app_context.flask_app == mock_flask_app
def test_context_manager_protocol(self, mock_flask_app):
"""Test FlaskExecutionContext supports context manager protocol."""
from context.flask_app_context import FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=contextvars.copy_context(),
)
# Should have __enter__ and __exit__ methods
assert hasattr(ctx, "__enter__")
assert hasattr(ctx, "__exit__")
# Should work as context manager
with ctx:
pass
class TestCaptureFlaskContext:
"""Test capture_flask_context function."""
@patch("context.flask_app_context.current_app")
@patch("context.flask_app_context.g")
def test_capture_flask_context_captures_app(self, mock_g, mock_current_app):
"""Test capture_flask_context captures Flask app."""
mock_app = MagicMock()
mock_app._get_current_object = MagicMock(return_value=mock_app)
mock_current_app._get_current_object = MagicMock(return_value=mock_app)
from context.flask_app_context import capture_flask_context
ctx = capture_flask_context()
assert ctx._flask_app == mock_app
@patch("context.flask_app_context.current_app")
@patch("context.flask_app_context.g")
def test_capture_flask_context_captures_user_from_g(self, mock_g, mock_current_app):
"""Test capture_flask_context captures user from Flask g object."""
mock_app = MagicMock()
mock_app._get_current_object = MagicMock(return_value=mock_app)
mock_current_app._get_current_object = MagicMock(return_value=mock_app)
mock_user = MagicMock()
mock_user.id = "user_123"
mock_g._login_user = mock_user
from context.flask_app_context import capture_flask_context
ctx = capture_flask_context()
assert ctx.user == mock_user
@patch("context.flask_app_context.current_app")
def test_capture_flask_context_with_explicit_user(self, mock_current_app):
"""Test capture_flask_context uses explicit user parameter."""
mock_app = MagicMock()
mock_app._get_current_object = MagicMock(return_value=mock_app)
mock_current_app._get_current_object = MagicMock(return_value=mock_app)
explicit_user = MagicMock()
explicit_user.id = "user_456"
from context.flask_app_context import capture_flask_context
ctx = capture_flask_context(user=explicit_user)
assert ctx.user == explicit_user
@patch("context.flask_app_context.current_app")
def test_capture_flask_context_captures_contextvars(self, mock_current_app):
"""Test capture_flask_context captures context variables."""
mock_app = MagicMock()
mock_app._get_current_object = MagicMock(return_value=mock_app)
mock_current_app._get_current_object = MagicMock(return_value=mock_app)
# Set a context variable
test_var = contextvars.ContextVar("test_var")
test_var.set("test_value")
from context.flask_app_context import capture_flask_context
ctx = capture_flask_context()
# Context variables should be captured
assert ctx.context_vars is not None
# Verify the variable is in the captured context
captured_value = ctx.context_vars[test_var]
assert captured_value == "test_value"
class TestFlaskExecutionContextIntegration:
"""Integration tests for FlaskExecutionContext."""
@pytest.fixture
def mock_flask_app(self):
"""Create a mock Flask app with proper app context."""
app = MagicMock()
app.config = {"TEST": "value"}
app.extensions = {"db": MagicMock()}
# Mock app context
mock_app_context = MagicMock()
mock_app_context.__enter__ = MagicMock(return_value=None)
mock_app_context.__exit__ = MagicMock(return_value=None)
app.app_context.return_value = mock_app_context
return app
def test_enter_restores_context_vars(self, mock_flask_app):
"""Test that enter restores captured context variables."""
# Create a context variable and set a value
test_var = contextvars.ContextVar("integration_test_var")
test_var.set("original_value")
# Capture the context
context_vars = contextvars.copy_context()
# Change the value
test_var.set("new_value")
# Create FlaskExecutionContext and enter it
from context.flask_app_context import FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=context_vars,
)
with ctx:
# Value should be restored to original
assert test_var.get() == "original_value"
# After exiting, variable stays at the value from within the context
# (this is expected Python contextvars behavior)
assert test_var.get() == "original_value"
def test_enter_enters_flask_app_context(self, mock_flask_app):
"""Test that enter enters Flask app context."""
from context.flask_app_context import FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=contextvars.copy_context(),
)
with ctx:
# Verify app context was entered
assert mock_flask_app.app_context.called
@patch("context.flask_app_context.g")
def test_enter_restores_user_in_g(self, mock_g, mock_flask_app):
"""Test that enter restores user in Flask g object."""
mock_user = MagicMock()
mock_user.id = "test_user"
# Note: FlaskExecutionContext saves the user from g before entering the context
# and restores it after the app context is entered. The user passed to the
# constructor is NOT written back to g; it is only exposed via ctx.user.
# The assertions below verify that actual behavior.
# Create FlaskExecutionContext with user in constructor
from context.flask_app_context import FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=contextvars.copy_context(),
user=mock_user,
)
# Set user in g before entering (simulating existing user in g)
mock_g._login_user = mock_user
with ctx:
# After entering, the user from g before entry should be restored
assert mock_g._login_user == mock_user
# The user in constructor is stored but not automatically restored to g
# (it's available via ctx.user property)
assert ctx.user == mock_user
def test_enter_method_as_context_manager(self, mock_flask_app):
"""Test enter method returns a proper context manager."""
from context.flask_app_context import FlaskExecutionContext
ctx = FlaskExecutionContext(
flask_app=mock_flask_app,
context_vars=contextvars.copy_context(),
)
# enter() should return a generator/context manager
with ctx.enter():
# Should work without issues
pass
# Verify app context was called
assert mock_flask_app.app_context.called
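The Flask-specific counterpart follows the same shape. A sketch of handing a request's context to a background thread (the wrapper function and the work done inside it are made up; the capture/enter calls match the tests above, and capture_flask_context must be called while a Flask app context is active):

import threading

from context.flask_app_context import capture_flask_context


def start_background_job(current_user) -> None:
    # Capture inside the request so the Flask app and user travel with the job.
    ctx = capture_flask_context(user=current_user)

    def job() -> None:
        with ctx:
            ...  # runs inside the captured Flask app context; ctx.user is available

    threading.Thread(target=job).start()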

View File

@ -25,12 +25,6 @@ class _StubErrorHandler:
"""Minimal error handler stub for tests."""
class _StubNodeData:
"""Simple node data stub with is_extractor_node property."""
is_extractor_node = False
class _StubNode:
"""Simple node stub exposing the attributes needed by the state manager."""
@ -42,7 +36,6 @@ class _StubNode:
self.error_strategy = None
self.retry_config = RetryConfig()
self.retry = False
self.node_data = _StubNodeData()
def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]:

View File

@ -1,174 +0,0 @@
"""Tests for llm_utils module, specifically multimodal content handling."""
import string
from unittest.mock import patch
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
TextPromptMessageContent,
UserPromptMessage,
)
from core.workflow.nodes.llm.llm_utils import (
_truncate_multimodal_content,
build_context,
restore_multimodal_content_in_messages,
)
class TestTruncateMultimodalContent:
"""Tests for _truncate_multimodal_content function."""
def test_returns_message_unchanged_for_string_content(self):
"""String content should pass through unchanged."""
message = UserPromptMessage(content="Hello, world!")
result = _truncate_multimodal_content(message)
assert result.content == "Hello, world!"
def test_returns_message_unchanged_for_none_content(self):
"""None content should pass through unchanged."""
message = UserPromptMessage(content=None)
result = _truncate_multimodal_content(message)
assert result.content is None
def test_clears_base64_when_file_ref_present(self):
"""When file_ref is present, base64_data and url should be cleared."""
image_content = ImagePromptMessageContent(
format="png",
base64_data=string.ascii_lowercase,
url="https://example.com/image.png",
mime_type="image/png",
filename="test.png",
file_ref="local:test-file-id",
)
message = UserPromptMessage(content=[image_content])
result = _truncate_multimodal_content(message)
assert isinstance(result.content, list)
assert len(result.content) == 1
result_content = result.content[0]
assert isinstance(result_content, ImagePromptMessageContent)
assert result_content.base64_data == ""
assert result_content.url == ""
# file_ref should be preserved
assert result_content.file_ref == "local:test-file-id"
def test_truncates_base64_when_no_file_ref(self):
"""When file_ref is missing (legacy), base64_data should be truncated."""
long_base64 = "a" * 100
image_content = ImagePromptMessageContent(
format="png",
base64_data=long_base64,
mime_type="image/png",
filename="test.png",
file_ref=None,
)
message = UserPromptMessage(content=[image_content])
result = _truncate_multimodal_content(message)
assert isinstance(result.content, list)
result_content = result.content[0]
assert isinstance(result_content, ImagePromptMessageContent)
# Should be truncated with marker
assert "...[TRUNCATED]..." in result_content.base64_data
assert len(result_content.base64_data) < len(long_base64)
def test_preserves_text_content(self):
"""Text content should pass through unchanged."""
text_content = TextPromptMessageContent(data="Hello!")
image_content = ImagePromptMessageContent(
format="png",
base64_data="test123",
mime_type="image/png",
file_ref="local:file-id",
)
message = UserPromptMessage(content=[text_content, image_content])
result = _truncate_multimodal_content(message)
assert isinstance(result.content, list)
assert len(result.content) == 2
# Text content unchanged
assert result.content[0].data == "Hello!"
# Image content base64 cleared
assert result.content[1].base64_data == ""
class TestBuildContext:
"""Tests for build_context function."""
def test_excludes_system_messages(self):
"""System messages should be excluded from context."""
from core.model_runtime.entities.message_entities import SystemPromptMessage
messages = [
SystemPromptMessage(content="You are a helpful assistant."),
UserPromptMessage(content="Hello!"),
]
context = build_context(messages, "Hi there!")
# Should have user message + assistant response, no system message
assert len(context) == 2
assert context[0].content == "Hello!"
assert context[1].content == "Hi there!"
def test_appends_assistant_response(self):
"""Assistant response should be appended to context."""
messages = [UserPromptMessage(content="What is 2+2?")]
context = build_context(messages, "The answer is 4.")
assert len(context) == 2
assert context[1].content == "The answer is 4."
class TestRestoreMultimodalContentInMessages:
"""Tests for restore_multimodal_content_in_messages function."""
@patch("core.file.file_manager.restore_multimodal_content")
def test_restores_multimodal_content(self, mock_restore):
"""Should restore multimodal content in messages."""
# Setup mock
restored_content = ImagePromptMessageContent(
format="png",
base64_data="restored-base64",
mime_type="image/png",
file_ref="local:abc123",
)
mock_restore.return_value = restored_content
# Create message with truncated content
truncated_content = ImagePromptMessageContent(
format="png",
base64_data="",
mime_type="image/png",
file_ref="local:abc123",
)
message = UserPromptMessage(content=[truncated_content])
result = restore_multimodal_content_in_messages([message])
assert len(result) == 1
assert result[0].content[0].base64_data == "restored-base64"
mock_restore.assert_called_once()
def test_passes_through_string_content(self):
"""String content should pass through unchanged."""
message = UserPromptMessage(content="Hello!")
result = restore_multimodal_content_in_messages([message])
assert len(result) == 1
assert result[0].content == "Hello!"
def test_passes_through_text_content(self):
"""TextPromptMessageContent should pass through unchanged."""
text_content = TextPromptMessageContent(data="Hello!")
message = UserPromptMessage(content=[text_content])
result = restore_multimodal_content_in_messages([message])
assert len(result) == 1
assert result[0].content[0].data == "Hello!"
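Taken together, the deleted tests describe a truncate-then-restore round trip for multimodal prompt messages: image parts carrying a file_ref have their base64_data and url cleared before persistence, legacy parts without a file_ref are merely truncated with a ...[TRUNCATED]... marker, and restore_multimodal_content_in_messages later rehydrates the payload from the file_ref. A behavioral sketch based only on these tests, not on the actual llm_utils implementation; the file_ref value is hypothetical:

# Sketch derived from the deleted tests above; not the real implementation.
from core.model_runtime.entities.message_entities import (
    ImagePromptMessageContent,
    UserPromptMessage,
)
from core.workflow.nodes.llm.llm_utils import (
    _truncate_multimodal_content,
    restore_multimodal_content_in_messages,
)

original = UserPromptMessage(
    content=[
        ImagePromptMessageContent(
            format="png",
            base64_data="<long base64 payload>",  # placeholder, not real base64
            mime_type="image/png",
            file_ref="local:example-file-id",  # hypothetical file reference
        )
    ]
)

# Persist a lightweight copy: base64_data and url are dropped, file_ref survives.
stored = _truncate_multimodal_content(original)

# Later, rebuild the full message from file_ref before re-invoking the model.
restored = restore_multimodal_content_in_messages([stored])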

View File

@ -1,142 +0,0 @@
from unittest.mock import Mock, patch
import pytest
from werkzeug.exceptions import Forbidden
from libs.workspace_permission import (
check_workspace_member_invite_permission,
check_workspace_owner_transfer_permission,
)
class TestWorkspacePermissionHelper:
"""Test workspace permission helper functions."""
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.EnterpriseService")
def test_community_edition_allows_invite(self, mock_enterprise_service, mock_config):
"""Community edition should always allow invitations without calling any service."""
mock_config.ENTERPRISE_ENABLED = False
# Should not raise
check_workspace_member_invite_permission("test-workspace-id")
# EnterpriseService should NOT be called in community edition
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_not_called()
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.FeatureService")
def test_community_edition_allows_transfer(self, mock_feature_service, mock_config):
"""Community edition should check billing plan but not call enterprise service."""
mock_config.ENTERPRISE_ENABLED = False
mock_features = Mock()
mock_features.is_allow_transfer_workspace = True
mock_feature_service.get_features.return_value = mock_features
# Should not raise
check_workspace_owner_transfer_permission("test-workspace-id")
mock_feature_service.get_features.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
def test_enterprise_blocks_invite_when_disabled(self, mock_config, mock_enterprise_service):
"""Enterprise edition should block invitations when workspace policy is False."""
mock_config.ENTERPRISE_ENABLED = True
mock_permission = Mock()
mock_permission.allow_member_invite = False
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
with pytest.raises(Forbidden, match="Workspace policy prohibits member invitations"):
check_workspace_member_invite_permission("test-workspace-id")
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
def test_enterprise_allows_invite_when_enabled(self, mock_config, mock_enterprise_service):
"""Enterprise edition should allow invitations when workspace policy is True."""
mock_config.ENTERPRISE_ENABLED = True
mock_permission = Mock()
mock_permission.allow_member_invite = True
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
# Should not raise
check_workspace_member_invite_permission("test-workspace-id")
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.FeatureService")
def test_billing_plan_blocks_transfer(self, mock_feature_service, mock_config, mock_enterprise_service):
"""SANDBOX billing plan should block owner transfer before checking enterprise policy."""
mock_config.ENTERPRISE_ENABLED = True
mock_features = Mock()
mock_features.is_allow_transfer_workspace = False # SANDBOX plan
mock_feature_service.get_features.return_value = mock_features
with pytest.raises(Forbidden, match="Your current plan does not allow workspace ownership transfer"):
check_workspace_owner_transfer_permission("test-workspace-id")
# Enterprise service should NOT be called since billing plan already blocks
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_not_called()
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.FeatureService")
def test_enterprise_blocks_transfer_when_disabled(self, mock_feature_service, mock_config, mock_enterprise_service):
"""Enterprise edition should block transfer when workspace policy is False."""
mock_config.ENTERPRISE_ENABLED = True
mock_features = Mock()
mock_features.is_allow_transfer_workspace = True # Billing plan allows
mock_feature_service.get_features.return_value = mock_features
mock_permission = Mock()
mock_permission.allow_owner_transfer = False # Workspace policy blocks
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
with pytest.raises(Forbidden, match="Workspace policy prohibits ownership transfer"):
check_workspace_owner_transfer_permission("test-workspace-id")
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
@patch("libs.workspace_permission.FeatureService")
def test_enterprise_allows_transfer_when_both_enabled(
self, mock_feature_service, mock_config, mock_enterprise_service
):
"""Enterprise edition should allow transfer when both billing and workspace policy allow."""
mock_config.ENTERPRISE_ENABLED = True
mock_features = Mock()
mock_features.is_allow_transfer_workspace = True # Billing plan allows
mock_feature_service.get_features.return_value = mock_features
mock_permission = Mock()
mock_permission.allow_owner_transfer = True # Workspace policy allows
mock_enterprise_service.WorkspacePermissionService.get_permission.return_value = mock_permission
# Should not raise
check_workspace_owner_transfer_permission("test-workspace-id")
mock_enterprise_service.WorkspacePermissionService.get_permission.assert_called_once_with("test-workspace-id")
@patch("libs.workspace_permission.logger")
@patch("libs.workspace_permission.EnterpriseService")
@patch("libs.workspace_permission.dify_config")
def test_enterprise_service_error_fails_open(self, mock_config, mock_enterprise_service, mock_logger):
"""On enterprise service error, should fail-open (allow) and log error."""
mock_config.ENTERPRISE_ENABLED = True
# Simulate enterprise service error
mock_enterprise_service.WorkspacePermissionService.get_permission.side_effect = Exception("Service unavailable")
# Should not raise (fail-open)
check_workspace_member_invite_permission("test-workspace-id")
# Should log the error
mock_logger.exception.assert_called_once()
assert "Failed to check workspace invite permission" in str(mock_logger.exception.call_args)

View File

@ -2,13 +2,9 @@ from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from pydantic import BaseModel, ConfigDict
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import (
invoke_llm_with_pydantic_model,
invoke_llm_with_structured_output,
)
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
@ -465,68 +461,3 @@ def test_model_specific_schema_preparation():
# For Gemini, the schema should not have additionalProperties and boolean should be converted to string
assert "json_schema" in call_args.kwargs["model_parameters"]
class ExampleOutput(BaseModel):
model_config = ConfigDict(extra="forbid")
name: str
def test_structured_output_with_pydantic_model():
model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
model_instance = get_model_instance()
model_instance.invoke_llm.return_value = LLMResult(
model="gpt-4o",
message=AssistantPromptMessage(content='{"name": "test"}'),
usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
)
prompt_messages = [UserPromptMessage(content="Return a JSON object with name.")]
result = invoke_llm_with_pydantic_model(
provider="openai",
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=prompt_messages,
output_model=ExampleOutput,
stream=False,
)
assert isinstance(result, LLMResultWithStructuredOutput)
assert result.structured_output == {"name": "test"}
def test_structured_output_with_pydantic_model_streaming_rejected():
model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
model_instance = get_model_instance()
with pytest.raises(ValueError):
invoke_llm_with_pydantic_model(
provider="openai",
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=[UserPromptMessage(content="test")],
output_model=ExampleOutput,
stream=True,
)
def test_structured_output_with_pydantic_model_validation_error():
model_schema = get_model_entity("openai", "gpt-4o", support_structure_output=True)
model_instance = get_model_instance()
model_instance.invoke_llm.return_value = LLMResult(
model="gpt-4o",
message=AssistantPromptMessage(content='{"name": 123}'),
usage=create_mock_usage(prompt_tokens=8, completion_tokens=4),
)
with pytest.raises(OutputParserError):
invoke_llm_with_pydantic_model(
provider="openai",
model_schema=model_schema,
model_instance=model_instance,
prompt_messages=[UserPromptMessage(content="test")],
output_model=ExampleOutput,
stream=False,
)
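For context, the tests removed above document the invoke_llm_with_pydantic_model wrapper whose tests this diff deletes: it takes a Pydantic model as output_model, rejects stream=True with ValueError, and raises OutputParserError when the model's JSON does not validate. A usage sketch reconstructed only from those tests; how model_schema and model_instance are obtained is application-specific and elided:

# Usage sketch based on the deleted tests; error handling is illustrative.
from pydantic import BaseModel, ConfigDict

from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_pydantic_model
from core.model_runtime.entities.message_entities import UserPromptMessage

class ExampleOutput(BaseModel):
    model_config = ConfigDict(extra="forbid")
    name: str

def extract_name(model_schema, model_instance) -> str:
    try:
        result = invoke_llm_with_pydantic_model(
            provider="openai",
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=[UserPromptMessage(content="Return a JSON object with name.")],
            output_model=ExampleOutput,
            stream=False,  # stream=True is rejected with ValueError
        )
    except OutputParserError:
        return ""  # model output failed ExampleOutput validation
    return result.structured_output["name"]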

View File

@ -203,7 +203,7 @@ const Annotation: FC<Props> = (props) => {
</Filter>
{isLoading
? <Loading type="app" />
// eslint-disable-next-line sonarjs/no-nested-conditional
: total > 0
? (
<List

View File

@ -134,6 +134,7 @@ const GetAutomaticRes: FC<IGetAutomaticResProps> = ({
},
] as const
// eslint-disable-next-line sonarjs/no-nested-template-literals, sonarjs/no-nested-conditional
const [instructionFromSessionStorage, setInstruction] = useSessionStorageState<string>(`improve-instruction-${flowId}${isBasicMode ? '' : `-${nodeId}${editorId ? `-${editorId}` : ''}`}`)
const instruction = instructionFromSessionStorage || ''
const [ideaOutput, setIdeaOutput] = useState<string>('')

View File

@ -175,7 +175,7 @@ describe('SettingsModal', () => {
renderSettingsModal()
fireEvent.click(screen.getByText('appOverview.overview.appInfo.settings.more.entry'))
const privacyInput = screen.getByPlaceholderText('appOverview.overview.appInfo.settings.more.privacyPolicyPlaceholder')
// eslint-disable-next-line sonarjs/no-clear-text-protocols
fireEvent.change(privacyInput, { target: { value: 'ftp://invalid-url' } })
fireEvent.click(screen.getByText('common.operation.save'))

View File

@ -14,6 +14,7 @@ import { BlockEnum } from '@/app/components/workflow/types'
import { useAppContext } from '@/context/app-context'
import { useDocLink } from '@/context/i18n'
import {
useAppTriggers,
useInvalidateAppTriggers,
useUpdateTriggerStatus,

Some files were not shown because too many files have changed in this diff Show More