diff --git a/api/.importlinter b/api/.importlinter index a836d09088..81a7b01c1b 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -1,10 +1,14 @@ [importlinter] root_packages = core + constants + context dify_graph configs controllers extensions + factories + libs models tasks services @@ -33,29 +37,19 @@ ignore_imports = # TODO(QuantumGhost): fix the import violation later dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities -[importlinter:contract:workflow-infrastructure-dependencies] -name = Workflow Infrastructure Dependencies -type = forbidden -source_modules = - dify_graph -forbidden_modules = - extensions.ext_database - extensions.ext_redis -allow_indirect_imports = True -ignore_imports = - dify_graph.nodes.llm.node -> extensions.ext_database - dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis - dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis - [importlinter:contract:workflow-external-imports] name = Workflow External Imports type = forbidden source_modules = dify_graph forbidden_modules = + constants configs + context controllers extensions + factories + libs models services tasks @@ -88,46 +82,14 @@ forbidden_modules = core.tools core.trigger core.variables -ignore_imports = - dify_graph.nodes.llm.llm_utils -> core.model_manager - dify_graph.nodes.llm.protocols -> core.model_manager - dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model - dify_graph.nodes.llm.node -> core.tools.signature - dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler - dify_graph.nodes.tool.tool_node -> core.tools.tool_engine - dify_graph.nodes.tool.tool_node -> core.tools.tool_manager - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> dify_graph.model_runtime.model_providers.__base.large_language_model - dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager - dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager - dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer - dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors - dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output - dify_graph.nodes.llm.node -> core.model_manager - dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.llm.node -> models.dataset - dify_graph.nodes.llm.file_saver -> core.tools.signature - 
dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager - dify_graph.nodes.tool.tool_node -> core.tools.errors - dify_graph.nodes.llm.node -> extensions.ext_database - dify_graph.nodes.llm.node -> models.model - dify_graph.nodes.tool.tool_node -> services - dify_graph.model_runtime.model_providers.__base.ai_model -> configs - dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis - dify_graph.model_runtime.model_providers.__base.large_language_model -> configs - dify_graph.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type - dify_graph.model_runtime.model_providers.model_provider_factory -> configs - dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis - dify_graph.model_runtime.model_providers.model_provider_factory -> models.provider_ids + +[importlinter:contract:workflow-third-party-imports] +name = Workflow Third-Party Imports +type = forbidden +source_modules = + dify_graph +forbidden_modules = + sqlalchemy [importlinter:contract:rsc] name = RSC diff --git a/api/context/__init__.py b/api/context/__init__.py index 969e5f583d..7957eb0076 100644 --- a/api/context/__init__.py +++ b/api/context/__init__.py @@ -1,74 +1,36 @@ """ -Core Context - Framework-agnostic context management. +Application-layer context adapters. -This module provides context management that is independent of any specific -web framework. Framework-specific implementations register their context -capture functions at application initialization time. - -This ensures the workflow layer remains completely decoupled from Flask -or any other web framework. +Concrete execution-context implementations live here so `dify_graph` only +depends on injected context managers rather than framework state capture. """ -import contextvars -from collections.abc import Callable - -from dify_graph.context.execution_context import ( +from context.execution_context import ( + AppContext, + ContextProviderNotFoundError, ExecutionContext, + ExecutionContextBuilder, IExecutionContext, NullAppContext, + capture_current_context, + read_context, + register_context, + register_context_capturer, + reset_context_provider, ) - -# Global capturer function - set by framework-specific modules -_capturer: Callable[[], IExecutionContext] | None = None - - -def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None: - """ - Register a context capture function. - - This should be called by framework-specific modules (e.g., Flask) - during application initialization. - - Args: - capturer: Function that captures current context and returns IExecutionContext - """ - global _capturer - _capturer = capturer - - -def capture_current_context() -> IExecutionContext: - """ - Capture current execution context. - - This function uses the registered context capturer. If no capturer - is registered, it returns a minimal context with only contextvars - (suitable for non-framework environments like tests or standalone scripts). - - Returns: - IExecutionContext with captured context - """ - if _capturer is None: - # No framework registered - return minimal context - return ExecutionContext( - app_context=NullAppContext(), - context_vars=contextvars.copy_context(), - ) - - return _capturer() - - -def reset_context_provider() -> None: - """ - Reset the context capturer. - - This is primarily useful for testing to ensure a clean state. 
- """ - global _capturer - _capturer = None - +from context.models import SandboxContext __all__ = [ + "AppContext", + "ContextProviderNotFoundError", + "ExecutionContext", + "ExecutionContextBuilder", + "IExecutionContext", + "NullAppContext", + "SandboxContext", "capture_current_context", + "read_context", + "register_context", "register_context_capturer", "reset_context_provider", ] diff --git a/api/dify_graph/context/execution_context.py b/api/context/execution_context.py similarity index 60% rename from api/dify_graph/context/execution_context.py rename to api/context/execution_context.py index e3007530f0..dd825c2f91 100644 --- a/api/dify_graph/context/execution_context.py +++ b/api/context/execution_context.py @@ -1,5 +1,8 @@ """ -Execution Context - Abstracted context management for workflow execution. +Application-layer execution context adapters. + +Concrete context capture lives outside `dify_graph` so the graph package only +consumes injected context managers when it needs to preserve thread-local state. """ import contextvars @@ -16,33 +19,33 @@ class AppContext(ABC): """ Abstract application context interface. - This abstraction allows workflow execution to work with or without Flask - by providing a common interface for application context management. + Application adapters can implement this to restore framework-specific state + such as Flask app context around worker execution. """ @abstractmethod def get_config(self, key: str, default: Any = None) -> Any: """Get configuration value by key.""" - pass + raise NotImplementedError @abstractmethod def get_extension(self, name: str) -> Any: - """Get Flask extension by name (e.g., 'db', 'cache').""" - pass + """Get application extension by name.""" + raise NotImplementedError @abstractmethod def enter(self) -> AbstractContextManager[None]: """Enter the application context.""" - pass + raise NotImplementedError @runtime_checkable class IExecutionContext(Protocol): """ - Protocol for execution context. + Protocol for enterable execution context objects. - This protocol defines the interface that all execution contexts must implement, - allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably. + Concrete implementations may carry extra framework state, but callers only + depend on standard context-manager behavior plus optional user metadata. """ def __enter__(self) -> "IExecutionContext": @@ -62,14 +65,10 @@ class IExecutionContext(Protocol): @final class ExecutionContext: """ - Execution context for workflow execution in worker threads. + Generic execution context used by application-layer adapters. - This class encapsulates all context needed for workflow execution: - - Application context (Flask app or standalone) - - Context variables for Python contextvars - - User information (optional) - - It is designed to be serializable and passable to worker threads. + It restores captured `contextvars` and optionally enters an application + context before the worker executes graph logic. """ def __init__( @@ -78,14 +77,6 @@ class ExecutionContext: context_vars: contextvars.Context | None = None, user: Any = None, ) -> None: - """ - Initialize execution context. 
- - Args: - app_context: Application context (Flask or standalone) - context_vars: Python contextvars to preserve - user: User object (optional) - """ self._app_context = app_context self._context_vars = context_vars self._user = user @@ -98,27 +89,21 @@ class ExecutionContext: @property def context_vars(self) -> contextvars.Context | None: - """Get context variables.""" + """Get captured context variables.""" return self._context_vars @property def user(self) -> Any: - """Get user object.""" + """Get captured user object.""" return self._user @contextmanager def enter(self) -> Generator[None, None, None]: - """ - Enter this execution context. - - This is a convenience method that creates a context manager. - """ - # Restore context variables if provided + """Enter this execution context.""" if self._context_vars: for var, val in self._context_vars.items(): var.set(val) - # Enter app context if available if self._app_context is not None: with self._app_context.enter(): yield @@ -141,18 +126,10 @@ class ExecutionContext: class NullAppContext(AppContext): """ - Null implementation of AppContext for non-Flask environments. - - This is used when running without Flask (e.g., in tests or standalone mode). + Null application context for non-framework environments. """ def __init__(self, config: dict[str, Any] | None = None) -> None: - """ - Initialize null app context. - - Args: - config: Optional configuration dictionary - """ self._config = config or {} self._extensions: dict[str, Any] = {} @@ -165,7 +142,7 @@ class NullAppContext(AppContext): return self._extensions.get(name) def set_extension(self, name: str, extension: Any) -> None: - """Set extension by name.""" + """Register an extension for tests or standalone execution.""" self._extensions[name] = extension @contextmanager @@ -176,9 +153,7 @@ class NullAppContext(AppContext): class ExecutionContextBuilder: """ - Builder for creating ExecutionContext instances. - - This provides a fluent API for building execution contexts. + Builder for creating `ExecutionContext` instances. """ def __init__(self) -> None: @@ -211,63 +186,42 @@ class ExecutionContextBuilder: _capturer: Callable[[], IExecutionContext] | None = None - -# Tenant-scoped providers using tuple keys for clarity and constant-time lookup. -# Key mapping: -# (name, tenant_id) -> provider -# - name: namespaced identifier (recommend prefixing, e.g. "workflow.sandbox") -# - tenant_id: tenant identifier string -# Value: -# provider: Callable[[], BaseModel] returning the typed context value -# Type-safety note: -# - This registry cannot enforce that all providers for a given name return the same BaseModel type. -# - Implementors SHOULD provide typed wrappers around register/read (like Go's context best practice), -# e.g. def register_sandbox_ctx(tenant_id: str, p: Callable[[], SandboxContext]) and -# def read_sandbox_ctx(tenant_id: str) -> SandboxContext. 
_tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {} T = TypeVar("T", bound=BaseModel) class ContextProviderNotFoundError(KeyError): - """Raised when a tenant-scoped context provider is missing for a given (name, tenant_id).""" + """Raised when a tenant-scoped context provider is missing.""" pass def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None: - """Register a single enterable execution context capturer (e.g., Flask).""" + """Register an enterable execution context capturer.""" global _capturer _capturer = capturer def register_context(name: str, tenant_id: str, provider: Callable[[], BaseModel]) -> None: - """Register a tenant-specific provider for a named context. - - Tip: use a namespaced "name" (e.g., "workflow.sandbox") to avoid key collisions. - Consider adding a typed wrapper for this registration in your feature module. - """ + """Register a tenant-specific provider for a named context.""" _tenant_context_providers[(name, tenant_id)] = provider def read_context(name: str, *, tenant_id: str) -> BaseModel: - """ - Read a context value for a specific tenant. - - Raises KeyError if the provider for (name, tenant_id) is not registered. - """ - prov = _tenant_context_providers.get((name, tenant_id)) - if prov is None: + """Read a context value for a specific tenant.""" + provider = _tenant_context_providers.get((name, tenant_id)) + if provider is None: raise ContextProviderNotFoundError(f"Context provider '{name}' not registered for tenant '{tenant_id}'") - return prov() + return provider() def capture_current_context() -> IExecutionContext: """ Capture current execution context from the calling environment. - If a capturer is registered (e.g., Flask), use it. Otherwise, return a minimal - context with NullAppContext + copy of current contextvars. + If no framework adapter is registered, return a minimal context that only + restores `contextvars`. 
""" if _capturer is None: return ExecutionContext( @@ -278,7 +232,22 @@ def capture_current_context() -> IExecutionContext: def reset_context_provider() -> None: - """Reset the capturer and all tenant-scoped context providers (primarily for tests).""" + """Reset the capturer and tenant-scoped providers.""" global _capturer _capturer = None _tenant_context_providers.clear() + + +__all__ = [ + "AppContext", + "ContextProviderNotFoundError", + "ExecutionContext", + "ExecutionContextBuilder", + "IExecutionContext", + "NullAppContext", + "capture_current_context", + "read_context", + "register_context", + "register_context_capturer", + "reset_context_provider", +] diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py index 324a9ee8b4..eddd6448d8 100644 --- a/api/context/flask_app_context.py +++ b/api/context/flask_app_context.py @@ -10,11 +10,7 @@ from typing import Any, final from flask import Flask, current_app, g -from dify_graph.context import register_context_capturer -from dify_graph.context.execution_context import ( - AppContext, - IExecutionContext, -) +from context.execution_context import AppContext, IExecutionContext, register_context_capturer @final diff --git a/api/dify_graph/context/models.py b/api/context/models.py similarity index 100% rename from api/dify_graph/context/models.py rename to api/context/models.py diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index c52dcf8a57..764f9f8ee2 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar if TYPE_CHECKING: from core.datasource.__base.datasource_provider import DatasourcePluginProviderController - from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.tools.plugin_tool.provider import PluginToolProviderController from core.trigger.provider import PluginTriggerProviderController @@ -20,14 +19,6 @@ plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderControl plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock")) -plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar( - ContextVar("plugin_model_providers") -) - -plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( - ContextVar("plugin_model_providers_lock") -) - datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = ( RecyclableContextVar(ContextVar("datasource_plugin_providers")) ) diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index e9bd30ba7e..8bb5aa2c1b 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -88,6 +88,7 @@ class ModelConfigResource(Resource): tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, + user_id=current_user.id, ) manager = ToolParameterConfigurationManager( tenant_id=current_tenant_id, @@ -127,6 +128,7 @@ class ModelConfigResource(Resource): tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, + user_id=current_user.id, ) except Exception: continue diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index d59aa44718..31a48b1a06 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -20,6 +20,7 @@ from 
core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.file_access import DatabaseFileAccessController from core.helper.trace_id_helper import get_external_trace_id from core.plugin.impl.exc import PluginInvokeError from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE @@ -51,6 +52,7 @@ from services.errors.llm import InvokeRateLimitError from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() LISTENING_RETRY_IN = 2000 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published" @@ -204,6 +206,7 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence mappings=files, tenant_id=workflow.tenant_id, config=file_extra_config, + access_controller=_file_access_controller, ) return file_objs diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index b78d97a382..6b51f2d1f1 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -15,7 +15,8 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.app.file_access import DatabaseFileAccessController +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from dify_graph.file import helpers as file_helpers from dify_graph.variables.segment_group import SegmentGroup from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment @@ -30,6 +31,7 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList, from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -389,13 +391,21 @@ class VariableApi(Resource): if variable.value_type == SegmentType.FILE: if not isinstance(raw_value, dict): raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id) + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) elif variable.value_type == SegmentType.ARRAY_FILE: if not isinstance(raw_value, list): raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") if len(raw_value) > 0 and not isinstance(raw_value[0], dict): raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id) + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) new_value =
build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 7ac653395e..aa6a6ed843 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -12,6 +12,7 @@ from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import NotFoundError +from core.workflow.human_input_forms import load_form_tokens_by_form_id as _load_form_tokens_by_form_id from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.enums import WorkflowExecutionStatus from extensions.ext_database import db @@ -496,6 +497,9 @@ class ConsoleWorkflowPauseDetailsApi(Resource): pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) pause_reasons = pause_entity.get_pause_reasons() if pause_entity else [] + form_tokens_by_form_id = _load_form_tokens_by_form_id( + [reason.form_id for reason in pause_reasons if isinstance(reason, HumanInputRequired)] + ) # Build response paused_at = pause_entity.paused_at if pause_entity else None @@ -514,7 +518,9 @@ class ConsoleWorkflowPauseDetailsApi(Resource): "pause_type": { "type": "human_input", "form_id": reason.form_id, - "backstage_input_url": _build_backstage_input_url(reason.form_token), + "backstage_input_url": _build_backstage_input_url( + form_tokens_by_form_id.get(reason.form_id) + ), }, } ) diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 27c772fbe0..8a191d83c6 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -25,7 +25,7 @@ from controllers.console.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.indexing_runner import IndexingRunner -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo @@ -332,7 +332,7 @@ class DatasetListApi(Resource): ) # check embedding setting - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -446,7 +446,7 @@ class DatasetApi(Resource): data.update({"partial_member_list": part_users_list}) # check embedding setting - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 897724182f..697c81e784 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -454,7 +454,7 @@ class 
DatasetInitApi(Resource): if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=knowledge_config.embedding_model_provider, diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 7333fcaa07..9faf544e5d 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -283,7 +283,7 @@ class DatasetDocumentSegmentApi(Resource): if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -336,7 +336,7 @@ class DatasetDocumentSegmentAddApi(Resource): # check embedding model setting if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -387,7 +387,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -572,7 +572,7 @@ class ChildChunkAddApi(Resource): # check embedding model setting if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index c5dadb75f5..0738850251 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -21,7 +21,8 @@ from controllers.console.app.workflow_draft_variable import ( from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.app.file_access import DatabaseFileAccessController +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from dify_graph.variables.types import SegmentType from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings @@ -33,6 +34,7 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService logger = logging.getLogger(__name__) 
+_file_access_controller = DatabaseFileAccessController() def _create_pagination_parser(): @@ -223,13 +225,21 @@ class RagPipelineVariableApi(Resource): if variable.value_type == SegmentType.FILE: if not isinstance(raw_value, dict): raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id) + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) elif variable.value_type == SegmentType.ARRAY_FILE: if not isinstance(raw_value, list): raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") if len(raw_value) > 0 and not isinstance(raw_value[0], dict): raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id) + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index d7eceb656c..f9e20eddda 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -282,14 +282,18 @@ class ModelProviderModelCredentialApi(Resource): ) if args.config_from == "predefined-model": - available_credentials = model_provider_service.provider_manager.get_provider_available_credentials( - tenant_id=tenant_id, provider_name=provider + available_credentials = model_provider_service.get_provider_available_credentials( + tenant_id=tenant_id, + provider=provider, ) else: # Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM) normalized_model_type = args.model_type.to_origin_model_type() - available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( - tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model + available_credentials = model_provider_service.get_provider_model_available_credentials( + tenant_id=tenant_id, + provider=provider, + model_type=normalized_model_type, + model=args.model, ) return jsonable_encoder( diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 9e3fb3a90b..2f1e2f28bd 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -70,22 +70,25 @@ class ToolFileApi(Resource): except Exception: raise UnsupportedFileTypeError() + mime_type = tool_file.mime_type + filename = tool_file.filename + response = Response( stream, - mimetype=tool_file.mimetype, + mimetype=mime_type, direct_passthrough=True, headers={}, ) if tool_file.size > 0: response.headers["Content-Length"] = str(tool_file.size) - if args.as_attachment: - encoded_filename = quote(tool_file.name) + if args.as_attachment and filename: + encoded_filename = quote(filename) response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" enforce_download_for_html( response, - mime_type=tool_file.mimetype, - filename=tool_file.name, + mime_type=mime_type, + filename=filename, extension=extension, ) diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 
52690a12e1..ed3278a28b 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -7,8 +7,8 @@ from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden import services +from core.tools.signature import verify_plugin_file_signature from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file.helpers import verify_plugin_file_signature from fields.file_fields import FileResponse from ..common.errors import ( diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 9b8b3950e6..4339ec0513 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -28,7 +28,7 @@ from core.plugin.entities.request import ( RequestRequestUploadFile, ) from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.file.helpers import get_signed_file_url_for_plugin +from core.tools.signature import get_signed_file_url_for_plugin from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.helper import length_prefixed_response from models import Account, Tenant diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 25b6436a71..09a0be123f 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -14,7 +14,7 @@ from controllers.service_api.wraps import ( DatasetApiResource, cloud_edition_billing_rate_limit_check, ) -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.rag.index_processor.constant.index_type import IndexTechniqueType from dify_graph.model_runtime.entities.model_entities import ModelType from fields.dataset_fields import dataset_detail_fields @@ -140,10 +140,10 @@ class DatasetListApi(DatasetApiResource): query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all ) # check embedding setting - provider_manager = ProviderManager() assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None + provider_manager = create_plugin_provider_manager(tenant_id=cid) configurations = provider_manager.get_configurations(tenant_id=cid) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -259,10 +259,10 @@ class DatasetApi(DatasetApiResource): raise Forbidden(str(e)) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) # check embedding setting - provider_manager = ProviderManager() assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None + provider_manager = create_plugin_provider_manager(tenant_id=cid) configurations = provider_manager.get_configurations(tenant_id=cid) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 595b01a9f2..6ab60be005 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -106,7 +106,7 @@ class SegmentApi(DatasetApiResource): # check embedding model setting if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( 
tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -160,7 +160,7 @@ class SegmentApi(DatasetApiResource): # check embedding model setting if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -266,7 +266,7 @@ class DatasetSegmentApi(DatasetApiResource): if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -361,7 +361,7 @@ class ChildChunkApi(DatasetApiResource): # check embedding model setting if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 1bdc8df813..9ff4e6afde 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -15,6 +15,7 @@ from core.app.entities.app_invoke_entities import ( AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity, ) +from core.app.file_access import DatabaseFileAccessController from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory @@ -46,6 +47,7 @@ from models.enums import CreatorUserRole from models.model import Conversation, Message, MessageAgentThought, MessageFile logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class BaseAgentRunner(AppRunner): @@ -138,6 +140,7 @@ class BaseAgentRunner(AppRunner): tenant_id=self.tenant_id, app_id=self.app_config.app_id, agent_tool=tool, + user_id=self.user_id, invoke_from=self.application_generate_entity.invoke_from, ) assert tool_entity.entity.description @@ -524,7 +527,10 @@ class BaseAgentRunner(AppRunner): image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW file_objs = file_factory.build_from_message_files( - message_files=files, tenant_id=self.tenant_id, config=file_extra_config + message_files=files, + tenant_id=self.tenant_id, + config=file_extra_config, + access_controller=_file_access_controller, ) if not file_objs: return UserPromptMessage(content=message.query) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 9271ed10bd..7abac06dde 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -122,7 +122,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): tools=[], stop=app_generate_entity.model_conf.stop, stream=True, - user=self.user_id, callbacks=[], ) diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 5e13a13b21..d80f3579b9 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -96,7 +96,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): tools=prompt_messages_tools, stop=app_generate_entity.model_conf.stop, stream=self.stream_tool_call, - 
user=self.user_id, callbacks=[], ) diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index 558b6e69a0..3d3cccd45d 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -4,7 +4,7 @@ from core.app.app_config.entities import EasyUIBasedAppConfig from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from dify_graph.model_runtime.entities.llm_entities import LLMMode from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel @@ -21,7 +21,7 @@ class ModelConfigConverter: """ model_config = app_config.model - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=app_config.tenant_id) provider_model_bundle = provider_manager.get_provider_model_bundle( tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM ) diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 0929f52e33..cc75effe1f 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -2,9 +2,8 @@ from collections.abc import Mapping from typing import Any from core.app.app_config.entities import ModelConfigEntity -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.model import AppModelConfigDict from models.provider_ids import ModelProviderID @@ -54,9 +53,12 @@ class ModelConfigManager: if not isinstance(config["model"], dict): raise ValueError("model must be of object type") + # Keep provider discovery and provider-backed model listing on the same + # request-scoped runtime so caller scope and provider caches stay aligned. 
+ assembly = create_plugin_model_assembly(tenant_id=tenant_id) + # model.provider - model_provider_factory = ModelProviderFactory(tenant_id) - provider_entities = model_provider_factory.get_providers() + provider_entities = assembly.model_provider_factory.get_providers() model_provider_names = [provider.provider for provider in provider_entities] if "provider" not in config["model"]: raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") @@ -71,8 +73,7 @@ class ModelConfigManager: if "name" not in config["model"]: raise ValueError("model.name is required") - provider_manager = ProviderManager() - models = provider_manager.get_configurations(tenant_id).get_models( + models = assembly.provider_manager.get_configurations(tenant_id).get_models( provider=config["model"]["provider"], model_type=ModelType.LLM ) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 5d974335ff..1027b45600 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -24,6 +24,7 @@ from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager @@ -34,13 +35,9 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.repositories import DifyCoreRepositoryFactory +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import ( - DraftVariableSaverFactory, -) -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from dify_graph.runtime import GraphRuntimeState from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db @@ -150,85 +147,87 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. 
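The re-indented hunk below wraps file parsing in `_bind_file_access_scope` and threads an explicit `access_controller` into `file_factory`, matching the controller modules earlier in this diff. Isolated from the generator, the call shape is roughly the following; the helper name and mapping contents are illustrative, and `DatabaseFileAccessController()` is assumed to need no constructor arguments, as the controllers above construct it:

```python
from collections.abc import Mapping, Sequence
from typing import Any

from core.app.file_access import DatabaseFileAccessController
from factories import file_factory

_file_access_controller = DatabaseFileAccessController()


def build_request_files(mappings: Sequence[Mapping[str, Any]], tenant_id: str, config: Any):
    # Every build_from_mappings call site in this diff now receives the
    # controller explicitly instead of resolving access from ambient state.
    return file_factory.build_from_mappings(
        mappings=mappings,
        tenant_id=tenant_id,
        config=config,
        access_controller=_file_access_controller,
    )
```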
- files = args["files"] if args.get("files") else [] - file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args["files"] if args.get("files") else [] + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] + + # convert to app config + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id ) - else: - file_objs = [] - # convert to app config - app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + if invoke_from == InvokeFrom.DEBUGGER: + # always enable retriever resource in debugger mode + app_config.additional_features.show_retrieve_source = True # type: ignore - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id - ) + # init application generate entity + application_generate_entity = AdvancedChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + file_upload_config=file_extra_config, + conversation_id=conversation.id if conversation else None, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + extras=extras, + trace_manager=trace_manager, + workflow_run_id=str(workflow_run_id), + ) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) - if invoke_from == InvokeFrom.DEBUGGER: - # always enable retriever resource in debugger mode - app_config.additional_features.show_retrieve_source = True # type: ignore + # Create repositories + # + # Create session factory + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + if invoke_from == InvokeFrom.DEBUGGER: + workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING + else: + workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=workflow_triggered_from, + ) + # Create workflow node execution repository + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) - # init application generate entity - application_generate_entity = AdvancedChatAppGenerateEntity( - 
task_id=str(uuid.uuid4()), - app_config=app_config, - file_upload_config=file_extra_config, - conversation_id=conversation.id if conversation else None, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - extras=extras, - trace_manager=trace_manager, - workflow_run_id=str(workflow_run_id), - ) - contexts.plugin_tool_providers.set({}) - contexts.plugin_tool_providers_lock.set(threading.Lock()) - - # Create repositories - # - # Create session factory - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - # Create workflow execution(aka workflow run) repository - if invoke_from == InvokeFrom.DEBUGGER: - workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING - else: - workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=workflow_triggered_from, - ) - # Create workflow node execution repository - workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - - return self._generate( - workflow=workflow, - user=user, - invoke_from=invoke_from, - application_generate_entity=application_generate_entity, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - conversation=conversation, - stream=streaming, - pause_state_config=pause_state_config, - ) + return self._generate( + workflow=workflow, + user=user, + invoke_from=invoke_from, + application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + conversation=conversation, + stream=streaming, + pause_state_config=pause_state_config, + ) def resume( self, @@ -460,94 +459,90 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param conversation: conversation :param stream: is stream """ - is_first_conversation = conversation is None + with self._bind_file_access_scope( + tenant_id=application_generate_entity.app_config.tenant_id, + user=user, + invoke_from=invoke_from, + ): + is_first_conversation = conversation is None - if conversation is not None and message is not None: - pass - else: - conversation, message = self._init_generate_records(application_generate_entity, conversation) + if conversation is not None and message is not None: + pass + else: + conversation, message = self._init_generate_records(application_generate_entity, conversation) - if is_first_conversation: - # update conversation features - conversation.override_model_configs = workflow.features - db.session.commit() - db.session.refresh(conversation) + if is_first_conversation: + # update conversation features + conversation.override_model_configs = workflow.features + db.session.commit() + db.session.refresh(conversation) - # get conversation dialogue count - # NOTE: dialogue_count should not start from 0, - # 
because during the first conversation, dialogue_count should be 1. - self._dialogue_count = get_thread_messages_length(conversation.id) + 1 + # get conversation dialogue count + # NOTE: dialogue_count should not start from 0, + # because during the first conversation, dialogue_count should be 1. + self._dialogue_count = get_thread_messages_length(conversation.id) + 1 - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) - if pause_state_config is not None: - graph_layers.append( - PauseStatePersistenceLayer( - session_factory=pause_state_config.session_factory, - generate_entity=application_generate_entity, - state_owner_user_id=pause_state_config.state_owner_user_id, - ) + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id, ) - # new thread with request context and contextvars - context = contextvars.copy_context() + graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) + if pause_state_config is not None: + graph_layers.append( + PauseStatePersistenceLayer( + session_factory=pause_state_config.session_factory, + generate_entity=application_generate_entity, + state_owner_user_id=pause_state_config.state_owner_user_id, + ) + ) - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "conversation_id": conversation.id, - "message_id": message.id, - "context": context, - "variable_loader": variable_loader, - "workflow_execution_repository": workflow_execution_repository, - "workflow_node_execution_repository": workflow_node_execution_repository, - "graph_engine_layers": tuple(graph_layers), - "graph_runtime_state": graph_runtime_state, - }, - ) + # new thread with request context and contextvars + context = contextvars.copy_context() - worker_thread.start() + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + "context": context, + "variable_loader": variable_loader, + "workflow_execution_repository": workflow_execution_repository, + "workflow_node_execution_repository": workflow_node_execution_repository, + "graph_engine_layers": tuple(graph_layers), + "graph_runtime_state": graph_runtime_state, + }, + ) - # release database connection, because the following new thread operations may take a long time - with Session(bind=db.engine, expire_on_commit=False) as session: - workflow = _refresh_model(session, workflow) - message = _refresh_model(session, message) - # workflow_ = session.get(Workflow, workflow.id) - # assert workflow_ is not None - # workflow = workflow_ - # message_ = session.get(Message, message.id) - # assert message_ is not None - # message = message_ - # 
db.session.refresh(workflow) - # db.session.refresh(message) - # db.session.refresh(user) - db.session.close() + worker_thread.start() - # return response or stream generator - response = self._handle_advanced_chat_response( - application_generate_entity=application_generate_entity, - workflow=workflow, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=stream, - draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user), - ) + # release database connection, because the following new thread operations may take a long time + with Session(bind=db.engine, expire_on_commit=False) as session: + workflow = _refresh_model(session, workflow) + message = _refresh_model(session, message) + db.session.close() - return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # return response or stream generator + response = self._handle_advanced_chat_response( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=stream, + draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user), + ) + + return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 66037696af..6164001324 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -25,14 +25,19 @@ from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, Workfl from core.db.session_factory import session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.node_factory import get_default_root_node_id +from core.workflow.system_variables import ( + build_bootstrap_variables, + build_system_variables, + system_variables_to_mapping, +) +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from dify_graph.enums import WorkflowType from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variable_loader import VariableLoader from dify_graph.variables.variables import Variable from extensions.ext_database import db @@ -90,7 +95,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) - system_inputs = SystemVariable( + system_inputs = build_system_variables( query=self.application_generate_entity.query, files=self.application_generate_entity.files, conversation_id=self.conversation.id, @@ -150,7 +155,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): self.application_generate_entity.inputs = new_inputs 
self.application_generate_entity.query = new_query - system_inputs.query = new_query + system_inputs = build_system_variables( + system_variables_to_mapping(system_inputs), + query=new_query, + ) # annotation reply if self.handle_annotation_reply( @@ -166,14 +174,17 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # Create a variable pool. # init variable pool - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=new_inputs, - environment_variables=self._workflow.environment_variables, - # Based on the definition of `Variable`, - # `VariableBase` instances can be safely used as `Variable` since they are compatible. - conversation_variables=conversation_variables, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=self._workflow.environment_variables, + conversation_variables=conversation_variables, + ), ) + root_node_id = get_default_root_node_id(self._workflow.graph_dict) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=new_inputs) # init graph graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time()) @@ -185,6 +196,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=user_from, invoke_from=invoke_from, + root_node_id=root_node_id, ) db.session.close() diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index f7b5030d33..148af4f70c 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -14,6 +14,7 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, InvokeFrom, @@ -65,14 +66,14 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager from core.repositories.human_input_repository import HumanInputFormRepositoryImpl +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.system_variables import build_system_variables from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.enums import WorkflowExecutionStatus from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.model_runtime.utils.encoders import jsonable_encoder from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile @@ -117,7 +118,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): else: raise NotImplementedError(f"User type not supported: {type(user)}") - self._workflow_system_variables = SystemVariable( + 
self._workflow_system_variables = build_system_variables( query=message.query, files=application_generate_entity.files, conversation_id=conversation.id, @@ -741,8 +742,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): def _load_human_input_form_id(self, *, node_id: str) -> str | None: form_repository = HumanInputFormRepositoryImpl( tenant_id=self._workflow_tenant_id, + workflow_execution_id=self._workflow_run_id, ) - form = form_repository.get_form(self._workflow_run_id, node_id) + form = form_repository.get_form(node_id) if form is None: return None return form.id @@ -933,21 +935,23 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): metadata = self._task_state.metadata.model_dump() message.message_metadata = json.dumps(jsonable_encoder(metadata)) - message_files = [ - MessageFile( - message_id=message.id, - type=file["type"], - transfer_method=file["transfer_method"], - url=file["remote_url"], - belongs_to=MessageFileBelongsTo.ASSISTANT, - upload_file_id=file["related_id"], - created_by_role=CreatorUserRole.ACCOUNT - if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} - else CreatorUserRole.END_USER, - created_by=message.from_account_id or message.from_end_user_id or "", + message_files: list[MessageFile] = [] + for file in self._recorded_files: + reference = file.get("reference") or file.get("related_id") + message_files.append( + MessageFile( + message_id=message.id, + type=file["type"], + transfer_method=file["transfer_method"], + url=file["remote_url"], + belongs_to=MessageFileBelongsTo.ASSISTANT, + upload_file_id=resolve_file_record_id(reference if isinstance(reference, str) else None), + created_by_role=CreatorUserRole.ACCOUNT + if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else CreatorUserRole.END_USER, + created_by=message.from_account_id or message.from_end_user_id or "", + ) ) - for file in self._recorded_files - ] session.add_all(message_files) def _seed_graph_runtime_state_from_queue_manager(self) -> None: @@ -1003,13 +1007,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): return message def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str): - with Session(db.engine) as session, session.begin(): - saver = self._draft_var_saver_factory( - session=session, - app_id=self._application_generate_entity.app_config.app_id, - node_id=event.node_id, - node_type=event.node_type, - node_execution_id=node_execution_id, - enclosing_node_id=event.in_loop_id or event.in_iteration_id, - ) - saver.save(event.process_data, event.outputs) + saver = self._draft_var_saver_factory( + app_id=self._application_generate_entity.app_config.app_id, + node_id=event.node_id, + node_type=event.node_type, + node_execution_id=node_execution_id, + enclosing_node_id=event.in_loop_id or event.in_iteration_id, + ) + saver.save(event.process_data, event.outputs) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 76a067d7b6..2c47d29356 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -129,89 +129,93 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. 
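The generator hunks that follow wrap file parsing and entity construction in _bind_file_access_scope(...), which (per the base_app_generator.py hunk further down) returns a context manager binding tenant/user ownership markers for downstream file lookups. Its internals are not shown in this diff; a plausible ContextVar-backed sketch, with Scope and bind_scope as hypothetical stand-ins for FileAccessScope and bind_file_access_scope:

import contextvars
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass


@dataclass(frozen=True)
class Scope:
    tenant_id: str
    user_id: str


_current_scope: contextvars.ContextVar[Scope | None] = contextvars.ContextVar("file_access_scope", default=None)


@contextmanager
def bind_scope(scope: Scope) -> Iterator[None]:
    token = _current_scope.set(scope)
    try:
        yield
    finally:
        _current_scope.reset(token)  # restore the previous scope on exit


with bind_scope(Scope(tenant_id="t1", user_id="u1")):
    assert _current_scope.get() == Scope(tenant_id="t1", user_id="u1")
assert _current_scope.get() is None

Because the generators also copy contextvars into their worker threads, a binding made this way would remain visible to the spawned worker.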
- files = args.get("files") or [] - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args.get("files") or [] + file_extra_config = FileUploadConfigManager.convert( + override_model_config_dict or app_model_config.to_dict() ) - else: - file_objs = [] + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] - # convert to app config - app_config = AgentChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - conversation=conversation, - override_config_dict=override_model_config_dict, - ) + # convert to app config + app_config = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=override_model_config_dict, + ) - # get tracing instance - trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id) + # get tracing instance + trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id) - # init application generate entity - application_generate_entity = AgentChatAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - file_upload_config=file_extra_config, - conversation_id=conversation.id if conversation else None, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - extras=extras, - call_depth=0, - trace_manager=trace_manager, - ) + # init application generate entity + application_generate_entity = AgentChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, + conversation_id=conversation.id if conversation else None, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + extras=extras, + call_depth=0, + trace_manager=trace_manager, + ) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity, conversation) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) + # init queue manager + queue_manager 
= MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id, + ) - # new thread with request context and contextvars - context = contextvars.copy_context() + # new thread with request context and contextvars + context = contextvars.copy_context() - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "context": context, - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "conversation_id": conversation.id, - "message_id": message.id, - }, - ) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "context": context, + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + }, + ) - worker_thread.start() + worker_thread.start() - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=streaming, - ) - return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=streaming, + ) + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 20e6ac98ea..0e583d088d 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,17 +1,20 @@ from collections.abc import Generator, Mapping, Sequence +from contextlib import AbstractContextManager, nullcontext from typing import TYPE_CHECKING, Any, Union, final from sqlalchemy.orm import Session -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.enums import NodeType -from dify_graph.file import File, FileUploadConfig -from dify_graph.repositories.draft_variable_repository import ( +from core.app.apps.draft_variable_saver import ( DraftVariableSaver, DraftVariableSaverFactory, NoopDraftVariableSaver, ) +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope +from dify_graph.enums import NodeType +from dify_graph.file import File, FileUploadConfig from dify_graph.variables.input_entities import VariableEntityType +from extensions.ext_database import db from factories import file_factory from libs.orjson import orjson_dumps from models import Account, EndUser @@ -21,7 +24,66 @@ if TYPE_CHECKING: from dify_graph.variables.input_entities import VariableEntity +@final +class _DebuggerDraftVariableSaver: + """Adapter that binds SQLAlchemy session setup outside the saver port.""" + + def __init__( + self, + *, + account: Account, + app_id: str, + node_id: str, + node_type: NodeType, + node_execution_id: str, + enclosing_node_id: str | 
None = None, + ) -> None: + self._account = account + self._app_id = app_id + self._node_id = node_id + self._node_type = node_type + self._node_execution_id = node_execution_id + self._enclosing_node_id = enclosing_node_id + + def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + with Session(db.engine) as session, session.begin(): + DraftVariableSaverImpl( + session=session, + app_id=self._app_id, + node_id=self._node_id, + node_type=self._node_type, + node_execution_id=self._node_execution_id, + enclosing_node_id=self._enclosing_node_id, + user=self._account, + ).save(process_data, outputs) + + class BaseAppGenerator: + _file_access_controller: DatabaseFileAccessController = DatabaseFileAccessController() + + @staticmethod + def _bind_file_access_scope( + *, + tenant_id: str, + user: Account | EndUser, + invoke_from: InvokeFrom, + ) -> AbstractContextManager[None]: + """Bind request-scoped file ownership markers for downstream file lookups.""" + + user_id = getattr(user, "id", None) + if not isinstance(user_id, str) or not user_id: + return nullcontext() + + user_from = UserFrom.ACCOUNT if isinstance(user, Account) else UserFrom.END_USER + return bind_file_access_scope( + FileAccessScope( + tenant_id=tenant_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + ) + ) + def _prepare_user_inputs( self, *, @@ -50,6 +112,7 @@ class BaseAppGenerator: allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [], ), strict_type_validation=strict_type_validation, + access_controller=self._file_access_controller, ) for k, v in user_inputs.items() if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE @@ -64,6 +127,7 @@ class BaseAppGenerator: allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [], allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [], ), + access_controller=self._file_access_controller, ) for k, v in user_inputs.items() if isinstance(v, list) @@ -226,32 +290,30 @@ class BaseAppGenerator: assert isinstance(account, Account) def draft_var_saver_factory( - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: - return DraftVariableSaverImpl( - session=session, + return _DebuggerDraftVariableSaver( + account=account, app_id=app_id, node_id=node_id, node_type=node_type, node_execution_id=node_execution_id, enclosing_node_id=enclosing_node_id, - user=account, ) else: def draft_var_saver_factory( - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: + _ = app_id, node_id, node_type, node_execution_id, enclosing_node_id return NoopDraftVariableSaver() return draft_var_saver_factory diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 5addd41815..c413f904ff 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -61,27 +61,30 @@ class AppQueueManager(ABC): listen_timeout = dify_config.APP_MAX_EXECUTION_TIME start_time = time.time() last_ping_time: int | float = 0 - while True: - try: - message = self._q.get(timeout=1) - if message is None: - break + try: + while True: + try: + message = self._q.get(timeout=1) + if message is None: + break - yield message - except queue.Empty: - continue - finally: - elapsed_time = 
time.time() - start_time - if elapsed_time >= listen_timeout or self._is_stopped(): - # publish two messages to make sure the client can receive the stop signal - # and stop listening after the stop signal processed - self.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE - ) + yield message + except queue.Empty: + continue + finally: + elapsed_time = time.time() - start_time + if elapsed_time >= listen_timeout or self._is_stopped(): + # publish two messages to make sure the client can receive the stop signal + # and stop listening after the stop signal processed + self.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE + ) - if elapsed_time // 10 > last_ping_time: - self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) - last_ping_time = elapsed_time // 10 + if elapsed_time // 10 > last_ping_time: + self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) + last_ping_time = elapsed_time // 10 + finally: + self._graph_runtime_state = None # Release reference once consumers finish or close the generator. def stop_listen(self): """ @@ -90,7 +93,6 @@ class AppQueueManager(ABC): """ self._clear_task_belong_cache() self._q.put(None) - self._graph_runtime_state = None # Release reference to allow GC to reclaim memory def _clear_task_belong_cache(self) -> None: """ diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 91cf54c774..71fdf22829 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -1,3 +1,4 @@ +import contextvars import logging import threading import uuid @@ -120,89 +121,96 @@ class ChatAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. 
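The base_app_queue_manager.py change above moves the _graph_runtime_state release out of stop_listen() and into a finally clause wrapping the whole listen loop, so the reference is dropped even when the consumer abandons the streaming generator early. The mechanism in a minimal runnable form, with listen and resource as illustrative names:

from collections.abc import Iterator


def listen(resource: list[str]) -> Iterator[str]:
    try:
        yield from ("a", "b", "c")
    finally:
        # Runs when iteration completes *or* when the consumer calls
        # close() (which raises GeneratorExit inside the frame), so the
        # held reference is always released.
        resource.clear()


held = ["big", "state"]
gen = listen(held)
next(gen)    # consume one item...
gen.close()  # ...then abandon the stream early
assert held == []  # cleanup ran despite the early close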
- files = args["files"] if args.get("files") else [] - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args["files"] if args.get("files") else [] + file_extra_config = FileUploadConfigManager.convert( + override_model_config_dict or app_model_config.to_dict() ) - else: - file_objs = [] + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] - # convert to app config - app_config = ChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - conversation=conversation, - override_config_dict=override_model_config_dict, - ) + # convert to app config + app_config = ChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=override_model_config_dict, + ) - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id - ) + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) - # init application generate entity - application_generate_entity = ChatAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - file_upload_config=file_extra_config, - conversation_id=conversation.id if conversation else None, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, - user_id=user.id, - invoke_from=invoke_from, - extras=extras, - trace_manager=trace_manager, - stream=streaming, - ) + # init application generate entity + application_generate_entity = ChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, + conversation_id=conversation.id if conversation else None, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, + user_id=user.id, + invoke_from=invoke_from, + extras=extras, + trace_manager=trace_manager, + stream=streaming, + ) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity, conversation) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - # new 
thread with request context - @copy_current_request_context - def worker_with_context(): - return self._generate_worker( - flask_app=current_app._get_current_object(), # type: ignore - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, + app_mode=conversation.mode, message_id=message.id, ) - worker_thread = threading.Thread(target=worker_with_context) + context = contextvars.copy_context() - worker_thread.start() + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id=conversation.id, + message_id=message.id, + ) - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=streaming, - ) + worker_thread = threading.Thread(target=worker_with_context) - return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=streaming, + ) + + return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index f63b38fc86..f8656aac02 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -223,7 +223,6 @@ class ChatAppRunner(AppRunner): model_parameters=application_generate_entity.model_conf.parameters, stop=stop, stream=application_generate_entity.stream, - user=application_generate_entity.user_id, ) # handle invoke result diff --git a/api/core/app/apps/common/graph_runtime_state_support.py b/api/core/app/apps/common/graph_runtime_state_support.py index 6a8e436163..c2343c2108 100644 --- a/api/core/app/apps/common/graph_runtime_state_support.py +++ b/api/core/app/apps/common/graph_runtime_state_support.py @@ -4,6 +4,7 @@ from __future__ import annotations from typing import TYPE_CHECKING +from core.workflow.system_variables import SystemVariableKey, get_system_text from dify_graph.runtime import GraphRuntimeState if TYPE_CHECKING: @@ -30,10 +31,10 @@ class GraphRuntimeStateSupport: return self._resolve_graph_runtime_state(graph_runtime_state) def _extract_workflow_run_id(self, graph_runtime_state: GraphRuntimeState) -> str: - system_variables = graph_runtime_state.variable_pool.system_variables - if not system_variables or not system_variables.workflow_execution_id: + workflow_run_id = get_system_text(graph_runtime_state.variable_pool, SystemVariableKey.WORKFLOW_EXECUTION_ID) + if not workflow_run_id: raise ValueError("workflow_execution_id missing from runtime state") - return str(system_variables.workflow_execution_id) + return workflow_run_id def _resolve_graph_runtime_state( self, diff --git 
a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 621b0d8cf3..0de9f36770 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -1,3 +1,4 @@ +import json import logging import time from collections.abc import Mapping, Sequence @@ -50,20 +51,21 @@ from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.trigger_manager import TriggerManager +from core.workflow.human_input_forms import load_form_tokens_by_form_id +from core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.enums import ( BuiltinNodeTypes, - SystemVariableKey, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) from dify_graph.file import FILE_MODEL_IDENTITY, File from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment +from dify_graph.variables.variables import Variable from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.ext_database import db from libs.datetime_utils import naive_utc_now @@ -111,11 +113,11 @@ class WorkflowResponseConverter: *, application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], user: Union[Account, EndUser], - system_variables: SystemVariable, + system_variables: Sequence[Variable], ): self._application_generate_entity = application_generate_entity self._user = user - self._system_variables = system_variables + self._system_variables = system_variables_to_mapping(system_variables) self._workflow_inputs = self._prepare_workflow_inputs() # Disable truncation for SERVICE_API calls to keep backward compatibility. @@ -133,7 +135,7 @@ class WorkflowResponseConverter: # ------------------------------------------------------------------ def _prepare_workflow_inputs(self) -> Mapping[str, Any]: inputs = dict(self._application_generate_entity.inputs) - for field_name, value in self._system_variables.to_dict().items(): + for field_name, value in self._system_variables.items(): # TODO(@future-refactor): store system variables separately from user inputs so we don't # need to flatten `sys.*` entries into the input payload just for rerun/export tooling. 
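WorkflowResponseConverter now accepts system variables as a flat Sequence[Variable] and normalizes them once via system_variables_to_mapping, replacing the old SystemVariable aggregate. The helper's body is not part of this diff; one plausible shape, with Var and to_mapping as hypothetical stand-ins:

from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Var:  # stand-in for dify_graph.variables.variables.Variable
    name: str
    value: Any


def to_mapping(variables: Sequence[Var]) -> Mapping[str, Any]:
    # Plausible shape for system_variables_to_mapping: key each variable
    # by its name so callers can iterate `sys.*` entries as plain items.
    return {v.name: v.value for v in variables}


sys_vars = [Var("query", "hello"), Var("conversation_id", "c-1")]
assert to_mapping(sys_vars)["conversation_id"] == "c-1"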
if field_name == SystemVariableKey.CONVERSATION_ID: @@ -318,13 +320,23 @@ class WorkflowResponseConverter: pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons] human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)] expiration_times_by_form_id: dict[str, datetime] = {} + display_in_ui_by_form_id: dict[str, bool] = {} + form_token_by_form_id: dict[str, str] = {} if human_input_form_ids: - stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where( - HumanInputForm.id.in_(human_input_form_ids) - ) + stmt = select( + HumanInputForm.id, + HumanInputForm.expiration_time, + HumanInputForm.form_definition, + ).where(HumanInputForm.id.in_(human_input_form_ids)) with Session(bind=db.engine) as session: - for form_id, expiration_time in session.execute(stmt): + for form_id, expiration_time, form_definition in session.execute(stmt): expiration_times_by_form_id[str(form_id)] = expiration_time + try: + definition_payload = json.loads(form_definition) if form_definition else {} + except (TypeError, json.JSONDecodeError): + definition_payload = {} + display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui")) + form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session) responses: list[StreamResponse] = [] @@ -344,8 +356,8 @@ class WorkflowResponseConverter: form_content=reason.form_content, inputs=reason.inputs, actions=reason.actions, - display_in_ui=reason.display_in_ui, - form_token=reason.form_token, + display_in_ui=display_in_ui_by_form_id.get(reason.form_id, False), + form_token=form_token_by_form_id.get(reason.form_id), resolved_default_values=reason.resolved_default_values, expiration_time=int(expiration_time.timestamp()), ), diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 002b914ef1..fb88faa18f 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -1,3 +1,4 @@ +import contextvars import logging import threading import uuid @@ -108,83 +109,90 @@ class CompletionAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. 
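The workflow_response_converter.py hunk above now derives display_in_ui and the form token from persisted state rather than from the pause-reason payload, decoding HumanInputForm.form_definition defensively. The decode pattern in isolation (with a small isinstance guard added here for non-object JSON):

import json
from typing import Any


def parse_definition(raw: str | None) -> dict[str, Any]:
    # Mirrors the hunk above: empty or malformed definitions degrade to an
    # empty payload instead of raising mid-stream.
    try:
        payload = json.loads(raw) if raw else {}
    except (TypeError, json.JSONDecodeError):
        payload = {}
    return payload if isinstance(payload, dict) else {}


assert parse_definition('{"display_in_ui": true}')["display_in_ui"] is True
assert parse_definition("not-json") == {}
assert parse_definition(None) == {}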
- files = args["files"] if args.get("files") else [] - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args["files"] if args.get("files") else [] + file_extra_config = FileUploadConfigManager.convert( + override_model_config_dict or app_model_config.to_dict() ) - else: - file_objs = [] + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] - # convert to app config - app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict - ) + # convert to app config + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict + ) - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id - ) + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) - # init application generate entity - application_generate_entity = CompletionAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - file_upload_config=file_extra_config, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - extras={}, - trace_manager=trace_manager, - ) + # init application generate entity + application_generate_entity = CompletionAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + extras={}, + trace_manager=trace_manager, + ) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity) - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - # new thread with request context - @copy_current_request_context - def worker_with_context(): - return self._generate_worker( - flask_app=current_app._get_current_object(), # type: ignore - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + 
user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, message_id=message.id, ) - worker_thread = threading.Thread(target=worker_with_context) + context = contextvars.copy_context() - worker_thread.start() + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message_id=message.id, + ) - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=streaming, - ) + worker_thread = threading.Thread(target=worker_with_context) - return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=streaming, + ) + + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, @@ -280,71 +288,76 @@ class CompletionAppGenerator(MessageBasedAppGenerator): model_dict["completion_params"] = completion_params override_model_config_dict["model"] = model_dict - # parse files - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=message.message_files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + # parse files + file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=message.message_files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] + + # convert to app config + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict ) - else: - file_objs = [] - # convert to app config - app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict - ) + # init application generate entity + application_generate_entity = CompletionAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + inputs=message.inputs, + query=message.query, + files=list(file_objs), + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras={}, + ) - # init application generate entity - application_generate_entity = CompletionAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - inputs=message.inputs, - query=message.query, - files=list(file_objs), - user_id=user.id, - stream=stream, - invoke_from=invoke_from, - extras={}, - ) + # init generate records + (conversation, message) = 
self._init_generate_records(application_generate_entity) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity) - - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - # new thread with request context - @copy_current_request_context - def worker_with_context(): - return self._generate_worker( - flask_app=current_app._get_current_object(), # type: ignore - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, message_id=message.id, ) - worker_thread = threading.Thread(target=worker_with_context) + context = contextvars.copy_context() - worker_thread.start() + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message_id=message.id, + ) - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=stream, - ) + worker_thread = threading.Thread(target=worker_with_context) - return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=stream, + ) + + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 56a4519879..a62a6ad0ab 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -181,7 +181,6 @@ class CompletionAppRunner(AppRunner): model_parameters=application_generate_entity.model_conf.parameters, stop=stop, stream=application_generate_entity.stream, - user=application_generate_entity.user_id, ) # handle invoke result diff --git a/api/dify_graph/repositories/draft_variable_repository.py b/api/core/app/apps/draft_variable_saver.py similarity index 69% rename from api/dify_graph/repositories/draft_variable_repository.py rename to api/core/app/apps/draft_variable_saver.py index b2ebfacffd..4963300e94 100644 --- a/api/dify_graph/repositories/draft_variable_repository.py +++ b/api/core/app/apps/draft_variable_saver.py @@ -4,31 +4,30 @@ import abc from collections.abc import Mapping from typing import Any, Protocol -from sqlalchemy.orm import Session - from dify_graph.enums import NodeType class DraftVariableSaver(Protocol): @abc.abstractmethod - def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None): - pass + def save(self, 
process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + """Persist node draft variables for a completed execution.""" + raise NotImplementedError class DraftVariableSaverFactory(Protocol): @abc.abstractmethod def __call__( self, - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: - pass + """Build a saver bound to a concrete node execution.""" + raise NotImplementedError class NoopDraftVariableSaver(DraftVariableSaver): - def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None): - pass + def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + return None diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 44d10d79b8..fe61224ada 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -28,6 +28,7 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.file_reference import resolve_file_record_id from extensions.ext_database import db from extensions.ext_redis import get_pubsub_broadcast_channel from libs.broadcast_channel.channel import Topic @@ -227,7 +228,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): transfer_method=file.transfer_method, belongs_to=MessageFileBelongsTo.USER, url=file.remote_url, - upload_file_id=file.related_id, + upload_file_id=resolve_file_record_id(file.reference), created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER), created_by=account_id or end_user_id or "", ) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 19d67eb108..be4f1b5841 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -18,6 +18,7 @@ import contexts from configs import dify_config from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager @@ -34,11 +35,12 @@ from core.datasource.entities.datasource_entities import ( from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin from core.entities.knowledge_entities import PipelineDataset, PipelineDocument from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.repositories.factory import DifyCoreRepositoryFactory +from core.repositories.factory import ( + DifyCoreRepositoryFactory, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from 
dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from libs.flask_utils import preserve_flask_contexts diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index e767766bdb..78df0639bd 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -12,16 +12,16 @@ from core.app.entities.app_invoke_entities import ( build_dify_run_context, ) from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.enums import WorkflowType from dify_graph.graph import Graph from dify_graph.graph_events import GraphEngineEvent, GraphRunFailedEvent -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variable_loader import VariableLoader from dify_graph.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from extensions.ext_database import db @@ -112,7 +112,7 @@ class PipelineRunner(WorkflowBasedAppRunner): files = self.application_generate_entity.files # Create a variable pool. 
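As in advanced_chat/app_runner.py above, the hunk that follows replaces the monolithic VariablePool constructor with an empty pool seeded through helpers: build_bootstrap_variables collects the system/environment/RAG-pipeline variables, add_variables_to_pool installs them, and add_node_inputs_to_pool attaches user inputs under the resolved root node id instead of a pool-wide user-input bucket. The helpers' internals are not shown in this diff; a rough sketch of how they plausibly compose, with Pool and the helper names below as hypothetical stand-ins:

from dataclasses import dataclass, field
from typing import Any


@dataclass
class Pool:  # stand-in for dify_graph.runtime.VariablePool
    data: dict[tuple[str, str], Any] = field(default_factory=dict)

    def add(self, selector: tuple[str, str], value: Any) -> None:
        self.data[selector] = value


def add_variables(pool: Pool, variables: list[tuple[str, str, Any]]) -> None:
    # Assumed shape: each bootstrap variable already carries its namespace
    # ("sys", "env", "conversation", ...) and name.
    for namespace, name, value in variables:
        pool.add((namespace, name), value)


def add_node_inputs(pool: Pool, *, node_id: str, inputs: dict[str, Any]) -> None:
    # Assumed shape: user inputs are keyed under the root node id.
    for name, value in inputs.items():
        pool.add((node_id, name), value)


pool = Pool()
add_variables(pool, [("sys", "files", []), ("env", "API_KEY", "secret")])
add_node_inputs(pool, node_id="start_1", inputs={"query": "hi"})
assert pool.data[("start_1", "query")] == "hi"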
- system_inputs = SystemVariable( + system_inputs = build_system_variables( files=files, user_id=user_id, app_id=app_config.app_id, @@ -142,19 +142,25 @@ class PipelineRunner(WorkflowBasedAppRunner): ) ) - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=[], - rag_pipeline_variables=rag_pipeline_variables, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=workflow.environment_variables, + rag_pipeline_variables=rag_pipeline_variables, + ), ) + root_node_id = self.application_generate_entity.start_node_id or get_default_root_node_id( + workflow.graph_dict + ) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) # init graph graph = self._init_rag_pipeline_graph( graph_runtime_state=graph_runtime_state, - start_node_id=self.application_generate_entity.start_node_id, + start_node_id=root_node_id, workflow=workflow, user_from=user_from, invoke_from=invoke_from, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 6fbe19a3b2..393c0acd72 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -17,6 +17,7 @@ from configs import dify_config from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager @@ -30,11 +31,9 @@ from core.db.session_factory import session_factory from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.ops.ops_trace_manager import TraceQueueManager from core.repositories import DifyCoreRepositoryFactory +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from dify_graph.runtime import GraphRuntimeState from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db @@ -129,107 +128,109 @@ class WorkflowAppGenerator(BaseAppGenerator): graph_engine_layers: Sequence[GraphEngineLayer] = (), pause_state_config: PauseStateLayerConfig | None = None, ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: - files: Sequence[Mapping[str, Any]] = args.get("files") or [] + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files: Sequence[Mapping[str, Any]] = args.get("files") or [] - # parse files - # TODO(QuantumGhost): Move file parsing logic to the API controller 
layer - # for better separation of concerns. - # - # For implementation reference, see the `_parse_file` function and - # `DraftWorkflowNodeRunApi` class which handle this properly. - file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - system_files = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, - strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, - ) - - # convert to app config - app_config = WorkflowAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow, - ) - - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, - user_id=user.id if isinstance(user, Account) else user.session_id, - ) - - inputs: Mapping[str, Any] = args["inputs"] - - extras = { - **extract_external_trace_id_from_args(args), - } - workflow_run_id = str(workflow_run_id or uuid.uuid4()) - # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args - # trigger shouldn't prepare user inputs - if self._should_prepare_user_inputs(args): - inputs = self._prepare_user_inputs( - user_inputs=inputs, - variables=app_config.variables, + # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. + # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + system_files = file_factory.build_from_mappings( + mappings=files, tenant_id=app_model.tenant_id, + config=file_extra_config, strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, + access_controller=self._file_access_controller, ) - # init application generate entity - application_generate_entity = WorkflowAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - file_upload_config=file_extra_config, - inputs=inputs, - files=list(system_files), - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - call_depth=call_depth, - trace_manager=trace_manager, - workflow_execution_id=workflow_run_id, - extras=extras, - ) - contexts.plugin_tool_providers.set({}) - contexts.plugin_tool_providers_lock.set(threading.Lock()) + # convert to app config + app_config = WorkflowAppConfigManager.get_app_config( + app_model=app_model, + workflow=workflow, + ) - # Create repositories - # - # Create session factory - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - # Create workflow execution(aka workflow run) repository - if triggered_from is not None: - # Use explicitly provided triggered_from (for async triggers) - workflow_triggered_from = triggered_from - elif invoke_from == InvokeFrom.DEBUGGER: - workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING - else: - workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=workflow_triggered_from, - ) - # Create workflow node execution repository - workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - 
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, + user_id=user.id if isinstance(user, Account) else user.session_id, + ) - return self._generate( - app_model=app_model, - workflow=workflow, - user=user, - application_generate_entity=application_generate_entity, - invoke_from=invoke_from, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - streaming=streaming, - root_node_id=root_node_id, - graph_engine_layers=graph_engine_layers, - pause_state_config=pause_state_config, - ) + inputs: Mapping[str, Any] = args["inputs"] + + extras = { + **extract_external_trace_id_from_args(args), + } + workflow_run_id = str(workflow_run_id or uuid.uuid4()) + # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args + # trigger shouldn't prepare user inputs + if self._should_prepare_user_inputs(args): + inputs = self._prepare_user_inputs( + user_inputs=inputs, + variables=app_config.variables, + tenant_id=app_model.tenant_id, + strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, + ) + # init application generate entity + application_generate_entity = WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + file_upload_config=file_extra_config, + inputs=inputs, + files=list(system_files), + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + call_depth=call_depth, + trace_manager=trace_manager, + workflow_execution_id=workflow_run_id, + extras=extras, + ) + + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + + # Create repositories + # + # Create session factory + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + if triggered_from is not None: + # Use explicitly provided triggered_from (for async triggers) + workflow_triggered_from = triggered_from + elif invoke_from == InvokeFrom.DEBUGGER: + workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING + else: + workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=workflow_triggered_from, + ) + # Create workflow node execution repository + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + return self._generate( + app_model=app_model, + workflow=workflow, + user=user, + application_generate_entity=application_generate_entity, + invoke_from=invoke_from, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + root_node_id=root_node_id, + graph_engine_layers=graph_engine_layers, + pause_state_config=pause_state_config, + ) def resume( self, @@ -292,62 +293,67 @@ class WorkflowAppGenerator(BaseAppGenerator): :param workflow_node_execution_repository: repository for workflow node execution :param streaming: is stream """ - graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) + with 
self._bind_file_access_scope( + tenant_id=application_generate_entity.app_config.tenant_id, + user=user, + invoke_from=invoke_from, + ): + graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) - # init queue manager - queue_manager = WorkflowAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - app_mode=app_model.mode, - ) - - if pause_state_config is not None: - graph_layers.append( - PauseStatePersistenceLayer( - session_factory=pause_state_config.session_factory, - generate_entity=application_generate_entity, - state_owner_user_id=pause_state_config.state_owner_user_id, - ) + # init queue manager + queue_manager = WorkflowAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + app_mode=app_model.mode, ) - # new thread with request context and contextvars - context = contextvars.copy_context() + if pause_state_config is not None: + graph_layers.append( + PauseStatePersistenceLayer( + session_factory=pause_state_config.session_factory, + generate_entity=application_generate_entity, + state_owner_user_id=pause_state_config.state_owner_user_id, + ) + ) - # release database connection, because the following new thread operations may take a long time - db.session.close() + # new thread with request context and contextvars + context = contextvars.copy_context() - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "context": context, - "variable_loader": variable_loader, - "root_node_id": root_node_id, - "workflow_execution_repository": workflow_execution_repository, - "workflow_node_execution_repository": workflow_node_execution_repository, - "graph_engine_layers": tuple(graph_layers), - "graph_runtime_state": graph_runtime_state, - }, - ) + # release database connection, because the following new thread operations may take a long time + db.session.close() - worker_thread.start() + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "context": context, + "variable_loader": variable_loader, + "root_node_id": root_node_id, + "workflow_execution_repository": workflow_execution_repository, + "workflow_node_execution_repository": workflow_node_execution_repository, + "graph_engine_layers": tuple(graph_layers), + "graph_runtime_state": graph_runtime_state, + }, + ) - draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user) + worker_thread.start() - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - workflow=workflow, - queue_manager=queue_manager, - user=user, - draft_var_saver_factory=draft_var_saver_factory, - stream=streaming, - ) + draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user) - return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, 
+ user=user, + draft_var_saver_factory=draft_var_saver_factory, + stream=streaming, + ) + + return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def single_iteration_generate( self, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index caea8b6b95..0cae506a4b 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -8,14 +8,15 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.node_factory import get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from dify_graph.enums import WorkflowType from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variable_loader import VariableLoader from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span @@ -96,7 +97,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): inputs = self.application_generate_entity.inputs # Create a variable pool. 
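For reference on the thread handoff used by the generator above: `contextvars.copy_context()` snapshots every `ContextVar` bound on the request thread, and running the worker inside that snapshot is what lets bindings such as the file access scope survive into `_generate_worker`. A minimal, self-contained sketch of the pattern, assuming only the standard library (the `request_tenant` variable is hypothetical, not part of this patch):

```python
import contextvars
import threading

request_tenant: contextvars.ContextVar[str | None] = contextvars.ContextVar("request_tenant", default=None)


def worker(ctx: contextvars.Context) -> None:
    # Reads resolve against the captured snapshot, not the worker thread's
    # own (empty) context.
    assert ctx.run(request_tenant.get) == "tenant-123"


request_tenant.set("tenant-123")
captured = contextvars.copy_context()
thread = threading.Thread(target=worker, args=(captured,))
thread.start()
thread.join()
```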
- system_inputs = SystemVariable( + system_inputs = build_system_variables( files=self.application_generate_entity.files, user_id=self._sys_user_id, app_id=app_config.app_id, @@ -104,12 +105,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): workflow_id=app_config.workflow_id, workflow_execution_id=self.application_generate_entity.workflow_execution_id, ) - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=self._workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=self._workflow.environment_variables, + ), ) + root_node_id = self._root_node_id or get_default_root_node_id(self._workflow.graph_dict) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) graph = self._init_graph( @@ -120,7 +125,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=user_from, invoke_from=invoke_from, - root_node_id=self._root_node_id, + root_node_id=root_node_id, ) # RUN WORKFLOW diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index bd6e2a0302..1496763601 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -10,6 +10,7 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( AppQueueEvent, @@ -55,11 +56,10 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager +from core.workflow.system_variables import build_system_variables from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable from extensions.ext_database import db from models import Account from models.enums import CreatorUserRole @@ -104,7 +104,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._invoke_from = queue_manager.invoke_from self._draft_var_saver_factory = draft_var_saver_factory self._workflow = workflow - self._workflow_system_variables = SystemVariable( + self._workflow_system_variables = build_system_variables( files=application_generate_entity.files, user_id=user_session_id, app_id=application_generate_entity.app_config.app_id, @@ -728,13 +728,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): return response def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: 
str): - with Session(db.engine) as session, session.begin(): - saver = self._draft_var_saver_factory( - session=session, - app_id=self._application_generate_entity.app_config.app_id, - node_id=event.node_id, - node_type=event.node_type, - node_execution_id=node_execution_id, - enclosing_node_id=event.in_loop_id or event.in_iteration_id, - ) - saver.save(event.process_data, event.outputs) + saver = self._draft_var_saver_factory( + app_id=self._application_generate_entity.app_config.app_id, + node_id=event.node_id, + node_type=event.node_type, + node_execution_id=node_execution_id, + enclosing_node_id=event.in_loop_id or event.in_iteration_id, + ) + saver.save(event.process_data, event.outputs) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index adc6cce9af..8a6c91a771 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -34,7 +34,16 @@ from core.app.entities.queue_entities import ( ) from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class +from core.workflow.system_variables import ( + build_bootstrap_variables, + default_system_variables, + get_node_creation_preload_selectors, + inject_default_system_variable_mappings, + preload_node_creation_variables, +) +from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.entities.pause_reason import HumanInputRequired @@ -68,7 +77,6 @@ from dify_graph.graph_events import ( ) from dify_graph.graph_events.graph import GraphRunAbortedEvent from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from models.workflow import Workflow from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task @@ -173,14 +181,15 @@ class WorkflowBasedAppRunner: ValueError: If neither single_iteration_run nor single_loop_run is specified """ # Create initial runtime state with variable pool containing environment variables - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), environment_variables=workflow.environment_variables, ), - start_at=time.time(), ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time()) # Determine which type of single node execution and get graph/variable_pool if single_iteration_run: @@ -272,6 +281,8 @@ class WorkflowBasedAppRunner: graph_config["edges"] = edge_configs + typed_node_configs = [NodeConfigDictAdapter.validate_python(node) for node in node_configs] + # Create required parameters for Graph.init graph_init_params = GraphInitParams( workflow_id=workflow.id, @@ -291,26 +302,15 @@ class WorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, ) - # init graph - graph = Graph.init( - graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, 
skip_validation=True - ) - - if not graph: - raise ValueError("graph not found in workflow") - - # fetch node config from node id target_node_config = None - for node in node_configs: - if node.get("id") == node_id: + for node in typed_node_configs: + if node["id"] == node_id: target_node_config = node break if not target_node_config: raise ValueError(f"{node_type_label} node id not found in workflow graph") - target_node_config = NodeConfigDictAdapter.validate_python(target_node_config) - # Get node class node_type = target_node_config["data"].type node_version = str(target_node_config["data"].version) @@ -319,12 +319,31 @@ class WorkflowBasedAppRunner: # Use the variable pool from graph_runtime_state instead of creating a new one variable_pool = graph_runtime_state.variable_pool + preload_node_creation_variables( + variable_loader=self._variable_loader, + variable_pool=variable_pool, + selectors=[ + selector + for node_config in typed_node_configs + for selector in get_node_creation_preload_selectors( + node_type=node_config["data"].type, + node_data=node_config["data"], + ) + ], + ) + try: variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( graph_config=workflow.graph_dict, config=target_node_config ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=target_node_config["id"], + node_type=node_type, + node_data=target_node_config["data"], + variable_mapping=variable_mapping, + ) load_into_variable_pool( variable_loader=self._variable_loader, @@ -340,6 +359,14 @@ class WorkflowBasedAppRunner: tenant_id=workflow.tenant_id, ) + # init graph after constructor-time context has been loaded + graph = Graph.init( + graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True + ) + + if not graph: + raise ValueError("graph not found in workflow") + return graph, variable_pool @staticmethod @@ -408,7 +435,11 @@ class WorkflowBasedAppRunner: node_run_result = event.node_run_result inputs = node_run_result.inputs process_data = node_run_result.process_data - outputs = node_run_result.outputs + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=inputs, + outputs=node_run_result.outputs, + ) execution_metadata = node_run_result.metadata self._publish_event( QueueNodeRetryEvent( @@ -448,7 +479,11 @@ class WorkflowBasedAppRunner: node_run_result = event.node_run_result inputs = node_run_result.inputs process_data = node_run_result.process_data - outputs = node_run_result.outputs + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=inputs, + outputs=node_run_result.outputs, + ) execution_metadata = node_run_result.metadata self._publish_event( QueueNodeSucceededEvent( @@ -466,6 +501,11 @@ class WorkflowBasedAppRunner: ) ) elif isinstance(event, NodeRunFailedEvent): + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=event.node_run_result.inputs, + outputs=event.node_run_result.outputs, + ) self._publish_event( QueueNodeFailedEvent( node_execution_id=event.id, @@ -475,7 +515,7 @@ class WorkflowBasedAppRunner: finished_at=event.finished_at, inputs=event.node_run_result.inputs, process_data=event.node_run_result.process_data, - outputs=event.node_run_result.outputs, + outputs=outputs, error=event.node_run_result.error or "Unknown error", execution_metadata=event.node_run_result.metadata, in_iteration_id=event.in_iteration_id, @@ -483,6 +523,11 @@ class 
WorkflowBasedAppRunner: ) ) elif isinstance(event, NodeRunExceptionEvent): + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=event.node_run_result.inputs, + outputs=event.node_run_result.outputs, + ) self._publish_event( QueueNodeExceptionEvent( node_execution_id=event.id, @@ -492,7 +537,7 @@ class WorkflowBasedAppRunner: finished_at=event.finished_at, inputs=event.node_run_result.inputs, process_data=event.node_run_result.process_data, - outputs=event.node_run_result.outputs, + outputs=outputs, error=event.node_run_result.error or "Unknown error", execution_metadata=event.node_run_result.metadata, in_iteration_id=event.in_iteration_id, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index ecbb1cf2f3..f949c2409d 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -7,7 +7,6 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validat from constants import UUID_NIL from core.app.app_config.entities import EasyUIBasedAppConfig, WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.file import File, FileUploadConfig from dify_graph.model_runtime.entities.model_entities import AIModelEntity @@ -15,6 +14,9 @@ if TYPE_CHECKING: from core.ops.ops_trace_manager import TraceQueueManager +DIFY_RUN_CONTEXT_KEY = "_dify" + + class UserFrom(StrEnum): ACCOUNT = "account" END_USER = "end-user" diff --git a/api/core/app/file_access/__init__.py b/api/core/app/file_access/__init__.py new file mode 100644 index 0000000000..a75ab9781b --- /dev/null +++ b/api/core/app/file_access/__init__.py @@ -0,0 +1,11 @@ +from .controller import DatabaseFileAccessController +from .protocols import FileAccessControllerProtocol +from .scope import FileAccessScope, bind_file_access_scope, get_current_file_access_scope + +__all__ = [ + "DatabaseFileAccessController", + "FileAccessControllerProtocol", + "FileAccessScope", + "bind_file_access_scope", + "get_current_file_access_scope", +] diff --git a/api/core/app/file_access/controller.py b/api/core/app/file_access/controller.py new file mode 100644 index 0000000000..300c187083 --- /dev/null +++ b/api/core/app/file_access/controller.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from collections.abc import Callable + +from sqlalchemy import select +from sqlalchemy.orm import Session +from sqlalchemy.sql import Select + +from models import ToolFile, UploadFile +from models.enums import CreatorUserRole + +from .protocols import FileAccessControllerProtocol +from .scope import FileAccessScope, get_current_file_access_scope + + +class DatabaseFileAccessController(FileAccessControllerProtocol): + """Workflow-layer authorization helper for database-backed file lookups. + + Tenant scoping remains mandatory. When the current execution belongs to an + end user, the lookup is additionally constrained to that end user's file + ownership markers. 
+ """ + + _scope_getter: Callable[[], FileAccessScope | None] + + def __init__( + self, + *, + scope_getter: Callable[[], FileAccessScope | None] = get_current_file_access_scope, + ) -> None: + self._scope_getter = scope_getter + + def current_scope(self) -> FileAccessScope | None: + return self._scope_getter() + + def apply_upload_file_filters( + self, + stmt: Select[tuple[UploadFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[UploadFile]]: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return stmt + + scoped_stmt = stmt.where(UploadFile.tenant_id == resolved_scope.tenant_id) + if not resolved_scope.requires_user_ownership: + return scoped_stmt + + return scoped_stmt.where( + UploadFile.created_by_role == CreatorUserRole.END_USER, + UploadFile.created_by == resolved_scope.user_id, + ) + + def apply_tool_file_filters( + self, + stmt: Select[tuple[ToolFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[ToolFile]]: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return stmt + + scoped_stmt = stmt.where(ToolFile.tenant_id == resolved_scope.tenant_id) + if not resolved_scope.requires_user_ownership: + return scoped_stmt + + return scoped_stmt.where(ToolFile.user_id == resolved_scope.user_id) + + def get_upload_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> UploadFile | None: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return session.get(UploadFile, file_id) + + stmt = self.apply_upload_file_filters( + select(UploadFile).where(UploadFile.id == file_id), + scope=resolved_scope, + ) + return session.scalar(stmt) + + def get_tool_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> ToolFile | None: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return session.get(ToolFile, file_id) + + stmt = self.apply_tool_file_filters( + select(ToolFile).where(ToolFile.id == file_id), + scope=resolved_scope, + ) + return session.scalar(stmt) diff --git a/api/core/app/file_access/protocols.py b/api/core/app/file_access/protocols.py new file mode 100644 index 0000000000..8bb3eb9924 --- /dev/null +++ b/api/core/app/file_access/protocols.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import Protocol + +from sqlalchemy.orm import Session +from sqlalchemy.sql import Select + +from models import ToolFile, UploadFile + +from .scope import FileAccessScope + + +class FileAccessControllerProtocol(Protocol): + """Contract for applying access rules to file lookups. + + Implementations translate an optional execution scope into query constraints + and authorized record retrieval. The contract is intentionally limited to + ownership and tenancy rules for workflow-layer file access. + """ + + def current_scope(self) -> FileAccessScope | None: + """Return the scope active for the current execution, if one exists. + + Callers use this to decide whether embedded file metadata may be trusted + or whether a fresh authorized lookup is required. + """ + ... + + def apply_upload_file_filters( + self, + stmt: Select[tuple[UploadFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[UploadFile]]: + """Return an upload-file query constrained by the supplied access scope. + + The returned statement must preserve the caller's existing predicates and + append only access-control conditions. + """ + ... 
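The `apply_*_filters` contract above leans on SQLAlchemy's additive `Select.where()`: each call AND-s a new predicate onto the statement without disturbing what the caller already built. A minimal sketch against a stand-in table (the real code filters the `UploadFile`/`ToolFile` models; this table is illustrative only):

```python
from sqlalchemy import Column, MetaData, String, Table, select

metadata = MetaData()
upload_files = Table(
    "upload_files",
    metadata,
    Column("id", String, primary_key=True),
    Column("tenant_id", String),
    Column("created_by", String),
)


def apply_access_filters(stmt, *, tenant_id: str, end_user_id: str | None):
    # Tenant scoping is unconditional; end-user ownership is appended only
    # when the execution belongs to an end user.
    stmt = stmt.where(upload_files.c.tenant_id == tenant_id)
    if end_user_id is not None:
        stmt = stmt.where(upload_files.c.created_by == end_user_id)
    return stmt


base = select(upload_files).where(upload_files.c.id == "file-1")
scoped = apply_access_filters(base, tenant_id="t-1", end_user_id="u-1")
# The caller's id predicate is preserved; access conditions are AND-ed on:
# SELECT ... WHERE id = :id_1 AND tenant_id = :tenant_id_1 AND created_by = :created_by_1
print(scoped)
```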
+ + def apply_tool_file_filters( + self, + stmt: Select[tuple[ToolFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[ToolFile]]: + """Return a tool-file query constrained by the supplied access scope. + + The returned statement must preserve the caller's existing predicates and + append only access-control conditions. + """ + ... + + def get_upload_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> UploadFile | None: + """Load one authorized upload-file record for the given identifier. + + Returns ``None`` when the file does not exist or when the scope does not + permit access to that record. + """ + ... + + def get_tool_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> ToolFile | None: + """Load one authorized tool-file record for the given identifier. + + Returns ``None`` when the file does not exist or when the scope does not + permit access to that record. + """ + ... diff --git a/api/core/app/file_access/scope.py b/api/core/app/file_access/scope.py new file mode 100644 index 0000000000..80d504ef1c --- /dev/null +++ b/api/core/app/file_access/scope.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + +_current_file_access_scope: ContextVar[FileAccessScope | None] = ContextVar( + "current_file_access_scope", + default=None, +) + + +@dataclass(frozen=True, slots=True) +class FileAccessScope: + """Request-scoped ownership context used by workflow-layer file lookups.""" + + tenant_id: str + user_id: str + user_from: UserFrom + invoke_from: InvokeFrom + + @property + def requires_user_ownership(self) -> bool: + return self.user_from == UserFrom.END_USER + + +def get_current_file_access_scope() -> FileAccessScope | None: + return _current_file_access_scope.get() + + +@contextmanager +def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]: + token = _current_file_access_scope.set(scope) + try: + yield + finally: + _current_file_access_scope.reset(token) diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py index d227e4e904..cedb87490b 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -1,12 +1,19 @@ +""" +Persist conversation-scoped variable updates emitted by the graph engine. + +The graph package emits generic variable update events and stays unaware of +conversation identity or storage concerns. This layer lives in the application +core, listens to those generic events, and persists only the `conversation.*` +scope updates that matter to chat applications. 
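`scope.py` above relies on the token discipline of `ContextVar.set()`/`reset()`: resetting with the token restores whatever binding was active before the block, which keeps nested scopes correct where assigning a sentinel back would not. A standalone sketch of the same pattern with a plain string value (names hypothetical):

```python
from collections.abc import Iterator
from contextlib import contextmanager
from contextvars import ContextVar

_scope: ContextVar[str | None] = ContextVar("scope", default=None)


@contextmanager
def bind(value: str) -> Iterator[None]:
    token = _scope.set(value)
    try:
        yield
    finally:
        # reset(token) restores the previous binding; _scope.set(None)
        # would clobber an enclosing scope.
        _scope.reset(token)


with bind("outer"):
    with bind("inner"):
        assert _scope.get() == "inner"
    assert _scope.get() == "outer"
assert _scope.get() is None
```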
+""" + import logging -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.conversation_variable_updater import ConversationVariableUpdater -from dify_graph.enums import BuiltinNodeTypes +from core.workflow.system_variables import SystemVariableKey, get_system_text +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphEngineEvent, NodeRunSucceededEvent -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.variables import VariableBase +from dify_graph.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent +from services.conversation_variable_updater import ConversationVariableUpdater logger = logging.getLogger(__name__) @@ -20,41 +27,22 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer): pass def on_event(self, event: GraphEngineEvent) -> None: - if not isinstance(event, NodeRunSucceededEvent): - return - if event.node_type != BuiltinNodeTypes.VARIABLE_ASSIGNER: - return - if self.graph_runtime_state is None: + if not isinstance(event, NodeRunVariableUpdatedEvent): return - updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or [] - if not updated_variables: + selector = event.variable.selector + if len(selector) < 2: + logger.warning("Conversation variable selector invalid. selector=%s", selector) return - conversation_id = self.graph_runtime_state.system_variable.conversation_id + conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) if conversation_id is None: return - updated_any = False - for item in updated_variables: - selector = item.selector - if len(selector) < 2: - logger.warning("Conversation variable selector invalid. selector=%s", selector) - continue - if selector[0] != CONVERSATION_VARIABLE_NODE_ID: - continue - variable = self.graph_runtime_state.variable_pool.get(selector) - if not isinstance(variable, VariableBase): - logger.warning( - "Conversation variable not found in variable pool. 
selector=%s", - selector, - ) - continue - self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable) - updated_any = True + if selector[0] != CONVERSATION_VARIABLE_NODE_ID: + return - if updated_any: - self._conversation_variable_updater.flush() + self._conversation_variable_updater.update(conversation_id=conversation_id, variable=event.variable) def on_graph_end(self, error: Exception | None) -> None: pass diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 4370c01a0b..93ab7ae9ce 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -6,6 +6,7 @@ from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity +from core.workflow.system_variables import SystemVariableKey, get_system_text from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_events.base import GraphEngineEvent from dify_graph.graph_events.graph import GraphRunPausedEvent @@ -119,7 +120,10 @@ class PauseStatePersistenceLayer(GraphEngineLayer): generate_entity=entity_wrapper, ) - workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id + workflow_run_id = get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ) assert workflow_run_id is not None repo = self._get_repo() repo.create_workflow_pause( diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index a4019a83e1..da9c6bada9 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ b/api/core/app/layers/trigger_post_layer.py @@ -5,6 +5,7 @@ from typing import Any, ClassVar from pydantic import TypeAdapter from core.db.session_factory import session_factory +from core.workflow.system_variables import SystemVariableKey, get_system_text from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_events.base import GraphEngineEvent from dify_graph.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent @@ -59,7 +60,10 @@ class TriggerPostLayer(GraphEngineLayer): outputs = self.graph_runtime_state.outputs # BASICALLY, workflow_execution_id is the same as workflow_run_id - workflow_run_id = self.graph_runtime_state.system_variable.workflow_execution_id + workflow_run_id = get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ) assert workflow_run_id, "Workflow run id is not set" total_tokens = self.graph_runtime_state.total_tokens diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index a63ff39fa5..5706487c34 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -2,23 +2,34 @@ from __future__ import annotations from typing import Any -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance, ModelManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.nodes.llm.entities
import ModelConfig from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.llm.protocols import CredentialsProvider class DifyCredentialsProvider: tenant_id: str provider_manager: ProviderManager - def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None: - self.tenant_id = tenant_id - self.provider_manager = provider_manager or ProviderManager() + def __init__( + self, + *, + run_context: DifyRunContext, + provider_manager: ProviderManager | None = None, + ) -> None: + self.tenant_id = run_context.tenant_id + if provider_manager is None: + provider_manager = create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, + ) + self.provider_manager = provider_manager def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: provider_configurations = self.provider_manager.get_configurations(self.tenant_id) @@ -42,9 +53,21 @@ class DifyModelFactory: tenant_id: str model_manager: ModelManager - def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None: - self.tenant_id = tenant_id - self.model_manager = model_manager or ModelManager() + def __init__( + self, + *, + run_context: DifyRunContext, + model_manager: ModelManager | None = None, + ) -> None: + self.tenant_id = run_context.tenant_id + if model_manager is None: + model_manager = ModelManager( + provider_manager=create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, + ) + ) + self.model_manager = model_manager def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance: return self.model_manager.get_model_instance( @@ -55,18 +78,42 @@ class DifyModelFactory: ) -def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]: - return ( - DifyCredentialsProvider(tenant_id=tenant_id), - DifyModelFactory(tenant_id=tenant_id), +def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsProvider, DifyModelFactory]: + """Create LLM access adapters that share the same tenant-bound manager graph.""" + provider_manager = create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, ) + model_manager = ModelManager(provider_manager=provider_manager) + + return ( + DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager), + DifyModelFactory(run_context=run_context, model_manager=model_manager), + ) + + +def _normalize_completion_params(completion_params: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: + """ + Split node-level completion params into provider parameters and stop sequences. + + Workflow LLM-compatible nodes still consume runtime invocation settings from + ``ModelInstance.parameters`` and ``ModelInstance.stop``. Keep the + ``ModelInstance`` view and the returned config entity aligned here so callers + do not need to duplicate normalization logic. 
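The normalization described above is small enough to restate standalone. A sketch of the split-and-validate behavior (`split_stop` is a hypothetical name; the patch's helper is `_normalize_completion_params`):

```python
from typing import Any


def split_stop(completion_params: dict[str, Any]) -> tuple[dict[str, Any], list[str]]:
    # Copy first so the node's configured params are left untouched, then
    # pull out stop sequences; anything that is not a list[str] is dropped.
    normalized = dict(completion_params)
    stop = normalized.pop("stop", [])
    if not isinstance(stop, list) or not all(isinstance(item, str) for item in stop):
        stop = []
    return normalized, stop


assert split_stop({"temperature": 0.7, "stop": ["\n\n"]}) == ({"temperature": 0.7}, ["\n\n"])
assert split_stop({"stop": "not-a-list"}) == ({}, [])
```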
+ """ + normalized_parameters = dict(completion_params) + stop = normalized_parameters.pop("stop", []) + if not isinstance(stop, list) or not all(isinstance(item, str) for item in stop): + stop = [] + + return normalized_parameters, stop def fetch_model_config( *, node_data_model: ModelConfig, credentials_provider: CredentialsProvider, - model_factory: ModelFactory, + model_factory: DifyModelFactory, ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: if not node_data_model.mode: raise LLMModeRequiredError("LLM mode is required.") @@ -80,22 +127,18 @@ def fetch_model_config( model_type=ModelType.LLM, ) if provider_model is None: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + raise ModelNotExistError(f"Model {node_data_model.name} does not exist.") provider_model.raise_for_status() - completion_params = dict(node_data_model.completion_params) - stop = completion_params.pop("stop", []) - if not isinstance(stop, list): - stop = [] - model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) - if not model_schema: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + if model_schema is None: + raise ModelNotExistError(f"Model {node_data_model.name} schema does not exist.") + parameters, stop = _normalize_completion_params(node_data_model.completion_params) model_instance.provider = node_data_model.provider model_instance.model_name = node_data_model.name model_instance.credentials = credentials - model_instance.parameters = completion_params + model_instance.parameters = parameters model_instance.stop = tuple(stop) return model_instance, ModelConfigWithCredentialsEntity( @@ -103,8 +146,8 @@ def fetch_model_config( model=node_data_model.name, model_schema=model_schema, mode=node_data_model.mode, - provider_model_bundle=provider_model_bundle, credentials=credentials, - parameters=completion_params, + parameters=parameters, stop=stop, + provider_model_bundle=provider_model_bundle, ) diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index e0f8d27111..6f8166c067 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -1,33 +1,42 @@ from __future__ import annotations +import base64 +import hashlib +import hmac +import os +import time +import urllib.parse from collections.abc import Generator +from typing import TYPE_CHECKING, Literal from configs import dify_config +from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol +from core.db.session_factory import session_factory from core.helper.ssrf_proxy import ssrf_proxy from core.tools.signature import sign_tool_file +from core.workflow.file_reference import parse_file_reference +from dify_graph.file.enums import FileTransferMethod from dify_graph.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol from dify_graph.file.runtime import set_workflow_file_runtime from extensions.ext_storage import storage +if TYPE_CHECKING: + from dify_graph.file.models import File + class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): - """Production runtime wiring for ``dify_graph.file``.""" + """Production runtime wiring for ``dify_graph.file``. - @property - def files_url(self) -> str: - return dify_config.FILES_URL + Opaque file references are resolved back to canonical database records before + URLs are signed or storage keys are used. 
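The preview-URL scheme introduced further down in this file (`_sign_query` / `verify_preview_signature`) is a timestamp-and-nonce HMAC over the payload. A self-contained round-trip sketch, assuming a stand-in secret for `dify_config.SECRET_KEY`:

```python
import base64
import hashlib
import hmac
import os
import time

SECRET = b"stand-in-secret"


def sign(payload: str) -> dict[str, str]:
    timestamp = str(int(time.time()))
    nonce = os.urandom(16).hex()
    digest = hmac.new(SECRET, f"{payload}|{timestamp}|{nonce}".encode(), hashlib.sha256).digest()
    return {"timestamp": timestamp, "nonce": nonce, "sign": base64.urlsafe_b64encode(digest).decode()}


def verify(payload: str, *, timestamp: str, nonce: str, sign: str, timeout: int = 300) -> bool:
    digest = hmac.new(SECRET, f"{payload}|{timestamp}|{nonce}".encode(), hashlib.sha256).digest()
    if not hmac.compare_digest(sign, base64.urlsafe_b64encode(digest).decode()):
        return False
    # Reject signatures older than the access timeout.
    return int(time.time()) - int(timestamp) <= timeout


query = sign("file-preview|file-123")
assert verify("file-preview|file-123", **query)
```

One design note: the sketch compares digests with `hmac.compare_digest` for a constant-time check, whereas the patch's `verify_preview_signature` compares with `!=`.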
When a request-scoped file access + scope is present, those lookups additionally enforce tenant and end-user + ownership filters. + """ - @property - def internal_files_url(self) -> str | None: - return dify_config.INTERNAL_FILES_URL + _file_access_controller: FileAccessControllerProtocol - @property - def secret_key(self) -> str: - return dify_config.SECRET_KEY - - @property - def files_access_timeout(self) -> int: - return dify_config.FILES_ACCESS_TIMEOUT + def __init__(self, *, file_access_controller: FileAccessControllerProtocol) -> None: + self._file_access_controller = file_access_controller @property def multimodal_send_format(self) -> str: @@ -39,9 +48,137 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: return storage.load(path, stream=stream) - def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: + def load_file_bytes(self, *, file: File) -> bytes: + storage_key = self._resolve_storage_key(file=file) + data = storage.load(storage_key, stream=False) + if not isinstance(data, bytes): + raise ValueError(f"file {storage_key} is not a bytes object") + return data + + def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: + if file.transfer_method == FileTransferMethod.REMOTE_URL: + return file.remote_url + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("Missing file reference") + if file.transfer_method == FileTransferMethod.LOCAL_FILE: + return self.resolve_upload_file_url( + upload_file_id=parsed_reference.record_id, + for_external=for_external, + ) + if file.transfer_method == FileTransferMethod.DATASOURCE_FILE: + if file.extension is None: + raise ValueError("Missing file extension") + self._assert_upload_file_access(upload_file_id=parsed_reference.record_id) + return sign_tool_file( + tool_file_id=parsed_reference.record_id, + extension=file.extension, + for_external=for_external, + ) + if file.transfer_method == FileTransferMethod.TOOL_FILE: + if file.extension is None: + raise ValueError("Missing file extension") + return self.resolve_tool_file_url( + tool_file_id=parsed_reference.record_id, + extension=file.extension, + for_external=for_external, + ) + return None + + def resolve_upload_file_url( + self, + *, + upload_file_id: str, + as_attachment: bool = False, + for_external: bool = True, + ) -> str: + self._assert_upload_file_access(upload_file_id=upload_file_id) + base_url = self._base_url(for_external=for_external) + url = f"{base_url}/files/{upload_file_id}/file-preview" + query = self._sign_query(payload=f"file-preview|{upload_file_id}") + if as_attachment: + query["as_attachment"] = "true" + return f"{url}?{urllib.parse.urlencode(query)}" + + def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: + self._assert_tool_file_access(tool_file_id=tool_file_id) return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external) + def verify_preview_signature( + self, + *, + preview_kind: Literal["image", "file"], + file_id: str, + timestamp: str, + nonce: str, + sign: str, + ) -> bool: + payload = f"{preview_kind}-preview|{file_id}|{timestamp}|{nonce}" + recalculated = hmac.new(self._secret_key(), payload.encode(), hashlib.sha256).digest() + if sign != base64.urlsafe_b64encode(recalculated).decode(): + return False + return int(time.time()) - int(timestamp) <= 
dify_config.FILES_ACCESS_TIMEOUT + + @staticmethod + def _base_url(*, for_external: bool) -> str: + if for_external: + return dify_config.FILES_URL + return dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + + @staticmethod + def _secret_key() -> bytes: + return dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + + def _sign_query(self, *, payload: str) -> dict[str, str]: + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + sign = hmac.new(self._secret_key(), f"{payload}|{timestamp}|{nonce}".encode(), hashlib.sha256).digest() + return { + "timestamp": timestamp, + "nonce": nonce, + "sign": base64.urlsafe_b64encode(sign).decode(), + } + + def _resolve_storage_key(self, *, file: File) -> str: + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("Missing file reference") + + record_id = parsed_reference.record_id + with session_factory.create_session() as session: + if file.transfer_method in { + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + FileTransferMethod.DATASOURCE_FILE, + }: + upload_file = self._file_access_controller.get_upload_file(session=session, file_id=record_id) + if upload_file is None: + raise ValueError(f"Upload file {record_id} not found") + return upload_file.key + + tool_file = self._file_access_controller.get_tool_file(session=session, file_id=record_id) + if tool_file is None: + raise ValueError(f"Tool file {record_id} not found") + return tool_file.file_key + + def _assert_upload_file_access(self, *, upload_file_id: str) -> None: + if self._file_access_controller.current_scope() is None: + return + + with session_factory.create_session() as session: + upload_file = self._file_access_controller.get_upload_file(session=session, file_id=upload_file_id) + if upload_file is None: + raise ValueError(f"Upload file {upload_file_id} not found") + + def _assert_tool_file_access(self, *, tool_file_id: str) -> None: + if self._file_access_controller.current_scope() is None: + return + + with session_factory.create_session() as session: + tool_file = self._file_access_controller.get_tool_file(session=session, file_id=tool_file_id) + if tool_file is None: + raise ValueError(f"Tool file {tool_file_id} not found") + def bind_dify_workflow_file_runtime() -> None: - set_workflow_file_runtime(DifyWorkflowFileRuntime()) + set_workflow_file_runtime(DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController())) diff --git a/api/core/app/workflow/layers/llm_quota.py b/api/core/app/workflow/layers/llm_quota.py index faf1516c40..4bbd229cbb 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, cast, final from typing_extensions import override +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.app.llm import deduct_llm_quota, ensure_llm_quota_available from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance @@ -75,7 +76,7 @@ class LLMQuotaLayer(GraphEngineLayer): return try: - dify_ctx = node.require_dify_context() + dify_ctx = DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) deduct_llm_quota( tenant_id=dify_ctx.tenant_id, model_instance=model_instance, @@ -114,11 +115,11 @@ class LLMQuotaLayer(GraphEngineLayer): try: match node.node_type: case BuiltinNodeTypes.LLM: - return cast("LLMNode", node).model_instance + model_instance = cast("LLMNode", 
node).model_instance case BuiltinNodeTypes.PARAMETER_EXTRACTOR: - return cast("ParameterExtractorNode", node).model_instance + model_instance = cast("ParameterExtractorNode", node).model_instance case BuiltinNodeTypes.QUESTION_CLASSIFIER: - return cast("QuestionClassifierNode", node).model_instance + model_instance = cast("QuestionClassifierNode", node).model_instance case _: return None except AttributeError: @@ -127,3 +128,12 @@ class LLMQuotaLayer(GraphEngineLayer): node.id, ) return None + + if isinstance(model_instance, ModelInstance): + return model_instance + + raw_model_instance = getattr(model_instance, "_model_instance", None) + if isinstance(raw_model_instance, ModelInstance): + return raw_model_instance + + return None diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index d95a378575..f423b18f8f 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -17,10 +17,12 @@ from typing import Any, Union from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from dify_graph.entities import WorkflowExecution, WorkflowNodeExecution from dify_graph.enums import ( - SystemVariableKey, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, @@ -43,8 +45,6 @@ from dify_graph.graph_events import ( NodeRunSucceededEvent, ) from dify_graph.node_events import NodeRunResult -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from libs.datetime_utils import naive_utc_now @@ -372,10 +372,15 @@ class WorkflowPersistenceLayer(GraphEngineLayer): domain_execution.error = error if update_outputs: + projected_outputs = project_node_outputs_for_workflow_run( + node_type=domain_execution.node_type, + inputs=node_result.inputs, + outputs=node_result.outputs, + ) domain_execution.update_from_mapping( inputs=node_result.inputs, process_data=node_result.process_data, - outputs=node_result.outputs, + outputs=projected_outputs, metadata=node_result.metadata, ) diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index beda515666..c0ea1d6f65 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -25,12 +25,10 @@ class AudioTrunk: self.status = status -def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str): +def _invoice_tts(text_content: str, model_instance: ModelInstance, voice: str): if not text_content or text_content.isspace(): return - return model_instance.invoke_tts( - content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice - ) + return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice) def _process_future( @@ -62,7 +60,7 @@ class AppGeneratorTTSPublisher: 
self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue() self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() self.match = re.compile(r"[。.!?]") - self.model_manager = ModelManager() + self.model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id, user_id="responding_tts") self.model_instance = self.model_manager.get_default_model_instance( tenant_id=self.tenant_id, model_type=ModelType.TTS ) @@ -89,7 +87,7 @@ class AppGeneratorTTSPublisher: if message is None: if self.msg_text and len(self.msg_text.strip()) > 0: futures_result = self.executor.submit( - _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice + _invoice_tts, self.msg_text, self.model_instance, self.voice ) future_queue.put(futures_result) break @@ -117,9 +115,7 @@ class AppGeneratorTTSPublisher: if len(sentence_arr) >= min(self.max_sentence, 7): self.max_sentence += 1 text_content = "".join(sentence_arr) - futures_result = self.executor.submit( - _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice - ) + futures_result = self.executor.submit(_invoice_tts, text_content, self.model_instance, self.voice) future_queue.put(futures_result) if isinstance(text_tmp, str): self.msg_text = text_tmp diff --git a/api/core/datasource/datasource_manager.py b/api/core/datasource/datasource_manager.py index 4fa941ae16..6cb9bb868c 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -6,6 +6,7 @@ from typing import Any, cast from sqlalchemy import select import contexts +from core.app.file_access import DatabaseFileAccessController from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.entities.datasource_entities import ( @@ -24,10 +25,11 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController from core.db.session_factory import session_factory from core.plugin.impl.datasource import PluginDatasourceManager +from core.workflow.file_reference import build_file_reference from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.enums import WorkflowNodeExecutionMetadataKey -from dify_graph.file import File +from dify_graph.file import File, get_file_type_by_mime_type from dify_graph.file.enums import FileTransferMethod, FileType from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from factories import file_factory @@ -36,6 +38,7 @@ from models.tools import ToolFile from services.datasource_provider_service import DatasourceProviderService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class DatasourceManager: @@ -279,11 +282,15 @@ class DatasourceManager: if datasource_file is not None: mapping = { "tool_file_id": datasource_file_id, - "type": file_factory.get_file_type_by_mime_type(mime_type), + "type": get_file_type_by_mime_type(mime_type), "transfer_method": FileTransferMethod.TOOL_FILE, "url": url, } - file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id) + file_out = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + 
access_controller=_file_access_controller, + ) elif mtype == DatasourceMessage.MessageType.TEXT: assert isinstance(message.message, DatasourceMessage.TextMessage) yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False) @@ -351,11 +358,10 @@ class DatasourceManager: filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=tenant_id, type=FileType.CUSTOM, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, - related_id=upload_file.id, + reference=build_file_reference(record_id=str(upload_file.id)), size=upload_file.size, storage_key=upload_file.key, url=upload_file.source_url, diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index 2881888e27..96b3c2b86b 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -4,6 +4,7 @@ from mimetypes import guess_extension, guess_type from core.datasource.entities.datasource_entities import DatasourceMessage from core.tools.tool_file_manager import ToolFileManager +from core.workflow.file_reference import parse_file_reference from dify_graph.file import File, FileTransferMethod, FileType from models.tools import ToolFile @@ -103,8 +104,14 @@ class DatasourceFileMessageTransformer: file: File | None = meta.get("file") if isinstance(file, File): if file.transfer_method == FileTransferMethod.TOOL_FILE: - assert file.related_id is not None - url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension) + reference = getattr(file, "reference", None) or getattr(file, "related_id", None) + parsed_reference = parse_file_reference(reference) if isinstance(reference, str) else None + if parsed_reference is None: + raise ValueError("datasource file is missing reference") + url = cls.get_datasource_file_url( + datasource_file_id=parsed_reference.record_id, + extension=file.extension, + ) if file.type == FileType.IMAGE: yield DatasourceMessage( type=DatasourceMessage.MessageType.IMAGE_LINK, diff --git a/api/core/entities/embedding_type.py b/api/core/entities/embedding_type.py index 89b48fd2ef..f8589a89ec 100644 --- a/api/core/entities/embedding_type.py +++ b/api/core/entities/embedding_type.py @@ -1,10 +1,5 @@ -from enum import StrEnum, auto +"""Compatibility wrapper for the runtime embedding input enum.""" +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType -class EmbeddingInputType(StrEnum): - """ - Enum for embedding input type. 
- """ - - DOCUMENT = auto() - QUERY = auto() +__all__ = ["EmbeddingInputType"] diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index a9f2300ba2..297b4fa90e 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import logging import re @@ -5,7 +7,7 @@ from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -19,6 +21,7 @@ from core.entities.provider_entities import ( ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType from dify_graph.model_runtime.entities.provider_entities import ( ConfigurateMethod, @@ -28,6 +31,7 @@ from dify_graph.model_runtime.entities.provider_entities import ( ) from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from dify_graph.model_runtime.runtime import ModelRuntime from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType @@ -60,6 +64,10 @@ class ProviderConfiguration(BaseModel): - Load balancing configurations - Model enablement/disablement + Request flows can bind a pre-scoped runtime via ``bind_model_runtime()`` so + nested schema and model lookups reuse the caller scope that was already + resolved by the composition layer. + TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified """ @@ -73,6 +81,7 @@ class ProviderConfiguration(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) + _bound_model_runtime: ModelRuntime | None = PrivateAttr(default=None) @model_validator(mode="after") def _(self): @@ -92,6 +101,16 @@ class ProviderConfiguration(BaseModel): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) return self + def bind_model_runtime(self, model_runtime: ModelRuntime) -> None: + """Attach the already-composed runtime for request-bound call chains.""" + self._bound_model_runtime = model_runtime + + def get_model_provider_factory(self) -> ModelProviderFactory: + """Return a provider factory that preserves any request-bound runtime.""" + if self._bound_model_runtime is not None: + return ModelProviderFactory(model_runtime=self._bound_model_runtime) + return create_plugin_model_provider_factory(tenant_id=self.tenant_id) + def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: """ Get current credentials. 
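`bind_model_runtime()` above is a bind-or-create pattern: when the composition layer has already resolved a request-scoped runtime, every nested factory call reuses it; otherwise a tenant-scoped factory is built on demand. A reduced sketch with stand-in types (`Runtime`/`Factory` are illustrative, not the real `ModelRuntime`/`ModelProviderFactory`):

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Runtime:
    tenant_id: str


@dataclass
class Factory:
    runtime: Runtime


@dataclass
class ProviderConfig:
    tenant_id: str
    _bound: Runtime | None = field(default=None, repr=False)

    def bind_runtime(self, runtime: Runtime) -> None:
        # Pin the runtime the caller already composed for this request.
        self._bound = runtime

    def factory(self) -> Factory:
        # Reuse the bound runtime when present; otherwise fall back to a
        # freshly built tenant-scoped one.
        if self._bound is not None:
            return Factory(runtime=self._bound)
        return Factory(runtime=Runtime(tenant_id=self.tenant_id))


config = ProviderConfig(tenant_id="t-1")
shared = Runtime(tenant_id="t-1")
config.bind_runtime(shared)
assert config.factory().runtime is shared
```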
@@ -343,7 +362,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, token=original_credentials[key] ) - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() validated_credentials = model_provider_factory.provider_credentials_validate( provider=self.provider.provider, credentials=credentials ) @@ -902,7 +921,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, token=original_credentials[key] ) - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() validated_credentials = model_provider_factory.model_credentials_validate( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) @@ -1388,7 +1407,7 @@ class ProviderConfiguration(BaseModel): :param model_type: model type :return: """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() # Get model instance of LLM return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) @@ -1397,7 +1416,7 @@ class ProviderConfiguration(BaseModel): """ Get model schema """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() return model_provider_factory.get_model_schema( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) @@ -1499,7 +1518,7 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() provider_schema = model_provider_factory.get_provider_schema(self.provider.provider) model_types: list[ModelType] = [] diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index 873f6a4093..eb6967be33 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -4,10 +4,10 @@ from typing import cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions.ext_hosting_provider import hosting_configuration from models.provider import ProviderType @@ -41,7 +41,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt text_chunk = secrets.choice(text_chunks) try: - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) # Get model instance of LLM model_type_instance = model_provider_factory.get_model_type_instance( diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 06bc366081..c25d41fae5 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -50,7 +50,10 @@ logger = logging.getLogger(__name__) class IndexingRunner: def __init__(self): self.storage = storage - self.model_manager = ModelManager() + + @staticmethod + def _get_model_manager(tenant_id: str) -> ModelManager: + return 
ModelManager.for_tenant(tenant_id=tenant_id) def _handle_indexing_error(self, document_id: str, error: Exception) -> None: """Handle indexing errors by updating document status.""" @@ -291,20 +294,20 @@ class IndexingRunner: raise ValueError("Dataset not found.") if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}: if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( + embedding_model_instance = self._get_model_manager(tenant_id).get_model_instance( tenant_id=tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) else: - embedding_model_instance = self.model_manager.get_default_model_instance( + embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) else: if indexing_technique == IndexTechniqueType.HIGH_QUALITY: - embedding_model_instance = self.model_manager.get_default_model_instance( + embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) @@ -574,7 +577,7 @@ class IndexingRunner: embedding_model_instance = None if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - embedding_model_instance = self.model_manager.get_model_instance( + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, @@ -766,14 +769,14 @@ class IndexingRunner: embedding_model_instance = None if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) else: - embedding_model_instance = self.model_manager.get_default_model_instance( + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_default_model_instance( tenant_id=dataset.tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index c8848336d9..eb38fb8125 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -62,7 +62,7 @@ class LLMGenerator: prompt += query + "\n" - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -120,7 +120,7 @@ class LLMGenerator: prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions}) try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -172,7 +172,7 @@ class LLMGenerator: prompt_messages = [UserPromptMessage(content=prompt_generate)] - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, @@ -219,7 +219,7 @@ class LLMGenerator: prompt_messages = 
[UserPromptMessage(content=prompt_generate_prompt)]
 
         # get model instance
-        model_manager = ModelManager()
+        model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
         model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
@@ -306,7 +306,7 @@ class LLMGenerator:
             remove_template_variables=False,
         )
 
-        model_manager = ModelManager()
+        model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
         model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
@@ -337,7 +337,7 @@ class LLMGenerator:
     def generate_qa_document(cls, tenant_id: str, query, document_language: str):
         prompt = GENERATOR_QA_PROMPT.format(language=document_language)
 
-        model_manager = ModelManager()
+        model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
         model_instance = model_manager.get_default_model_instance(
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
@@ -362,7 +362,7 @@ class LLMGenerator:
     @classmethod
     def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
 
-        model_manager = ModelManager()
+        model_manager = ModelManager.for_tenant(tenant_id=tenant_id)
         model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
@@ -536,7 +536,7 @@ class LLMGenerator:
             injected_instruction = injected_instruction.replace(CURRENT, current or "null")
         if ERROR_MESSAGE in injected_instruction:
             injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null")
-        model_instance = ModelManager().get_model_instance(
+        model_instance = ModelManager.for_tenant(tenant_id=tenant_id).get_model_instance(
             tenant_id=tenant_id,
             model_type=ModelType.LLM,
             provider=model_config.provider,
diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py
index 77ea1713ea..8945021d47 100644
--- a/api/core/llm_generator/output_parser/structured_output.py
+++ b/api/core/llm_generator/output_parser/structured_output.py
@@ -55,7 +55,6 @@ def invoke_llm_with_structured_output(
     tools: Sequence[PromptMessageTool] | None = None,
     stop: list[str] | None = None,
     stream: Literal[True],
-    user: str | None = None,
     callbacks: list[Callback] | None = None,
 ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
 @overload
@@ -70,7 +69,6 @@ def invoke_llm_with_structured_output(
     tools: Sequence[PromptMessageTool] | None = None,
     stop: list[str] | None = None,
     stream: Literal[False],
-    user: str | None = None,
     callbacks: list[Callback] | None = None,
 ) -> LLMResultWithStructuredOutput: ...
 @overload
@@ -85,7 +83,6 @@ def invoke_llm_with_structured_output(
     tools: Sequence[PromptMessageTool] | None = None,
     stop: list[str] | None = None,
     stream: bool = True,
-    user: str | None = None,
     callbacks: list[Callback] | None = None,
 ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
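# --- Editor's sketch (not part of the patch) ---------------------------------
# The before/after shape of the `user` parameter removal shown above. Caller
# identity now rides on the manager built via ModelManager.for_tenant(), so the
# invoke_* methods drop their per-call user= argument. Prompt content and
# parameter values are placeholders; invoke_llm's leading parameters are
# assumed from its call sites in this patch.
from core.model_manager import ModelManager
from dify_graph.model_runtime.entities.message_entities import UserPromptMessage
from dify_graph.model_runtime.entities.model_entities import ModelType

def ask(tenant_id: str, user_id: str, question: str):
    # Before: ModelManager() plus user=user_id threaded through every call.
    # After: scope tenant and caller once, at construction time.
    manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
    model_instance = manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM)
    return model_instance.invoke_llm(
        prompt_messages=[UserPromptMessage(content=question)],
        model_parameters={},
        stream=False,  # note: no user= keyword anymore
    )
# ------------------------------------------------------------------------------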
def invoke_llm_with_structured_output( @@ -99,7 +96,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: """ @@ -113,7 +109,6 @@ def invoke_llm_with_structured_output( :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id :param callbacks: callbacks :return: full response or stream response chunk generator result """ @@ -143,7 +138,6 @@ def invoke_llm_with_structured_output( tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 1156a98af1..51fba1b3f4 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -4,6 +4,7 @@ from sqlalchemy import select from sqlalchemy.orm import sessionmaker from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.file_access import DatabaseFileAccessController from core.model_manager import ModelInstance from core.prompt.utils.extract_thread_messages import extract_thread_messages from dify_graph.file import file_manager @@ -23,6 +24,8 @@ from models.workflow import Workflow from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory +_file_access_controller = DatabaseFileAccessController() + class TokenBufferMemory: def __init__( @@ -85,7 +88,10 @@ class TokenBufferMemory: # Build files directly without filtering by belongs_to file_objs = [ file_factory.build_from_message_file( - message_file=message_file, tenant_id=app_record.tenant_id, config=file_extra_config + message_file=message_file, + tenant_id=app_record.tenant_id, + config=file_extra_config, + access_controller=_file_access_controller, ) for message_file in message_files ] diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 0f710a8fcf..ab47facdab 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -7,11 +7,12 @@ from core.entities.embedding_type import EmbeddingInputType from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.errors.error import ProviderTokenNotInitError +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager from dify_graph.model_runtime.callbacks.base_callback import Callback from dify_graph.model_runtime.entities.llm_entities import LLMResult from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType from dify_graph.model_runtime.entities.rerank_entities import RerankResult from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError @@ -30,7 +31,7 @@ logger = logging.getLogger(__name__) class ModelInstance: """ - Model instance class + Model instance class. 
""" def __init__(self, provider_model_bundle: ProviderModelBundle, model: str): @@ -49,6 +50,13 @@ class ModelInstance: credentials=self.credentials, ) + def get_model_schema(self) -> AIModelEntity: + """Return the resolved schema for the current model instance.""" + model_schema = self.model_type_instance.get_model_schema(self.model_name, self.credentials) + if model_schema is None: + raise ValueError(f"model schema not found for {self.model_name}") + return model_schema + @staticmethod def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str): """ @@ -110,7 +118,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[True] = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Generator: ... @@ -122,7 +129,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[False] = False, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResult: ... @@ -134,7 +140,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: ... @@ -145,7 +150,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: """ @@ -156,7 +160,6 @@ class ModelInstance: :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id :param callbacks: callbacks :return: full response or stream response chunk generator result """ @@ -173,7 +176,6 @@ class ModelInstance: tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ), ) @@ -202,13 +204,12 @@ class ModelInstance: ) def invoke_text_embedding( - self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT + self, texts: list[str], input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT ) -> EmbeddingResult: """ Invoke large language model :param texts: texts to embed - :param user: unique user id :param input_type: input type :return: embeddings result """ @@ -221,7 +222,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, texts=texts, - user=user, input_type=input_type, ), ) @@ -229,14 +229,12 @@ class ModelInstance: def invoke_multimodal_embedding( self, multimodel_documents: list[dict], - user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, ) -> EmbeddingResult: """ Invoke large language model :param multimodel_documents: multimodel documents to embed - :param user: unique user id :param input_type: input type :return: embeddings result """ @@ -249,7 +247,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, multimodel_documents=multimodel_documents, - user=user, input_type=input_type, ), ) @@ -279,7 +276,6 @@ class ModelInstance: docs: list[str], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke rerank model @@ -288,7 +284,6 @@ class ModelInstance: :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ if not 
isinstance(self.model_type_instance, RerankModel): @@ -303,7 +298,6 @@ class ModelInstance: docs=docs, score_threshold=score_threshold, top_n=top_n, - user=user, ), ) @@ -313,7 +307,6 @@ class ModelInstance: docs: list[dict], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke rerank model @@ -322,7 +315,6 @@ class ModelInstance: :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ if not isinstance(self.model_type_instance, RerankModel): @@ -337,16 +329,14 @@ class ModelInstance: docs=docs, score_threshold=score_threshold, top_n=top_n, - user=user, ), ) - def invoke_moderation(self, text: str, user: str | None = None) -> bool: + def invoke_moderation(self, text: str) -> bool: """ Invoke moderation model :param text: text to moderate - :param user: unique user id :return: false if text is safe, true otherwise """ if not isinstance(self.model_type_instance, ModerationModel): @@ -358,16 +348,14 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, text=text, - user=user, ), ) - def invoke_speech2text(self, file: IO[bytes], user: str | None = None) -> str: + def invoke_speech2text(self, file: IO[bytes]) -> str: """ Invoke large language model :param file: audio file - :param user: unique user id :return: text for given audio file """ if not isinstance(self.model_type_instance, Speech2TextModel): @@ -379,18 +367,15 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, file=file, - user=user, ), ) - def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: str | None = None) -> Iterable[bytes]: + def invoke_tts(self, content_text: str, voice: str = "") -> Iterable[bytes]: """ Invoke large language tts model :param content_text: text content to be translated - :param tenant_id: user tenant id :param voice: model timbre - :param user: unique user id :return: text for given audio file """ if not isinstance(self.model_type_instance, TTSModel): @@ -402,8 +387,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, content_text=content_text, - user=user, - tenant_id=tenant_id, voice=voice, ), ) @@ -477,10 +460,20 @@ class ModelInstance: class ModelManager: - def __init__(self): - self._provider_manager = ProviderManager() + def __init__(self, provider_manager: ProviderManager): + self._provider_manager = provider_manager - def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance: + @classmethod + def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager": + return cls(provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id)) + + def get_model_instance( + self, + tenant_id: str, + provider: str, + model_type: ModelType, + model: str, + ) -> ModelInstance: """ Get model instance :param tenant_id: tenant id @@ -496,7 +489,8 @@ class ModelManager: tenant_id=tenant_id, provider=provider, model_type=model_type ) - return ModelInstance(provider_model_bundle, model) + model_instance = ModelInstance(provider_model_bundle, model) + return model_instance def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]: """ diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 06676f5cf4..78b0cd91c3 100644 --- 
a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -50,7 +50,7 @@ class OpenAIModeration(Moderation): def _is_violated(self, inputs: dict): text = "\n".join(str(inputs.values())) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id) model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="omni-moderation-latest" ) diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index 18f35b5b9c..2a4527a349 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -296,7 +296,9 @@ class AliyunDataTrace(BaseTraceInstance): triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - return workflow_node_execution_repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id) + return workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id + ) def build_workflow_node_span( self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: TraceMetadata diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index f54461e99a..e354c3909a 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -271,8 +271,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) try: diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 6e62387a1f..394b4e682c 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -130,8 +130,8 @@ class LangFuseDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 32a0c77fe2..c65440935e 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -152,8 +152,8 @@ class LangSmithDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index fb72bc2381..70b1fb67a8 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -176,8 +176,8 @@ class 
OpikDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/core/ops/tencent_trace/tencent_trace.py index 7e56b1effa..d7e8ee964b 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/core/ops/tencent_trace/tencent_trace.py @@ -256,7 +256,7 @@ class TencentDataTrace(BaseTraceInstance): triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - executions = repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id) + executions = repository.get_by_workflow_execution(workflow_execution_id=trace_info.workflow_run_id) return list(executions) except Exception: diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 2a657b672c..98b232410a 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -161,8 +161,8 @@ class WeaveDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) # rearrange workflow_node_executions by starting time diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index 11c9191bac..5f8c69ff58 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -30,10 +30,27 @@ from dify_graph.model_runtime.entities.message_entities import ( SystemPromptMessage, UserPromptMessage, ) +from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Tenant class PluginModelBackwardsInvocation(BaseBackwardsInvocation): + @staticmethod + def _get_bound_model_instance( + *, + tenant_id: str, + user_id: str | None, + provider: str, + model_type: ModelType, + model: str, + ): + return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance( + tenant_id=tenant_id, + provider=provider, + model_type=model_type, + model=model, + ) + @classmethod def invoke_llm( cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM @@ -41,8 +58,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke llm """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -55,7 +73,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tools=payload.tools, stop=payload.stop, stream=True if payload.stream is None else payload.stream, - user=user_id, ) if isinstance(response, Generator): @@ -94,8 +111,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke llm with structured output """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -115,7 +133,6 @@ 
class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tools=payload.tools, stop=payload.stop, stream=True if payload.stream is None else payload.stream, - user=user_id, model_parameters=payload.completion_params, ) @@ -156,18 +173,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke text embedding """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_text_embedding( - texts=payload.texts, - user=user_id, - ) + response = model_instance.invoke_text_embedding(texts=payload.texts) return response @@ -176,8 +191,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke rerank """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -189,7 +205,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): docs=payload.docs, score_threshold=payload.score_threshold, top_n=payload.top_n, - user=user_id, ) return response @@ -199,20 +214,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke tts """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_tts( - content_text=payload.content_text, - tenant_id=tenant.id, - voice=payload.voice, - user=user_id, - ) + response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice) def handle() -> Generator[dict, None, None]: for chunk in response: @@ -225,8 +236,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke speech2text """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -238,10 +250,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): temp.flush() temp.seek(0) - response = model_instance.invoke_speech2text( - file=temp, - user=user_id, - ) + response = model_instance.invoke_speech2text(file=temp) return { "result": response, @@ -252,36 +261,38 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke moderation """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_moderation( - text=payload.text, - user=user_id, - ) + response = model_instance.invoke_moderation(text=payload.text) return { "result": response, } @classmethod - def get_system_model_max_tokens(cls, tenant_id: str) -> int: + def get_system_model_max_tokens(cls, tenant_id: str, user_id: str | None = None) -> int: """ get system model max tokens """ - return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id) + return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id, user_id=user_id) @classmethod - def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) 
-> int: + def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int: """ get prompt tokens """ - return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages) + return ModelInvocationUtils.calculate_tokens( + tenant_id=tenant_id, + prompt_messages=prompt_messages, + user_id=user_id, + ) @classmethod def invoke_system_model( @@ -299,6 +310,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tool_type=ToolProviderType.PLUGIN, tool_name="plugin", prompt_messages=prompt_messages, + caller_user_id=user_id, ) @classmethod @@ -306,7 +318,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke summary """ - max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id) + max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id, user_id=user_id) content = payload.text SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language @@ -325,6 +337,7 @@ Here is the extra instruction you need to follow: cls.get_prompt_tokens( tenant_id=tenant.id, prompt_messages=[UserPromptMessage(content=content)], + user_id=user_id, ) < max_tokens * 0.6 ): @@ -337,6 +350,7 @@ Here is the extra instruction you need to follow: SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)), UserPromptMessage(content=content), ], + user_id=user_id, ) def summarize(content: str) -> str: @@ -394,6 +408,7 @@ Here is the extra instruction you need to follow: cls.get_prompt_tokens( tenant_id=tenant.id, prompt_messages=[UserPromptMessage(content=result)], + user_id=user_id, ) > max_tokens * 0.7 ): diff --git a/api/core/plugin/backwards_invocation/tool.py b/api/core/plugin/backwards_invocation/tool.py index c2d1574e67..0585494269 100644 --- a/api/core/plugin/backwards_invocation/tool.py +++ b/api/core/plugin/backwards_invocation/tool.py @@ -31,7 +31,13 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): # get tool runtime try: tool_runtime = ToolManager.get_tool_runtime_from_plugin( - tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id + tool_type, + tenant_id, + provider, + tool_name, + tool_parameters, + user_id=user_id, + credential_id=credential_id, ) response = ToolEngine.generic_invoke( tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1 diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 49ee5d79cb..30c11bb740 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -1,6 +1,6 @@ import binascii from collections.abc import Generator, Sequence -from typing import IO +from typing import IO, Any from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, @@ -16,12 +16,19 @@ from core.plugin.impl.base import BasePluginClient from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool from dify_graph.model_runtime.entities.model_entities import AIModelEntity -from dify_graph.model_runtime.entities.rerank_entities import RerankResult +from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult from dify_graph.model_runtime.utils.encoders import jsonable_encoder class PluginModelClient(BasePluginClient): + @staticmethod + def 
_dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]: + payload: dict[str, Any] = {"data": data} + if user_id is not None: + payload["user_id"] = user_id + return payload + def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]: """ Fetch model providers for the given tenant. @@ -37,7 +44,7 @@ class PluginModelClient(BasePluginClient): def get_model_schema( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -51,15 +58,15 @@ class PluginModelClient(BasePluginClient): "POST", f"plugin/{tenant_id}/dispatch/model/schema", PluginModelSchemaEntity, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -72,7 +79,7 @@ class PluginModelClient(BasePluginClient): return None def validate_provider_credentials( - self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict + self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict ) -> bool: """ validate the credentials of the provider @@ -81,13 +88,13 @@ class PluginModelClient(BasePluginClient): "POST", f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials", PluginBasicBooleanResponse, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -105,7 +112,7 @@ class PluginModelClient(BasePluginClient): def validate_model_credentials( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -119,15 +126,15 @@ class PluginModelClient(BasePluginClient): "POST", f"plugin/{tenant_id}/dispatch/model/validate_model_credentials", PluginBasicBooleanResponse, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -145,7 +152,7 @@ class PluginModelClient(BasePluginClient): def invoke_llm( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -164,9 +171,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/llm/invoke", type_=LLMResultChunk, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "llm", "model": model, @@ -177,7 +184,7 @@ class PluginModelClient(BasePluginClient): "stop": stop, "stream": stream, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -193,7 +200,7 @@ class PluginModelClient(BasePluginClient): def get_llm_num_tokens( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -210,9 +217,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/llm/num_tokens", type_=PluginLLMNumTokensResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, @@ -220,7 +227,7 @@ class 
PluginModelClient(BasePluginClient): "prompt_messages": prompt_messages, "tools": tools, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -236,7 +243,7 @@ class PluginModelClient(BasePluginClient): def invoke_text_embedding( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -252,9 +259,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke", type_=EmbeddingResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, @@ -262,7 +269,7 @@ class PluginModelClient(BasePluginClient): "texts": texts, "input_type": input_type, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -278,7 +285,7 @@ class PluginModelClient(BasePluginClient): def invoke_multimodal_embedding( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -294,9 +301,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke", type_=EmbeddingResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, @@ -304,7 +311,7 @@ class PluginModelClient(BasePluginClient): "documents": documents, "input_type": input_type, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -320,7 +327,7 @@ class PluginModelClient(BasePluginClient): def get_text_embedding_num_tokens( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -335,16 +342,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens", type_=PluginTextEmbeddingNumTokensResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, "credentials": credentials, "texts": texts, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -360,7 +367,7 @@ class PluginModelClient(BasePluginClient): def invoke_rerank( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -378,9 +385,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/rerank/invoke", type_=RerankResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "rerank", "model": model, @@ -390,7 +397,7 @@ class PluginModelClient(BasePluginClient): "score_threshold": score_threshold, "top_n": top_n, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -406,13 +413,13 @@ class PluginModelClient(BasePluginClient): def invoke_multimodal_rerank( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, credentials: dict, - query: dict, - docs: list[dict], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], score_threshold: float | None = None, top_n: int | None = None, ) -> RerankResult: @@ -424,9 +431,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke", type_=RerankResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "rerank", "model": 
model, @@ -436,7 +443,7 @@ class PluginModelClient(BasePluginClient): "score_threshold": score_threshold, "top_n": top_n, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -451,7 +458,7 @@ class PluginModelClient(BasePluginClient): def invoke_tts( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -467,9 +474,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/tts/invoke", type_=PluginStringResultResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "tts", "model": model, @@ -478,7 +485,7 @@ class PluginModelClient(BasePluginClient): "content_text": content_text, "voice": voice, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -496,7 +503,7 @@ class PluginModelClient(BasePluginClient): def get_tts_model_voices( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -511,16 +518,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/tts/model/voices", type_=PluginVoicesResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "tts", "model": model, "credentials": credentials, "language": language, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -540,7 +547,7 @@ class PluginModelClient(BasePluginClient): def invoke_speech_to_text( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -555,16 +562,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/speech2text/invoke", type_=PluginStringResultResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "speech2text", "model": model, "credentials": credentials, "file": binascii.hexlify(file.read()).decode(), }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -580,7 +587,7 @@ class PluginModelClient(BasePluginClient): def invoke_moderation( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -595,16 +602,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/moderation/invoke", type_=PluginBasicBooleanResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "moderation", "model": model, "credentials": credentials, "text": text, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py new file mode 100644 index 0000000000..af2a2de3fd --- /dev/null +++ b/api/core/plugin/impl/model_runtime.py @@ -0,0 +1,499 @@ +from __future__ import annotations + +import hashlib +import logging +from collections.abc import Generator, Iterable, Sequence +from threading import Lock +from typing import IO, Any, Union + +from pydantic import ValidationError +from redis import RedisError + +from configs import dify_config +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.asset import PluginAssetManager +from core.plugin.impl.model import PluginModelClient +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from 
dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType +from dify_graph.model_runtime.entities.provider_entities import ProviderEntity +from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult +from dify_graph.model_runtime.runtime import ModelRuntime +from extensions.ext_redis import redis_client +from models.provider_ids import ModelProviderID + +logger = logging.getLogger(__name__) + +# `TS` means tenant scope +TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__" + + +class PluginModelRuntime(ModelRuntime): + """Plugin-backed runtime adapter bound to tenant context and optional caller scope.""" + + tenant_id: str + user_id: str | None + client: PluginModelClient + _provider_entities: tuple[ProviderEntity, ...] | None + _provider_entities_lock: Lock + + def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None: + if client is None: + raise ValueError("client is required.") + self.tenant_id = tenant_id + self.user_id = user_id + self.client = client + self._provider_entities = None + self._provider_entities_lock = Lock() + + def fetch_model_providers(self) -> Sequence[ProviderEntity]: + if self._provider_entities is not None: + return self._provider_entities + + with self._provider_entities_lock: + if self._provider_entities is None: + self._provider_entities = tuple( + self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id) + ) + + return self._provider_entities + + def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: + provider_schema = self._get_provider_schema(provider) + + if icon_type.lower() == "icon_small": + if not provider_schema.icon_small: + raise ValueError(f"Provider {provider} does not have small icon.") + file_name = ( + provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US + ) + elif icon_type.lower() == "icon_small_dark": + if not provider_schema.icon_small_dark: + raise ValueError(f"Provider {provider} does not have small dark icon.") + file_name = ( + provider_schema.icon_small_dark.zh_Hans + if lang.lower() == "zh_hans" + else provider_schema.icon_small_dark.en_US + ) + else: + raise ValueError(f"Unsupported icon type: {icon_type}.") + + if not file_name: + raise ValueError(f"Provider {provider} does not have icon.") + + image_mime_types = { + "jpg": "image/jpeg", + "jpeg": "image/jpeg", + "png": "image/png", + "gif": "image/gif", + "bmp": "image/bmp", + "tiff": "image/tiff", + "tif": "image/tiff", + "webp": "image/webp", + "svg": "image/svg+xml", + "ico": "image/vnd.microsoft.icon", + "heif": "image/heif", + "heic": "image/heic", + } + + extension = file_name.split(".")[-1] + mime_type = image_mime_types.get(extension, "image/png") + return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type + + def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: + plugin_id, provider_name = self._split_provider(provider) + self.client.validate_provider_credentials( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + credentials=credentials, + ) + + def validate_model_credentials( + self, + *, + provider: str, + model_type: 
ModelType, + model: str, + credentials: dict[str, Any], + ) -> None: + plugin_id, provider_name = self._split_provider(provider) + self.client.validate_model_credentials( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + ) + + def get_model_schema( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> AIModelEntity | None: + cache_key = self._get_schema_cache_key( + provider=provider, + model_type=model_type, + model=model, + credentials=credentials, + ) + + cached_schema_json = None + try: + cached_schema_json = redis_client.get(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to read plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + if cached_schema_json: + try: + return AIModelEntity.model_validate_json(cached_schema_json) + except ValidationError: + logger.warning("Failed to validate cached plugin model schema for model %s", model, exc_info=True) + try: + redis_client.delete(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to delete invalid plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + plugin_id, provider_name = self._split_provider(provider) + schema = self.client.get_model_schema( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + ) + + if schema: + try: + redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to write plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + return schema + + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_llm( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + model_parameters=model_parameters, + prompt_messages=list(prompt_messages), + tools=tools, + stop=list(stop) if stop else None, + stream=stream, + ) + + def get_llm_num_tokens( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: Sequence[PromptMessageTool] | None, + ) -> int: + if not dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED: + return 0 + + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_llm_num_tokens( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + prompt_messages=list(prompt_messages), + tools=list(tools) if tools else None, + ) + + def invoke_text_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: + plugin_id, provider_name = self._split_provider(provider) + return 
self.client.invoke_text_embedding( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + texts=texts, + input_type=input_type, + ) + + def invoke_multimodal_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + documents: list[dict[str, Any]], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_multimodal_embedding( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + documents=documents, + input_type=input_type, + ) + + def get_text_embedding_num_tokens( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + ) -> list[int]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_text_embedding_num_tokens( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + texts=texts, + ) + + def invoke_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: str, + docs: list[str], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_rerank( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + ) + + def invoke_multimodal_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_multimodal_rerank( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + ) + + def invoke_tts( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + content_text: str, + voice: str, + ) -> Iterable[bytes]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_tts( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + content_text=content_text, + voice=voice, + ) + + def get_tts_model_voices( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + language: str | None, + ) -> Any: + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_tts_model_voices( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + language=language, + ) + + def invoke_speech_to_text( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + file: IO[bytes], + ) -> str: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_speech_to_text( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + file=file, + ) + + def invoke_moderation( + 
+        self,
+        *,
+        provider: str,
+        model: str,
+        credentials: dict[str, Any],
+        text: str,
+    ) -> bool:
+        plugin_id, provider_name = self._split_provider(provider)
+        return self.client.invoke_moderation(
+            tenant_id=self.tenant_id,
+            user_id=self.user_id,
+            plugin_id=plugin_id,
+            provider=provider_name,
+            model=model,
+            credentials=credentials,
+            text=text,
+        )
+
+    def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str:
+        """
+        Expose a bare provider alias only for the canonical provider mapping.
+
+        Multiple plugins can publish the same short provider slug. If every
+        provider entity keeps that slug in ``provider_name``, callers that still
+        resolve by short name become order-dependent. Restrict the alias to the
+        provider selected by ``ModelProviderID`` so legacy short-name lookups
+        remain deterministic while the runtime surface stays canonical.
+        """
+        try:
+            canonical_provider_id = ModelProviderID(provider.provider)
+        except ValueError:
+            return ""
+
+        if canonical_provider_id.plugin_id != provider.plugin_id:
+            return ""
+        if canonical_provider_id.provider_name != provider.provider:
+            return ""
+
+        return provider.provider
+
+    def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity:
+        declaration = provider.declaration.model_copy(deep=True)
+        declaration.provider = f"{provider.plugin_id}/{provider.provider}"
+        declaration.provider_name = self._get_provider_short_name_alias(provider)
+        return declaration
+
+    def _get_provider_schema(self, provider: str) -> ProviderEntity:
+        providers = self.fetch_model_providers()
+        provider_entity = next((item for item in providers if item.provider == provider), None)
+        if provider_entity is None:
+            provider_entity = next((item for item in providers if provider == item.provider_name), None)
+        if provider_entity is None:
+            raise ValueError(f"Invalid provider: {provider}")
+        return provider_entity
+
+    def _get_schema_cache_key(
+        self,
+        *,
+        provider: str,
+        model_type: ModelType,
+        model: str,
+        credentials: dict[str, Any],
+    ) -> str:
+        # The plugin daemon distinguishes ``None`` from an explicit empty-string
+        # caller id, so the cache must only collapse ``None`` into tenant scope.
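# --- Editor's note (not part of the patch) ------------------------------------
# Worked example of the cache-key shape the code below produces, assuming
# tenant "t1", provider "langgenius/openai/openai", ModelType.LLM, model
# "gpt-4o", and no credentials:
#   user_id=None  -> "t1:langgenius/openai/openai:llm:gpt-4o:__DIFY_TS__"
#   user_id=""    -> "t1:langgenius/openai/openai:llm:gpt-4o:"
#   user_id="u42" -> "t1:langgenius/openai/openai:llm:gpt-4o:u42"
# Only None collapses into the tenant-scope sentinel; an explicit empty string
# stays distinct because the plugin daemon treats the two differently. When
# credentials are present, an md5-per-pair digest string is appended.
# ------------------------------------------------------------------------------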
+ cache_user_id = TENANT_SCOPE_SCHEMA_CACHE_USER_ID if self.user_id is None else self.user_id + cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}:{cache_user_id}" + sorted_credentials = sorted(credentials.items()) if credentials else [] + if not sorted_credentials: + return cache_key + hashed_credentials = ":".join( + [hashlib.md5(f"{key}:{value}".encode()).hexdigest() for key, value in sorted_credentials] + ) + return f"{cache_key}:{hashed_credentials}" + + def _split_provider(self, provider: str) -> tuple[str, str]: + provider_id = ModelProviderID(provider) + return provider_id.plugin_id, provider_id.provider_name diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py new file mode 100644 index 0000000000..efdf93576a --- /dev/null +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from core.plugin.impl.model import PluginModelClient +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory + +if TYPE_CHECKING: + from core.model_manager import ModelManager + from core.plugin.impl.model_runtime import PluginModelRuntime + from core.provider_manager import ProviderManager + + +class PluginModelAssembly: + """Compose request-scoped model views on top of a single plugin runtime.""" + + tenant_id: str + user_id: str | None + _model_runtime: PluginModelRuntime | None + _model_provider_factory: ModelProviderFactory | None + _provider_manager: ProviderManager | None + _model_manager: ModelManager | None + + def __init__(self, *, tenant_id: str, user_id: str | None = None) -> None: + self.tenant_id = tenant_id + self.user_id = user_id + self._model_runtime = None + self._model_provider_factory = None + self._provider_manager = None + self._model_manager = None + + @property + def model_runtime(self) -> PluginModelRuntime: + if self._model_runtime is None: + self._model_runtime = create_plugin_model_runtime(tenant_id=self.tenant_id, user_id=self.user_id) + return self._model_runtime + + @property + def model_provider_factory(self) -> ModelProviderFactory: + if self._model_provider_factory is None: + self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime) + return self._model_provider_factory + + @property + def provider_manager(self) -> ProviderManager: + if self._provider_manager is None: + from core.provider_manager import ProviderManager + + self._provider_manager = ProviderManager(model_runtime=self.model_runtime) + return self._provider_manager + + @property + def model_manager(self) -> ModelManager: + if self._model_manager is None: + from core.model_manager import ModelManager + + self._model_manager = ModelManager(provider_manager=self.provider_manager) + return self._model_manager + + +def create_plugin_model_assembly(*, tenant_id: str, user_id: str | None = None) -> PluginModelAssembly: + """Create a request-scoped assembly that shares one plugin runtime across model views.""" + return PluginModelAssembly(tenant_id=tenant_id, user_id=user_id) + + +def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -> PluginModelRuntime: + """Create a plugin runtime with its client dependency fully composed.""" + from core.plugin.impl.model_runtime import PluginModelRuntime + + return PluginModelRuntime( + tenant_id=tenant_id, + user_id=user_id, + client=PluginModelClient(), + ) + + +def create_plugin_model_provider_factory(*, tenant_id: str, 
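Reviewer note: the cache-key derivation above is worth pinning down with a standalone sketch, since two properties are easy to miss: only a None caller id collapses into the tenant-scope sentinel (an explicit empty string stays distinct), and credential ordering never changes the key. The code below is illustrative, stdlib-only; the sentinel value and function name are stand-ins, not the module's real identifiers.

    import hashlib

    TENANT_SCOPE = "__tenant_scope__"  # stand-in for TENANT_SCOPE_SCHEMA_CACHE_USER_ID

    def schema_cache_key(tenant_id, provider, model_type, model, user_id, credentials):
        # Only None collapses into tenant scope; "" stays a distinct caller id.
        cache_user_id = TENANT_SCOPE if user_id is None else user_id
        key = f"{tenant_id}:{provider}:{model_type}:{model}:{cache_user_id}"
        items = sorted(credentials.items()) if credentials else []
        if not items:
            return key
        hashed = ":".join(hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in items)
        return f"{key}:{hashed}"

    a = schema_cache_key("t1", "langgenius/openai/openai", "llm", "gpt-4o", None, {"k": 1, "x": 2})
    b = schema_cache_key("t1", "langgenius/openai/openai", "llm", "gpt-4o", None, {"x": 2, "k": 1})
    assert a == b  # credential ordering is irrelevant
    assert a != schema_cache_key("t1", "langgenius/openai/openai", "llm", "gpt-4o", "", {"k": 1, "x": 2})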
diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py
new file mode 100644
index 0000000000..efdf93576a
--- /dev/null
+++ b/api/core/plugin/impl/model_runtime_factory.py
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from core.plugin.impl.model import PluginModelClient
+from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
+
+if TYPE_CHECKING:
+    from core.model_manager import ModelManager
+    from core.plugin.impl.model_runtime import PluginModelRuntime
+    from core.provider_manager import ProviderManager
+
+
+class PluginModelAssembly:
+    """Compose request-scoped model views on top of a single plugin runtime."""
+
+    tenant_id: str
+    user_id: str | None
+    _model_runtime: PluginModelRuntime | None
+    _model_provider_factory: ModelProviderFactory | None
+    _provider_manager: ProviderManager | None
+    _model_manager: ModelManager | None
+
+    def __init__(self, *, tenant_id: str, user_id: str | None = None) -> None:
+        self.tenant_id = tenant_id
+        self.user_id = user_id
+        self._model_runtime = None
+        self._model_provider_factory = None
+        self._provider_manager = None
+        self._model_manager = None
+
+    @property
+    def model_runtime(self) -> PluginModelRuntime:
+        if self._model_runtime is None:
+            self._model_runtime = create_plugin_model_runtime(tenant_id=self.tenant_id, user_id=self.user_id)
+        return self._model_runtime
+
+    @property
+    def model_provider_factory(self) -> ModelProviderFactory:
+        if self._model_provider_factory is None:
+            self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime)
+        return self._model_provider_factory
+
+    @property
+    def provider_manager(self) -> ProviderManager:
+        if self._provider_manager is None:
+            from core.provider_manager import ProviderManager
+
+            self._provider_manager = ProviderManager(model_runtime=self.model_runtime)
+        return self._provider_manager
+
+    @property
+    def model_manager(self) -> ModelManager:
+        if self._model_manager is None:
+            from core.model_manager import ModelManager
+
+            self._model_manager = ModelManager(provider_manager=self.provider_manager)
+        return self._model_manager
+
+
+def create_plugin_model_assembly(*, tenant_id: str, user_id: str | None = None) -> PluginModelAssembly:
+    """Create a request-scoped assembly that shares one plugin runtime across model views."""
+    return PluginModelAssembly(tenant_id=tenant_id, user_id=user_id)
+
+
+def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -> PluginModelRuntime:
+    """Create a plugin runtime with its client dependency fully composed."""
+    from core.plugin.impl.model_runtime import PluginModelRuntime
+
+    return PluginModelRuntime(
+        tenant_id=tenant_id,
+        user_id=user_id,
+        client=PluginModelClient(),
+    )
+
+
+def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None = None) -> ModelProviderFactory:
+    """Create a tenant-bound model provider factory for service flows."""
+    return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_provider_factory
+
+
+def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> ProviderManager:
+    """Create a tenant-bound provider manager for service flows."""
+    return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).provider_manager
+
+
+def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager:
+    """Create a tenant-bound model manager for service flows."""
+    return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_manager
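Reviewer note: a usage sketch for the assembly, assuming one instance per request (the ids below are placeholders). Every view is built lazily and shares the single PluginModelRuntime, so caller identity is bound exactly once per request:

    from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly

    assembly = create_plugin_model_assembly(tenant_id="tenant-123", user_id="user-456")

    factory = assembly.model_provider_factory   # schema lookups
    providers = assembly.provider_manager       # provider configurations
    models = assembly.model_manager             # model instance resolution

    assert assembly.model_manager is models     # views are cached on the assembly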
""" - def __init__(self): + def __init__(self, model_runtime: ModelRuntime): self.decoding_rsa_key = None self.decoding_cipher_rsa = None + self._model_runtime = model_runtime def get_configurations(self, tenant_id: str) -> ProviderConfigurations: """ @@ -127,7 +139,7 @@ class ProviderManager: ) # Get all provider entities - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime) provider_entities = model_provider_factory.get_providers() # Get All preferred provider types of the workspace @@ -255,6 +267,7 @@ class ProviderManager: custom_configuration=custom_configuration, model_settings=model_settings, ) + provider_configuration.bind_model_runtime(self._model_runtime) provider_configurations[str(provider_id_entity)] = provider_configuration @@ -321,7 +334,7 @@ class ProviderManager: if not default_model: return None - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime) provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name) return DefaultModelEntity( @@ -392,7 +405,7 @@ class ProviderManager: # create default model default_model = TenantDefaultModel( tenant_id=tenant_id, - model_type=model_type.value, + model_type=model_type.to_origin_model_type(), provider_name=provider, model_name=model, ) diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 33eb5f963a..ab86118900 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -52,11 +52,10 @@ class DataPostProcessor: documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: if self.rerank_runner: - documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type) + documents = self.rerank_runner.run(query, documents, score_threshold, top_n, query_type) if self.reorder_runner: documents = self.reorder_runner.run(documents) @@ -106,9 +105,9 @@ class DataPostProcessor: ) -> ModelInstance | None: if reranking_model: try: - model_manager = ModelManager() - reranking_provider_name = reranking_model["reranking_provider_name"] - reranking_model_name = reranking_model["reranking_model_name"] + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) + reranking_provider_name = reranking_model.get("reranking_provider_name") + reranking_model_name = reranking_model.get("reranking_model_name") if not reranking_provider_name or not reranking_model_name: return None rerank_model_instance = model_manager.get_model_instance( diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 713319ab9d..2c449e9163 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -328,7 +328,7 @@ class RetrievalService: str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False ) if dataset.is_multimodal: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) is_support_vision = model_manager.check_model_support_vision( tenant_id=dataset.tenant_id, provider=reranking_model["reranking_provider_name"], diff --git a/api/core/rag/datasource/vdb/vector_factory.py 
b/api/core/rag/datasource/vdb/vector_factory.py index cd12cd3fae..300a2e14ed 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -303,7 +303,7 @@ class Vector: redis_client.delete(collection_exist_cache_key) def _get_embeddings(self) -> Embeddings: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index cd27113245..660f37926d 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -73,7 +73,7 @@ class DatasetDocumentStore: max_position = 0 embedding_model = None if self._dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, provider=self._dataset.embedding_model_provider, diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 6d1b65a055..1e14ac8e6d 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -21,9 +21,8 @@ logger = logging.getLogger(__name__) class CacheEmbedding(Embeddings): - def __init__(self, model_instance: ModelInstance, user: str | None = None): + def __init__(self, model_instance: ModelInstance): self._model_instance = model_instance - self._user = user def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs in batches of 10.""" @@ -65,7 +64,7 @@ class CacheEmbedding(Embeddings): batch_texts = embedding_queue_texts[i : i + max_chunks] embedding_result = self._model_instance.invoke_text_embedding( - texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT + texts=batch_texts, input_type=EmbeddingInputType.DOCUMENT ) for vector in embedding_result.embeddings: @@ -147,7 +146,6 @@ class CacheEmbedding(Embeddings): embedding_result = self._model_instance.invoke_multimodal_embedding( multimodel_documents=batch_multimodel_documents, - user=self._user, input_type=EmbeddingInputType.DOCUMENT, ) @@ -202,7 +200,7 @@ class CacheEmbedding(Embeddings): return [float(x) for x in decoded_embedding] try: embedding_result = self._model_instance.invoke_text_embedding( - texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY + texts=[text], input_type=EmbeddingInputType.QUERY ) embedding_results = embedding_result.embeddings[0] @@ -245,7 +243,7 @@ class CacheEmbedding(Embeddings): return [float(x) for x in decoded_embedding] try: embedding_result = self._model_instance.invoke_multimodal_embedding( - multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY + multimodel_documents=[multimodel_document], input_type=EmbeddingInputType.QUERY ) embedding_results = embedding_result.embeddings[0] diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 726cc062f6..c1fcdd1e69 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -8,11 +8,12 @@ from typing import Any, cast logger = logging.getLogger(__name__) +from core.app.file_access import 
DatabaseFileAccessController from core.app.llm import deduct_llm_quota from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT from core.model_manager import ModelInstance -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.keyword.keyword_factory import Keyword @@ -27,6 +28,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor, Su from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols +from core.workflow.file_reference import build_file_reference from dify_graph.file import File, FileTransferMethod, FileType, file_manager from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage from dify_graph.model_runtime.entities.message_entities import ( @@ -48,6 +50,8 @@ from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import Rule from services.summary_index_service import SummaryIndexService +_file_access_controller = DatabaseFileAccessController() + class ParagraphIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: @@ -410,7 +414,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): # If default prompt doesn't have {language} placeholder, use it as-is pass - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) provider_model_bundle = provider_manager.get_provider_model_bundle( tenant_id, model_provider_name, ModelType.LLM ) @@ -555,6 +559,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): file_obj = build_from_mapping( mapping=mapping, tenant_id=tenant_id, + access_controller=_file_access_controller, ) file_objects.append(file_obj) except Exception as e: @@ -604,11 +609,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor): filename=upload_file.name, extension="." 
+ upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, - related_id=upload_file.id, + reference=build_file_reference( + record_id=str(upload_file.id), + ), size=upload_file.size, storage_key=upload_file.key, ) diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py index 88acb75133..cc65262527 100644 --- a/api/core/rag/rerank/rerank_base.py +++ b/api/core/rag/rerank/rerank_base.py @@ -12,7 +12,6 @@ class BaseRerankRunner(ABC): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ @@ -21,7 +20,6 @@ class BaseRerankRunner(ABC): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ raise NotImplementedError diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index fcb14ffc52..313d45a62f 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -22,7 +22,6 @@ class RerankModelRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ @@ -31,10 +30,11 @@ class RerankModelRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ - model_manager = ModelManager() + model_manager = ModelManager.for_tenant( + tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id + ) is_support_vision = model_manager.check_model_support_vision( tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id, provider=self.rerank_model_instance.provider, @@ -43,12 +43,12 @@ class RerankModelRunner(BaseRerankRunner): ) if not is_support_vision: if query_type == QueryType.TEXT_QUERY: - rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n) else: return documents else: rerank_result, unique_documents = self.fetch_multimodal_rerank( - query, documents, score_threshold, top_n, user, query_type + query, documents, score_threshold, top_n, query_type ) rerank_documents = [] @@ -73,7 +73,6 @@ class RerankModelRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> tuple[RerankResult, list[Document]]: """ Fetch text rerank @@ -81,7 +80,6 @@ class RerankModelRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ docs = [] @@ -103,7 +101,7 @@ class RerankModelRunner(BaseRerankRunner): unique_documents.append(document) rerank_result = self.rerank_model_instance.invoke_rerank( - query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user + query=query, docs=docs, score_threshold=score_threshold, top_n=top_n ) return rerank_result, unique_documents @@ -113,7 +111,6 @@ class RerankModelRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, 
top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> tuple[RerankResult, list[Document]]: """ @@ -122,7 +119,6 @@ class RerankModelRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :param query_type: query type :return: rerank result """ @@ -168,7 +164,7 @@ class RerankModelRunner(BaseRerankRunner): documents = unique_documents if query_type == QueryType.TEXT_QUERY: - rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n) return rerank_result, unique_documents elif query_type == QueryType.IMAGE_QUERY: # Query file info within db.session context to ensure thread-safe access @@ -181,7 +177,7 @@ class RerankModelRunner(BaseRerankRunner): "content_type": DocType.IMAGE, } rerank_result = self.rerank_model_instance.invoke_multimodal_rerank( - query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user + query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n ) return rerank_result, unique_documents else: diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 7edd05d2d1..9d7cd5304c 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -25,7 +25,6 @@ class WeightRerankRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ @@ -34,7 +33,6 @@ class WeightRerankRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ @@ -163,7 +161,7 @@ class WeightRerankRunner(BaseRerankRunner): """ query_vector_scores = [] - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=tenant_id, diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 52061fd93d..cacddd5139 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -56,6 +56,7 @@ from core.rag.retrieval.template_prompts import ( ) from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool +from core.workflow.file_reference import build_file_reference from core.workflow.nodes.knowledge_retrieval import exc from core.workflow.nodes.knowledge_retrieval.retrieval import ( KnowledgeRetrievalRequest, @@ -160,7 +161,7 @@ class DatasetRetrieval: if request.model_provider is None or request.model_name is None or request.query is None: raise ValueError("model_provider, model_name, and query are required for single retrieval mode") - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id, user_id=request.user_id) model_instance = model_manager.get_model_instance( tenant_id=request.tenant_id, model_type=ModelType.LLM, @@ -383,23 +384,27 @@ class DatasetRetrieval: return None, [] retrieve_config = config.retrieve_config - # check model is support tool calling - model_type_instance = 
model_config.provider_model_bundle.model_type_instance
-        model_type_instance = cast(LargeLanguageModel, model_type_instance)
-
-        model_manager = ModelManager()
+        model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id)
         model_instance = model_manager.get_model_instance(
             tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model
         )
+        model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)

-        # get model schema
+        # Reuse the caller-bound model instance for both schema resolution and
+        # downstream planner/invoke calls so a single request never mixes
+        # tenant-scope and request-bound runtimes.
         model_schema = model_type_instance.get_model_schema(
-            model=model_config.model, credentials=model_config.credentials
+            model=model_instance.model_name,
+            credentials=model_instance.credentials,
         )

         if not model_schema:
             return None, []

+        model_config.provider_model_bundle = model_instance.provider_model_bundle
+        model_config.credentials = model_instance.credentials
+        model_config.model_schema = model_schema
+
         planning_strategy = PlanningStrategy.REACT_ROUTER
         features = model_schema.features
         if features:
@@ -517,11 +522,12 @@
                     filename=upload_file.name,
                     extension="." + upload_file.extension,
                     mime_type=upload_file.mime_type,
-                    tenant_id=segment.tenant_id,
                     type=FileType.IMAGE,
                     transfer_method=FileTransferMethod.LOCAL_FILE,
                     remote_url=upload_file.source_url,
-                    related_id=upload_file.id,
+                    reference=build_file_reference(
+                        record_id=str(upload_file.id),
+                    ),
                     size=upload_file.size,
                     storage_key=upload_file.key,
                     url=sign_upload_file(upload_file.id, upload_file.extension),
@@ -986,6 +992,24 @@
             )
         )

+    @staticmethod
+    def _resolve_creator_user_role(user_from: str) -> CreatorUserRole | None:
+        """Map runtime user source values to dataset query audit roles.
+
+        Workflow run context uses the hyphenated ``end-user`` value, while
+        ``DatasetQuery.created_by_role`` persists the underscore-based
+        ``CreatorUserRole.END_USER`` enum. Query logging is a side effect, so an
+        unsupported value should be skipped instead of aborting retrieval.
+        """
+        normalized_user_from = str(user_from).strip().lower().replace("-", "_")
+        if normalized_user_from == CreatorUserRole.ACCOUNT.value:
+            return CreatorUserRole.ACCOUNT
+        if normalized_user_from == CreatorUserRole.END_USER.value:
+            return CreatorUserRole.END_USER
+
+        logger.warning("Skipping dataset query audit log for unsupported user_from=%r", user_from)
+        return None
+
     def _on_query(
         self,
         query: str | None,
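Reviewer note: the hyphen/underscore contract documented above is easy to verify with a tiny standalone mirror of the normalization. Plain strings here; the real method returns CreatorUserRole members:

    def normalize_user_from(user_from: str) -> str | None:
        # Mirrors _resolve_creator_user_role: trim, lowercase, hyphen -> underscore.
        normalized = str(user_from).strip().lower().replace("-", "_")
        return normalized if normalized in {"account", "end_user"} else None

    assert normalize_user_from("end-user") == "end_user"   # workflow-run spelling
    assert normalize_user_from(" Account ") == "account"
    assert normalize_user_from("bot") is None               # skipped, not an error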
""" if not query and not attachment_ids: return + created_by_role = self._resolve_creator_user_role(user_from) + if created_by_role is None: + return dataset_queries = [] for dataset_id in dataset_ids: contents = [] @@ -1014,7 +1041,7 @@ class DatasetRetrieval: content=json.dumps(contents), source=DatasetQuerySource.APP, source_app_id=app_id, - created_by_role=CreatorUserRole(user_from), + created_by_role=created_by_role, created_by=user_id, ) dataset_queries.append(dataset_query) @@ -1411,7 +1438,7 @@ class DatasetRetrieval: raise ValueError("metadata_model_config is required") # get metadata model instance # fetch model config - model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config) + model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config, user_id=user_id) # fetch prompt messages prompt_messages, stop = self._get_prompt_template( @@ -1430,7 +1457,6 @@ class DatasetRetrieval: model_parameters=model_config.parameters, stop=stop, stream=True, - user=user_id, ), ) @@ -1533,7 +1559,7 @@ class DatasetRetrieval: return filters def _fetch_model_config( - self, tenant_id: str, model: ModelConfig + self, tenant_id: str, model: ModelConfig, user_id: str | None = None ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config @@ -1543,7 +1569,7 @@ class DatasetRetrieval: model_name = model.name provider_name = model.provider - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name ) diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index ea110fa0a7..6051802d33 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -3,13 +3,14 @@ from typing import Union from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.llm import deduct_llm_quota -from core.model_manager import ModelInstance +from core.model_manager import ModelInstance, ModelManager from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import ModelType PREFIX = """Respond to the human as helpfully and accurately as possible. 
You have access to the following tools:""" @@ -119,6 +120,7 @@ class ReactMultiDatasetRouter: memory_config=None, memory=None, model_config=model_config, + model_instance=model_instance, ) result_text, usage = self._invoke_llm( completion_param=model_config.parameters, @@ -150,19 +152,24 @@ class ReactMultiDatasetRouter: :param stop: stop :return: """ - invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm( + bound_model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance( + tenant_id=tenant_id, + provider=model_instance.provider, + model_type=ModelType.LLM, + model=model_instance.model_name, + ) + invoke_result: Generator[LLMResult, None, None] = bound_model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=completion_param, stop=stop, stream=True, - user=user_id, ) # handle invoke result text, usage = self._handle_invoke_result(invoke_result=invoke_result) # deduct quota - deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) + deduct_llm_quota(tenant_id=tenant_id, model_instance=bound_model_instance, usage=usage) return text, usage diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index 6f2826f634..cfa9962ea8 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -4,7 +4,13 @@ from __future__ import annotations from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from .factory import DifyCoreRepositoryFactory, RepositoryImportError +from .factory import ( + DifyCoreRepositoryFactory, + OrderConfig, + RepositoryImportError, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository @@ -12,7 +18,10 @@ __all__ = [ "CeleryWorkflowExecutionRepository", "CeleryWorkflowNodeExecutionRepository", "DifyCoreRepositoryFactory", + "OrderConfig", "RepositoryImportError", "SQLAlchemyWorkflowExecutionRepository", "SQLAlchemyWorkflowNodeExecutionRepository", + "WorkflowExecutionRepository", + "WorkflowNodeExecutionRepository", ] diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index 57764574d7..a93727db2a 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -11,8 +11,8 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker +from core.repositories.factory import WorkflowExecutionRepository from dify_graph.entities.workflow_execution import WorkflowExecution -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 650cf79550..b0db967153 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -12,11 +12,11 @@ from typing import Union from 
sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution -from dify_graph.repositories.workflow_node_execution_repository import ( +from core.repositories.factory import ( OrderConfig, WorkflowNodeExecutionRepository, ) +from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -148,24 +148,24 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): # For now, we'll re-raise the exception raise - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache. + Retrieve all workflow node executions for a workflow execution from cache. Args: - workflow_run_id: The workflow run ID + workflow_execution_id: The workflow execution identifier order_config: Optional configuration for ordering results Returns: A sequence of WorkflowNodeExecution instances """ try: - # Get execution IDs for this workflow run from cache - execution_ids = self._workflow_execution_mapping.get(workflow_run_id, []) + # Get execution IDs for this workflow execution from cache + execution_ids = self._workflow_execution_mapping.get(workflow_execution_id, []) # Retrieve executions from cache result = [] @@ -182,9 +182,16 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): for field_name in reversed(order_config.order_by): result.sort(key=lambda x: getattr(x, field_name, 0), reverse=reverse) - logger.debug("Retrieved %d workflow node executions for run %s from cache", len(result), workflow_run_id) + logger.debug( + "Retrieved %d workflow node executions for execution %s from cache", + len(result), + workflow_execution_id, + ) return result except Exception: - logger.exception("Failed to get workflow node executions for run %s from cache", workflow_run_id) + logger.exception( + "Failed to get workflow node executions for execution %s from cache", + workflow_execution_id, + ) return [] diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index dc9f8c96bf..76caf527db 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -5,20 +5,45 @@ This module provides a Django-like settings system for repository implementation allowing users to configure different repository backends through string paths. 
""" -from typing import Union +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Literal, Protocol, Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from configs import dify_config -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from dify_graph.entities import WorkflowExecution, WorkflowNodeExecution from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowNodeExecutionTriggeredFrom +@dataclass +class OrderConfig: + """Configuration for ordering node execution instances.""" + + order_by: list[str] + order_direction: Literal["asc", "desc"] | None = None + + +class WorkflowExecutionRepository(Protocol): + def save(self, execution: WorkflowExecution): ... + + +class WorkflowNodeExecutionRepository(Protocol): + def save(self, execution: WorkflowNodeExecution): ... + + def save_execution_data(self, execution: WorkflowNodeExecution): ... + + def get_by_workflow_execution( + self, + workflow_execution_id: str, + order_config: OrderConfig | None = None, + ) -> Sequence[WorkflowNodeExecution]: ... + + class RepositoryImportError(Exception): """Raised when a repository implementation cannot be imported or instantiated.""" diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 6607a87032..5b8b4ec6a2 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -2,33 +2,23 @@ import dataclasses import json from collections.abc import Mapping, Sequence from datetime import datetime -from typing import Any +from typing import Any, Protocol from sqlalchemy import select from sqlalchemy.orm import Session, selectinload from core.db.session_factory import session_factory -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( + BoundRecipient, DeliveryChannelConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, - HumanInputNodeData, - MemberRecipient, - WebAppDeliveryMethod, -) -from dify_graph.nodes.human_input.enums import ( - DeliveryMethodType, - HumanInputFormKind, - HumanInputFormStatus, -) -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - FormNotFoundError, - HumanInputFormEntity, - HumanInputFormRecipientEntity, + InteractiveSurfaceDeliveryMethod, + is_human_input_webapp_enabled, ) +from dify_graph.nodes.human_input.entities import FormDefinition, HumanInputNodeData +from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models.account import Account, TenantAccountJoin @@ -36,6 +26,7 @@ from models.human_input import ( BackstageRecipientPayload, ConsoleDeliveryPayload, ConsoleRecipientPayload, + DeliveryMethodType, EmailExternalRecipientPayload, EmailMemberRecipientPayload, HumanInputDelivery, @@ -58,6 +49,65 @@ class _WorkspaceMemberInfo: email: str +class FormNotFoundError(Exception): + pass + + +@dataclasses.dataclass +class FormCreateParams: + workflow_execution_id: str | None + node_id: str + form_config: HumanInputNodeData + rendered_content: str + delivery_methods: Sequence[DeliveryChannelConfig] + display_in_ui: 
bool + resolved_default_values: Mapping[str, Any] + form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME + + +class HumanInputFormRecipientEntity(Protocol): + @property + def id(self) -> str: ... + + @property + def token(self) -> str: ... + + +class HumanInputFormEntity(Protocol): + @property + def id(self) -> str: ... + + @property + def submission_token(self) -> str | None: ... + + @property + def recipients(self) -> list[HumanInputFormRecipientEntity]: ... + + @property + def rendered_content(self) -> str: ... + + @property + def selected_action_id(self) -> str | None: ... + + @property + def submitted_data(self) -> Mapping[str, Any] | None: ... + + @property + def submitted(self) -> bool: ... + + @property + def status(self) -> HumanInputFormStatus: ... + + @property + def expiration_time(self) -> datetime: ... + + +class HumanInputFormRepository(Protocol): + def get_form(self, node_id: str) -> HumanInputFormEntity | None: ... + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: ... + + class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity): def __init__(self, recipient_model: HumanInputFormRecipient): self._recipient_model = recipient_model @@ -77,7 +127,7 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity): def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]): self._form_model = form_model self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models] - self._web_app_recipient = next( + self._interactive_surface_recipient = next( ( recipient for recipient in recipient_models @@ -98,12 +148,12 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity): return self._form_model.id @property - def web_app_token(self): + def submission_token(self) -> str | None: if self._console_recipient is not None: return self._console_recipient.access_token - if self._web_app_recipient is None: + if self._interactive_surface_recipient is None: return None - return self._web_app_recipient.access_token + return self._interactive_surface_recipient.access_token @property def recipients(self) -> list[HumanInputFormRecipientEntity]: @@ -201,8 +251,16 @@ class HumanInputFormRepositoryImpl: self, *, tenant_id: str, - ): + app_id: str | None = None, + workflow_execution_id: str | None = None, + invoke_source: str | None = None, + submission_actor_id: str | None = None, + ) -> None: self._tenant_id = tenant_id + self._app_id = app_id + self._workflow_execution_id = workflow_execution_id + self._invoke_source = invoke_source + self._submission_actor_id = submission_actor_id def _delivery_method_to_model( self, @@ -219,7 +277,7 @@ class HumanInputFormRepositoryImpl: channel_payload=delivery_method.model_dump_json(), ) recipients: list[HumanInputFormRecipient] = [] - if isinstance(delivery_method, WebAppDeliveryMethod): + if isinstance(delivery_method, InteractiveSurfaceDeliveryMethod): recipient_model = HumanInputFormRecipient( form_id=form_id, delivery_id=delivery_id, @@ -247,16 +305,16 @@ class HumanInputFormRepositoryImpl: delivery_id: str, recipients_config: EmailRecipients, ) -> list[HumanInputFormRecipient]: - member_user_ids = [ - recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient) + bound_reference_ids = [ + recipient.reference_id for recipient in recipients_config.items if isinstance(recipient, BoundRecipient) ] external_emails = [ recipient.email for recipient in recipients_config.items if isinstance(recipient, 
ExternalRecipient) ] - if recipients_config.whole_workspace: + if recipients_config.include_bound_group: members = self._query_all_workspace_members(session=session) else: - members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids) + members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=bound_reference_ids) return self._create_email_recipients_from_resolved( form_id=form_id, @@ -338,8 +396,33 @@ class HumanInputFormRepositoryImpl: rows = session.execute(stmt).all() return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows] + def _should_create_console_recipient( + self, + *, + form_config: HumanInputNodeData, + form_kind: HumanInputFormKind, + ) -> bool: + if form_kind != HumanInputFormKind.RUNTIME: + return False + if self._invoke_source == "debugger": + return True + if self._invoke_source == "explore": + return is_human_input_webapp_enabled(form_config) + return False + + def _should_create_backstage_recipient(self, *, form_kind: HumanInputFormKind) -> bool: + return form_kind == HumanInputFormKind.RUNTIME and ( + self._invoke_source is not None or self._submission_actor_id is not None + ) + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: form_config: HumanInputNodeData = params.form_config + app_id = self._app_id + if not app_id: + raise ValueError("app_id is required to create a human input form") + workflow_execution_id = params.workflow_execution_id or self._workflow_execution_id + if params.form_kind == HumanInputFormKind.RUNTIME and workflow_execution_id is None: + raise ValueError("workflow_execution_id is required for runtime human input forms") with session_factory.create_session() as session, session.begin(): # Generate unique form ID @@ -359,8 +442,8 @@ class HumanInputFormRepositoryImpl: form_model = HumanInputForm( id=form_id, tenant_id=self._tenant_id, - app_id=params.app_id, - workflow_run_id=params.workflow_execution_id, + app_id=app_id, + workflow_run_id=workflow_execution_id, form_kind=params.form_kind, node_id=params.node_id, form_definition=form_definition.model_dump_json(), @@ -379,7 +462,7 @@ class HumanInputFormRepositoryImpl: session.add(delivery_and_recipients.delivery) session.add_all(delivery_and_recipients.recipients) recipient_models.extend(delivery_and_recipients.recipients) - if params.console_recipient_required and not any( + if self._should_create_console_recipient(form_config=form_config, form_kind=params.form_kind) and not any( recipient.recipient_type == RecipientType.CONSOLE for recipient in recipient_models ): console_delivery_id = str(uuidv7()) @@ -395,13 +478,13 @@ class HumanInputFormRepositoryImpl: delivery_id=console_delivery_id, recipient_type=RecipientType.CONSOLE, recipient_payload=ConsoleRecipientPayload( - account_id=params.console_creator_account_id, + account_id=self._submission_actor_id, ).model_dump_json(), ) session.add(console_delivery) session.add(console_recipient) recipient_models.append(console_recipient) - if params.backstage_recipient_required and not any( + if self._should_create_backstage_recipient(form_kind=params.form_kind) and not any( recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models ): backstage_delivery_id = str(uuidv7()) @@ -417,7 +500,7 @@ class HumanInputFormRepositoryImpl: delivery_id=backstage_delivery_id, recipient_type=RecipientType.BACKSTAGE, recipient_payload=BackstageRecipientPayload( - account_id=params.console_creator_account_id, + 
account_id=self._submission_actor_id, ).model_dump_json(), ) session.add(backstage_delivery) @@ -427,9 +510,12 @@ class HumanInputFormRepositoryImpl: return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models) - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: + if self._workflow_execution_id is None: + raise ValueError("workflow_execution_id is required to load runtime human input forms") + form_query = select(HumanInputForm).where( - HumanInputForm.workflow_run_id == workflow_execution_id, + HumanInputForm.workflow_run_id == self._workflow_execution_id, HumanInputForm.node_id == node_id, HumanInputForm.tenant_id == self._tenant_id, ) diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 55e96515ac..011a79b07b 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -9,9 +9,9 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker +from core.repositories.factory import WorkflowExecutionRepository from dify_graph.entities import WorkflowExecution from dify_graph.enums import WorkflowExecutionStatus, WorkflowType -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from libs.helper import extract_tenant_id from models import ( diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 7373ebc7cc..c9d0bd8597 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -17,10 +17,10 @@ from sqlalchemy.orm import sessionmaker from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_attempt from configs import dify_config +from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from dify_graph.entities import WorkflowNodeExecution from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.ext_storage import storage from libs.helper import extract_tenant_id @@ -518,29 +518,28 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) return db_models - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, triggered_from: WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all NodeExecution instances for a specific workflow run. + Retrieve all node executions for a workflow execution. This method always queries the database to ensure complete and ordered results, but updates the cache with any retrieved executions. 
        Args:
-            workflow_run_id: The workflow run ID
+            workflow_execution_id: The workflow execution identifier
            order_config: Optional configuration for ordering results
                order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
                order_config.order_direction: Direction to order ("asc" or "desc")

        Returns:
-            A list of NodeExecution instances
+            A list of node execution instances
        """
-        # Get the database models using the new method
-        db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config, triggered_from)
+        db_models = self.get_db_models_by_workflow_run(workflow_execution_id, order_config, triggered_from)

        with ThreadPoolExecutor(max_workers=10) as executor:
            domain_models = executor.map(self._to_domain_model, db_models, timeout=30)
diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py
index 961d13f90a..5154bc9805 100644
--- a/api/core/tools/__base/tool_runtime.py
+++ b/api/core/tools/__base/tool_runtime.py
@@ -9,10 +9,14 @@ from core.tools.entities.tool_entities import ToolInvokeFrom

 class ToolRuntime(BaseModel):
     """
-    Meta data of a tool call processing
+    Metadata for a tool call invocation.
+
+    ``user_id`` is optional so that read-only tooling flows can stay tenant-scoped,
+    while execution paths may bind caller identity for model runtime lookups.
     """

     tenant_id: str
+    user_id: str | None = None
     tool_id: str | None = None
     invoke_from: InvokeFrom | None = None
     tool_invoke_from: ToolInvokeFrom | None = None
diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py
index dacc49c746..0860df2225 100644
--- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py
+++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py
@@ -22,6 +22,9 @@ class ASRTool(BuiltinTool):
         app_id: str | None = None,
         message_id: str | None = None,
     ) -> Generator[ToolInvokeMessage, None, None]:
+        if not self.runtime:
+            raise ValueError("Runtime is required")
+        runtime = self.runtime
         file = tool_parameters.get("audio_file")
         if file.type != FileType.AUDIO:  # type: ignore
             yield self.create_text_message("not a valid audio file")
@@ -29,20 +32,19 @@
         audio_binary = io.BytesIO(download(file))  # type: ignore
         audio_binary.name = "temp.mp3"
         provider, model = tool_parameters.get("model").split("#")  # type: ignore
-        model_manager = ModelManager()
+        model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id)
         model_instance = model_manager.get_model_instance(
-            tenant_id=self.runtime.tenant_id,
+            tenant_id=runtime.tenant_id,
             provider=provider,
             model_type=ModelType.SPEECH2TEXT,
             model=model,
         )
-        text = model_instance.invoke_speech2text(
-            file=audio_binary,
-            user=user_id,
-        )
+        text = model_instance.invoke_speech2text(file=audio_binary)
         yield self.create_text_message(text)

     def get_available_models(self) -> list[tuple[str, str]]:
+        if not self.runtime:
+            raise ValueError("Runtime is required")
         model_provider_service = ModelProviderService()
         models = model_provider_service.get_models_by_model_type(
             tenant_id=self.runtime.tenant_id, model_type="speech2text"
diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py
index 7818bff0ab..0b32b74e95 100644
--- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py
+++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py
@@ -20,13 +20,14 @@ class TTSTool(BuiltinTool):
         app_id: str | None = None,
         message_id: str | None = None,
     )
-> Generator[ToolInvokeMessage, None, None]: - provider, model = tool_parameters.get("model").split("#") # type: ignore - voice = tool_parameters.get(f"voice#{provider}#{model}") - model_manager = ModelManager() if not self.runtime: raise ValueError("Runtime is required") + runtime = self.runtime + provider, model = tool_parameters.get("model").split("#") # type: ignore + voice = tool_parameters.get(f"voice#{provider}#{model}") + model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( - tenant_id=self.runtime.tenant_id or "", + tenant_id=runtime.tenant_id or "", provider=provider, model_type=ModelType.TTS, model=model, @@ -39,12 +40,7 @@ class TTSTool(BuiltinTool): raise ValueError("Sorry, no voice available.") else: raise ValueError("Sorry, no voice available.") - tts = model_instance.invoke_tts( - content_text=tool_parameters.get("text"), # type: ignore - user=user_id, - tenant_id=self.runtime.tenant_id, - voice=voice, - ) + tts = model_instance.invoke_tts(content_text=tool_parameters.get("text"), voice=voice) # type: ignore[arg-type] buffer = io.BytesIO() for chunk in tts: buffer.write(chunk) diff --git a/api/core/tools/builtin_tool/providers/time/tools/current_time.py b/api/core/tools/builtin_tool/providers/time/tools/current_time.py index 44f94c2723..e07ca0d919 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/current_time.py +++ b/api/core/tools/builtin_tool/providers/time/tools/current_time.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import UTC, datetime from typing import Any -from pytz import timezone as pytz_timezone +from pytz import timezone as pytz_timezone # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py index d0a41b940f..dc49b64dd8 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py +++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import datetime from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py index 462e4be5ce..8045e4b980 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import datetime from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py index e23ae3b001..e2570811d6 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py @@ -2,7 +2,7 @@ from collections.abc import 
Generator
 from datetime import datetime
 from typing import Any

-import pytz
+import pytz  # type: ignore[import-untyped]

 from core.tools.builtin_tool.tool import BuiltinTool
 from core.tools.entities.tool_entities import ToolInvokeMessage
diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py
index bcf58394ba..64a2c697fe 100644
--- a/api/core/tools/builtin_tool/tool.py
+++ b/api/core/tools/builtin_tool/tool.py
@@ -53,6 +53,7 @@ class BuiltinTool(Tool):
             tool_type=ToolProviderType.BUILT_IN,
             tool_name=self.entity.identity.name,
             prompt_messages=prompt_messages,
+            caller_user_id=self.runtime.user_id,
         )

     def tool_provider_type(self) -> ToolProviderType:
@@ -69,6 +70,7 @@
         return ModelInvocationUtils.get_max_llm_context_tokens(
             tenant_id=self.runtime.tenant_id or "",
+            user_id=self.runtime.user_id,
         )

     def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int:
@@ -82,7 +84,9 @@
             raise ValueError("runtime is required")

         return ModelInvocationUtils.calculate_tokens(
-            tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages
+            tenant_id=self.runtime.tenant_id or "",
+            prompt_messages=prompt_messages,
+            user_id=self.runtime.user_id,
         )

     def summary(self, user_id: str, content: str) -> str:
diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py
index 22e099deba..1807226924 100644
--- a/api/core/tools/signature.py
+++ b/api/core/tools/signature.py
@@ -3,6 +3,7 @@ import hashlib
 import hmac
 import os
 import time
+import urllib.parse

 from configs import dify_config

@@ -58,3 +59,43 @@
     current_time = int(time.time())

     return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
+
+
+def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str:
+    """Build the signed upload URL used by the plugin-facing file upload endpoint."""
+
+    base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
+    upload_url = f"{base_url}/files/upload/for-plugin"
+    timestamp = str(int(time.time()))
+    nonce = os.urandom(16).hex()
+    data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
+    secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
+    sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+    encoded_sign = base64.urlsafe_b64encode(sign).decode()
+    query = urllib.parse.urlencode(
+        {
+            "timestamp": timestamp,
+            "nonce": nonce,
+            "sign": encoded_sign,
+            "user_id": user_id,
+            "tenant_id": tenant_id,
+        }
+    )
+    return f"{upload_url}?{query}"
+
+
+def verify_plugin_file_signature(
+    *, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str
+) -> bool:
+    """Verify the signature used by the plugin-facing file upload endpoint."""
+
+    data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
+    secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
+    recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+    recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
+
+    if sign != recalculated_encoded_sign:
+        return False
+
+    current_time = int(time.time())
+    return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
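Reviewer note: the two helpers above form a sign/verify pair over the same pipe-delimited payload. Below is a stdlib-only round-trip sketch with the config values swapped for local constants (both names are stand-ins for the dify_config settings):

    import base64
    import hashlib
    import hmac
    import os
    import time

    SECRET_KEY = b"example-secret"   # stand-in for dify_config.SECRET_KEY
    ACCESS_TIMEOUT = 300             # stand-in for dify_config.FILES_ACCESS_TIMEOUT

    def sign(filename, mimetype, tenant_id, user_id, timestamp, nonce):
        data = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
        digest = hmac.new(SECRET_KEY, data.encode(), hashlib.sha256).digest()
        return base64.urlsafe_b64encode(digest).decode()

    def verify(filename, mimetype, tenant_id, user_id, timestamp, nonce, signature):
        if signature != sign(filename, mimetype, tenant_id, user_id, timestamp, nonce):
            return False
        return int(time.time()) - int(timestamp) <= ACCESS_TIMEOUT

    ts, nonce = str(int(time.time())), os.urandom(16).hex()
    s = sign("a.png", "image/png", "t1", "u1", ts, nonce)
    assert verify("a.png", "image/png", "t1", "u1", ts, nonce, s)
    assert not verify("b.png", "image/png", "t1", "u1", ts, nonce, s)  # any field change breaks it

One observation, not a blocker: the verify path compares signatures with a plain inequality; hmac.compare_digest would make the comparison constant-time.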
--- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -14,7 +14,8 @@ import httpx from configs import dify_config from core.db.session_factory import session_factory from core.helper import ssrf_proxy -from dify_graph.file.models import ToolFile as ToolFilePydanticModel +from core.workflow.file_reference import build_file_reference +from dify_graph.file import File, FileTransferMethod, get_file_type_by_mime_type from extensions.ext_storage import storage from models.model import MessageFile from models.tools import ToolFile @@ -23,6 +24,21 @@ logger = logging.getLogger(__name__) class ToolFileManager: + @staticmethod + def _build_graph_file_reference(tool_file: ToolFile) -> File: + extension = guess_extension(tool_file.mimetype) or ".bin" + return File( + type=get_file_type_by_mime_type(tool_file.mimetype), + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url=tool_file.original_url, + reference=build_file_reference(record_id=str(tool_file.id)), + filename=tool_file.name, + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + storage_key=tool_file.file_key, + ) + @staticmethod def sign_file(tool_file_id: str, extension: str) -> str: """ @@ -209,9 +225,7 @@ class ToolFileManager: return blob, tool_file.mimetype - def get_file_generator_by_tool_file_id( - self, tool_file_id: str - ) -> tuple[Generator | None, ToolFilePydanticModel | None]: + def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, File | None]: """ get file binary @@ -233,7 +247,7 @@ class ToolFileManager: stream = storage.load_stream(tool_file.file_key) - return stream, ToolFilePydanticModel.model_validate(tool_file) + return stream, self._build_graph_file_reference(tool_file) # init tool_file_parser diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 23a877b7e3..9e58610f77 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,7 +5,7 @@ import time from collections.abc import Generator, Mapping from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, TypedDict, Union, cast import sqlalchemy as sa from sqlalchemy import select @@ -31,7 +31,7 @@ from services.enterprise.plugin_manager_service import PluginCredentialType from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: - from dify_graph.nodes.tool.entities import ToolEntity + pass from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -62,7 +62,7 @@ from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvi from services.tools.tools_transform_service import ToolTransformService if TYPE_CHECKING: - from dify_graph.nodes.tool.entities import ToolEntity + pass logger = logging.getLogger(__name__) @@ -77,6 +77,23 @@ class EmojiIconDict(TypedDict): content: str +class WorkflowToolRuntimeSpec(Protocol): + @property + def provider_type(self) -> ToolProviderType: ... + + @property + def provider_id(self) -> str: ... + + @property + def tool_name(self) -> str: ... + + @property + def tool_configurations(self) -> Mapping[str, Any]: ... + + @property + def credential_id(self) -> str | None: ... 
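Because WorkflowToolRuntimeSpec is a typing.Protocol, get_workflow_tool_runtime no longer needs the concrete dify_graph ToolEntity: any object exposing these read-only attributes satisfies the contract structurally, which is how the _WorkflowToolRuntimeSpec dataclass added to node_runtime.py later in this patch plugs in. A minimal sketch (the StubToolSpec name is illustrative, not part of the change):

from collections.abc import Mapping
from dataclasses import dataclass, field
from typing import Any

from core.tools.entities.tool_entities import ToolProviderType


@dataclass(frozen=True)
class StubToolSpec:
    # Plain attributes are enough to satisfy the Protocol's read-only
    # properties; no inheritance from WorkflowToolRuntimeSpec is needed.
    provider_type: ToolProviderType
    provider_id: str
    tool_name: str
    tool_configurations: Mapping[str, Any] = field(default_factory=dict)
    credential_id: str | None = None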
+ + class ToolManager: _builtin_provider_lock = Lock() _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} @@ -167,6 +184,7 @@ class ToolManager: provider_id: str, tool_name: str, tenant_id: str, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, credential_id: str | None = None, @@ -178,6 +196,7 @@ class ToolManager: :param provider_id: the id of the provider :param tool_name: the name of the tool :param tenant_id: the tenant id + :param user_id: the caller id bound to runtime-scoped model/tool lookups :param invoke_from: invoke from :param tool_invoke_from: the tool invoke from :param credential_id: the credential id @@ -196,6 +215,7 @@ class ToolManager: return builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -304,6 +324,7 @@ class ToolManager: return builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials=dict(decrypted_credentials), credential_type=CredentialType.of(builtin_provider.credential_type), runtime_parameters={}, @@ -321,6 +342,7 @@ class ToolManager: return api_provider.get_tool(tool_name).fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials=dict(encrypter.decrypt(credentials)), invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -344,6 +366,7 @@ class ToolManager: return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -352,9 +375,21 @@ class ToolManager: elif provider_type == ToolProviderType.APP: raise NotImplementedError("app provider not implemented") elif provider_type == ToolProviderType.PLUGIN: - return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) + plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) + runtime = getattr(plugin_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return plugin_tool elif provider_type == ToolProviderType.MCP: - return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) + mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) + runtime = getattr(mcp_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return mcp_tool else: raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") @@ -364,6 +399,7 @@ class ToolManager: tenant_id: str, app_id: str, agent_tool: AgentToolEntity, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, variable_pool: Optional["VariablePool"] = None, ) -> Tool: @@ -375,6 +411,7 @@ class ToolManager: provider_id=agent_tool.provider_id, tool_name=agent_tool.tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.AGENT, credential_id=agent_tool.credential_id, @@ -405,7 +442,8 @@ class ToolManager: tenant_id: str, app_id: str, node_id: str, - workflow_tool: "ToolEntity", + workflow_tool: WorkflowToolRuntimeSpec, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, variable_pool: 
Optional["VariablePool"] = None, ) -> Tool: @@ -418,6 +456,7 @@ class ToolManager: provider_id=workflow_tool.provider_id, tool_name=workflow_tool.tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, credential_id=workflow_tool.credential_id, @@ -450,6 +489,7 @@ class ToolManager: provider: str, tool_name: str, tool_parameters: dict[str, Any], + user_id: str | None = None, credential_id: str | None = None, ) -> Tool: """ @@ -460,6 +500,7 @@ class ToolManager: provider_id=provider, tool_name=tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=InvokeFrom.SERVICE_API, tool_invoke_from=ToolInvokeFrom.PLUGIN, credential_id=credential_id, @@ -1015,7 +1056,7 @@ class ToolManager: cls, parameters: list[ToolParameter], variable_pool: Optional["VariablePool"], - tool_configurations: dict[str, Any], + tool_configurations: Mapping[str, Any], typ: Literal["agent", "workflow", "tool"] = "workflow", ) -> dict[str, Any]: """ diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 75b923fd8b..87f925f4a4 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -66,7 +66,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): for thread in threads: thread.join() # do rerank for searched documents - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id) rerank_model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id, provider=self.reranking_provider_name, diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 6fc5fead2d..1a3c491216 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,4 +1,5 @@ import logging +import re from collections.abc import Generator from datetime import date, datetime from decimal import Decimal @@ -10,12 +11,15 @@ import pytz from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager +from core.workflow.file_reference import parse_file_reference from dify_graph.file import File, FileTransferMethod, FileType from libs.login import current_user from models import Account logger = logging.getLogger(__name__) +_TOOL_FILE_URL_PATTERN = re.compile(r"(?:^|/+)files/tools/(?P[^/?#.]+)") + def safe_json_value(v): if isinstance(v, datetime): @@ -82,11 +86,15 @@ class ToolFileMessageTransformer: ) url = f"/files/tools/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}" + meta = cls._with_tool_file_meta( + message.meta, + tool_file_id=str(tool_file.id), + ) yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=message.meta.copy() if message.meta is not None else {}, + meta=meta, ) except Exception as e: yield ToolInvokeMessage( @@ -122,38 +130,45 @@ class ToolFileMessageTransformer: ) url = cls.get_tool_file_url(tool_file_id=tool_file.id, extension=guess_extension(tool_file.mimetype)) + meta = cls._with_tool_file_meta(meta, tool_file_id=str(tool_file.id)) # check if file is image if "image" in mimetype: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=meta, ) 
else: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.BINARY_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=meta, ) elif message.type == ToolInvokeMessage.MessageType.FILE: meta = message.meta or {} file = meta.get("file", None) if isinstance(file, File): if file.transfer_method == FileTransferMethod.TOOL_FILE: - assert file.related_id is not None - url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("tool file is missing reference") + url = cls.get_tool_file_url( + tool_file_id=parsed_reference.record_id, + extension=file.extension, + ) + tool_file_meta = cls._with_tool_file_meta(meta, tool_file_id=parsed_reference.record_id) if file.type == FileType.IMAGE: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=tool_file_meta, ) else: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=tool_file_meta, ) else: yield message @@ -162,9 +177,40 @@ class ToolFileMessageTransformer: if isinstance(message.message, ToolInvokeMessage.JsonMessage): message.message.json_object = safe_json_value(message.message.json_object) yield message + elif message.type in { + ToolInvokeMessage.MessageType.IMAGE_LINK, + ToolInvokeMessage.MessageType.BINARY_LINK, + } and isinstance(message.message, ToolInvokeMessage.TextMessage): + yield ToolInvokeMessage( + type=message.type, + message=message.message, + meta=cls._with_tool_file_meta(message.meta, url=message.message.text), + ) else: yield message @classmethod def get_tool_file_url(cls, tool_file_id: str, extension: str | None) -> str: return f"/files/tools/{tool_file_id}{extension or '.bin'}" + + @staticmethod + def _with_tool_file_meta( + meta: dict | None, + *, + tool_file_id: str | None = None, + url: str | None = None, + ) -> dict: + normalized_meta = meta.copy() if meta is not None else {} + resolved_tool_file_id = tool_file_id or ToolFileMessageTransformer._extract_tool_file_id(url) + if resolved_tool_file_id and "tool_file_id" not in normalized_meta: + normalized_meta["tool_file_id"] = resolved_tool_file_id + return normalized_meta + + @staticmethod + def _extract_tool_file_id(url: str | None) -> str | None: + if not url: + return None + match = _TOOL_FILE_URL_PATTERN.search(url) + if match is None: + return None + return match.group("tool_file_id") diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 373bd1b1c8..5d49bf9f23 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -34,11 +34,12 @@ class ModelInvocationUtils: @staticmethod def get_max_llm_context_tokens( tenant_id: str, + user_id: str | None = None, ) -> int: """ get max llm context tokens of the model """ - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -60,13 +61,13 @@ class ModelInvocationUtils: return max_tokens @staticmethod - def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int: + def calculate_tokens(tenant_id: str, 
prompt_messages: list[PromptMessage], user_id: str | None = None) -> int: """ calculate tokens from prompt messages and model parameters """ # get model instance - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) if not model_instance: @@ -79,7 +80,12 @@ class ModelInvocationUtils: @staticmethod def invoke( - user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage] + user_id: str, + tenant_id: str, + tool_type: ToolProviderType, + tool_name: str, + prompt_messages: list[PromptMessage], + caller_user_id: str | None = None, ) -> LLMResult: """ invoke model with parameters in user's own context @@ -93,7 +99,7 @@ class ModelInvocationUtils: """ # get model manager - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=caller_user_id or user_id) # get model instance model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, @@ -137,7 +143,6 @@ class ModelInvocationUtils: tools=[], stop=[], stream=False, - user=user_id, callbacks=[], ) except InvokeRateLimitError as e: diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 9b9aa7a741..22b8a5206e 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -7,6 +7,7 @@ from typing import Any, cast from sqlalchemy import select +from core.app.file_access import DatabaseFileAccessController from core.db.session_factory import session_factory from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime @@ -17,14 +18,17 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.errors import ToolInvokeError +from core.workflow.file_reference import resolve_file_record_id from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from factories.file_factory import build_from_mapping from models import Account, Tenant from models.model import App, EndUser +from models.utils.file_input_compat import build_file_from_stored_mapping from models.workflow import Workflow logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class WorkflowTool(Tool): @@ -288,16 +292,25 @@ class WorkflowTool(Tool): file = tool_parameters.get(parameter.name) if file: try: - file_var_list = [File.model_validate(f) for f in file] + file_var_list = [ + build_file_from_stored_mapping( + file_mapping=cast(Mapping[str, Any], f), + tenant_id=str(self.runtime.tenant_id), + ) + for f in file + if isinstance(f, Mapping) + ] for file in file_var_list: file_dict: dict[str, str | None] = { "transfer_method": file.transfer_method.value, "type": file.type.value, } if file.transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = file.related_id + file_dict["tool_file_id"] = resolve_file_record_id(file.reference) elif file.transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = file.related_id + file_dict["upload_file_id"] = resolve_file_record_id(file.reference) + elif file.transfer_method == FileTransferMethod.DATASOURCE_FILE: + file_dict["datasource_file_id"] = resolve_file_record_id(file.reference) elif file.transfer_method == FileTransferMethod.REMOTE_URL: 
file_dict["url"] = file.generate_url() @@ -325,6 +338,7 @@ class WorkflowTool(Tool): file = build_from_mapping( mapping=item, tenant_id=str(self.runtime.tenant_id), + access_controller=_file_access_controller, ) files.append(file) elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: @@ -332,6 +346,7 @@ class WorkflowTool(Tool): file = build_from_mapping( mapping=value, tenant_id=str(self.runtime.tenant_id), + access_controller=_file_access_controller, ) files.append(file) @@ -340,9 +355,10 @@ class WorkflowTool(Tool): return result, files def _update_file_mapping(self, file_dict: dict): + file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id")) transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method")) if transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = file_dict.get("related_id") + file_dict["tool_file_id"] = file_id elif transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = file_dict.get("related_id") + file_dict["upload_file_id"] = file_id return file_dict diff --git a/api/core/workflow/file_reference.py b/api/core/workflow/file_reference.py new file mode 100644 index 0000000000..c80acb3783 --- /dev/null +++ b/api/core/workflow/file_reference.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import base64 +import json +from dataclasses import dataclass + +_FILE_REFERENCE_PREFIX = "dify-file-ref:" + + +@dataclass(frozen=True) +class FileReference: + record_id: str + storage_key: str | None = None + + +def build_file_reference(*, record_id: str, storage_key: str | None = None) -> str: + payload = {"record_id": record_id} + if storage_key is not None: + payload["storage_key"] = storage_key + encoded_payload = base64.urlsafe_b64encode(json.dumps(payload, separators=(",", ":")).encode()).decode() + return f"{_FILE_REFERENCE_PREFIX}{encoded_payload}" + + +def parse_file_reference(reference: str | None) -> FileReference | None: + if not reference: + return None + + if not reference.startswith(_FILE_REFERENCE_PREFIX): + return FileReference(record_id=reference) + + encoded_payload = reference.removeprefix(_FILE_REFERENCE_PREFIX) + try: + payload = json.loads(base64.urlsafe_b64decode(encoded_payload.encode())) + except (ValueError, json.JSONDecodeError): + return FileReference(record_id=reference) + + record_id = payload.get("record_id") + if not isinstance(record_id, str) or not record_id: + return FileReference(record_id=reference) + + storage_key = payload.get("storage_key") + if storage_key is not None and not isinstance(storage_key, str): + storage_key = None + + return FileReference(record_id=record_id, storage_key=storage_key) + + +def resolve_file_record_id(reference: str | None) -> str | None: + parsed_reference = parse_file_reference(reference) + if parsed_reference is None: + return None + return parsed_reference.record_id diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_compat.py new file mode 100644 index 0000000000..41b24635d9 --- /dev/null +++ b/api/core/workflow/human_input_compat.py @@ -0,0 +1,299 @@ +"""Workflow-layer adapters for legacy human-input payload keys. + +Stored workflow graphs and editor payloads may still use Dify-specific human +input recipient keys. Normalize them here before handing configs to +`dify_graph` so graph-owned models only see graph-neutral field names. 
+""" + +from __future__ import annotations + +import enum +import uuid +from collections.abc import Mapping, Sequence +from typing import Annotated, Any, ClassVar, Literal + +import bleach +import markdown +from markdown.extensions.tables import TableExtension +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, TypeAdapter + +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from dify_graph.runtime import VariablePool +from dify_graph.variables.consts import SELECTORS_LENGTH + + +class DeliveryMethodType(enum.StrEnum): + WEBAPP = enum.auto() + EMAIL = enum.auto() + + +class EmailRecipientType(enum.StrEnum): + BOUND = "member" + MEMBER = BOUND + EXTERNAL = "external" + + +class _InteractiveSurfaceDeliveryConfig(BaseModel): + pass + + +class BoundRecipient(BaseModel): + model_config = ConfigDict(extra="forbid") + + type: Literal[EmailRecipientType.BOUND] = EmailRecipientType.BOUND + reference_id: str + + +class ExternalRecipient(BaseModel): + model_config = ConfigDict(extra="forbid") + + type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL + email: str + + +MemberRecipient = BoundRecipient +EmailRecipient = Annotated[BoundRecipient | ExternalRecipient, Field(discriminator="type")] + + +class EmailRecipients(BaseModel): + model_config = ConfigDict(extra="forbid") + + include_bound_group: bool = Field( + default=False, + validation_alias=AliasChoices("include_bound_group", "whole_workspace"), + ) + items: list[EmailRecipient] = Field(default_factory=list) + + +class EmailDeliveryConfig(BaseModel): + URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}" + _ALLOWED_HTML_TAGS: ClassVar[list[str]] = [ + "a", + "br", + "code", + "em", + "li", + "ol", + "p", + "pre", + "strong", + "table", + "tbody", + "td", + "th", + "thead", + "tr", + "ul", + ] + _ALLOWED_HTML_ATTRIBUTES: ClassVar[dict[str, list[str]]] = { + "a": ["href", "title"], + "td": ["align"], + "th": ["align"], + } + _ALLOWED_PROTOCOLS: ClassVar[set[str]] = set(bleach.sanitizer.ALLOWED_PROTOCOLS) | {"mailto"} + + recipients: EmailRecipients + subject: str + body: str + debug_mode: bool = False + + def with_recipients(self, recipients: EmailRecipients) -> EmailDeliveryConfig: + return self.model_copy(update={"recipients": recipients}) + + @classmethod + def replace_url_placeholder(cls, body: str, url: str | None) -> str: + return body.replace(cls.URL_PLACEHOLDER, url or "") + + @classmethod + def render_body_template( + cls, + *, + body: str, + url: str | None, + variable_pool: VariablePool | None = None, + ) -> str: + templated_body = cls.replace_url_placeholder(body, url) + if variable_pool is None: + return templated_body + return variable_pool.convert_template(templated_body).text + + @classmethod + def render_markdown_body(cls, body: str) -> str: + stripped_body = bleach.clean(body, tags=[], attributes={}, strip=True) + rendered = markdown.markdown( + stripped_body, + extensions=[TableExtension(use_align_attribute=True)], + output_format="html", + ) + return bleach.clean( + rendered, + tags=cls._ALLOWED_HTML_TAGS, + attributes=cls._ALLOWED_HTML_ATTRIBUTES, + protocols=cls._ALLOWED_PROTOCOLS, + strip=True, + ) + + @staticmethod + def sanitize_subject(subject: str) -> str: + sanitized = subject.replace("\r", " ").replace("\n", " ") + sanitized = bleach.clean(sanitized, tags=[], strip=True) + return " ".join(sanitized.split()) + + +class _DeliveryMethodBase(BaseModel): + enabled: bool = True + id: uuid.UUID = 
Field(default_factory=uuid.uuid4) + + def extract_variable_selectors(self) -> Sequence[Sequence[str]]: + return () + + +class InteractiveSurfaceDeliveryMethod(_DeliveryMethodBase): + type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP + config: _InteractiveSurfaceDeliveryConfig = Field(default_factory=_InteractiveSurfaceDeliveryConfig) + + +class EmailDeliveryMethod(_DeliveryMethodBase): + type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL + config: EmailDeliveryConfig + + def extract_variable_selectors(self) -> Sequence[Sequence[str]]: + variable_template_parser = VariableTemplateParser(template=self.config.body) + selectors: list[Sequence[str]] = [] + for variable_selector in variable_template_parser.extract_variable_selectors(): + value_selector = list(variable_selector.value_selector) + if len(value_selector) < SELECTORS_LENGTH: + continue + selectors.append(value_selector[:SELECTORS_LENGTH]) + return selectors + + +WebAppDeliveryMethod = InteractiveSurfaceDeliveryMethod +_WebAppDeliveryConfig = _InteractiveSurfaceDeliveryConfig + +DeliveryChannelConfig = Annotated[InteractiveSurfaceDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")] + +_DELIVERY_METHODS_ADAPTER = TypeAdapter(list[DeliveryChannelConfig]) + + +def _copy_mapping(value: object) -> dict[str, Any] | None: + if isinstance(value, BaseModel): + return value.model_dump(mode="python") + if isinstance(value, Mapping): + return dict(value) + return None + + +def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: + normalized = _copy_mapping(node_data) + if normalized is None: + raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}") + + delivery_methods = normalized.get("delivery_methods") + if not isinstance(delivery_methods, list): + return normalized + + normalized_methods: list[Any] = [] + for method in delivery_methods: + method_mapping = _copy_mapping(method) + if method_mapping is None: + normalized_methods.append(method) + continue + + config_mapping = _copy_mapping(method_mapping.get("config")) + if config_mapping is not None: + recipients_mapping = _copy_mapping(config_mapping.get("recipients")) + if recipients_mapping is not None: + config_mapping["recipients"] = _normalize_email_recipients(recipients_mapping) + method_mapping["config"] = config_mapping + + normalized_methods.append(method_mapping) + + normalized["delivery_methods"] = normalized_methods + return normalized + + +def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]: + normalized = normalize_human_input_node_data_for_graph(node_data) + raw_delivery_methods = normalized.get("delivery_methods") + if not isinstance(raw_delivery_methods, list): + return [] + return list(_DELIVERY_METHODS_ADAPTER.validate_python(raw_delivery_methods)) + + +def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> bool: + for method in parse_human_input_delivery_methods(node_data): + if method.enabled and method.type == DeliveryMethodType.WEBAPP: + return True + return False + + +def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]: + normalized = _copy_mapping(node_data) + if normalized is None: + raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}") + + if normalized.get("type") != BuiltinNodeTypes.HUMAN_INPUT: + return normalized + return normalize_human_input_node_data_for_graph(normalized) + + 
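For reference, the legacy-key rewrite these helpers perform (a sketch; the payload and ids below are made up):

legacy_node_data = {
    "delivery_methods": [
        {
            "type": "email",
            "config": {
                "recipients": {
                    "whole_workspace": True,  # legacy key
                    "items": [{"type": "member", "user_id": "account-1"}],  # legacy key
                },
                "subject": "Input needed",
                "body": "Resume here: {{#url#}}",
            },
        }
    ],
}
normalized = normalize_human_input_node_data_for_graph(legacy_node_data)
recipients = normalized["delivery_methods"][0]["config"]["recipients"]
assert recipients["include_bound_group"] is True
assert recipients["items"][0]["reference_id"] == "account-1"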
+def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]: + normalized = _copy_mapping(node_config) + if normalized is None: + raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}") + + data_mapping = _copy_mapping(normalized.get("data")) + if data_mapping is None: + return normalized + + normalized["data"] = normalize_node_data_for_graph(data_mapping) + return normalized + + +def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]: + normalized = dict(recipients) + + legacy_include_bound_group = normalized.pop("whole_workspace", None) + if "include_bound_group" not in normalized and legacy_include_bound_group is not None: + normalized["include_bound_group"] = legacy_include_bound_group + + items = normalized.get("items") + if not isinstance(items, list): + return normalized + + normalized_items: list[Any] = [] + for item in items: + item_mapping = _copy_mapping(item) + if item_mapping is None: + normalized_items.append(item) + continue + + legacy_reference_id = item_mapping.pop("user_id", None) + if "reference_id" not in item_mapping and legacy_reference_id is not None: + item_mapping["reference_id"] = legacy_reference_id + normalized_items.append(item_mapping) + + normalized["items"] = normalized_items + return normalized + + +__all__ = [ + "BoundRecipient", + "DeliveryChannelConfig", + "DeliveryMethodType", + "EmailDeliveryConfig", + "EmailDeliveryMethod", + "EmailRecipientType", + "EmailRecipients", + "ExternalRecipient", + "MemberRecipient", + "WebAppDeliveryMethod", + "_WebAppDeliveryConfig", + "is_human_input_webapp_enabled", + "normalize_human_input_node_data_for_graph", + "normalize_node_config_for_graph", + "normalize_node_data_for_graph", + "parse_human_input_delivery_methods", +] diff --git a/api/core/workflow/human_input_forms.py b/api/core/workflow/human_input_forms.py new file mode 100644 index 0000000000..f124b321d4 --- /dev/null +++ b/api/core/workflow/human_input_forms.py @@ -0,0 +1,55 @@ +"""Shared helpers for workflow pause-time human input form lookups. + +Both controllers and streaming response converters need the same recipient +priority when exposing resume links for paused human input forms. Keep that +selection logic here so all API surfaces stay consistent. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from extensions.ext_database import db +from models.human_input import HumanInputFormRecipient, RecipientType + +_FORM_TOKEN_PRIORITY = { + RecipientType.BACKSTAGE: 0, + RecipientType.CONSOLE: 1, + RecipientType.STANDALONE_WEB_APP: 2, +} + + +def load_form_tokens_by_form_id( + form_ids: Sequence[str], + *, + session: Session | None = None, +) -> dict[str, str]: + """Load the preferred access token for each human input form.""" + unique_form_ids = list(dict.fromkeys(form_ids)) + if not unique_form_ids: + return {} + + if session is not None: + return _load_form_tokens_by_form_id(session, unique_form_ids) + + with Session(bind=db.engine, expire_on_commit=False) as new_session: + return _load_form_tokens_by_form_id(new_session, unique_form_ids) + + +def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]: + tokens_by_form_id: dict[str, tuple[int, str]] = {} + stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) + for recipient in session.scalars(stmt): + priority = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type) + if priority is None or not recipient.access_token: + continue + + candidate = (priority, recipient.access_token) + current = tokens_by_form_id.get(recipient.form_id) + if current is None or candidate[0] < current[0]: + tokens_by_form_id[recipient.form_id] = candidate + + return {form_id: token for form_id, (_, token) in tokens_by_form_id.items()} diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index ab34263a79..db94ea0693 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -9,8 +9,8 @@ from sqlalchemy.orm import Session from typing_extensions import override from configs import dify_config -from core.app.entities.app_invoke_entities import DifyRunContext -from core.app.llm.model_access import build_dify_model_access +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.llm.model_access import build_dify_model_access, fetch_model_config from core.helper.code_executor.code_executor import ( CodeExecutionError, CodeExecutor, @@ -19,22 +19,31 @@ from core.helper.ssrf_proxy import ssrf_proxy from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.tools.tool_file_manager import ToolFileManager from core.trigger.constants import TRIGGER_NODE_TYPES +from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.node_runtime import ( + DifyFileReferenceFactory, + DifyHumanInputNodeRuntime, + DifyPreparedLLM, + DifyPromptMessageSerializer, + DifyRetrieverAttachmentLoader, + DifyToolFileManager, + DifyToolNodeRuntime, + build_dify_llm_file_saver, +) from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer from core.workflow.nodes.agent.plugin_strategy_adapter import ( PluginAgentStrategyPresentationProvider, PluginAgentStrategyResolver, ) from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport +from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector +from core.workflow.template_rendering import 
CodeExecutorJinja2TemplateRenderer from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey +from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.file.file_manager import file_manager from dify_graph.graph.graph import NodeFactory -from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.memory import PromptMessageMemory from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from dify_graph.nodes.base.node import Node @@ -44,14 +53,8 @@ from dify_graph.nodes.code.limits import CodeNodeLimits from dify_graph.nodes.document_extractor import UnstructuredApiConfig from dify_graph.nodes.http_request import build_http_request_config from dify_graph.nodes.llm.entities import LLMNodeData -from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from dify_graph.nodes.llm.protocols import TemplateRenderer from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData -from dify_graph.nodes.template_transform.template_renderer import ( - CodeExecutorJinja2TemplateRenderer, -) -from dify_graph.variables.segments import StringSegment from extensions.ext_database import db from models.model import Conversation @@ -229,16 +232,6 @@ class DefaultWorkflowCodeExecutor: return isinstance(error, CodeExecutionError) -class DefaultLLMTemplateRenderer(TemplateRenderer): - def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str: - result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, - code=template, - inputs=inputs, - ) - return str(result.get("result", "")) - - @final class DifyNodeFactory(NodeFactory): """ @@ -264,11 +257,31 @@ class DifyNodeFactory(NodeFactory): max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, ) - self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor) - self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer() + self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer() self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH self._http_request_http_client = ssrf_proxy - self._http_request_tool_file_manager_factory = ToolFileManager + self._bound_tool_file_manager_factory = lambda: DifyToolFileManager( + self._dify_context, + conversation_id_getter=self._conversation_id, + ) + self._file_reference_factory = DifyFileReferenceFactory(self._dify_context) + self._prompt_message_serializer = DifyPromptMessageSerializer() + self._retriever_attachment_loader = DifyRetrieverAttachmentLoader( + file_reference_factory=self._file_reference_factory, + ) + self._llm_file_saver = build_dify_llm_file_saver( + run_context=self._dify_context, + http_client=self._http_request_http_client, + conversation_id_getter=self._conversation_id, + ) + self._human_input_runtime = DifyHumanInputNodeRuntime( + self._dify_context, + workflow_execution_id_getter=lambda: get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ), + ) + self._tool_runtime = 
DifyToolNodeRuntime(self._dify_context) self._http_request_file_manager = file_manager self._document_extractor_unstructured_api_config = UnstructuredApiConfig( api_url=dify_config.UNSTRUCTURED_API_URL, @@ -284,7 +297,7 @@ class DifyNodeFactory(NodeFactory): ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES, ) - self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id) + self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context) self._agent_strategy_resolver = PluginAgentStrategyResolver() self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider() self._agent_runtime_support = AgentRuntimeSupport() @@ -299,6 +312,9 @@ class DifyNodeFactory(NodeFactory): return raw_ctx return DifyRunContext.model_validate(raw_ctx) + def _conversation_id(self) -> str | None: + return get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) + @override def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node: """ @@ -310,7 +326,7 @@ class DifyNodeFactory(NodeFactory): (including pydantic ValidationError, which subclasses ValueError), if node type is unknown, or if no implementation exists for the resolved version """ - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) node_id = typed_node_config["id"] node_data = typed_node_config["data"] node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version)) @@ -321,22 +337,29 @@ class DifyNodeFactory(NodeFactory): "code_limits": self._code_limits, }, BuiltinNodeTypes.TEMPLATE_TRANSFORM: lambda: { - "template_renderer": self._template_renderer, + "jinja2_template_renderer": self._jinja2_template_renderer, "max_output_length": self._template_transform_max_output_length, }, BuiltinNodeTypes.HTTP_REQUEST: lambda: { "http_request_config": self._http_request_config, "http_client": self._http_request_http_client, - "tool_file_manager_factory": self._http_request_tool_file_manager_factory, + "tool_file_manager_factory": self._bound_tool_file_manager_factory, "file_manager": self._http_request_file_manager, + "file_reference_factory": self._file_reference_factory, }, BuiltinNodeTypes.HUMAN_INPUT: lambda: { - "form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id), + "runtime": self._human_input_runtime, + "form_repository": self._human_input_runtime.build_form_repository(), }, BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=True, + include_llm_file_saver=True, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=True, + include_jinja2_template_renderer=True, ), BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: { "unstructured_api_config": self._document_extractor_unstructured_api_config, @@ -345,15 +368,26 @@ class DifyNodeFactory(NodeFactory): BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=True, + include_llm_file_saver=True, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=False, + include_jinja2_template_renderer=False, ), BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: 
self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=False, + include_llm_file_saver=False, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=False, + include_jinja2_template_renderer=False, ), BuiltinNodeTypes.TOOL: lambda: { - "tool_file_manager_factory": self._http_request_tool_file_manager_factory(), + "tool_file_manager_factory": self._bound_tool_file_manager_factory(), + "runtime": self._tool_runtime, }, BuiltinNodeTypes.AGENT: lambda: { "strategy_resolver": self._agent_strategy_resolver, @@ -387,7 +421,12 @@ class DifyNodeFactory(NodeFactory): *, node_class: type[Node], node_data: BaseNodeData, + wrap_model_instance: bool, include_http_client: bool, + include_llm_file_saver: bool, + include_prompt_message_serializer: bool, + include_retriever_attachment_loader: bool, + include_jinja2_template_renderer: bool, ) -> dict[str, object]: validated_node_data = cast( LLMCompatibleNodeData, @@ -397,49 +436,35 @@ class DifyNodeFactory(NodeFactory): node_init_kwargs: dict[str, object] = { "credentials_provider": self._llm_credentials_provider, "model_factory": self._llm_model_factory, - "model_instance": model_instance, + "model_instance": DifyPreparedLLM(model_instance) if wrap_model_instance else model_instance, "memory": self._build_memory_for_llm_node( node_data=validated_node_data, model_instance=model_instance, ), } - if validated_node_data.type in {BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER}: - node_init_kwargs["template_renderer"] = self._llm_template_renderer + if validated_node_data.type == BuiltinNodeTypes.QUESTION_CLASSIFIER: + node_init_kwargs["template_renderer"] = self._jinja2_template_renderer if include_http_client: node_init_kwargs["http_client"] = self._http_request_http_client + if include_llm_file_saver: + node_init_kwargs["llm_file_saver"] = self._llm_file_saver + if include_prompt_message_serializer: + node_init_kwargs["prompt_message_serializer"] = self._prompt_message_serializer + if include_retriever_attachment_loader: + node_init_kwargs["retriever_attachment_loader"] = self._retriever_attachment_loader + if include_jinja2_template_renderer: + node_init_kwargs["jinja2_template_renderer"] = self._jinja2_template_renderer + if validated_node_data.type == BuiltinNodeTypes.LLM: + node_init_kwargs["default_query_selector"] = system_variable_selector(SystemVariableKey.QUERY) return node_init_kwargs def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance: node_data_model = node_data.model - if not node_data_model.mode: - raise LLMModeRequiredError("LLM mode is required.") - - credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name) - model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name) - provider_model_bundle = model_instance.provider_model_bundle - - provider_model = provider_model_bundle.configuration.get_provider_model( - model=node_data_model.name, - model_type=ModelType.LLM, + model_instance, _ = fetch_model_config( + node_data_model=node_data_model, + credentials_provider=self._llm_credentials_provider, + model_factory=self._llm_model_factory, ) - if provider_model is None: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - provider_model.raise_for_status() - - completion_params = dict(node_data_model.completion_params) - stop = completion_params.pop("stop", []) - if not 
isinstance(stop, list): - stop = [] - - model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) - if not model_schema: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - - model_instance.provider = node_data_model.provider - model_instance.model_name = node_data_model.name - model_instance.credentials = credentials - model_instance.parameters = completion_params - model_instance.stop = tuple(stop) model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) return model_instance @@ -452,12 +477,7 @@ class DifyNodeFactory(NodeFactory): if node_data.memory is None: return None - conversation_id_variable = self.graph_runtime_state.variable_pool.get( - ["sys", SystemVariableKey.CONVERSATION_ID] - ) - conversation_id = ( - conversation_id_variable.value if isinstance(conversation_id_variable, StringSegment) else None - ) + conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) return fetch_memory( conversation_id=conversation_id, app_id=self._dify_context.app_id, diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py new file mode 100644 index 0000000000..2cf13bff0d --- /dev/null +++ b/api/core/workflow/node_runtime.py @@ -0,0 +1,670 @@ +from __future__ import annotations + +from collections.abc import Callable, Generator, Mapping, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.file_access import DatabaseFileAccessController +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output +from core.model_manager import ModelInstance +from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError +from core.plugin.impl.plugin import PluginInstaller +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormRepository, + HumanInputFormRepositoryImpl, +) +from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType +from core.tools.errors import ToolInvokeError +from core.tools.tool_engine import ToolEngine +from core.tools.tool_file_manager import ToolFileManager +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.file_reference import build_file_reference +from dify_graph.file import FileTransferMethod, FileType +from dify_graph.model_runtime.entities import LLMMode +from dify_graph.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, + LLMUsage, +) +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import AIModelEntity +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dify_graph.nodes.human_input.entities import HumanInputNodeData +from dify_graph.nodes.llm.runtime_protocols import ( + PreparedLLMProtocol, + 
PromptMessageSerializerProtocol, + RetrieverAttachmentLoaderProtocol, +) +from dify_graph.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol +from dify_graph.nodes.runtime import ( + HumanInputFormStateProtocol, + HumanInputNodeRuntimeProtocol, + ToolNodeRuntimeProtocol, +) +from dify_graph.nodes.tool.exc import ToolNodeError, ToolRuntimeInvocationError, ToolRuntimeResolutionError +from dify_graph.nodes.tool_runtime_entities import ( + ToolRuntimeHandle, + ToolRuntimeMessage, + ToolRuntimeParameter, +) +from extensions.ext_database import db +from factories import file_factory +from models.dataset import SegmentAttachmentBinding +from models.model import UploadFile +from services.tools.builtin_tools_manage_service import BuiltinToolManageService + +from .human_input_compat import ( + BoundRecipient, + DeliveryChannelConfig, + DeliveryMethodType, + EmailDeliveryMethod, + EmailRecipients, + is_human_input_webapp_enabled, + parse_human_input_delivery_methods, +) +from .system_variables import SystemVariableKey, get_system_text + +if TYPE_CHECKING: + from core.tools.__base.tool import Tool + from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage + from dify_graph.file import File + from dify_graph.nodes.llm.file_saver import LLMFileSaver + from dify_graph.nodes.tool.entities import ToolNodeData + + +_file_access_controller = DatabaseFileAccessController() + + +def resolve_dify_run_context(run_context: Mapping[str, Any] | DifyRunContext) -> DifyRunContext: + if isinstance(run_context, DifyRunContext): + return run_context + + raw_ctx = run_context.get(DIFY_RUN_CONTEXT_KEY) + if raw_ctx is None: + raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") + if isinstance(raw_ctx, DifyRunContext): + return raw_ctx + return DifyRunContext.model_validate(raw_ctx) + + +def apply_dify_debug_email_recipient( + method: DeliveryChannelConfig, + *, + enabled: bool, + actor_id: str | None, +) -> DeliveryChannelConfig: + """Apply the Dify debugger-specific email recipient override outside `dify_graph`.""" + if not enabled: + return method + if not isinstance(method, EmailDeliveryMethod): + return method + if not method.config.debug_mode: + return method + + if actor_id is None: + debug_recipients = EmailRecipients(include_bound_group=False, items=[]) + else: + debug_recipients = EmailRecipients( + include_bound_group=False, + items=[BoundRecipient(reference_id=actor_id)], + ) + debug_config = method.config.with_recipients(debug_recipients) + return method.model_copy(update={"config": debug_config}) + + +class DifyFileReferenceFactory(FileReferenceFactoryProtocol): + def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None: + self._run_context = resolve_dify_run_context(run_context) + + def build_from_mapping(self, *, mapping: Mapping[str, Any]): + return file_factory.build_from_mapping( + mapping=mapping, + tenant_id=self._run_context.tenant_id, + access_controller=_file_access_controller, + ) + + +class DifyPreparedLLM(PreparedLLMProtocol): + """Workflow-layer adapter that hides the full `ModelInstance` API from `dify_graph` nodes.""" + + def __init__(self, model_instance: ModelInstance) -> None: + self._model_instance = model_instance + + @property + def provider(self) -> str: + return self._model_instance.provider + + @property + def model_name(self) -> str: + return self._model_instance.model_name + + @property + def parameters(self) -> Mapping[str, Any]: + return 
self._model_instance.parameters + + @parameters.setter + def parameters(self, value: Mapping[str, Any]) -> None: + self._model_instance.parameters = value + + @property + def stop(self) -> Sequence[str] | None: + return self._model_instance.stop + + def get_model_schema(self) -> AIModelEntity: + model_schema = cast(LargeLanguageModel, self._model_instance.model_type_instance).get_model_schema( + self._model_instance.model_name, + self._model_instance.credentials, + ) + if model_schema is None: + raise ValueError(f"Model schema not found for {self._model_instance.model_name}") + return model_schema + + def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: + return self._model_instance.get_llm_num_tokens(prompt_messages) + + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResult | Generator[LLMResultChunk, None, None]: + return self._model_instance.invoke_llm( + prompt_messages=list(prompt_messages), + model_parameters=dict(model_parameters), + tools=list(tools or []), + stop=list(stop or []), + stream=stream, + ) + + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: + return invoke_llm_with_structured_output( + provider=self.provider, + model_schema=self.get_model_schema(), + model_instance=self._model_instance, + prompt_messages=prompt_messages, + json_schema=json_schema, + model_parameters=model_parameters, + stop=list(stop or []), + stream=stream, + ) + + def is_structured_output_parse_error(self, error: Exception) -> bool: + return isinstance(error, OutputParserError) + + +class DifyPromptMessageSerializer(PromptMessageSerializerProtocol): + def serialize( + self, + *, + model_mode: LLMMode, + prompt_messages: Sequence[PromptMessage], + ) -> Any: + return PromptMessageUtil.prompt_messages_to_prompt_for_saving( + model_mode=model_mode, + prompt_messages=prompt_messages, + ) + + +class DifyRetrieverAttachmentLoader(RetrieverAttachmentLoaderProtocol): + """Resolve retriever attachments through Dify persistence and return graph file references.""" + + def __init__(self, *, file_reference_factory: FileReferenceFactoryProtocol) -> None: + self._file_reference_factory = file_reference_factory + + def load(self, *, segment_id: str) -> Sequence[File]: + with Session(db.engine, expire_on_commit=False) as session: + attachments_with_bindings = session.execute( + select(SegmentAttachmentBinding, UploadFile) + .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) + .where(SegmentAttachmentBinding.segment_id == segment_id) + ).all() + + return [ + self._file_reference_factory.build_from_mapping( + mapping={ + "id": upload_file.id, + "filename": upload_file.name, + "extension": "." 
+                    upload_file.extension,
+                    "mime_type": upload_file.mime_type,
+                    "type": FileType.IMAGE,
+                    "transfer_method": FileTransferMethod.LOCAL_FILE,
+                    "remote_url": upload_file.source_url,
+                    "reference": build_file_reference(record_id=str(upload_file.id)),
+                    "size": upload_file.size,
+                }
+            )
+            for _, upload_file in attachments_with_bindings
+        ]
+
+
+class DifyToolFileManager(ToolFileManagerProtocol):
+    """Workflow adapter that resolves conversation scope outside `dify_graph`."""
+
+    _conversation_id_getter: Callable[[], str | None] | None
+
+    def __init__(
+        self,
+        run_context: Mapping[str, Any] | DifyRunContext,
+        *,
+        conversation_id_getter: Callable[[], str | None] | None = None,
+    ) -> None:
+        self._run_context = resolve_dify_run_context(run_context)
+        self._manager = ToolFileManager()
+        self._conversation_id_getter = conversation_id_getter
+
+    def create_file_by_raw(
+        self,
+        *,
+        file_binary: bytes,
+        mimetype: str,
+        filename: str | None = None,
+    ) -> Any:
+        conversation_id = self._conversation_id_getter() if self._conversation_id_getter is not None else None
+        return self._manager.create_file_by_raw(
+            user_id=self._run_context.user_id,
+            tenant_id=self._run_context.tenant_id,
+            conversation_id=conversation_id,
+            file_binary=file_binary,
+            mimetype=mimetype,
+            filename=filename,
+        )
+
+    def get_file_generator_by_tool_file_id(self, tool_file_id: str):
+        return self._manager.get_file_generator_by_tool_file_id(tool_file_id)
+
+
+@dataclass(frozen=True, slots=True)
+class _WorkflowToolRuntimeSpec:
+    provider_type: CoreToolProviderType
+    provider_id: str
+    tool_name: str
+    tool_configurations: dict[str, Any]
+    credential_id: str | None = None
+
+
+@dataclass(frozen=True, slots=True)
+class _WorkflowToolRuntimeBinding:
+    """Workflow-private runtime state stored inside the opaque graph handle.
+
+    The binding keeps conversation scope in `core.workflow` while `dify_graph`
+    continues to treat the handle as an opaque token.
+    """
+
+    tool: Tool
+    conversation_id: str | None = None
+
+
+class DifyToolNodeRuntime(ToolNodeRuntimeProtocol):
+    def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None:
+        self._run_context = resolve_dify_run_context(run_context)
+        self._file_reference_factory = DifyFileReferenceFactory(self._run_context)
+
+    @property
+    def file_reference_factory(self) -> FileReferenceFactoryProtocol:
+        return self._file_reference_factory
+
+    def build_file_reference(self, *, mapping: Mapping[str, Any]):
+        return self._file_reference_factory.build_from_mapping(mapping=mapping)
+
+    def get_runtime(
+        self,
+        *,
+        node_id: str,
+        node_data: ToolNodeData,
+        variable_pool,
+    ) -> ToolRuntimeHandle:
+        try:
+            tool_runtime = ToolManager.get_workflow_tool_runtime(
+                self._run_context.tenant_id,
+                self._run_context.app_id,
+                node_id,
+                self._build_tool_runtime_spec(node_data),
+                self._run_context.user_id,
+                self._run_context.invoke_from,
+                variable_pool,
+            )
+        except ToolNodeError:
+            raise
+        except Exception as exc:
+            raise ToolRuntimeResolutionError(str(exc)) from exc
+
+        conversation_id = (
+            None if variable_pool is None else get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID)
+        )
+        return ToolRuntimeHandle(raw=_WorkflowToolRuntimeBinding(tool=tool_runtime, conversation_id=conversation_id))
+
+    def get_runtime_parameters(
+        self,
+        *,
+        tool_runtime: ToolRuntimeHandle,
+    ) -> Sequence[ToolRuntimeParameter]:
+        tool = self._tool_from_handle(tool_runtime)
+        return [
+            ToolRuntimeParameter(name=parameter.name, required=parameter.required)
+            for parameter in (tool.get_merged_runtime_parameters() or [])
+        ]
+
+    def invoke(
+        self,
+        *,
+        tool_runtime: ToolRuntimeHandle,
+        tool_parameters: Mapping[str, Any],
+        workflow_call_depth: int,
+        provider_name: str,
+    ) -> Generator[ToolRuntimeMessage, None, None]:
+        runtime_binding = self._binding_from_handle(tool_runtime)
+        tool = runtime_binding.tool
+        callback = DifyWorkflowCallbackHandler()
+
+        try:
+            messages = ToolEngine.generic_invoke(
+                tool=tool,
+                tool_parameters=dict(tool_parameters),
+                user_id=self._run_context.user_id,
+                workflow_tool_callback=callback,
+                workflow_call_depth=workflow_call_depth,
+                app_id=self._run_context.app_id,
+                conversation_id=runtime_binding.conversation_id,
+            )
+        except Exception as exc:
+            raise self._map_invocation_exception(exc, provider_name=provider_name) from exc
+
+        transformed_messages = ToolFileMessageTransformer.transform_tool_invoke_messages(
+            messages=messages,
+            user_id=self._run_context.user_id,
+            tenant_id=self._run_context.tenant_id,
+            conversation_id=runtime_binding.conversation_id,
+        )
+
+        return self._adapt_messages(transformed_messages, provider_name=provider_name)
+
+    def get_usage(
+        self,
+        *,
+        tool_runtime: ToolRuntimeHandle,
+    ) -> LLMUsage:
+        latest = getattr(self._binding_from_handle(tool_runtime).tool, "latest_usage", None)
+        if isinstance(latest, LLMUsage):
+            return latest
+        if isinstance(latest, dict):
+            return LLMUsage.model_validate(latest)
+        return LLMUsage.empty_usage()
+
+    def resolve_provider_icons(
+        self,
+        *,
+        provider_name: str,
+        default_icon: str | None = None,
+    ) -> tuple[str | Mapping[str, str] | None, str | Mapping[str, str] | None]:
+        icon: str | Mapping[str, str] | None = default_icon
+        icon_dark: str | Mapping[str, str] | None = None
+
+        manager = PluginInstaller()
+        plugins = manager.list_plugins(self._run_context.tenant_id)
+        try:
+            current_plugin = next(plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == provider_name)
+            icon = current_plugin.declaration.icon
+        except StopIteration:
+            pass
+
+        try:
+            builtin_tool = next(
+                provider
+                for provider in BuiltinToolManageService.list_builtin_tools(
+                    self._run_context.user_id,
+                    self._run_context.tenant_id,
+                )
+                if provider.name == provider_name
+            )
+            icon = builtin_tool.icon
+            icon_dark = builtin_tool.icon_dark
+        except StopIteration:
+            pass
+
+        return icon, icon_dark
+
+    @staticmethod
+    def _tool_from_handle(tool_runtime: ToolRuntimeHandle) -> Tool:
+        return DifyToolNodeRuntime._binding_from_handle(tool_runtime).tool
+
+    @staticmethod
+    def _binding_from_handle(tool_runtime: ToolRuntimeHandle) -> _WorkflowToolRuntimeBinding:
+        if isinstance(tool_runtime.raw, _WorkflowToolRuntimeBinding):
+            return tool_runtime.raw
+        return _WorkflowToolRuntimeBinding(tool=cast("Tool", tool_runtime.raw))
+
+    @staticmethod
+    def _build_tool_runtime_spec(node_data: ToolNodeData) -> _WorkflowToolRuntimeSpec:
+        return _WorkflowToolRuntimeSpec(
+            provider_type=CoreToolProviderType(node_data.provider_type.value),
+            provider_id=node_data.provider_id,
+            tool_name=node_data.tool_name,
+            tool_configurations=dict(node_data.tool_configurations),
+            credential_id=node_data.credential_id,
+        )
+
+    def _adapt_messages(
+        self,
+        messages: Generator[CoreToolInvokeMessage, None, None],
+        *,
+        provider_name: str,
+    ) -> Generator[ToolRuntimeMessage, None, None]:
+        try:
+            for message in messages:
+                yield self._convert_message(message)
+        except Exception as exc:
+            raise self._map_invocation_exception(exc, provider_name=provider_name) from exc
+
+    def _convert_message(self, message: CoreToolInvokeMessage) -> ToolRuntimeMessage:
+        graph_message_type = ToolRuntimeMessage.MessageType(message.type.value)
+        graph_message = self._convert_message_payload(message.message)
+        graph_meta = message.meta.copy() if message.meta is not None else None
+        return ToolRuntimeMessage(type=graph_message_type, message=graph_message, meta=graph_meta)
+
+    def _convert_message_payload(
+        self,
+        message: CoreToolInvokeMessage.TextMessage
+        | CoreToolInvokeMessage.JsonMessage
+        | CoreToolInvokeMessage.BlobChunkMessage
+        | CoreToolInvokeMessage.BlobMessage
+        | CoreToolInvokeMessage.LogMessage
+        | CoreToolInvokeMessage.FileMessage
+        | CoreToolInvokeMessage.VariableMessage
+        | CoreToolInvokeMessage.RetrieverResourceMessage
+        | None,
+    ) -> (
+        ToolRuntimeMessage.TextMessage
+        | ToolRuntimeMessage.JsonMessage
+        | ToolRuntimeMessage.BlobChunkMessage
+        | ToolRuntimeMessage.BlobMessage
+        | ToolRuntimeMessage.LogMessage
+        | ToolRuntimeMessage.FileMessage
+        | ToolRuntimeMessage.VariableMessage
+        | ToolRuntimeMessage.RetrieverResourceMessage
+        | None
+    ):
+        if message is None:
+            return None
+
+        from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage
+
+        if isinstance(message, CoreToolInvokeMessage.TextMessage):
+            return ToolRuntimeMessage.TextMessage(text=message.text)
+        if isinstance(message, CoreToolInvokeMessage.JsonMessage):
+            return ToolRuntimeMessage.JsonMessage(
+                json_object=message.json_object,
+                suppress_output=message.suppress_output,
+            )
+        if isinstance(message, CoreToolInvokeMessage.BlobMessage):
+            return ToolRuntimeMessage.BlobMessage(blob=message.blob)
+        if isinstance(message, CoreToolInvokeMessage.BlobChunkMessage):
+            return ToolRuntimeMessage.BlobChunkMessage(
+                id=message.id,
+                sequence=message.sequence,
+                total_length=message.total_length,
+                blob=message.blob,
+                end=message.end,
+            )
+        if isinstance(message, CoreToolInvokeMessage.FileMessage):
+            return ToolRuntimeMessage.FileMessage(file_marker=message.file_marker)
+        if isinstance(message, CoreToolInvokeMessage.VariableMessage):
+            return ToolRuntimeMessage.VariableMessage(
+                variable_name=message.variable_name,
+                variable_value=message.variable_value,
+                stream=message.stream,
+            )
+        if isinstance(message, CoreToolInvokeMessage.LogMessage):
+            return ToolRuntimeMessage.LogMessage(
+                id=message.id,
+                label=message.label,
+                parent_id=message.parent_id,
+                error=message.error,
+                status=ToolRuntimeMessage.LogMessage.LogStatus(message.status.value),
+                data=dict(message.data),
+                metadata=dict(message.metadata),
+            )
+        if isinstance(message, CoreToolInvokeMessage.RetrieverResourceMessage):
+            retriever_resources = [
+                resource.model_dump() if hasattr(resource, "model_dump") else dict(resource)
+                for resource in message.retriever_resources
+            ]
+            return ToolRuntimeMessage.RetrieverResourceMessage(
+                retriever_resources=retriever_resources,
+                context=message.context,
+            )
+
+        raise TypeError(f"unsupported tool message payload: {type(message).__name__}")
+
+    @staticmethod
+    def _map_invocation_exception(exc: Exception, *, provider_name: str) -> ToolNodeError:
+        if isinstance(exc, ToolNodeError):
+            return exc
+        if isinstance(exc, PluginInvokeError):
+            return ToolRuntimeInvocationError(exc.to_user_friendly_error(plugin_name=provider_name))
+        if isinstance(exc, PluginDaemonClientSideError):
+            return ToolRuntimeInvocationError(f"Failed to invoke tool, error: {exc.description}")
+        if isinstance(exc, ToolInvokeError):
+            return ToolRuntimeInvocationError(f"Failed to invoke tool {provider_name}: {exc}")
+        return ToolRuntimeInvocationError(str(exc))
+
+
+class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol):
+    def __init__(
+        self,
+        run_context: Mapping[str, Any] | DifyRunContext,
+        *,
+        workflow_execution_id_getter: Callable[[], str | None] | None = None,
+        form_repository: HumanInputFormRepository | None = None,
+    ) -> None:
+        self._run_context = resolve_dify_run_context(run_context)
+        self._workflow_execution_id_getter = workflow_execution_id_getter
+        self._form_repository = form_repository
+
+    def _invoke_source(self) -> str:
+        invoke_from = self._run_context.invoke_from
+        if isinstance(invoke_from, str):
+            return invoke_from
+        return str(getattr(invoke_from, "value", invoke_from))
+
+    def _resolve_delivery_methods(self, *, node_data: HumanInputNodeData) -> Sequence[DeliveryChannelConfig]:
+        invoke_source = self._invoke_source()
+        methods = [method for method in parse_human_input_delivery_methods(node_data) if method.enabled]
+        if invoke_source in {"debugger", "explore"}:
+            methods = [method for method in methods if method.type != DeliveryMethodType.WEBAPP]
+        return [
+            apply_dify_debug_email_recipient(
+                method,
+                enabled=invoke_source == "debugger",
+                actor_id=self._run_context.user_id,
+            )
+            for method in methods
+        ]
+
+    def _display_in_ui(self, *, node_data: HumanInputNodeData) -> bool:
+        if self._invoke_source() == "debugger":
+            return True
+        return is_human_input_webapp_enabled(node_data)
+
+    def build_form_repository(self) -> HumanInputFormRepository:
+        if self._form_repository is not None:
+            return self._form_repository
+
+        return self._build_form_repository()
+
+    def _build_form_repository(self) -> HumanInputFormRepository:
+        invoke_source = self._invoke_source()
+        return HumanInputFormRepositoryImpl(
+            tenant_id=self._run_context.tenant_id,
+            app_id=self._run_context.app_id,
+            workflow_execution_id=self._workflow_execution_id_getter() if self._workflow_execution_id_getter else None,
+            invoke_source=invoke_source,
+            submission_actor_id=self._run_context.user_id if invoke_source in {"debugger", "explore"} else None,
+        )
+
+    def with_form_repository(self, form_repository: HumanInputFormRepository) -> DifyHumanInputNodeRuntime:
+        return DifyHumanInputNodeRuntime(
+            self._run_context,
+            workflow_execution_id_getter=self._workflow_execution_id_getter,
+            form_repository=form_repository,
+        )
+
+    def get_form(self, *, node_id: str) -> HumanInputFormStateProtocol | None:
+        repo = self.build_form_repository()
+        return repo.get_form(node_id)
+
+    def create_form(
+        self,
+        *,
+        node_id: str,
+        node_data: HumanInputNodeData,
+        rendered_content: str,
+        resolved_default_values: Mapping[str, Any],
+    ) -> HumanInputFormStateProtocol:
+        repo = self.build_form_repository()
+        params = FormCreateParams(
+            workflow_execution_id=self._workflow_execution_id_getter() if self._workflow_execution_id_getter else None,
+            node_id=node_id,
+            form_config=node_data,
+            rendered_content=rendered_content,
+            delivery_methods=self._resolve_delivery_methods(node_data=node_data),
+            display_in_ui=self._display_in_ui(node_data=node_data),
+            resolved_default_values=resolved_default_values,
+        )
+        return repo.create_form(params)
+
+
+def build_dify_llm_file_saver(
+    *,
+    run_context: Mapping[str, Any] | DifyRunContext,
+    http_client: HttpClientProtocol,
+    conversation_id_getter: Callable[[], str | None] | None = None,
+) -> LLMFileSaver:
+    from dify_graph.nodes.llm.file_saver import FileSaverImpl
+
+    return FileSaverImpl(
+        tool_file_manager=DifyToolFileManager(run_context, conversation_id_getter=conversation_id_getter),
+        file_reference_factory=DifyFileReferenceFactory(run_context),
+        http_client=http_client,
+    )
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index 5699ccf404..9fd1365b39 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -3,8 +3,10 @@ from __future__ import annotations
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any
 
+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
+from core.workflow.system_variables import SystemVariableKey, get_system_text
 from dify_graph.entities.graph_config import NodeConfigDict
-from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionStatus
+from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
 from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent
 from dify_graph.nodes.base.node import Node
 from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
@@ -59,7 +61,7 @@ class AgentNode(Node[AgentNodeData]):
         return "1"
 
     def populate_start_event(self, event) -> None:
-        dify_ctx = self.require_dify_context()
+        dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
         event.extras["agent_strategy"] = {
             "name": self.node_data.agent_strategy_name,
             "icon": self._presentation_provider.get_icon(
@@ -71,7 +73,7 @@ class AgentNode(Node[AgentNodeData]):
     def _run(self) -> Generator[NodeEventBase, None, None]:
         from core.plugin.impl.exc import PluginDaemonClientSideError
 
-        dify_ctx = self.require_dify_context()
+        dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
 
         try:
             strategy = self._strategy_resolver.resolve(
@@ -97,6 +99,7 @@ class AgentNode(Node[AgentNodeData]):
                 node_data=self.node_data,
                 strategy=strategy,
                 tenant_id=dify_ctx.tenant_id,
+                user_id=dify_ctx.user_id,
                 app_id=dify_ctx.app_id,
                 invoke_from=dify_ctx.invoke_from,
             )
@@ -106,20 +109,21 @@ class AgentNode(Node[AgentNodeData]):
                 node_data=self.node_data,
                 strategy=strategy,
                 tenant_id=dify_ctx.tenant_id,
+                user_id=dify_ctx.user_id,
                 app_id=dify_ctx.app_id,
                 invoke_from=dify_ctx.invoke_from,
                 for_log=True,
             )
         credentials = self._runtime_support.build_credentials(parameters=parameters)
-        conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
+        conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID)
 
         try:
             message_stream = strategy.invoke(
                 params=parameters,
                 user_id=dify_ctx.user_id,
                 app_id=dify_ctx.app_id,
-                conversation_id=conversation_id.text if conversation_id else None,
+                conversation_id=conversation_id,
                 credentials=credentials,
             )
         except Exception as e:
@@ -146,6 +150,7 @@ class AgentNode(Node[AgentNodeData]):
             parameters_for_log=parameters_for_log,
             user_id=dify_ctx.user_id,
             tenant_id=dify_ctx.tenant_id,
+            conversation_id=conversation_id,
             node_type=self.node_type,
             node_id=self._node_id,
             node_execution_id=self.id,
diff --git a/api/core/workflow/nodes/agent/message_transformer.py b/api/core/workflow/nodes/agent/message_transformer.py
index f58a5665f4..59133a1e56 100644
--- a/api/core/workflow/nodes/agent/message_transformer.py
+++ b/api/core/workflow/nodes/agent/message_transformer.py
@@ -6,10 +6,11 @@ from typing import Any, cast
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 
+from core.app.file_access import DatabaseFileAccessController
 from core.tools.entities.tool_entities import ToolInvokeMessage
 from core.tools.utils.message_transformer import ToolFileMessageTransformer
 from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
-from dify_graph.file import File, FileTransferMethod
+from dify_graph.file import File, FileTransferMethod, get_file_type_by_mime_type
 from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
 from dify_graph.model_runtime.utils.encoders import jsonable_encoder
 from dify_graph.node_events import (
@@ -27,6 +28,8 @@ from services.tools.builtin_tools_manage_service import BuiltinToolManageService
 
 from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError
 
+_file_access_controller = DatabaseFileAccessController()
+
 
 class AgentMessageTransformer:
     def transform(
@@ -37,6 +40,7 @@ class AgentMessageTransformer:
         parameters_for_log: dict[str, Any],
         user_id: str,
         tenant_id: str,
+        conversation_id: str | None,
         node_type: NodeType,
         node_id: str,
         node_execution_id: str,
@@ -47,7 +51,7 @@ class AgentMessageTransformer:
             messages=messages,
             user_id=user_id,
            tenant_id=tenant_id,
-            conversation_id=None,
+            conversation_id=conversation_id,
        )
 
         text = ""
@@ -70,10 +74,12 @@ class AgentMessageTransformer:
                 url = message.message.text
                 if message.meta:
                     transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
+                    tool_file_id = message.meta.get("tool_file_id")
                 else:
                     transfer_method = FileTransferMethod.TOOL_FILE
-
-                tool_file_id = str(url).split("/")[-1].split(".")[0]
+                    tool_file_id = None
+                if not isinstance(tool_file_id, str) or not tool_file_id:
+                    raise ToolFileNotFoundError("missing tool_file_id metadata")
 
                 with Session(db.engine) as session:
                     stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
@@ -83,20 +89,23 @@ class AgentMessageTransformer:
 
                 mapping = {
                     "tool_file_id": tool_file_id,
-                    "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
+                    "type": get_file_type_by_mime_type(tool_file.mimetype),
                     "transfer_method": transfer_method,
                     "url": url,
                 }
                 file = file_factory.build_from_mapping(
                     mapping=mapping,
                     tenant_id=tenant_id,
+                    access_controller=_file_access_controller,
                 )
                 files.append(file)
             elif message.type == ToolInvokeMessage.MessageType.BLOB:
                 assert isinstance(message.message, ToolInvokeMessage.TextMessage)
                 assert message.meta
-                tool_file_id = message.message.text.split("/")[-1].split(".")[0]
+                tool_file_id = message.meta.get("tool_file_id")
+                if not isinstance(tool_file_id, str) or not tool_file_id:
+                    raise ToolFileNotFoundError("missing tool_file_id metadata")
                 with Session(db.engine) as session:
                     stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
                     tool_file = session.scalar(stmt)
@@ -111,6 +120,7 @@ class AgentMessageTransformer:
                     file_factory.build_from_mapping(
                         mapping=mapping,
                         tenant_id=tenant_id,
+                        access_controller=_file_access_controller,
                     )
                 )
             elif message.type == ToolInvokeMessage.MessageType.TEXT:
diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py
index 2ff7c964b9..48fcd6a749 100644
--- a/api/core/workflow/nodes/agent/runtime_support.py
+++ b/api/core/workflow/nodes/agent/runtime_support.py
@@ -12,15 +12,14 @@ from sqlalchemy.orm import Session
 from core.agent.entities import AgentToolEntity
 from core.agent.plugin_entities import AgentStrategyParameter
 from core.memory.token_buffer_memory import TokenBufferMemory
-from core.model_manager import ModelInstance, ModelManager
+from core.model_manager import ModelInstance
 from core.plugin.entities.request import InvokeCredentials
-from core.provider_manager import ProviderManager
+from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly
 from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType
 from core.tools.tool_manager import ToolManager
-from dify_graph.enums import SystemVariableKey
+from core.workflow.system_variables import SystemVariableKey, get_system_text
 from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType
 from dify_graph.runtime import VariablePool
-from dify_graph.variables.segments import StringSegment
 from extensions.ext_database import db
 from models.model import Conversation
 
@@ -38,6 +37,7 @@ class AgentRuntimeSupport:
         node_data: AgentNodeData,
         strategy: ResolvedAgentStrategy,
         tenant_id: str,
+        user_id: str,
         app_id: str,
         invoke_from: Any,
         for_log: bool = False,
@@ -141,6 +141,7 @@ class AgentRuntimeSupport:
                         tenant_id,
                         app_id,
                         entity,
+                        user_id,
                         invoke_from,
                         runtime_variable_pool,
                     )
@@ -174,7 +175,11 @@ class AgentRuntimeSupport:
                 value = tool_value
             if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
                 value = cast(dict[str, Any], value)
-                model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
+                model_instance, model_schema = self.fetch_model(
+                    tenant_id=tenant_id,
+                    user_id=user_id,
+                    value=value,
+                )
                 history_prompt_messages = []
                 if node_data.memory:
                     memory = self.fetch_memory(
@@ -219,10 +224,9 @@ class AgentRuntimeSupport:
         app_id: str,
         model_instance: ModelInstance,
     ) -> TokenBufferMemory | None:
-        conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
-        if not isinstance(conversation_id_variable, StringSegment):
+        conversation_id = get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID)
+        if conversation_id is None:
             return None
-        conversation_id = conversation_id_variable.value
 
         with Session(db.engine, expire_on_commit=False) as session:
             stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
@@ -232,9 +236,15 @@ class AgentRuntimeSupport:
 
         return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
 
-    def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
-        provider_manager = ProviderManager()
-        provider_model_bundle = provider_manager.get_provider_model_bundle(
+    def fetch_model(
+        self,
+        *,
+        tenant_id: str,
+        user_id: str,
+        value: dict[str, Any],
+    ) -> tuple[ModelInstance, AIModelEntity | None]:
+        assembly = create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id)
+        provider_model_bundle = assembly.provider_manager.get_provider_model_bundle(
             tenant_id=tenant_id,
             provider=value.get("provider", ""),
             model_type=ModelType.LLM,
@@ -246,7 +256,7 @@ class AgentRuntimeSupport:
         )
         provider_name = provider_model_bundle.configuration.provider.provider
         model_type_instance = provider_model_bundle.model_type_instance
-        model_instance = ModelManager().get_model_instance(
+        model_instance = assembly.model_manager.get_model_instance(
             tenant_id=tenant_id,
             provider=provider_name,
             model_type=ModelType(value.get("model_type", "")),
diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py
index 44f4a23a5a..cca2cb14b5 100644
--- a/api/core/workflow/nodes/datasource/datasource_node.py
+++ b/api/core/workflow/nodes/datasource/datasource_node.py
@@ -1,12 +1,15 @@
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any
 
+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
 from core.datasource.datasource_manager import DatasourceManager
 from core.datasource.entities.datasource_entities import DatasourceProviderType
 from core.plugin.impl.exc import PluginDaemonClientSideError
+from core.workflow.file_reference import resolve_file_record_id
+from core.workflow.system_variables import SystemVariableKey, get_system_segment
 from dify_graph.entities.graph_config import NodeConfigDict
 from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, SystemVariableKey, WorkflowNodeExecutionMetadataKey
+from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionMetadataKey
 from dify_graph.node_events import NodeRunResult, StreamCompletedEvent
 from dify_graph.nodes.base.node import Node
 from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
@@ -50,15 +53,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
         """
         Run the datasource node
         """
-
-        dify_ctx = self.require_dify_context()
+        dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
         node_data = self.node_data
         variable_pool = self.graph_runtime_state.variable_pool
-        datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE])
+        datasource_type_segment = get_system_segment(variable_pool, SystemVariableKey.DATASOURCE_TYPE)
         if not datasource_type_segment:
             raise DatasourceNodeError("Datasource type is not set")
         datasource_type = str(datasource_type_segment.value) if datasource_type_segment.value else None
-        datasource_info_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO])
+        datasource_info_segment = get_system_segment(variable_pool, SystemVariableKey.DATASOURCE_INFO)
         if not datasource_info_segment:
             raise DatasourceNodeError("Datasource info is not set")
         datasource_info_value = datasource_info_segment.value
@@ -131,12 +133,14 @@ class DatasourceNode(Node[DatasourceNodeData]):
                         )
                     )
                 case DatasourceProviderType.LOCAL_FILE:
-                    related_id = datasource_info.get("related_id")
-                    if not related_id:
+                    file_id = resolve_file_record_id(
+                        datasource_info.get("reference") or datasource_info.get("related_id")
+                    )
+                    if not file_id:
                         raise DatasourceNodeError("File is not exist")
                     file_info = self.datasource_manager.get_upload_file_by_id(
-                        file_id=related_id, tenant_id=dify_ctx.tenant_id
+                        file_id=file_id, tenant_id=dify_ctx.tenant_id
                     )
                     variable_pool.add([self._node_id, "file"], file_info)
                     # variable_pool.add([self.node_id, "file"], file_info.to_dict())
diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
index 4ea9091c5b..f51abfdd8c 100644
--- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
+++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py
@@ -6,9 +6,10 @@ from core.rag.index_processor.index_processor import IndexProcessor
 from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict
 from core.rag.summary_index.summary_index import SummaryIndex
 from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE
+from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text
 from dify_graph.entities.graph_config import NodeConfigDict
 from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from dify_graph.enums import NodeExecutionType, SystemVariableKey
+from dify_graph.enums import NodeExecutionType
 from dify_graph.node_events import NodeRunResult
 from dify_graph.nodes.base.node import Node
 from dify_graph.nodes.base.template import Template
@@ -46,21 +47,20 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
         variable_pool = self.graph_runtime_state.variable_pool
 
         # get dataset id as string
-        dataset_id_segment = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
+        dataset_id_segment = get_system_segment(variable_pool, SystemVariableKey.DATASET_ID)
         if not dataset_id_segment:
             raise KnowledgeIndexNodeError("Dataset ID is required.")
         dataset_id: str = dataset_id_segment.value
 
         # get document id as string (may be empty when not provided)
-        document_id_segment = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
+        document_id_segment = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID)
         document_id: str = document_id_segment.value if document_id_segment else ""
 
         # extract variables
         variable = variable_pool.get(node_data.index_chunk_variable_selector)
         if not variable:
             raise KnowledgeIndexNodeError("Index chunk variable is required.")
-        invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
-        invoke_from_value = str(invoke_from.value) if invoke_from else None
+        invoke_from_value = get_system_text(variable_pool, SystemVariableKey.INVOKE_FROM)
         is_preview = invoke_from_value == _INVOKE_FROM_DEBUGGER
 
         chunks = variable.value
@@ -87,8 +87,8 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]):
                 outputs=outputs.model_dump(exclude_none=True),
             )
 
-        original_document_id_segment = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID])
-        batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
+        original_document_id_segment = get_system_segment(variable_pool, SystemVariableKey.ORIGINAL_DOCUMENT_ID)
+        batch = get_system_segment(variable_pool, SystemVariableKey.BATCH)
         if not batch:
             raise KnowledgeIndexNodeError("Batch is required.")
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index 80f59140be..bdffc6cdfe 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -9,8 +9,10 @@ from collections.abc import Mapping, Sequence
 from typing import TYPE_CHECKING, Any, Literal
 
 from core.app.app_config.entities import DatasetRetrieveConfigEntity
+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
 from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict
 from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
+from core.workflow.file_reference import parse_file_reference
 from dify_graph.entities import GraphInitParams
 from dify_graph.entities.graph_config import NodeConfigDict
 from dify_graph.enums import (
@@ -160,7 +162,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
     def _fetch_dataset_retriever(
         self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
     ) -> tuple[list[Source], LLMUsage]:
-        dify_ctx = self.require_dify_context()
+        dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
         dataset_ids = node_data.dataset_ids
         query = variables.get("query")
         attachments = variables.get("attachments")
@@ -254,7 +256,13 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
                 metadata_model_config=node_data.metadata_model_config,
                 metadata_filtering_conditions=resolved_metadata_conditions,
                 metadata_filtering_mode=metadata_filtering_mode,
-                attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
+                attachment_ids=[
+                    parsed_reference.record_id
+                    for attachment in attachments
+                    if (parsed_reference := parse_file_reference(attachment.reference)) is not None
+                ]
+                if attachments
+                else None,
             )
         )
diff --git a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py
index e1311ab962..6a9a6f2f2e 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py
@@ -54,7 +54,7 @@ class KnowledgeRetrievalRequest(BaseModel):
     tenant_id: str = Field(description="Tenant unique identifier")
     user_id: str = Field(description="User unique identifier")
     app_id: str = Field(description="Application unique identifier")
-    user_from: str = Field(description="Source of the user request (e.g., 'workflow', 'api')")
+    user_from: str = Field(description="User identity source for audit logging (e.g., 'account', 'end-user')")
     dataset_ids: list[str] = Field(description="List of dataset IDs to retrieve from")
     query: str | None = Field(default=None, description="Query text for knowledge retrieval")
     retrieval_mode: str = Field(description="Retrieval strategy: 'single' or 'multiple'")
diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py
index 118c2f2668..be08a034df 100644
--- a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py
+++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py
@@ -2,7 +2,7 @@ from collections.abc import Mapping
 from typing import Any
 
 from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE
-from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
+from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID
 from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from dify_graph.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey
 from dify_graph.node_events import NodeRunResult
@@ -53,13 +53,11 @@ class TriggerEventNode(Node[TriggerEventNodeData]):
                 "plugin_unique_identifier": self.node_data.plugin_unique_identifier,
             },
         }
-        node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
-        system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
+        node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id))
+        system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
 
-        # TODO: System variables should be directly accessible, no need for special handling
-        # Set system variables as node outputs.
-        for var in system_inputs:
-            node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
+        for variable_name, value in system_inputs.items():
+            node_inputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value
         outputs = dict(node_inputs)
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py
index b9580e6ab1..1d53c43eb5 100644
--- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py
+++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py
@@ -1,7 +1,7 @@
 from collections.abc import Mapping
 
 from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE
-from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
+from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID
 from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from dify_graph.enums import NodeExecutionType
 from dify_graph.node_events import NodeRunResult
@@ -31,13 +31,11 @@ class TriggerScheduleNode(Node[TriggerScheduleNodeData]):
         }
 
     def _run(self) -> NodeRunResult:
-        node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
-        system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
+        node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id))
+        system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
 
-        # TODO: System variables should be directly accessible, no need for special handling
-        # Set system variables as node outputs.
-        for var in system_inputs:
-            node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
+        for variable_name, value in system_inputs.items():
+            node_inputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value
         outputs = dict(node_inputs)
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py
index 317844cbda..e2dee3656a 100644
--- a/api/core/workflow/nodes/trigger_webhook/node.py
+++ b/api/core/workflow/nodes/trigger_webhook/node.py
@@ -3,15 +3,16 @@ from collections.abc import Mapping
 from typing import Any
 
 from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE
-from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID
+from core.workflow.file_reference import resolve_file_record_id
+from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID
 from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from dify_graph.enums import NodeExecutionType
 from dify_graph.file import FileTransferMethod
 from dify_graph.node_events import NodeRunResult
 from dify_graph.nodes.base.node import Node
+from dify_graph.nodes.protocols import FileReferenceFactoryProtocol
 from dify_graph.variables.types import SegmentType
 from dify_graph.variables.variables import FileVariable
-from factories import file_factory
 from factories.variable_factory import build_segment_with_type
 
 from .entities import ContentType, WebhookData
@@ -23,6 +24,13 @@ class TriggerWebhookNode(Node[WebhookData]):
     node_type = TRIGGER_WEBHOOK_NODE_TYPE
     execution_type = NodeExecutionType.ROOT
 
+    _file_reference_factory: FileReferenceFactoryProtocol
+
+    def post_init(self) -> None:
+        from core.workflow.node_runtime import DifyFileReferenceFactory
+
+        self._file_reference_factory = DifyFileReferenceFactory(self.graph_init_params.run_context)
+
     @classmethod
     def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
         return {
@@ -53,16 +61,14 @@ class TriggerWebhookNode(Node[WebhookData]):
         happens in the trigger controller.
         """
         # Get webhook data from variable pool (injected by Celery task)
-        webhook_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
+        webhook_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id))
 
         # Extract webhook-specific outputs based on node configuration
         outputs = self._extract_configured_outputs(webhook_inputs)
-        system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
+        system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID)
 
-        # TODO: System variables should be directly accessible, no need for special handling
-        # Set system variables as node outputs.
-        for var in system_inputs:
-            outputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
+        for variable_name, value in system_inputs.items():
+            outputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
             inputs=webhook_inputs,
@@ -70,24 +76,20 @@ class TriggerWebhookNode(Node[WebhookData]):
         )
 
     def generate_file_var(self, param_name: str, file: dict):
-        dify_ctx = self.require_dify_context()
-        related_id = file.get("related_id")
+        file_id = resolve_file_record_id(file.get("reference") or file.get("related_id"))
         transfer_method_value = file.get("transfer_method")
         if transfer_method_value:
             transfer_method = FileTransferMethod.value_of(transfer_method_value)
             match transfer_method:
                 case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL:
-                    file["upload_file_id"] = related_id
+                    file["upload_file_id"] = file_id
                 case FileTransferMethod.TOOL_FILE:
-                    file["tool_file_id"] = related_id
+                    file["tool_file_id"] = file_id
                 case FileTransferMethod.DATASOURCE_FILE:
-                    file["datasource_file_id"] = related_id
+                    file["datasource_file_id"] = file_id
         try:
-            file_obj = file_factory.build_from_mapping(
-                mapping=file,
-                tenant_id=dify_ctx.tenant_id,
-            )
+            file_obj = self._file_reference_factory.build_from_mapping(mapping=file)
             file_segment = build_segment_with_type(SegmentType.FILE, file_obj)
             return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name])
         except ValueError:
diff --git a/api/core/workflow/system_variables.py b/api/core/workflow/system_variables.py
new file mode 100644
index 0000000000..888d75202a
--- /dev/null
+++ b/api/core/workflow/system_variables.py
@@ -0,0 +1,236 @@
+from __future__ import annotations
+
+from collections import defaultdict
+from collections.abc import Mapping, Sequence
+from enum import StrEnum
+from typing import Any, Protocol, cast
+from uuid import uuid4
+
+from dify_graph.enums import BuiltinNodeTypes
+from dify_graph.variables import build_segment, segment_to_variable
+from dify_graph.variables.segments import Segment
+from dify_graph.variables.variables import RAGPipelineVariableInput, Variable
+
+from .variable_prefixes import (
+    CONVERSATION_VARIABLE_NODE_ID,
+    ENVIRONMENT_VARIABLE_NODE_ID,
+    RAG_PIPELINE_VARIABLE_NODE_ID,
+    SYSTEM_VARIABLE_NODE_ID,
+)
+
+
+class SystemVariableKey(StrEnum):
+    QUERY = "query"
+    FILES = "files"
+    CONVERSATION_ID = "conversation_id"
+    USER_ID = "user_id"
+    DIALOGUE_COUNT = "dialogue_count"
+    APP_ID = "app_id"
+    WORKFLOW_ID = "workflow_id"
+    WORKFLOW_EXECUTION_ID = "workflow_run_id"
+    TIMESTAMP = "timestamp"
+    DOCUMENT_ID = "document_id"
+    ORIGINAL_DOCUMENT_ID = "original_document_id"
+    BATCH = "batch"
+    DATASET_ID = "dataset_id"
+    DATASOURCE_TYPE = "datasource_type"
+    DATASOURCE_INFO = "datasource_info"
+    INVOKE_FROM = "invoke_from"
+
+
+class _VariablePoolReader(Protocol):
+    def get(self, selector: Sequence[str], /) -> Segment | None: ...
+
+    def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: ...
+
+
+class _VariablePoolWriter(_VariablePoolReader, Protocol):
+    def add(self, selector: Sequence[str], value: object, /) -> None: ...
+
+
+class _VariableLoader(Protocol):
+    def load_variables(self, selectors: list[list[str]]) -> Sequence[object]: ...
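(Reviewer note, not part of the patch: the module below replaces the `SystemVariableKey` enum removed from `dify_graph.enums` and the raw `variable_pool.get(["sys", key])` lookups removed from the node diffs above. A minimal sketch of the intended call pattern, assuming the no-argument `VariablePool()` construction that `workflow_entry.py` adopts later in this diff; the pool contents here are illustrative only.)

    from dify_graph.runtime import VariablePool

    from core.workflow.system_variables import (
        SystemVariableKey,
        build_system_variables,
        get_system_text,
    )
    from core.workflow.variable_pool_initializer import add_variables_to_pool

    pool = VariablePool()
    # build_system_variables() drops None values, maps workflow_execution_id to
    # its "workflow_run_id" storage name, and defaults `files` to an empty list.
    add_variables_to_pool(pool, build_system_variables(conversation_id="c-1", query="hello"))

    # Typed read instead of pool.get(["sys", "conversation_id"]); yields None when absent.
    conversation_id = get_system_text(pool, SystemVariableKey.CONVERSATION_ID)  # -> "c-1"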
+ + +def system_variable_name(key: str | SystemVariableKey) -> str: + return key.value if isinstance(key, SystemVariableKey) else key + + +def system_variable_selector(key: str | SystemVariableKey) -> tuple[str, str]: + return SYSTEM_VARIABLE_NODE_ID, system_variable_name(key) + + +def _normalize_system_variable_values(values: Mapping[str, Any] | None = None, /, **kwargs: Any) -> dict[str, Any]: + raw_values = dict(values or {}) + raw_values.update(kwargs) + + workflow_execution_id = raw_values.pop("workflow_execution_id", None) + if workflow_execution_id is not None and SystemVariableKey.WORKFLOW_EXECUTION_ID.value not in raw_values: + raw_values[SystemVariableKey.WORKFLOW_EXECUTION_ID.value] = workflow_execution_id + + normalized: dict[str, Any] = {} + for key, value in raw_values.items(): + if value is None: + continue + normalized[system_variable_name(key)] = value + + normalized.setdefault(SystemVariableKey.FILES.value, []) + return normalized + + +def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs: Any) -> list[Variable]: + normalized = _normalize_system_variable_values(values, **kwargs) + + return [ + cast( + Variable, + segment_to_variable( + segment=build_segment(value), + selector=system_variable_selector(key), + name=key, + ), + ) + for key, value in normalized.items() + ] + + +def default_system_variables() -> list[Variable]: + return build_system_variables(workflow_run_id=str(uuid4())) + + +def system_variables_to_mapping(system_variables: Sequence[Variable]) -> dict[str, Any]: + return {variable.name: variable.value for variable in system_variables} + + +def _with_selector(variable: Variable, node_id: str) -> Variable: + selector = [node_id, variable.name] + if list(variable.selector) == selector: + return variable + return variable.model_copy(update={"selector": selector}) + + +def build_bootstrap_variables( + *, + system_variables: Sequence[Variable] = (), + environment_variables: Sequence[Variable] = (), + conversation_variables: Sequence[Variable] = (), + rag_pipeline_variables: Sequence[RAGPipelineVariableInput] = (), +) -> list[Variable]: + variables = [ + *(_with_selector(variable, SYSTEM_VARIABLE_NODE_ID) for variable in system_variables), + *(_with_selector(variable, ENVIRONMENT_VARIABLE_NODE_ID) for variable in environment_variables), + *(_with_selector(variable, CONVERSATION_VARIABLE_NODE_ID) for variable in conversation_variables), + ] + + rag_pipeline_variables_map: defaultdict[str, dict[str, Any]] = defaultdict(dict) + for rag_var in rag_pipeline_variables: + node_id = rag_var.variable.belong_to_node_id + key = rag_var.variable.variable + rag_pipeline_variables_map[node_id][key] = rag_var.value + + for node_id, value in rag_pipeline_variables_map.items(): + variables.append( + cast( + Variable, + segment_to_variable( + segment=build_segment(value), + selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id), + name=node_id, + ), + ) + ) + + return variables + + +def get_system_segment(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> Segment | None: + return variable_pool.get(system_variable_selector(key)) + + +def get_system_value(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> Any: + segment = get_system_segment(variable_pool, key) + return None if segment is None else segment.value + + +def get_system_text(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> str | None: + segment = get_system_segment(variable_pool, key) + if segment is None: + return None + text = getattr(segment, 
"text", None) + return text if isinstance(text, str) else None + + +def get_all_system_variables(variable_pool: _VariablePoolReader) -> Mapping[str, object]: + return variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) + + +_MEMORY_BOOTSTRAP_NODE_TYPES = frozenset( + ( + BuiltinNodeTypes.LLM, + BuiltinNodeTypes.QUESTION_CLASSIFIER, + BuiltinNodeTypes.PARAMETER_EXTRACTOR, + ) +) + + +def get_node_creation_preload_selectors( + *, + node_type: str, + node_data: object, +) -> tuple[tuple[str, str], ...]: + """Return selectors that must exist before node construction begins.""" + + if node_type not in _MEMORY_BOOTSTRAP_NODE_TYPES or getattr(node_data, "memory", None) is None: + return () + + return (system_variable_selector(SystemVariableKey.CONVERSATION_ID),) + + +def preload_node_creation_variables( + *, + variable_loader: _VariableLoader, + variable_pool: _VariablePoolWriter, + selectors: Sequence[Sequence[str]], +) -> None: + """Load constructor-time variables before node or graph creation.""" + + seen_selectors: set[tuple[str, ...]] = set() + selectors_to_load: list[list[str]] = [] + for selector in selectors: + normalized_selector = tuple(selector) + if len(normalized_selector) < 2: + raise ValueError(f"Invalid preload selector: {selector}") + if normalized_selector in seen_selectors: + continue + seen_selectors.add(normalized_selector) + if variable_pool.get(normalized_selector) is None: + selectors_to_load.append(list(normalized_selector)) + + loaded_variables = variable_loader.load_variables(selectors_to_load) + for variable in loaded_variables: + raw_selector = getattr(variable, "selector", ()) + loaded_selector = list(raw_selector) + if len(loaded_selector) < 2: + raise ValueError(f"Invalid loaded variable selector: {raw_selector}") + variable_pool.add(loaded_selector[:2], variable) + + +def inject_default_system_variable_mappings( + *, + node_id: str, + node_type: str, + node_data: object, + variable_mapping: Mapping[str, Sequence[str]], +) -> Mapping[str, Sequence[str]]: + """Add workflow-owned implicit sys mappings that `dify_graph` should not know about.""" + + if node_type != BuiltinNodeTypes.LLM or getattr(node_data, "memory", None) is None: + return variable_mapping + + query_mapping_key = f"{node_id}.#sys.query#" + if query_mapping_key in variable_mapping: + return variable_mapping + + augmented_mapping = dict(variable_mapping) + augmented_mapping[query_mapping_key] = system_variable_selector(SystemVariableKey.QUERY) + return augmented_mapping diff --git a/api/core/workflow/template_rendering.py b/api/core/workflow/template_rendering.py new file mode 100644 index 0000000000..fb7f9651ec --- /dev/null +++ b/api/core/workflow/template_rendering.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor +from dify_graph.nodes.code.entities import CodeLanguage +from dify_graph.template_rendering import Jinja2TemplateRenderer, TemplateRenderError + + +class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): + """Sandbox-backed Jinja2 renderer for workflow-owned node composition.""" + + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + try: + result = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, + code=template, + inputs=variables, + ) + except Exception as exc: + if isinstance(exc, CodeExecutionError): + raise TemplateRenderError(str(exc)) from exc + raise + + 
rendered = result.get("result") + if not isinstance(rendered, str): + raise TemplateRenderError("Template render result must be a string.") + return rendered diff --git a/api/core/workflow/variable_pool_initializer.py b/api/core/workflow/variable_pool_initializer.py new file mode 100644 index 0000000000..42151affdd --- /dev/null +++ b/api/core/workflow/variable_pool_initializer.py @@ -0,0 +1,15 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +from dify_graph.runtime import VariablePool +from dify_graph.variables.variables import Variable + + +def add_variables_to_pool(variable_pool: VariablePool, variables: Sequence[Variable]) -> None: + for variable in variables: + variable_pool.add(variable.selector, variable) + + +def add_node_inputs_to_pool(variable_pool: VariablePool, *, node_id: str, inputs: Mapping[str, Any]) -> None: + for key, value in inputs.items(): + variable_pool.add((node_id, key), value) diff --git a/api/dify_graph/constants.py b/api/core/workflow/variable_prefixes.py similarity index 100% rename from api/dify_graph/constants.py rename to api/core/workflow/variable_prefixes.py diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 2e51a06bab..d4a9cbd359 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -1,15 +1,24 @@ import logging import time from collections.abc import Generator, Mapping, Sequence -from typing import Any, cast +from typing import Any from configs import dify_config +from context import capture_current_context from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.file_access import DatabaseFileAccessController from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.app.workflow.layers.observability import ObservabilityLayer -from core.workflow.node_factory import DifyNodeFactory, resolve_workflow_node_class -from dify_graph.constants import ENVIRONMENT_VARIABLE_NODE_ID +from core.workflow.node_factory import DifyNodeFactory, is_start_node_type, resolve_workflow_node_class +from core.workflow.system_variables import ( + default_system_variables, + get_node_creation_preload_selectors, + inject_default_system_variable_mappings, + preload_node_creation_variables, +) +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.variable_prefixes import ENVIRONMENT_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.errors import WorkflowNodeRunFailedError @@ -18,19 +27,18 @@ from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig from dify_graph.graph_engine.command_channels import InMemoryChannel from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer -from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_engine.protocols.command_channel import CommandChannel from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent from dify_graph.nodes import BuiltinNodeTypes from dify_graph.nodes.base.node import Node from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, 
load_into_variable_pool from extensions.otel.runtime import is_instrument_flag_enabled from factories import file_factory from models.workflow import Workflow logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class _WorkflowChildEngineBuilder: @@ -59,16 +67,22 @@ class _WorkflowChildEngineBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> GraphEngine: + """Build a child engine with a fresh runtime state and only child-safe layers.""" + child_graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool if variable_pool is not None else parent_graph_runtime_state.variable_pool, + start_at=time.perf_counter(), + execution_context=parent_graph_runtime_state.execution_context, + ) node_factory = DifyNodeFactory( graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, ) + graph_config = graph_init_params.graph_config has_root_node = self._has_node_id(graph_config=graph_config, node_id=root_node_id) if has_root_node is False: raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") @@ -79,17 +93,17 @@ class _WorkflowChildEngineBuilder: root_node_id=root_node_id, ) + command_channel = InMemoryChannel() + config = GraphEngineConfig() child_engine = GraphEngine( workflow_id=workflow_id, graph=child_graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), + graph_runtime_state=child_graph_runtime_state, + command_channel=command_channel, + config=config, child_engine_builder=self, ) child_engine.layer(LLMQuotaLayer()) - for layer in layers: - child_engine.layer(cast(GraphEngineLayer, layer)) return child_engine @@ -136,6 +150,8 @@ class WorkflowEntry: command_channel = InMemoryChannel() self.command_channel = command_channel + execution_context = capture_current_context() + graph_runtime_state.execution_context = execution_context self._child_engine_builder = _WorkflowChildEngineBuilder() self.graph_engine = GraphEngine( workflow_id=workflow_id, @@ -212,6 +228,8 @@ class WorkflowEntry: # Get node type node_type = node_config_data.type + node_version = str(node_config_data.version) + node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version) # init graph init params and runtime state graph_init_params = GraphInitParams( @@ -226,15 +244,23 @@ class WorkflowEntry: ), call_depth=0, ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - # init workflow run state - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + execution_context=capture_current_context(), + ) + + if is_start_node_type(node_type): + add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=user_inputs) + + preload_node_creation_variables( + variable_loader=variable_loader, + variable_pool=variable_pool, + selectors=get_node_creation_preload_selectors( + node_type=node_type, + node_data=node_config_data, + ), ) - node = node_factory.create_node(node_config) - node_cls = type(node) try: # variable selector to variable mapping @@ -243,6 
+269,12 @@ class WorkflowEntry: ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=node_id, + node_type=node_type, + node_data=node_config_data, + variable_mapping=variable_mapping, + ) # Loading missing variable from draft var here, and set it into # variable_pool. @@ -260,6 +292,13 @@ class WorkflowEntry: tenant_id=workflow.tenant_id, ) + # init workflow run state + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + node = node_factory.create_node(node_config) + try: generator = cls._traced_node_run(node) except Exception as e: @@ -347,11 +386,8 @@ class WorkflowEntry: raise ValueError(f"Node class not found for node type {node_type}") # init variable pool - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - environment_variables=[], - ) + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, default_system_variables()) # init graph init params and runtime state graph_init_params = GraphInitParams( @@ -366,7 +402,11 @@ class WorkflowEntry: ), call_depth=0, ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + execution_context=capture_current_context(), + ) # init workflow run state node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data}) @@ -384,6 +424,12 @@ class WorkflowEntry: ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=node_id, + node_type=node_type, + node_data=node_data, + variable_mapping=variable_mapping, + ) cls.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, @@ -477,13 +523,21 @@ class WorkflowEntry: continue if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value: - input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id) + input_value = file_factory.build_from_mapping( + mapping=input_value, + tenant_id=tenant_id, + access_controller=_file_access_controller, + ) if ( isinstance(input_value, list) and all(isinstance(item, dict) for item in input_value) and all("type" in item and "transfer_method" in item for item in input_value) ): - input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id) + input_value = file_factory.build_from_mappings( + mappings=input_value, + tenant_id=tenant_id, + access_controller=_file_access_controller, + ) # append variable and value to variable pool if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID: diff --git a/api/core/workflow/workflow_run_outputs.py b/api/core/workflow/workflow_run_outputs.py new file mode 100644 index 0000000000..354140a138 --- /dev/null +++ b/api/core/workflow/workflow_run_outputs.py @@ -0,0 +1,18 @@ +from collections.abc import Mapping +from typing import Any + +from dify_graph.enums import BuiltinNodeTypes, NodeType + + +def project_node_outputs_for_workflow_run( + *, + node_type: NodeType, + inputs: Mapping[str, Any], + outputs: Mapping[str, Any], +) -> dict[str, Any]: + """Project internal node outputs onto the workflow-run public contract.""" + + if node_type == BuiltinNodeTypes.START: + return dict(inputs) + + return dict(outputs) diff --git a/api/dify_graph/context/__init__.py b/api/dify_graph/context/__init__.py deleted file mode 
diff --git a/api/dify_graph/context/__init__.py b/api/dify_graph/context/__init__.py
deleted file mode 100644
index 103f526bec..0000000000
--- a/api/dify_graph/context/__init__.py
+++ /dev/null
@@ -1,34 +0,0 @@
-"""
-Execution Context - Context management for workflow execution.
-
-This package provides Flask-independent context management for workflow
-execution in multi-threaded environments.
-"""
-
-from dify_graph.context.execution_context import (
-    AppContext,
-    ContextProviderNotFoundError,
-    ExecutionContext,
-    IExecutionContext,
-    NullAppContext,
-    capture_current_context,
-    read_context,
-    register_context,
-    register_context_capturer,
-    reset_context_provider,
-)
-from dify_graph.context.models import SandboxContext
-
-__all__ = [
-    "AppContext",
-    "ContextProviderNotFoundError",
-    "ExecutionContext",
-    "IExecutionContext",
-    "NullAppContext",
-    "SandboxContext",
-    "capture_current_context",
-    "read_context",
-    "register_context",
-    "register_context_capturer",
-    "reset_context_provider",
-]
diff --git a/api/dify_graph/conversation_variable_updater.py b/api/dify_graph/conversation_variable_updater.py
deleted file mode 100644
index 17b19f2502..0000000000
--- a/api/dify_graph/conversation_variable_updater.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import abc
-from typing import Protocol
-
-from dify_graph.variables import VariableBase
-
-
-class ConversationVariableUpdater(Protocol):
-    """
-    ConversationVariableUpdater defines an abstraction for updating conversation variable values.
-
-    It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating
-    conversation variables.
-
-    Implementations may choose to batch updates. If batching is used, the `flush` method
-    should be implemented to persist buffered changes, and `update`
-    should handle buffering accordingly.
-
-    Note: Since implementations may buffer updates, instances of ConversationVariableUpdater
-    are not thread-safe. Each VariableAssignerNode should create its own instance during execution.
-    """
-
-    @abc.abstractmethod
-    def update(self, conversation_id: str, variable: "VariableBase"):
-        """
-        Updates the value of the specified conversation variable in the underlying storage.
-
-        :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
-        :param variable: The `VariableBase` instance containing the updated value.
-        """
-        pass
-
-    @abc.abstractmethod
-    def flush(self):
-        """
-        Flushes all pending updates to the underlying storage system.
-
-        If the implementation does not buffer updates, this method can be a no-op.
-        """
-        pass
diff --git a/api/dify_graph/entities/pause_reason.py b/api/dify_graph/entities/pause_reason.py
index 86d8c8ca16..d70110c446 100644
--- a/api/dify_graph/entities/pause_reason.py
+++ b/api/dify_graph/entities/pause_reason.py
@@ -18,7 +18,6 @@ class HumanInputRequired(BaseModel):
     form_content: str
     inputs: list[FormInput] = Field(default_factory=list)
     actions: list[UserAction] = Field(default_factory=list)
-    display_in_ui: bool = False
     node_id: str
     node_title: str
 
@@ -33,13 +32,6 @@
     # Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`.
     resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
 
-    # The `form_token` is the token used to submit the form via UI surfaces. It corresponds to
-    # `HumanInputFormRecipient.access_token`.
-    #
-    # This field is `None` if webapp delivery is not set and not
-    # in orchestrating mode.
-    form_token: str | None = None
-
 
 class SchedulingPause(BaseModel):
     TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE
diff --git a/api/dify_graph/entities/workflow_execution.py b/api/dify_graph/entities/workflow_execution.py
index 459ac46415..cf8ebf1e43 100644
--- a/api/dify_graph/entities/workflow_execution.py
+++ b/api/dify_graph/entities/workflow_execution.py
@@ -1,26 +1,23 @@
 """
 Domain entities for workflow execution.
 
-Models are independent of the storage mechanism and don't contain
-implementation details like tenant_id, app_id, etc.
+Models describe graph runtime state and avoid infrastructure-specific details.
 """
 
 from __future__ import annotations
 
 from collections.abc import Mapping
-from datetime import datetime
+from datetime import UTC, datetime
 from typing import Any
 
 from pydantic import BaseModel, Field
 
 from dify_graph.enums import WorkflowExecutionStatus, WorkflowType
-from libs.datetime_utils import naive_utc_now
 
 
 class WorkflowExecution(BaseModel):
     """
-    Domain model for workflow execution based on WorkflowRun but without
-    user, tenant, and app attributes.
+    Domain model for a workflow execution within the graph runtime.
     """
 
     id_: str = Field(...)
@@ -47,7 +44,7 @@ class WorkflowExecution(BaseModel):
         Calculate elapsed time in seconds.
         If workflow is not finished, use current time.
         """
-        end_time = self.finished_at or naive_utc_now()
+        end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None)
         return (end_time - self.started_at).total_seconds()
 
     @classmethod
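The elapsed-time change inlines a naive-UTC timestamp in place of the removed `libs.datetime_utils.naive_utc_now()` import. A small sketch of what the inlined expression yields, assuming the old helper also returned a tz-naive UTC datetime:

from datetime import UTC, datetime

# Same shape as the old helper: wall-clock UTC with tzinfo stripped, so it
# remains subtractable from the model's naive `started_at` timestamps.
end_time = datetime.now(UTC).replace(tzinfo=None)
assert end_time.tzinfo is None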
diff --git a/api/dify_graph/entities/workflow_node_execution.py b/api/dify_graph/entities/workflow_node_execution.py
index bc7e0d02e5..3a7b02b299 100644
--- a/api/dify_graph/entities/workflow_node_execution.py
+++ b/api/dify_graph/entities/workflow_node_execution.py
@@ -1,9 +1,8 @@
 """
 Domain entities for workflow node execution.
 
-This module contains the domain model for workflow node execution, which is used
-by the core workflow module. These models are independent of the storage mechanism
-and don't contain implementation details like tenant_id, app_id, etc.
+These models capture node-level execution state for the graph runtime without
+describing storage or application-layer concerns.
 """
 
 from collections.abc import Mapping
@@ -19,13 +18,8 @@ class WorkflowNodeExecution(BaseModel):
     """
     Domain model for workflow node execution.
 
-    This model represents the core business entity of a node execution,
-    without implementation details like tenant_id, app_id, etc.
-
-    Note: User/context-specific fields (triggered_from, created_by, created_by_role)
-    have been moved to the repository implementation to keep the domain model clean.
-    These fields are still accepted in the constructor for backward compatibility,
-    but they are not stored in the model.
+    This model represents the graph-level record of a node execution and
+    contains only execution state relevant to the runtime.
     """
 
     # --------- Core identification fields ---------
@@ -41,7 +35,7 @@ class WorkflowNodeExecution(BaseModel):
     # In most scenarios, `id` should be used as the primary identifier.
     node_execution_id: str | None = None
     workflow_id: str  # ID of the workflow this node belongs to
-    workflow_execution_id: str | None = None  # ID of the specific workflow run (null for single-step debugging)
+    workflow_execution_id: str | None = None  # ID of the workflow execution (null for single-step debugging)
     # --------- Core identification fields ends ---------
 
     # Execution positioning and flow
diff --git a/api/dify_graph/enums.py b/api/dify_graph/enums.py
index cfb135cbb0..94c7c57f18 100644
--- a/api/dify_graph/enums.py
+++ b/api/dify_graph/enums.py
@@ -10,30 +10,6 @@ class NodeState(StrEnum):
     SKIPPED = "skipped"
 
 
-class SystemVariableKey(StrEnum):
-    """
-    System Variables.
-    """
-
-    QUERY = "query"
-    FILES = "files"
-    CONVERSATION_ID = "conversation_id"
-    USER_ID = "user_id"
-    DIALOGUE_COUNT = "dialogue_count"
-    APP_ID = "app_id"
-    WORKFLOW_ID = "workflow_id"
-    WORKFLOW_EXECUTION_ID = "workflow_run_id"
-    TIMESTAMP = "timestamp"
-    # RAG Pipeline
-    DOCUMENT_ID = "document_id"
-    ORIGINAL_DOCUMENT_ID = "original_document_id"
-    BATCH = "batch"
-    DATASET_ID = "dataset_id"
-    DATASOURCE_TYPE = "datasource_type"
-    DATASOURCE_INFO = "datasource_info"
-    INVOKE_FROM = "invoke_from"
-
-
 NodeType: TypeAlias = str
diff --git a/api/dify_graph/file/__init__.py b/api/dify_graph/file/__init__.py
index 44749ebec3..4908ae9795 100644
--- a/api/dify_graph/file/__init__.py
+++ b/api/dify_graph/file/__init__.py
@@ -1,5 +1,6 @@
 from .constants import FILE_MODEL_IDENTITY
 from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType
+from .file_factory import get_file_type_by_mime_type, standardize_file_type
 from .models import (
     File,
     FileUploadConfig,
@@ -16,4 +17,6 @@
     "FileType",
     "FileUploadConfig",
     "ImageConfig",
+    "get_file_type_by_mime_type",
+    "standardize_file_type",
 ]
diff --git a/api/dify_graph/file/constants.py b/api/dify_graph/file/constants.py
index 0665ed7e0d..56b95b5f0d 100644
--- a/api/dify_graph/file/constants.py
+++ b/api/dify_graph/file/constants.py
@@ -1,3 +1,4 @@
+from collections.abc import Iterable
 from typing import Any
 
 # TODO(QuantumGhost): Refactor variable type identification. Instead of directly
@@ -5,6 +6,42 @@ from typing import Any
 # this logic into a dedicated function. This would encapsulate the implementation
 # details of how different variable types are identified.
 FILE_MODEL_IDENTITY = "__dify__file__"
+DEFAULT_MIME_TYPE = "application/octet-stream"
+DEFAULT_EXTENSION = ".bin"
+
+
+def _with_case_variants(extensions: Iterable[str]) -> frozenset[str]:
+    normalized = {extension.lower() for extension in extensions}
+    return frozenset(normalized | {extension.upper() for extension in normalized})
+
+
+IMAGE_EXTENSIONS = _with_case_variants({"jpg", "jpeg", "png", "webp", "gif", "svg"})
+VIDEO_EXTENSIONS = _with_case_variants({"mp4", "mov", "mpeg", "webm"})
+AUDIO_EXTENSIONS = _with_case_variants({"mp3", "m4a", "wav", "amr", "mpga"})
+DOCUMENT_EXTENSIONS = _with_case_variants(
+    {
+        "txt",
+        "markdown",
+        "md",
+        "mdx",
+        "pdf",
+        "html",
+        "htm",
+        "xlsx",
+        "xls",
+        "vtt",
+        "properties",
+        "doc",
+        "docx",
+        "csv",
+        "eml",
+        "msg",
+        "ppt",
+        "pptx",
+        "xml",
+        "epub",
+    }
+)
 
 
 def maybe_file_object(o: Any) -> bool:
diff --git a/api/dify_graph/file/file_factory.py b/api/dify_graph/file/file_factory.py
new file mode 100644
index 0000000000..3d20b9377d
--- /dev/null
+++ b/api/dify_graph/file/file_factory.py
@@ -0,0 +1,39 @@
+from .constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
+from .enums import FileType
+
+
+def standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType:
+    """
+    Infer the actual file type from extension and mime type.
+    """
+    guessed_type = None
+    if extension:
+        guessed_type = _get_file_type_by_extension(extension)
+    if guessed_type is None and mime_type:
+        guessed_type = get_file_type_by_mime_type(mime_type)
+    return guessed_type or FileType.CUSTOM
+
+
+def _get_file_type_by_extension(extension: str) -> FileType | None:
+    normalized_extension = extension.lstrip(".")
+    if normalized_extension in IMAGE_EXTENSIONS:
+        return FileType.IMAGE
+    if normalized_extension in VIDEO_EXTENSIONS:
+        return FileType.VIDEO
+    if normalized_extension in AUDIO_EXTENSIONS:
+        return FileType.AUDIO
+    if normalized_extension in DOCUMENT_EXTENSIONS:
+        return FileType.DOCUMENT
+    return None
+
+
+def get_file_type_by_mime_type(mime_type: str) -> FileType:
+    if "image" in mime_type:
+        return FileType.IMAGE
+    if "video" in mime_type:
+        return FileType.VIDEO
+    if "audio" in mime_type:
+        return FileType.AUDIO
+    if "text" in mime_type or "pdf" in mime_type:
+        return FileType.DOCUMENT
+    return FileType.CUSTOM
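A quick sketch of the precedence encoded above: the extension is consulted first, then the MIME type, with `FileType.CUSTOM` as the fallback.

from dify_graph.file import FileType, get_file_type_by_mime_type, standardize_file_type

assert standardize_file_type(extension=".PNG", mime_type="text/plain") == FileType.IMAGE
assert standardize_file_type(mime_type="application/pdf") == FileType.DOCUMENT
assert get_file_type_by_mime_type("application/zip") == FileType.CUSTOM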
diff --git a/api/dify_graph/file/file_manager.py b/api/dify_graph/file/file_manager.py
index 8d998054db..45b86d91cc 100644
--- a/api/dify_graph/file/file_manager.py
+++ b/api/dify_graph/file/file_manager.py
@@ -12,7 +12,6 @@ from dify_graph.model_runtime.entities import (
 )
 from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
 
-from . import helpers
 from .enums import FileAttribute
 from .models import File, FileTransferMethod, FileType
 from .runtime import get_workflow_file_runtime
@@ -80,7 +79,7 @@ def download(f: File, /) -> bytes:
         FileTransferMethod.LOCAL_FILE,
         FileTransferMethod.DATASOURCE_FILE,
     ):
-        return _download_file_content(f.storage_key)
+        return _download_file_content(f)
     elif f.transfer_method == FileTransferMethod.REMOTE_URL:
         if f.remote_url is None:
             raise ValueError("Missing file remote_url")
@@ -90,12 +89,9 @@ def download(f: File, /) -> bytes:
         raise ValueError(f"unsupported transfer method: {f.transfer_method}")
 
 
-def _download_file_content(path: str, /) -> bytes:
+def _download_file_content(file: File, /) -> bytes:
     """Download and return a file from storage as bytes."""
-    data = get_workflow_file_runtime().storage_load(path, stream=False)
-    if not isinstance(data, bytes):
-        raise ValueError(f"file {path} is not a bytes object")
-    return data
+    return get_workflow_file_runtime().load_file_bytes(file=file)
 
 
 def _get_encoded_string(f: File, /) -> str:
@@ -107,30 +103,20 @@ def _get_encoded_string(f: File, /) -> str:
             response.raise_for_status()
             data = response.content
         case FileTransferMethod.LOCAL_FILE:
-            data = _download_file_content(f.storage_key)
+            data = _download_file_content(f)
         case FileTransferMethod.TOOL_FILE:
-            data = _download_file_content(f.storage_key)
+            data = _download_file_content(f)
         case FileTransferMethod.DATASOURCE_FILE:
-            data = _download_file_content(f.storage_key)
+            data = _download_file_content(f)
     return base64.b64encode(data).decode("utf-8")
 
 
 def _to_url(f: File, /):
-    if f.transfer_method == FileTransferMethod.REMOTE_URL:
-        if f.remote_url is None:
-            raise ValueError("Missing file remote_url")
-        return f.remote_url
-    elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
-        if f.related_id is None:
-            raise ValueError("Missing file related_id")
-        return f.remote_url or helpers.get_signed_file_url(upload_file_id=f.related_id)
-    elif f.transfer_method == FileTransferMethod.TOOL_FILE:
-        if f.related_id is None or f.extension is None:
-            raise ValueError("Missing file related_id or extension")
-        return helpers.get_signed_tool_file_url(tool_file_id=f.related_id, extension=f.extension)
-    else:
+    url = f.generate_url()
+    if url is None:
         raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
+    return url
 
 
 class FileManager:
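For orientation, a sketch of the new download path (the file values are hypothetical, and a workflow file runtime must be configured before this runs):

from dify_graph.file import File, FileTransferMethod, FileType, file_manager

doc = File(
    type=FileType.DOCUMENT,
    transfer_method=FileTransferMethod.LOCAL_FILE,
    reference="upload-file-record-id",  # opaque record reference, per the new model
)
# Resolution now flows File -> runtime.load_file_bytes(file=...), so the
# graph layer never touches raw storage keys directly.
data = file_manager.download(doc)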
diff --git a/api/dify_graph/file/helpers.py b/api/dify_graph/file/helpers.py
index 310cb1310b..dade761227 100644
--- a/api/dify_graph/file/helpers.py
+++ b/api/dify_graph/file/helpers.py
@@ -1,92 +1,48 @@
 from __future__ import annotations
 
-import base64
-import hashlib
-import hmac
-import os
-import time
-import urllib.parse
+from typing import TYPE_CHECKING
 
 from .runtime import get_workflow_file_runtime
 
+if TYPE_CHECKING:
+    from .models import File
+
+
+def resolve_file_url(file: File, /, *, for_external: bool = True) -> str | None:
+    return get_workflow_file_runtime().resolve_file_url(file=file, for_external=for_external)
+
 
 def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_external: bool = True) -> str:
-    runtime = get_workflow_file_runtime()
-    base_url = runtime.files_url if for_external else (runtime.internal_files_url or runtime.files_url)
-    url = f"{base_url}/files/{upload_file_id}/file-preview"
-
-    timestamp = str(int(time.time()))
-    nonce = os.urandom(16).hex()
-    key = runtime.secret_key.encode()
-    msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
-    sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
-    encoded_sign = base64.urlsafe_b64encode(sign).decode()
-    query: dict[str, str] = {"timestamp": timestamp, "nonce": nonce, "sign": encoded_sign}
-    if as_attachment:
-        query["as_attachment"] = "true"
-    query_string = urllib.parse.urlencode(query)
-
-    return f"{url}?{query_string}"
-
-
-def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str:
-    runtime = get_workflow_file_runtime()
-    # Plugin access should use internal URL for Docker network communication.
-    base_url = runtime.internal_files_url or runtime.files_url
-    url = f"{base_url}/files/upload/for-plugin"
-    timestamp = str(int(time.time()))
-    nonce = os.urandom(16).hex()
-    key = runtime.secret_key.encode()
-    msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
-    sign = hmac.new(key, msg.encode(), hashlib.sha256).digest()
-    encoded_sign = base64.urlsafe_b64encode(sign).decode()
-    return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}"
+    return get_workflow_file_runtime().resolve_upload_file_url(
+        upload_file_id=upload_file_id,
+        as_attachment=as_attachment,
+        for_external=for_external,
+    )
 
 
 def get_signed_tool_file_url(tool_file_id: str, extension: str, for_external: bool = True) -> str:
-    runtime = get_workflow_file_runtime()
-    return runtime.sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external)
-
-
-def verify_plugin_file_signature(
-    *, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str
-) -> bool:
-    runtime = get_workflow_file_runtime()
-    data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}"
-    secret_key = runtime.secret_key.encode()
-    recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
-    recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
-
-    if sign != recalculated_encoded_sign:
-        return False
-
-    current_time = int(time.time())
-    return current_time - int(timestamp) <= runtime.files_access_timeout
+    return get_workflow_file_runtime().resolve_tool_file_url(
+        tool_file_id=tool_file_id,
+        extension=extension,
+        for_external=for_external,
+    )
 
 
 def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
-    runtime = get_workflow_file_runtime()
-    data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
-    secret_key = runtime.secret_key.encode()
-    recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
-    recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
-
-    if sign != recalculated_encoded_sign:
-        return False
-
-    current_time = int(time.time())
-    return current_time - int(timestamp) <= runtime.files_access_timeout
+    return get_workflow_file_runtime().verify_preview_signature(
+        preview_kind="image",
+        file_id=upload_file_id,
+        timestamp=timestamp,
+        nonce=nonce,
+        sign=sign,
+    )
 
 
 def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
-    runtime = get_workflow_file_runtime()
-    data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}"
-    secret_key = runtime.secret_key.encode()
-    recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
-    recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
-
-    if sign != recalculated_encoded_sign:
-        return False
-
-    current_time = int(time.time())
-    return current_time - int(timestamp) <= runtime.files_access_timeout
+    return get_workflow_file_runtime().verify_preview_signature(
+        preview_kind="file",
+        file_id=upload_file_id,
+        timestamp=timestamp,
+        nonce=nonce,
+        sign=sign,
+    )
diff --git a/api/dify_graph/file/models.py b/api/dify_graph/file/models.py
index dcba00978e..570921003d 100644
--- a/api/dify_graph/file/models.py
+++ b/api/dify_graph/file/models.py
@@ -1,8 +1,9 @@
 from __future__ import annotations
 
+import base64
+import json
 from collections.abc import Mapping, Sequence
 from typing import Any
-from uuid import UUID, uuid4
 
 from pydantic import BaseModel, Field, model_validator
 
@@ -12,6 +13,8 @@
 from . import helpers
 from .constants import FILE_MODEL_IDENTITY
 from .enums import FileTransferMethod, FileType
 
+_FILE_REFERENCE_PREFIX = "dify-file-ref:"
+
 
 def sign_tool_file(*, tool_file_id: str, extension: str, for_external: bool = True) -> str:
     """Compatibility shim for tests and legacy callers patching ``models.sign_tool_file``."""
@@ -44,57 +47,68 @@ class FileUploadConfig(BaseModel):
     number_limits: int = 0
 
 
-class ToolFile(BaseModel):
-    id: UUID = Field(default_factory=uuid4, description="Unique identifier for the file")
-    user_id: UUID = Field(..., description="ID of the user who owns this file")
-    tenant_id: UUID = Field(..., description="ID of the tenant/organization")
-    conversation_id: UUID | None = Field(None, description="ID of the associated conversation")
-    file_key: str = Field(..., max_length=255, description="Storage key for the file")
-    mimetype: str = Field(..., max_length=255, description="MIME type of the file")
-    original_url: str | None = Field(
-        None, max_length=2048, description="Original URL if file was fetched from external source"
-    )
-    name: str = Field(default="", max_length=255, description="Display name of the file")
-    size: int = Field(default=-1, ge=-1, description="File size in bytes (-1 if unknown)")
+def _parse_reference(reference: str | None) -> tuple[str | None, str | None]:
+    """Best-effort parser for record references and historical storage-key payloads."""
+    if not reference:
+        return None, None
 
-    class Config:
-        from_attributes = True  # Enable ORM mode for SQLAlchemy compatibility
-        populate_by_name = True
+    if not reference.startswith(_FILE_REFERENCE_PREFIX):
+        return reference, None
+
+    encoded_payload = reference.removeprefix(_FILE_REFERENCE_PREFIX)
+    try:
+        payload = json.loads(base64.urlsafe_b64decode(encoded_payload.encode()))
+    except (ValueError, json.JSONDecodeError):
+        return reference, None
+
+    record_id = payload.get("record_id")
+    if not isinstance(record_id, str) or not record_id:
+        return reference, None
+
+    storage_key = payload.get("storage_key")
+    if not isinstance(storage_key, str):
+        storage_key = None
+
+    return record_id, storage_key
 
 
 class File(BaseModel):
+    """Graph-owned file reference.
+
+    The graph layer deliberately keeps only the metadata required to route,
+    serialize, and render files. Application ownership concerns such as
+    tenant/user/conversation identity stay in the workflow/storage layer.
+    """
+
     # NOTE: dify_model_identity is a special identifier used to distinguish between
     # new and old data formats during serialization and deserialization.
     dify_model_identity: str = FILE_MODEL_IDENTITY
 
     id: str | None = None  # message file id
-    tenant_id: str
     type: FileType
     transfer_method: FileTransferMethod
     # If `transfer_method` is `FileTransferMethod.remote_url`, the
    # `remote_url` attribute must not be `None`.
     remote_url: str | None = None  # remote url
-    # If `transfer_method` is `FileTransferMethod.local_file` or
-    # `FileTransferMethod.tool_file`, the `related_id` attribute must not be `None`.
-    #
-    # It should be set to `ToolFile.id` when `transfer_method` is `tool_file`.
-    related_id: str | None = None
+    # Opaque workflow-layer reference for files resolved outside ``dify_graph``.
+    # New payloads only carry the backing record id; historical payloads may
+    # still include storage_key and must remain readable.
+    reference: str | None = None
     filename: str | None = None
     extension: str | None = Field(default=None, description="File extension, should contain dot")
     mime_type: str | None = None
     size: int = -1
-
-    # Those properties are private, should not be exposed to the outside.
     _storage_key: str
 
     def __init__(
         self,
         *,
         id: str | None = None,
-        tenant_id: str,
+        tenant_id: str | None = None,
         type: FileType,
         transfer_method: FileTransferMethod,
         remote_url: str | None = None,
+        reference: str | None = None,
         related_id: str | None = None,
         filename: str | None = None,
         extension: str | None = None,
@@ -103,18 +117,23 @@ class File(BaseModel):
         storage_key: str | None = None,
         dify_model_identity: str | None = FILE_MODEL_IDENTITY,
         url: str | None = None,
-        # Legacy compatibility fields - explicitly handle known extra fields
+        # Legacy compatibility fields - explicitly accept known extra fields
         tool_file_id: str | None = None,
         upload_file_id: str | None = None,
         datasource_file_id: str | None = None,
     ):
+        legacy_record_id = related_id or tool_file_id or upload_file_id or datasource_file_id
+        normalized_reference = reference
+        if normalized_reference is None and legacy_record_id is not None:
+            normalized_reference = str(legacy_record_id)
+        _, parsed_storage_key = _parse_reference(normalized_reference)
+
         super().__init__(
             id=id,
-            tenant_id=tenant_id,
             type=type,
             transfer_method=transfer_method,
             remote_url=remote_url,
-            related_id=related_id,
+            reference=normalized_reference,
             filename=filename,
             extension=extension,
             mime_type=mime_type,
@@ -122,12 +141,15 @@ class File(BaseModel):
             dify_model_identity=dify_model_identity,
             url=url,
         )
-        self._storage_key = str(storage_key)
+        # Accept legacy constructor fields without promoting them back into the graph model.
+        _ = tenant_id
+        self._storage_key = storage_key or parsed_storage_key or ""
 
     def to_dict(self) -> Mapping[str, str | int | None]:
         data = self.model_dump(mode="json")
         return {
             **data,
+            "related_id": self.related_id,
             "url": self.generate_url(),
         }
 
@@ -142,21 +164,7 @@ class File(BaseModel):
         return text
 
     def generate_url(self, for_external: bool = True) -> str | None:
-        if self.transfer_method == FileTransferMethod.REMOTE_URL:
-            return self.remote_url
-        elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
-            if self.related_id is None:
-                raise ValueError("Missing file related_id")
-            return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external)
-        elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]:
-            assert self.related_id is not None
-            assert self.extension is not None
-            return sign_tool_file(
-                tool_file_id=self.related_id,
-                extension=self.extension,
-                for_external=for_external,
-            )
-        return None
+        return helpers.resolve_file_url(self, for_external=for_external)
 
     def to_plugin_parameter(self) -> dict[str, Any]:
         return {
@@ -178,19 +186,29 @@ class File(BaseModel):
                 if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"):
                     raise ValueError("Invalid file url")
             case FileTransferMethod.LOCAL_FILE:
-                if not self.related_id:
-                    raise ValueError("Missing file related_id")
+                if not self.reference:
+                    raise ValueError("Missing file reference")
             case FileTransferMethod.TOOL_FILE:
-                if not self.related_id:
-                    raise ValueError("Missing file related_id")
+                if not self.reference:
+                    raise ValueError("Missing file reference")
             case FileTransferMethod.DATASOURCE_FILE:
-                if not self.related_id:
-                    raise ValueError("Missing file related_id")
+                if not self.reference:
+                    raise ValueError("Missing file reference")
 
         return self
 
+    @property
+    def related_id(self) -> str | None:
+        record_id, _ = _parse_reference(self.reference)
+        return record_id
+
+    @related_id.setter
+    def related_id(self, value: str | None) -> None:
+        self.reference = value
+
     @property
     def storage_key(self) -> str:
-        return self._storage_key
+        _, storage_key = _parse_reference(self.reference)
+        return storage_key or self._storage_key
 
     @storage_key.setter
     def storage_key(self, value: str) -> None:
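To make the reference format concrete, here is a sketch of a historical payload round-tripping through `_parse_reference` (all values are made up):

import base64
import json

payload = {"record_id": "abc123", "storage_key": "upload_files/abc123.png"}
legacy_ref = "dify-file-ref:" + base64.urlsafe_b64encode(json.dumps(payload).encode()).decode()

# _parse_reference(legacy_ref)  -> ("abc123", "upload_files/abc123.png")
# _parse_reference("abc123")    -> ("abc123", None)   # new-style plain record id
# _parse_reference(None)        -> (None, None)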
diff --git a/api/dify_graph/file/protocols.py b/api/dify_graph/file/protocols.py
index 24cbb42735..4246d5d6ee 100644
--- a/api/dify_graph/file/protocols.py
+++ b/api/dify_graph/file/protocols.py
@@ -1,7 +1,10 @@
 from __future__ import annotations
 
 from collections.abc import Generator
-from typing import Protocol
+from typing import TYPE_CHECKING, Literal, Protocol
+
+if TYPE_CHECKING:
+    from .models import File
 
 
 class HttpResponseProtocol(Protocol):
@@ -21,18 +24,6 @@ class WorkflowFileRuntimeProtocol(Protocol):
     application infrastructure modules directly.
     """
 
-    @property
-    def files_url(self) -> str: ...
-
-    @property
-    def internal_files_url(self) -> str | None: ...
-
-    @property
-    def secret_key(self) -> str: ...
-
-    @property
-    def files_access_timeout(self) -> int: ...
-
     @property
     def multimodal_send_format(self) -> str: ...
 
@@ -40,4 +31,26 @@ class WorkflowFileRuntimeProtocol(Protocol):
 
     def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ...
 
-    def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ...
+    def load_file_bytes(self, *, file: File) -> bytes: ...
+
+    def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: ...
+
+    def resolve_upload_file_url(
+        self,
+        *,
+        upload_file_id: str,
+        as_attachment: bool = False,
+        for_external: bool = True,
+    ) -> str: ...
+
+    def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ...
+
+    def verify_preview_signature(
+        self,
+        *,
+        preview_kind: Literal["image", "file"],
+        file_id: str,
+        timestamp: str,
+        nonce: str,
+        sign: str,
+    ) -> bool: ...
diff --git a/api/dify_graph/file/runtime.py b/api/dify_graph/file/runtime.py
index 94253e0255..1c5d1c3ca4 100644
--- a/api/dify_graph/file/runtime.py
+++ b/api/dify_graph/file/runtime.py
@@ -1,10 +1,13 @@
 from __future__ import annotations
 
 from collections.abc import Generator
-from typing import NoReturn
+from typing import TYPE_CHECKING, Literal, NoReturn
 
 from .protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol
 
+if TYPE_CHECKING:
+    from .models import File
+
 
 class WorkflowFileRuntimeNotConfiguredError(RuntimeError):
     """Raised when workflow file runtime dependencies were not configured."""
@@ -16,22 +19,6 @@ class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
         "workflow file runtime is not configured, call set_workflow_file_runtime(...) first"
     )
 
-    @property
-    def files_url(self) -> str:
-        self._raise()
-
-    @property
-    def internal_files_url(self) -> str | None:
-        self._raise()
-
-    @property
-    def secret_key(self) -> str:
-        self._raise()
-
-    @property
-    def files_access_timeout(self) -> int:
-        self._raise()
-
     @property
     def multimodal_send_format(self) -> str:
         self._raise()
@@ -42,7 +29,33 @@ class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol):
     def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator:
         self._raise()
 
-    def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
+    def load_file_bytes(self, *, file: File) -> bytes:
+        self._raise()
+
+    def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None:
+        self._raise()
+
+    def resolve_upload_file_url(
+        self,
+        *,
+        upload_file_id: str,
+        as_attachment: bool = False,
+        for_external: bool = True,
+    ) -> str:
+        self._raise()
+
+    def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
+        self._raise()
+
+    def verify_preview_signature(
+        self,
+        *,
+        preview_kind: Literal["image", "file"],
+        file_id: str,
+        timestamp: str,
+        nonce: str,
+        sign: str,
+    ) -> bool:
         self._raise()
diff --git a/api/dify_graph/graph/graph.py b/api/dify_graph/graph/graph.py
index 85117583e0..add69ad884 100644
--- a/api/dify_graph/graph/graph.py
+++ b/api/dify_graph/graph/graph.py
@@ -10,7 +10,6 @@
 from pydantic import TypeAdapter
 
 from dify_graph.entities.graph_config import NodeConfigDict
 from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState
 from dify_graph.nodes.base.node import Node
-from libs.typing import is_str
 
 from .edge import Edge
 from .validation import get_graph_validator
@@ -102,7 +101,7 @@ class Graph:
             source = edge_config.get("source")
             target = edge_config.get("target")
 
-            if not is_str(source) or not is_str(target):
+            if not isinstance(source, str) or not isinstance(target, str):
                 continue
 
             # Create edge
@@ -110,7 +109,7 @@ class Graph:
             edge_counter += 1
 
             source_handle = edge_config.get("sourceHandle", "source")
-            if not is_str(source_handle):
+            if not isinstance(source_handle, str):
                 continue
 
             edge = Edge(
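A test-double sketch of the slimmed-down runtime protocol; method bodies are placeholders and the `example.invalid` URLs are stand-ins, not real endpoints:

from typing import Literal

from dify_graph.file.models import File


class FakeWorkflowFileRuntime:
    multimodal_send_format = "url"  # assumed value for the sketch

    def storage_load(self, path: str, *, stream: bool = False) -> bytes:
        return b""

    def load_file_bytes(self, *, file: File) -> bytes:
        return b""

    def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None:
        return file.remote_url

    def resolve_upload_file_url(self, *, upload_file_id: str, as_attachment: bool = False, for_external: bool = True) -> str:
        return f"https://example.invalid/files/{upload_file_id}/file-preview"

    def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str:
        return f"https://example.invalid/tool-files/{tool_file_id}{extension}"

    def verify_preview_signature(self, *, preview_kind: Literal["image", "file"], file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
        return True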
diff --git a/api/dify_graph/graph_engine/event_management/event_handlers.py b/api/dify_graph/graph_engine/event_management/event_handlers.py
index 7f5ad40e0e..e7a8291373 100644
--- a/api/dify_graph/graph_engine/event_management/event_handlers.py
+++ b/api/dify_graph/graph_engine/event_management/event_handlers.py
@@ -28,6 +28,7 @@ from dify_graph.graph_events import (
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
+    NodeRunVariableUpdatedEvent,
 )
 from dify_graph.model_runtime.entities.llm_entities import LLMUsage
 from dify_graph.runtime import GraphRuntimeState
@@ -93,6 +94,10 @@ class EventHandler:
         Args:
             event: The event to handle
         """
+        if isinstance(event, NodeRunVariableUpdatedEvent):
+            self._dispatch(event)
+            return
+
         # Events in loops or iterations are always collected
         if event.in_loop_id or event.in_iteration_id:
             self._event_collector.collect(event)
@@ -153,6 +158,17 @@ class EventHandler:
         for stream_event in streaming_events:
             self._event_collector.collect(stream_event)
 
+    @_dispatch.register
+    def _(self, event: NodeRunVariableUpdatedEvent) -> None:
+        """
+        Apply a node-requested variable mutation before downstream observers run.
+
+        The event is collected like other node events so parent/container engines can
+        forward the updated payload to outer layers, including persistence listeners.
+        """
+        self._graph_runtime_state.variable_pool.add(event.variable.selector, event.variable)
+        self._event_collector.collect(event)
+
     @_dispatch.register
     def _(self, event: NodeRunSucceededEvent) -> None:
         """
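In effect, the new branch bypasses the usual loop/iteration buffering. A sketch of the contract, with `pool` and `collector` standing in for the runtime state's variable pool and the event collector:

def apply_variable_update(pool, collector, event) -> None:
    """Mirror of the NodeRunVariableUpdatedEvent branch above (sketch)."""
    # Mutate the shared variable pool first so downstream observers read the
    # updated value, then collect the event for outer engines / persistence.
    pool.add(event.variable.selector, event.variable)
    collector.collect(event)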
diff --git a/api/dify_graph/graph_engine/graph_engine.py b/api/dify_graph/graph_engine/graph_engine.py
index ea98a46b06..bfacd8ed71 100644
--- a/api/dify_graph/graph_engine/graph_engine.py
+++ b/api/dify_graph/graph_engine/graph_engine.py
@@ -9,10 +9,9 @@ from __future__ import annotations
 
 import logging
 import queue
-from collections.abc import Generator, Mapping
+from collections.abc import Generator
 from typing import TYPE_CHECKING, cast, final
 
-from dify_graph.context import capture_current_context
 from dify_graph.entities.workflow_start_reason import WorkflowStartReason
 from dify_graph.enums import NodeExecutionType
 from dify_graph.graph import Graph
@@ -26,7 +25,7 @@ from dify_graph.graph_events import (
     GraphRunStartedEvent,
     GraphRunSucceededEvent,
 )
-from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper
+from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
 from dify_graph.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol
 
 if TYPE_CHECKING:  # pragma: no cover - used only for static analysis
@@ -86,6 +85,7 @@ class GraphEngine:
         self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
         self._command_channel = command_channel
         self._config = config
+        self._layers: list[GraphEngineLayer] = []
         self._child_engine_builder = child_engine_builder
         if child_engine_builder is not None:
             self._graph_runtime_state.bind_child_engine_builder(child_engine_builder)
@@ -149,21 +149,14 @@ class GraphEngine:
         update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool)
         self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler)
 
-        # === Extensibility ===
-        # Layers allow plugins to extend engine functionality
-        self._layers: list[GraphEngineLayer] = []
-
         # === Worker Pool Setup ===
-        # Capture execution context for worker threads
-        execution_context = capture_current_context()
-
         # Create worker pool for parallel node execution
         self._worker_pool = WorkerPool(
             ready_queue=self._ready_queue,
             event_queue=self._event_queue,
             graph=self._graph,
             layers=self._layers,
-            execution_context=execution_context,
+            execution_context=self._graph_runtime_state.execution_context,
             config=self._config,
         )
 
@@ -220,23 +213,23 @@ class GraphEngine:
             self._bind_layer_context(layer)
         return self
 
+    def request_abort(self, reason: str | None = None) -> None:
+        """Queue an abort command for this engine."""
+        self._command_channel.send_command(AbortCommand(reason=reason or "User requested abort"))
+
     def create_child_engine(
         self,
         *,
         workflow_id: str,
         graph_init_params: GraphInitParams,
-        graph_runtime_state: GraphRuntimeState,
-        graph_config: dict[str, object] | Mapping[str, object],
         root_node_id: str,
-        layers: list[GraphEngineLayer] | tuple[GraphEngineLayer, ...] = (),
+        variable_pool: VariablePool | None = None,
     ) -> GraphEngine:
         return self._graph_runtime_state.create_child_engine(
             workflow_id=workflow_id,
             graph_init_params=graph_init_params,
-            graph_runtime_state=graph_runtime_state,
-            graph_config=graph_config,
             root_node_id=root_node_id,
-            layers=layers,
+            variable_pool=variable_pool,
         )
 
     def run(self) -> Generator[GraphEngineEvent, None, None]:
diff --git a/api/dify_graph/graph_engine/worker.py b/api/dify_graph/graph_engine/worker.py
index 988c20d72a..860a46e5c2 100644
--- a/api/dify_graph/graph_engine/worker.py
+++ b/api/dify_graph/graph_engine/worker.py
@@ -9,19 +9,18 @@
 import queue
 import threading
 import time
 from collections.abc import Sequence
-from datetime import datetime
+from contextlib import AbstractContextManager
+from datetime import UTC, datetime
 from typing import TYPE_CHECKING, final
 
 from typing_extensions import override
 
-from dify_graph.context import IExecutionContext
 from dify_graph.enums import WorkflowNodeExecutionStatus
 from dify_graph.graph import Graph
 from dify_graph.graph_engine.layers.base import GraphEngineLayer
 from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event
 from dify_graph.node_events import NodeRunResult
 from dify_graph.nodes.base.node import Node
-from libs.datetime_utils import naive_utc_now
 
 from .ready_queue import ReadyQueue
 
@@ -46,7 +45,7 @@ class Worker(threading.Thread):
         graph: Graph,
         layers: Sequence[GraphEngineLayer],
         worker_id: int = 0,
-        execution_context: IExecutionContext | None = None,
+        execution_context: AbstractContextManager[object] | None = None,
     ) -> None:
         """
         Initialize worker thread.
@@ -187,7 +186,7 @@ class Worker(threading.Thread):
         self, node: Node, error: Exception, *, started_at: datetime | None = None
     ) -> NodeRunFailedEvent:
         """Build a failed event when worker-level execution aborts before a node emits its own result event."""
-        failure_time = naive_utc_now()
+        failure_time = datetime.now(UTC).replace(tzinfo=None)
         error_message = str(error)
         return NodeRunFailedEvent(
             id=node.execution_id,
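Since workers now accept any `AbstractContextManager` rather than an `IExecutionContext`, a null context suffices when no framework state needs to be re-entered on the worker thread (sketch):

from contextlib import nullcontext

# Satisfies AbstractContextManager[object]; a worker can enter it around node
# execution without any dependency on dify_graph.context.
worker_context = nullcontext()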
diff --git a/api/dify_graph/graph_engine/worker_management/worker_pool.py b/api/dify_graph/graph_engine/worker_management/worker_pool.py
index cc93087783..eb641165a0 100644
--- a/api/dify_graph/graph_engine/worker_management/worker_pool.py
+++ b/api/dify_graph/graph_engine/worker_management/worker_pool.py
@@ -8,9 +8,9 @@ DynamicScaler, and WorkerFactory into a single class.
 
 import logging
 import queue
 import threading
+from contextlib import AbstractContextManager
 from typing import final
 
-from dify_graph.context import IExecutionContext
 from dify_graph.graph import Graph
 from dify_graph.graph_events import GraphNodeEventBase
 
@@ -38,7 +38,7 @@ class WorkerPool:
         graph: Graph,
         layers: list[GraphEngineLayer],
         config: GraphEngineConfig,
-        execution_context: IExecutionContext | None = None,
+        execution_context: AbstractContextManager[object] | None = None,
     ) -> None:
         """
         Initialize the simple worker pool.
diff --git a/api/dify_graph/graph_events/__init__.py b/api/dify_graph/graph_events/__init__.py
index 56ea642092..7cec587a05 100644
--- a/api/dify_graph/graph_events/__init__.py
+++ b/api/dify_graph/graph_events/__init__.py
@@ -46,6 +46,7 @@ from .node import (
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
     NodeRunSucceededEvent,
+    NodeRunVariableUpdatedEvent,
     is_node_result_event,
 )
 
@@ -78,5 +79,6 @@
     "NodeRunStartedEvent",
     "NodeRunStreamChunkEvent",
     "NodeRunSucceededEvent",
+    "NodeRunVariableUpdatedEvent",
     "is_node_result_event",
 ]
diff --git a/api/dify_graph/graph_events/node.py b/api/dify_graph/graph_events/node.py
index df19d6c03b..3b880f5f6f 100644
--- a/api/dify_graph/graph_events/node.py
+++ b/api/dify_graph/graph_events/node.py
@@ -1,10 +1,11 @@
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from datetime import datetime
+from typing import Any
 
 from pydantic import Field
 
-from core.rag.entities.citation_metadata import RetrievalSourceMetadata
 from dify_graph.entities.pause_reason import PauseReason
+from dify_graph.variables.variables import Variable
 
 from .base import GraphNodeEventBase
 
@@ -30,7 +31,7 @@ class NodeRunStreamChunkEvent(GraphNodeEventBase):
 
 
 class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
-    retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources")
+    retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources")
     context: str = Field(..., description="context")
 
 
@@ -39,6 +40,12 @@ class NodeRunSucceededEvent(GraphNodeEventBase):
     finished_at: datetime | None = Field(default=None, description="node finish time")
 
 
+class NodeRunVariableUpdatedEvent(GraphNodeEventBase):
+    """Request that the engine apply a variable update before downstream observers continue."""
+
+    variable: Variable = Field(..., description="Updated variable payload to apply.")
+
+
 class NodeRunFailedEvent(GraphNodeEventBase):
     error: str = Field(..., description="error")
     start_at: datetime = Field(..., description="node start time")
diff --git a/api/dify_graph/model_runtime/callbacks/base_callback.py b/api/dify_graph/model_runtime/callbacks/base_callback.py
index 20faf3d6cd..76af06b67e 100644
--- a/api/dify_graph/model_runtime/callbacks/base_callback.py
+++ b/api/dify_graph/model_runtime/callbacks/base_callback.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 
 from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
 from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
@@ -34,6 +34,7 @@ class Callback(ABC):
         stop: Sequence[str] | None = None,
         stream: bool = True,
         user: str | None = None,
+        invocation_context: Mapping[str, object] | None = None,
     ):
         """
         Before invoke callback
@@ -46,7 +47,8 @@ class Callback(ABC):
         :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
-        :param user: unique user id
+        :param user: optional end-user identifier for the invocation
+        :param invocation_context: opaque request metadata for the current invocation
         """
         raise NotImplementedError()
 
@@ -63,6 +65,7 @@ class Callback(ABC):
         stop: Sequence[str] | None = None,
         stream: bool = True,
         user: str | None = None,
+        invocation_context: Mapping[str, object] | None = None,
     ):
         """
         On new chunk callback
@@ -76,7 +79,8 @@ class Callback(ABC):
         :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
-        :param user: unique user id
+        :param user: optional end-user identifier for the invocation
+        :param invocation_context: opaque request metadata for the current invocation
         """
         raise NotImplementedError()
 
@@ -93,6 +97,7 @@ class Callback(ABC):
         stop: Sequence[str] | None = None,
         stream: bool = True,
         user: str | None = None,
+        invocation_context: Mapping[str, object] | None = None,
     ):
         """
         After invoke callback
@@ -106,7 +111,8 @@ class Callback(ABC):
         :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
-        :param user: unique user id
+        :param user: optional end-user identifier for the invocation
+        :param invocation_context: opaque request metadata for the current invocation
         """
         raise NotImplementedError()
 
@@ -123,6 +129,7 @@ class Callback(ABC):
         stop: Sequence[str] | None = None,
         stream: bool = True,
         user: str | None = None,
+        invocation_context: Mapping[str, object] | None = None,
     ):
         """
         Invoke error callback
@@ -136,7 +143,8 @@ class Callback(ABC):
         :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
-        :param user: unique user id
+        :param user: optional end-user identifier for the invocation
+        :param invocation_context: opaque request metadata for the current invocation
         """
         raise NotImplementedError()
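A minimal sketch of a Callback subclass tolerating the new keyword. Hook names are assumed from the four callbacks documented above, and the positional parameters are collapsed into *args for brevity:

from collections.abc import Mapping

from dify_graph.model_runtime.callbacks.base_callback import Callback


class NullCallback(Callback):
    def on_before_invoke(self, *args, invocation_context: Mapping[str, object] | None = None, **kwargs):
        pass  # opaque metadata: log or forward it, never introspect specific keys

    def on_new_chunk(self, *args, invocation_context: Mapping[str, object] | None = None, **kwargs):
        pass

    def on_after_invoke(self, *args, invocation_context: Mapping[str, object] | None = None, **kwargs):
        pass

    def on_invoke_error(self, *args, invocation_context: Mapping[str, object] | None = None, **kwargs):
        pass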
diff --git a/api/dify_graph/model_runtime/callbacks/logging_callback.py b/api/dify_graph/model_runtime/callbacks/logging_callback.py
index 49b9ab27eb..03ab9534f7 100644
--- a/api/dify_graph/model_runtime/callbacks/logging_callback.py
+++ b/api/dify_graph/model_runtime/callbacks/logging_callback.py
@@ -1,7 +1,7 @@
 import json
 import logging
 import sys
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
 from typing import cast
 
 from dify_graph.model_runtime.callbacks.base_callback import Callback
@@ -24,6 +24,7 @@ class LoggingCallback(Callback):
         stop: Sequence[str] | None = None,
         stream: bool = True,
         user: str | None = None,
+        invocation_context: Mapping[str, object] | None = None,
     ):
         """
         Before invoke callback
@@ -36,7 +37,8 @@ class LoggingCallback(Callback):
         :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
-        :param user: unique user id
+        :param user: optional end-user identifier for the invocation
+        :param invocation_context: opaque request metadata for the current invocation
         """
         self.print_text("\n[on_llm_before_invoke]\n", color="blue")
         self.print_text(f"Model: {model}\n", color="blue")
@@ -53,10 +55,12 @@ class LoggingCallback(Callback):
                 self.print_text(f"\t\t{tool.name}\n", color="blue")
 
         self.print_text(f"Stream: {stream}\n", color="blue")
-        if user:
-            self.print_text(f"User: {user}\n", color="blue")
+        if invocation_context:
+            self.print_text(f"Invocation context: {dict(invocation_context)}\n", color="blue")
+
         self.print_text("Prompt messages:\n", color="blue")
         for prompt_message in prompt_messages:
             if prompt_message.name:
@@ -80,6 +84,7 @@ class LoggingCallback(Callback):
         stop: Sequence[str] | None = None,
         stream: bool = True,
         user: str | None = None,
+        invocation_context: Mapping[str, object] | None = None,
     ):
         """
         On new chunk callback
@@ -93,8 +98,9 @@ class LoggingCallback(Callback):
         :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
-        :param user: unique user id
+        :param invocation_context: opaque request metadata for the current invocation
         """
+        _ = user, invocation_context
         sys.stdout.write(cast(str, chunk.delta.message.content))
         sys.stdout.flush()
 
@@ -110,6 +116,7 @@ class LoggingCallback(Callback):
         stop: Sequence[str] | None = None,
         stream: bool = True,
         user: str | None = None,
+        invocation_context: Mapping[str, object] | None = None,
     ):
         """
         After invoke callback
@@ -123,8 +130,9 @@ class LoggingCallback(Callback):
         :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
-        :param user: unique user id
+        :param invocation_context: opaque request metadata for the current invocation
         """
+        _ = user, invocation_context
         self.print_text("\n[on_llm_after_invoke]\n", color="yellow")
         self.print_text(f"Content: {result.message.content}\n", color="yellow")
 
@@ -151,6 +159,7 @@ class LoggingCallback(Callback):
         stop: Sequence[str] | None = None,
         stream: bool = True,
         user: str | None = None,
+        invocation_context: Mapping[str, object] | None = None,
     ):
         """
         Invoke error callback
@@ -164,7 +173,8 @@ class LoggingCallback(Callback):
         :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
-        :param user: unique user id
+        :param invocation_context: opaque request metadata for the current invocation
         """
+        _ = user, invocation_context
         self.print_text("\n[on_llm_invoke_error]\n", color="red")
         logger.exception(ex)
diff --git a/api/dify_graph/model_runtime/entities/provider_entities.py b/api/dify_graph/model_runtime/entities/provider_entities.py
index 97a99ea7ce..55e1775826 100644
--- a/api/dify_graph/model_runtime/entities/provider_entities.py
+++ b/api/dify_graph/model_runtime/entities/provider_entities.py
@@ -93,10 +93,14 @@ class ModelCredentialSchema(BaseModel):
 
 class SimpleProviderEntity(BaseModel):
     """
-    Simple model class for provider.
+    Simplified provider schema exposed to callers.
+
+    `provider` is the canonical runtime identifier. `provider_name` is an optional
+    compatibility alias for short-name lookups and is empty when no alias exists.
     """
 
     provider: str
+    provider_name: str = ""
     label: I18nObject
     icon_small: I18nObject | None = None
     icon_small_dark: I18nObject | None = None
@@ -115,10 +119,15 @@ class ProviderHelpEntity(BaseModel):
 
 class ProviderEntity(BaseModel):
     """
-    Model class for provider.
+    Runtime-native provider schema.
+
+    `provider` is the canonical runtime identifier. `provider_name` is a
+    compatibility alias for callers that still resolve providers by short name and
+    is empty when no alias exists.
     """
 
     provider: str
+    provider_name: str = ""
     label: I18nObject
     description: I18nObject | None = None
     icon_small: I18nObject | None = None
@@ -153,6 +162,7 @@ class ProviderEntity(BaseModel):
         """
         return SimpleProviderEntity(
            provider=self.provider,
+            provider_name=self.provider_name,
             label=self.label,
             icon_small=self.icon_small,
             supported_model_types=self.supported_model_types,
diff --git a/api/dify_graph/model_runtime/entities/rerank_entities.py b/api/dify_graph/model_runtime/entities/rerank_entities.py
index 99709e1bcd..8a0bb5fac2 100644
--- a/api/dify_graph/model_runtime/entities/rerank_entities.py
+++ b/api/dify_graph/model_runtime/entities/rerank_entities.py
@@ -1,6 +1,13 @@
+from typing import TypedDict
+
 from pydantic import BaseModel
 
 
+class MultimodalRerankInput(TypedDict):
+    content: str
+    content_type: str
+
+
 class RerankDocument(BaseModel):
     """
     Model class for rerank document.
diff --git a/api/dify_graph/model_runtime/entities/text_embedding_entities.py b/api/dify_graph/model_runtime/entities/text_embedding_entities.py
index a0210c169d..7954b410d0 100644
--- a/api/dify_graph/model_runtime/entities/text_embedding_entities.py
+++ b/api/dify_graph/model_runtime/entities/text_embedding_entities.py
@@ -1,10 +1,18 @@
 from decimal import Decimal
+from enum import StrEnum, auto
 
 from pydantic import BaseModel
 
 from dify_graph.model_runtime.entities.model_entities import ModelUsage
 
 
+class EmbeddingInputType(StrEnum):
+    """Embedding request input variants understood by the model runtime."""
+
+    DOCUMENT = auto()
+    QUERY = auto()
+
+
 class EmbeddingUsage(ModelUsage):
     """
     Model class for embedding usage.
""" - tenant_id: str = Field(description="Tenant ID") - model_type: ModelType = Field(description="Model type") - plugin_id: str = Field(description="Plugin ID") - provider_name: str = Field(description="Provider") - plugin_model_provider: PluginModelProviderEntity = Field(description="Plugin model provider") - started_at: float = Field(description="Invoke start time", default=0) + model_type: ModelType + provider_schema: ProviderEntity + model_runtime: ModelRuntime + started_at: float - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) + def __init__( + self, + provider_schema: ProviderEntity, + model_runtime: ModelRuntime, + *, + started_at: float = 0, + ) -> None: + if getattr(type(self), "model_type", None) is None: + raise TypeError("AIModel subclasses must define model_type as a class attribute") + + self.model_type = type(self).model_type + self.provider_schema = provider_schema + self.model_runtime = model_runtime + self.started_at = started_at + + @property + def provider(self) -> str: + return self.provider_schema.provider + + @property + def provider_display_name(self) -> str: + return self.provider_schema.label.en_US @property def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]: """ - Map model invoke error to unified error - The key is the error type thrown to the caller - The value is the error type thrown by the model, - which needs to be converted into a unified error type for the caller. + Map model invoke error to unified error. - :return: Invoke error mapping + The key is the error type thrown to the caller, and the value contains + runtime-facing exception types that should be normalized to it. """ - from core.plugin.entities.plugin_daemon import PluginDaemonInnerError - return { InvokeConnectionError: [InvokeConnectionError], InvokeServerUnavailableError: [InvokeServerUnavailableError], InvokeRateLimitError: [InvokeRateLimitError], InvokeAuthorizationError: [InvokeAuthorizationError], InvokeBadRequestError: [InvokeBadRequestError], - PluginDaemonInnerError: [PluginDaemonInnerError], ValueError: [ValueError], } @@ -79,15 +89,18 @@ class AIModel(BaseModel): if invoke_error == InvokeAuthorizationError: return InvokeAuthorizationError( description=( - f"[{self.provider_name}] Incorrect model credentials provided, please check and try again." + f"[{self.provider_display_name}] Incorrect model credentials provided, " + "please check and try again." 
) ) elif isinstance(invoke_error, InvokeError): - return InvokeError(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}") + return InvokeError( + description=f"[{self.provider_display_name}] {invoke_error.description}, {str(error)}" + ) else: return error - return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}") + return InvokeError(description=f"[{self.provider_display_name}] Error: {str(error)}") def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo: """ @@ -144,65 +157,13 @@ class AIModel(BaseModel): :param credentials: model credentials :return: model schema """ - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}" - sorted_credentials = sorted(credentials.items()) if credentials else [] - cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) - - cached_schema_json = None - try: - cached_schema_json = redis_client.get(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to read plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - if cached_schema_json: - try: - return AIModelEntity.model_validate_json(cached_schema_json) - except ValidationError: - logger.warning( - "Failed to validate cached plugin model schema for model %s", - model, - exc_info=True, - ) - try: - redis_client.delete(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to delete invalid plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - schema = plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model_type=self.model_type.value, + return self.model_runtime.get_model_schema( + provider=self.provider, + model_type=self.model_type, model=model, credentials=credentials or {}, ) - if schema: - try: - redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to write plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - return schema - def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None: """ Get customizable model schema from credentials diff --git a/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py b/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py index bf864ca227..21c07cf5bb 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py @@ -1,12 +1,9 @@ import logging import time import uuid -from collections.abc import Callable, Generator, Iterator, Sequence +from collections.abc import Callable, Generator, Iterator, Mapping, Sequence from typing import Union -from pydantic import ConfigDict - -from configs import dify_config from dify_graph.model_runtime.callbacks.base_callback import Callback from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage @@ -140,11 +137,9 @@ def _build_llm_result_from_chunks( ) -def 
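A construction sketch under the new shape; `schema` and `runtime` stand in for a loaded `ProviderEntity` and any `ModelRuntime` implementation:

from dify_graph.model_runtime.entities.provider_entities import ProviderEntity
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.model_runtime.runtime import ModelRuntime


def build_llm(schema: ProviderEntity, runtime: ModelRuntime) -> LargeLanguageModel:
    # model_type is pinned on the subclass, so only the collaborators are wired in.
    llm = LargeLanguageModel(provider_schema=schema, model_runtime=runtime)
    assert llm.provider == schema.provider
    return llm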
diff --git a/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py b/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py
index bf864ca227..21c07cf5bb 100644
--- a/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py
@@ -1,12 +1,9 @@
 import logging
 import time
 import uuid
-from collections.abc import Callable, Generator, Iterator, Sequence
+from collections.abc import Callable, Generator, Iterator, Mapping, Sequence
 from typing import Union
 
-from pydantic import ConfigDict
-
-from configs import dify_config
 from dify_graph.model_runtime.callbacks.base_callback import Callback
 from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback
 from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
@@ -140,11 +137,9 @@
     )
 
 
-def _invoke_llm_via_plugin(
+def _invoke_llm_via_runtime(
     *,
-    tenant_id: str,
-    user_id: str,
-    plugin_id: str,
+    llm_model: "LargeLanguageModel",
     provider: str,
     model: str,
     credentials: dict,
@@ -154,25 +149,19 @@
     stop: Sequence[str] | None,
     stream: bool,
 ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
-    from core.plugin.impl.model import PluginModelClient
-
-    plugin_model_manager = PluginModelClient()
-    return plugin_model_manager.invoke_llm(
-        tenant_id=tenant_id,
-        user_id=user_id,
-        plugin_id=plugin_id,
+    return llm_model.model_runtime.invoke_llm(
         provider=provider,
         model=model,
         credentials=credentials,
         model_parameters=model_parameters,
         prompt_messages=list(prompt_messages),
         tools=tools,
-        stop=list(stop) if stop else None,
+        stop=stop,
         stream=stream,
     )
 
 
-def _normalize_non_stream_plugin_result(
+def _normalize_non_stream_runtime_result(
     model: str,
     prompt_messages: Sequence[PromptMessage],
     result: Union[LLMResult, Iterator[LLMResultChunk]],
@@ -208,9 +197,6 @@ class LargeLanguageModel(AIModel):
 
     model_type: ModelType = ModelType.LLM
 
-    # pydantic configs
-    model_config = ConfigDict(protected_namespaces=())
-
     def invoke(
         self,
         model: str,
@@ -220,7 +206,6 @@ class LargeLanguageModel(AIModel):
         tools: list[PromptMessageTool] | None = None,
         stop: list[str] | None = None,
         stream: bool = True,
-        user: str | None = None,
         callbacks: list[Callback] | None = None,
     ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
         """
@@ -233,7 +218,6 @@ class LargeLanguageModel(AIModel):
         :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
-        :param user: unique user id
         :param callbacks: callbacks
         :return: full response or stream response chunk generator result
         """
@@ -245,7 +229,7 @@ class LargeLanguageModel(AIModel):
 
         callbacks = callbacks or []
 
-        if dify_config.DEBUG:
+        if logger.isEnabledFor(logging.DEBUG):
             callbacks.append(LoggingCallback())
 
         # trigger before invoke callbacks
@@ -257,18 +241,15 @@ class LargeLanguageModel(AIModel):
             tools=tools,
             stop=stop,
             stream=stream,
-            user=user,
             callbacks=callbacks,
         )
 
         result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
 
         try:
-            result = _invoke_llm_via_plugin(
-                tenant_id=self.tenant_id,
-                user_id=user or "unknown",
-                plugin_id=self.plugin_id,
-                provider=self.provider_name,
+            result = _invoke_llm_via_runtime(
+                llm_model=self,
+                provider=self.provider,
                 model=model,
                 credentials=credentials,
                 model_parameters=model_parameters,
@@ -279,7 +260,7 @@ class LargeLanguageModel(AIModel):
             )
 
             if not stream:
-                result = _normalize_non_stream_plugin_result(
+                result = _normalize_non_stream_runtime_result(
                     model=model, prompt_messages=prompt_messages, result=result
                 )
         except Exception as e:
@@ -292,7 +273,6 @@ class LargeLanguageModel(AIModel):
                 tools=tools,
                 stop=stop,
                 stream=stream,
-                user=user,
                 callbacks=callbacks,
             )
 
@@ -309,7 +289,6 @@ class LargeLanguageModel(AIModel):
                 tools=tools,
                 stop=stop,
                 stream=stream,
-                user=user,
                 callbacks=callbacks,
             )
         elif isinstance(result, LLMResult):
@@ -322,7 +301,6 @@ class LargeLanguageModel(AIModel):
                 tools=tools,
                 stop=stop,
                 stream=stream,
-                user=user,
                 callbacks=callbacks,
             )
         # Following https://github.com/langgenius/dify/issues/17799,
@@ -342,7 +320,7 @@ class LargeLanguageModel(AIModel):
         tools: list[PromptMessageTool] | None = None,
         stop: Sequence[str] | None = None,
         stream: bool = True,
-        user: str | None = None,
+        invocation_context: Mapping[str, object] | None = None,
         callbacks: list[Callback] | None = None,
     ) -> Generator[LLMResultChunk, None, None]:
         """
@@ -384,7 +362,7 @@ class LargeLanguageModel(AIModel):
             tools=tools,
             stop=stop,
             stream=stream,
-            user=user,
+            invocation_context=invocation_context,
             callbacks=callbacks,
         )
 
@@ -415,7 +393,7 @@ class LargeLanguageModel(AIModel):
                 tools=tools,
                 stop=stop,
                 stream=stream,
-                user=user,
+                invocation_context=invocation_context,
                 callbacks=callbacks,
             )
 
@@ -435,22 +413,14 @@ class LargeLanguageModel(AIModel):
         :param tools: tools for tool calling
         :return:
         """
-        if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
-            from core.plugin.impl.model import PluginModelClient
-
-            plugin_model_manager = PluginModelClient()
-            return plugin_model_manager.get_llm_num_tokens(
-                tenant_id=self.tenant_id,
-                user_id="unknown",
-                plugin_id=self.plugin_id,
-                provider=self.provider_name,
-                model_type=self.model_type.value,
-                model=model,
-                credentials=credentials,
-                prompt_messages=prompt_messages,
-                tools=tools,
-            )
-        return 0
+        return self.model_runtime.get_llm_num_tokens(
+            provider=self.provider,
+            model_type=self.model_type,
+            model=model,
+            credentials=credentials,
+            prompt_messages=prompt_messages,
+            tools=tools,
+        )
 
     def calc_response_usage(
         self, model: str, credentials: dict, prompt_tokens: int, completion_tokens: int
@@ -504,7 +474,7 @@ class LargeLanguageModel(AIModel):
         tools: list[PromptMessageTool] | None = None,
         stop: Sequence[str] | None = None,
         stream: bool = True,
-        user: str | None = None,
+        invocation_context: Mapping[str, object] | None = None,
         callbacks: list[Callback] | None = None,
     ):
         """
@@ -517,7 +487,7 @@ class LargeLanguageModel(AIModel):
         :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
-        :param user: unique user id
+        :param invocation_context: opaque request metadata for the current invocation
         :param callbacks: callbacks
         """
         _run_callbacks(
@@ -532,7 +502,7 @@ class LargeLanguageModel(AIModel):
                 tools=tools,
                 stop=stop,
                 stream=stream,
-                user=user,
+                invocation_context=invocation_context,
             ),
         )
 
@@ -546,7 +516,7 @@ class LargeLanguageModel(AIModel):
         tools: list[PromptMessageTool] | None = None,
         stop: Sequence[str] | None = None,
         stream: bool = True,
-        user: str | None = None,
+        invocation_context: Mapping[str, object] | None = None,
         callbacks: list[Callback] | None = None,
     ):
         """
@@ -560,7 +530,7 @@ class LargeLanguageModel(AIModel):
         :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
-        :param user: unique user id
+        :param invocation_context: opaque request metadata for the current invocation
         """
         _run_callbacks(
             callbacks,
@@ -575,7 +545,7 @@ class LargeLanguageModel(AIModel):
                 tools=tools,
                 stop=stop,
                 stream=stream,
-                user=user,
+                invocation_context=invocation_context,
             ),
         )
 
@@ -589,7 +559,7 @@ class LargeLanguageModel(AIModel):
         tools: list[PromptMessageTool] | None = None,
         stop: Sequence[str] | None = None,
         stream: bool = True,
-        user: str | None = None,
+        invocation_context: Mapping[str, object] | None = None,
         callbacks: list[Callback] | None = None,
     ):
         """
@@ -603,7 +573,7 @@ class LargeLanguageModel(AIModel):
         :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
-        :param user: unique user id
+        :param invocation_context: opaque request metadata for the current invocation
         :param callbacks: callbacks
         """
         _run_callbacks(
@@ -619,7 +589,7 @@ class LargeLanguageModel(AIModel):
                 tools=tools,
                 stop=stop,
                 stream=stream,
-                user=user,
+                invocation_context=invocation_context,
             ),
         )
 
@@ -633,7 +603,7 @@ class LargeLanguageModel(AIModel):
         tools: list[PromptMessageTool] | None = None,
         stop: Sequence[str] | None = None,
         stream: bool = True,
-        user: str | None = None,
+        invocation_context: Mapping[str, object] | None = None,
         callbacks: list[Callback] | None = None,
    ):
         """
@@ -647,7 +617,7 @@ class LargeLanguageModel(AIModel):
         :param tools: tools for tool calling
         :param stop: stop words
         :param stream: is stream response
-        :param user: unique user id
+        :param invocation_context: opaque request metadata for the current invocation
         :param callbacks: callbacks
         """
         _run_callbacks(
@@ -663,6 +633,6 @@ class LargeLanguageModel(AIModel):
                 tools=tools,
                 stop=stop,
                 stream=stream,
-                user=user,
+                invocation_context=invocation_context,
             ),
         )
diff --git a/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py b/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py
index 5fa3d1634b..6983971770 100644
--- a/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py
+++ b/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py
@@ -1,7 +1,5 @@
 import time
 
-from pydantic import ConfigDict
-
 from dify_graph.model_runtime.entities.model_entities import ModelType
 from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
 
@@ -13,30 +11,20 @@ class ModerationModel(AIModel):
 
     model_type: ModelType = ModelType.MODERATION
 
-    # pydantic configs
-    model_config = ConfigDict(protected_namespaces=())
-
-    def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool:
+    def invoke(self, model: str, credentials: dict, text: str) -> bool:
         """
         Invoke moderation model
 
         :param model: model name
         :param credentials: model credentials
         :param text: text to moderate
-        :param user: unique user id
         :return: false if text is safe, true otherwise
         """
         self.started_at = time.perf_counter()
 
         try:
-            from core.plugin.impl.model import PluginModelClient
-
-            plugin_model_manager = PluginModelClient()
-            return plugin_model_manager.invoke_moderation(
-                tenant_id=self.tenant_id,
-                user_id=user or "unknown",
-                plugin_id=self.plugin_id,
-                provider=self.provider_name,
+            return self.model_runtime.invoke_moderation(
+                provider=self.provider,
                 model=model,
                 credentials=credentials,
                 text=text,
diff --git a/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py b/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py
index 5da2b84b95..f448690677 100644
--- a/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py
+++ b/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py
@@ -1,5 +1,5 @@
 from dify_graph.model_runtime.entities.model_entities import ModelType
-from dify_graph.model_runtime.entities.rerank_entities import RerankResult
+from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult
 from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel
 
@@ -18,7 +18,6 @@ class RerankModel(AIModel):
         docs: list[str],
         score_threshold: float | None = None,
         top_n: int | None = None,
-        user: str | None = None,
     ) -> RerankResult:
         """
         Invoke rerank model
@@ -29,18 +28,11 @@ class RerankModel(AIModel):
         :param docs: docs for reranking
         :param score_threshold: score threshold
         :param top_n: top n
-        :param user: unique user id
         :return: rerank result
         """
         try:
-            from core.plugin.impl.model import PluginModelClient
-
-            plugin_model_manager = PluginModelClient()
-            return plugin_model_manager.invoke_rerank(
-                tenant_id=self.tenant_id,
-                user_id=user or "unknown",
-                plugin_id=self.plugin_id,
-                provider=self.provider_name,
+            return
self.model_runtime.invoke_rerank( + provider=self.provider, model=model, credentials=credentials, query=query, @@ -55,11 +47,10 @@ class RerankModel(AIModel): self, model: str, credentials: dict, - query: dict, - docs: list[dict], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke multimodal rerank model @@ -69,18 +60,11 @@ class RerankModel(AIModel): :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_multimodal_rerank( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_multimodal_rerank( + provider=self.provider, model=model, credentials=credentials, query=query, diff --git a/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py b/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py index e69069a85d..0cfa5208e9 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py @@ -1,7 +1,5 @@ from typing import IO -from pydantic import ConfigDict - from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel @@ -13,28 +11,18 @@ class Speech2TextModel(AIModel): model_type: ModelType = ModelType.SPEECH2TEXT - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str: + def invoke(self, model: str, credentials: dict, file: IO[bytes]) -> str: """ Invoke speech to text model :param model: model name :param credentials: model credentials :param file: audio file - :param user: unique user id :return: text for given audio file """ try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_speech_to_text( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_speech_to_text( + provider=self.provider, model=model, credentials=credentials, file=file, diff --git a/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py b/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py index 3438da2ada..d03ff5f6c2 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py @@ -1,8 +1,5 @@ -from pydantic import ConfigDict - -from core.entities.embedding_type import EmbeddingInputType from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel @@ -13,16 +10,12 @@ class TextEmbeddingModel(AIModel): model_type: ModelType = ModelType.TEXT_EMBEDDING - # pydantic configs - model_config = 
ConfigDict(protected_namespaces=()) - def invoke( self, model: str, credentials: dict, texts: list[str] | None = None, multimodel_documents: list[dict] | None = None, - user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, ) -> EmbeddingResult: """ @@ -32,31 +25,21 @@ :param credentials: model credentials :param texts: texts to embed - :param files: files to embed + :param multimodel_documents: multimodal documents to embed - :param user: unique user id :param input_type: input type :return: embeddings result """ - from core.plugin.impl.model import PluginModelClient - try: - plugin_model_manager = PluginModelClient() if texts: - return plugin_model_manager.invoke_text_embedding( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_text_embedding( + provider=self.provider, model=model, credentials=credentials, texts=texts, input_type=input_type, ) if multimodel_documents: - return plugin_model_manager.invoke_multimodal_embedding( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_multimodal_embedding( + provider=self.provider, model=model, credentials=credentials, documents=multimodel_documents, @@ -75,14 +58,8 @@ class TextEmbeddingModel(AIModel): :param texts: texts to embed :return: """ - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.get_text_embedding_num_tokens( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.get_text_embedding_num_tokens( + provider=self.provider, model=model, credentials=credentials, texts=texts, diff --git a/api/dify_graph/model_runtime/model_providers/__base/tts_model.py b/api/dify_graph/model_runtime/model_providers/__base/tts_model.py index 0656529f22..23f5640b37 100644 --- a/api/dify_graph/model_runtime/model_providers/__base/tts_model.py +++ b/api/dify_graph/model_runtime/model_providers/__base/tts_model.py @@ -1,8 +1,6 @@ import logging from collections.abc import Iterable -from pydantic import ConfigDict - from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel @@ -16,38 +14,25 @@ class TTSModel(AIModel): model_type: ModelType = ModelType.TTS - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - def invoke( self, model: str, - tenant_id: str, credentials: dict, content_text: str, voice: str, - user: str | None = None, ) -> Iterable[bytes]: """ - Invoke large language model + Invoke text-to-speech model :param model: model name - :param tenant_id: user tenant id :param credentials: model credentials :param voice: model timbre - :param content_text: text content to be translated + :param content_text: text content to synthesize - :param user: unique user id - :return: translated audio file + :return: generated audio stream """ try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_tts( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.invoke_tts( + provider=self.provider, model=model, credentials=credentials, content_text=content_text, @@ -65,14 +50,8 @@ :param credentials: The credentials required to access the TTS model. :return: A list of voices supported by the TTS model.
""" - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.get_tts_model_voices( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, + return self.model_runtime.get_tts_model_voices( + provider=self.provider, model=model, credentials=credentials, language=language, diff --git a/api/dify_graph/model_runtime/model_providers/model_provider_factory.py b/api/dify_graph/model_runtime/model_providers/model_provider_factory.py index de0677a348..ae5c0ec7c4 100644 --- a/api/dify_graph/model_runtime/model_providers/model_provider_factory.py +++ b/api/dify_graph/model_runtime/model_providers/model_provider_factory.py @@ -1,16 +1,7 @@ from __future__ import annotations -import hashlib -import logging from collections.abc import Sequence -from threading import Lock -from pydantic import ValidationError -from redis import RedisError - -import contexts -from configs import dify_config -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType from dify_graph.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel @@ -20,120 +11,64 @@ from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankM from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel +from dify_graph.model_runtime.runtime import ModelRuntime from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import ( ProviderCredentialSchemaValidator, ) -from extensions.ext_redis import redis_client -from models.provider_ids import ModelProviderID - -logger = logging.getLogger(__name__) class ModelProviderFactory: - def __init__(self, tenant_id: str): - from core.plugin.impl.model import PluginModelClient + """Factory for provider schemas and model-type instances backed by a runtime adapter.""" - self.tenant_id = tenant_id - self.plugin_model_manager = PluginModelClient() + def __init__(self, model_runtime: ModelRuntime): + if model_runtime is None: + raise ValueError("model_runtime is required.") + self.model_runtime = model_runtime def get_providers(self) -> Sequence[ProviderEntity]: """ - Get all providers - :return: list of providers + Get all providers. """ - # FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server - # The plugin server should return providers in the desired order - plugin_providers = self.get_plugin_model_providers() - return [provider.declaration for provider in plugin_providers] + return list(self.get_model_providers()) - def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]: + def get_model_providers(self) -> Sequence[ProviderEntity]: """ - Get all plugin model providers - :return: list of plugin model providers + Get all model providers exposed by the runtime adapter. 
""" - # check if context is set - try: - contexts.plugin_model_providers.get() - except LookupError: - contexts.plugin_model_providers.set(None) - contexts.plugin_model_providers_lock.set(Lock()) - - with contexts.plugin_model_providers_lock.get(): - plugin_model_providers = contexts.plugin_model_providers.get() - if plugin_model_providers is not None: - return plugin_model_providers - - plugin_model_providers = [] - contexts.plugin_model_providers.set(plugin_model_providers) - - # Fetch plugin model providers - plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id) - - for provider in plugin_providers: - provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider - plugin_model_providers.append(provider) - - return plugin_model_providers + return self.model_runtime.fetch_model_providers() def get_provider_schema(self, provider: str) -> ProviderEntity: """ - Get provider schema - :param provider: provider name - :return: provider schema + Get provider schema. """ - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) - return plugin_model_provider_entity.declaration + return self.get_model_provider(provider=provider) - def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity: + def get_model_provider(self, provider: str) -> ProviderEntity: """ - Get plugin model provider - :param provider: provider name - :return: provider schema + Get provider schema. """ - if "/" not in provider: - provider = str(ModelProviderID(provider)) - - # fetch plugin model providers - plugin_model_provider_entities = self.get_plugin_model_providers() - - # get the provider - plugin_model_provider_entity = next( - (p for p in plugin_model_provider_entities if p.declaration.provider == provider), - None, - ) - - if not plugin_model_provider_entity: + provider_entity = self._resolve_provider(provider) + if provider_entity is None: raise ValueError(f"Invalid provider: {provider}") - return plugin_model_provider_entity + return provider_entity def provider_credentials_validate(self, *, provider: str, credentials: dict): """ - Validate provider credentials - - :param provider: provider name - :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. - :return: + Validate provider credentials. 
""" - # fetch plugin model provider - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) + provider_entity = self.get_model_provider(provider=provider) - # get provider_credential_schema and validate credentials according to the rules - provider_credential_schema = plugin_model_provider_entity.declaration.provider_credential_schema + provider_credential_schema = provider_entity.provider_credential_schema if not provider_credential_schema: raise ValueError(f"Provider {provider} does not have provider_credential_schema") - # validate provider credential schema validator = ProviderCredentialSchemaValidator(provider_credential_schema) filtered_credentials = validator.validate_and_filter(credentials) - # validate the credentials, raise exception if validation failed - self.plugin_model_manager.validate_provider_credentials( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_model_provider_entity.plugin_id, - provider=plugin_model_provider_entity.provider, + self.model_runtime.validate_provider_credentials( + provider=provider_entity.provider, credentials=filtered_credentials, ) @@ -141,33 +76,20 @@ class ModelProviderFactory: def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict): """ - Validate model credentials - - :param provider: provider name - :param model_type: model type - :param model: model name - :param credentials: model credentials, credentials form defined in `model_credential_schema`. - :return: + Validate model credentials. """ - # fetch plugin model provider - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) + provider_entity = self.get_model_provider(provider=provider) - # get model_credential_schema and validate credentials according to the rules - model_credential_schema = plugin_model_provider_entity.declaration.model_credential_schema + model_credential_schema = provider_entity.model_credential_schema if not model_credential_schema: raise ValueError(f"Provider {provider} does not have model_credential_schema") - # validate model credential schema validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) filtered_credentials = validator.validate_and_filter(credentials) - # call validate_credentials method of model type to validate credentials, raise exception if validation failed - self.plugin_model_manager.validate_model_credentials( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_model_provider_entity.plugin_id, - provider=plugin_model_provider_entity.provider, - model_type=model_type.value, + self.model_runtime.validate_model_credentials( + provider=provider_entity.provider, + model_type=model_type, model=model, credentials=filtered_credentials, ) @@ -178,65 +100,16 @@ class ModelProviderFactory: self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None ) -> AIModelEntity | None: """ - Get model schema + Get model schema. 
""" - plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) - cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}" - sorted_credentials = sorted(credentials.items()) if credentials else [] - cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) - - cached_schema_json = None - try: - cached_schema_json = redis_client.get(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to read plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - if cached_schema_json: - try: - return AIModelEntity.model_validate_json(cached_schema_json) - except ValidationError: - logger.warning( - "Failed to validate cached plugin model schema for model %s", - model, - exc_info=True, - ) - try: - redis_client.delete(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to delete invalid plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - schema = self.plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_id, - provider=provider_name, - model_type=model_type.value, + provider_entity = self.get_model_provider(provider) + return self.model_runtime.get_model_schema( + provider=provider_entity.provider, + model_type=model_type, model=model, credentials=credentials or {}, ) - if schema: - try: - redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to write plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - return schema - def get_models( self, *, @@ -245,143 +118,56 @@ class ModelProviderFactory: provider_configs: list[ProviderConfig] | None = None, ) -> list[SimpleProviderEntity]: """ - Get all models for given model type - - :param provider: provider name - :param model_type: model type - :param provider_configs: list of provider configs - :return: list of models + Get all models for given model type. 
""" - provider_configs = provider_configs or [] - - # scan all providers - plugin_model_provider_entities = self.get_plugin_model_providers() - - # traverse all model_provider_extensions providers = [] - for plugin_model_provider_entity in plugin_model_provider_entities: - # filter by provider if provider is present - if provider and plugin_model_provider_entity.declaration.provider != provider: + for provider_entity in self.get_model_providers(): + if provider and not self._matches_provider(provider_entity, provider): continue - # get provider schema - provider_schema = plugin_model_provider_entity.declaration - - model_types = provider_schema.supported_model_types - if model_type: - if model_type not in model_types: - continue - - model_types = [model_type] - - all_model_type_models = [] - for model_schema in provider_schema.models: - if model_schema.model_type != model_type: - continue - - all_model_type_models.append(model_schema) - - simple_provider_schema = provider_schema.to_simple_provider() - if model_type: - simple_provider_schema.models = all_model_type_models + if model_type and model_type not in provider_entity.supported_model_types: + continue + simple_provider_schema = provider_entity.to_simple_provider() + if model_type is not None: + simple_provider_schema.models = [ + model_schema for model_schema in provider_entity.models if model_schema.model_type == model_type + ] providers.append(simple_provider_schema) return providers def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel: """ - Get model type instance by provider name and model type - :param provider: provider name - :param model_type: model type - :return: model type instance + Get model type instance by provider name and model type. """ - plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) - init_params = { - "tenant_id": self.tenant_id, - "plugin_id": plugin_id, - "provider_name": provider_name, - "plugin_model_provider": self.get_plugin_model_provider(provider), - } + provider_schema = self.get_model_provider(provider) if model_type == ModelType.LLM: - return LargeLanguageModel.model_validate(init_params) - elif model_type == ModelType.TEXT_EMBEDDING: - return TextEmbeddingModel.model_validate(init_params) - elif model_type == ModelType.RERANK: - return RerankModel.model_validate(init_params) - elif model_type == ModelType.SPEECH2TEXT: - return Speech2TextModel.model_validate(init_params) - elif model_type == ModelType.MODERATION: - return ModerationModel.model_validate(init_params) - elif model_type == ModelType.TTS: - return TTSModel.model_validate(init_params) + return LargeLanguageModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.TEXT_EMBEDDING: + return TextEmbeddingModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.RERANK: + return RerankModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.SPEECH2TEXT: + return Speech2TextModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.MODERATION: + return ModerationModel(provider_schema=provider_schema, model_runtime=self.model_runtime) + if model_type == ModelType.TTS: + return TTSModel(provider_schema=provider_schema, model_runtime=self.model_runtime) raise ValueError(f"Unsupported model type: {model_type}") def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: """ - 
Get provider icon - :param provider: provider name - :param icon_type: icon type (icon_small or icon_small_dark) - :param lang: language (zh_Hans or en_US) - :return: provider icon + Get provider icon. """ - # get the provider schema - provider_schema = self.get_provider_schema(provider) + provider_entity = self.get_model_provider(provider) + return self.model_runtime.get_provider_icon(provider=provider_entity.provider, icon_type=icon_type, lang=lang) - if icon_type.lower() == "icon_small": - if not provider_schema.icon_small: - raise ValueError(f"Provider {provider} does not have small icon.") + def _resolve_provider(self, provider: str) -> ProviderEntity | None: + return next((item for item in self.get_model_providers() if self._matches_provider(item, provider)), None) - if lang.lower() == "zh_hans": - file_name = provider_schema.icon_small.zh_Hans - else: - file_name = provider_schema.icon_small.en_US - elif icon_type.lower() == "icon_small_dark": - if not provider_schema.icon_small_dark: - raise ValueError(f"Provider {provider} does not have small dark icon.") - - if lang.lower() == "zh_hans": - file_name = provider_schema.icon_small_dark.zh_Hans - else: - file_name = provider_schema.icon_small_dark.en_US - else: - raise ValueError(f"Unsupported icon type: {icon_type}.") - - if not file_name: - raise ValueError(f"Provider {provider} does not have icon.") - - image_mime_types = { - "jpg": "image/jpeg", - "jpeg": "image/jpeg", - "png": "image/png", - "gif": "image/gif", - "bmp": "image/bmp", - "tiff": "image/tiff", - "tif": "image/tiff", - "webp": "image/webp", - "svg": "image/svg+xml", - "ico": "image/vnd.microsoft.icon", - "heif": "image/heif", - "heic": "image/heic", - } - - extension = file_name.split(".")[-1] - mime_type = image_mime_types.get(extension, "image/png") - - # get icon bytes from plugin asset manager - from core.plugin.impl.asset import PluginAssetManager - - plugin_asset_manager = PluginAssetManager() - return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type - - def get_plugin_id_and_provider_name_from_provider(self, provider: str) -> tuple[str, str]: - """ - Get plugin id and provider name from provider name - :param provider: provider name - :return: plugin id and provider name - """ - - provider_id = ModelProviderID(provider) - return provider_id.plugin_id, provider_id.provider_name + @staticmethod + def _matches_provider(provider_entity: ProviderEntity, provider: str) -> bool: + return provider in (provider_entity.provider, provider_entity.provider_name) diff --git a/api/dify_graph/model_runtime/runtime.py b/api/dify_graph/model_runtime/runtime.py new file mode 100644 index 0000000000..577bb6f387 --- /dev/null +++ b/api/dify_graph/model_runtime/runtime.py @@ -0,0 +1,159 @@ +from __future__ import annotations + +from collections.abc import Generator, Iterable, Sequence +from typing import IO, Any, Protocol, Union, runtime_checkable + +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType +from dify_graph.model_runtime.entities.provider_entities import ProviderEntity +from dify_graph.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult + + +@runtime_checkable +class 
ModelRuntime(Protocol): + """Port for provider discovery, schema lookup, and model execution. + + `provider` is the model runtime's canonical provider identifier. Adapters may + derive transport-specific details from it, but those details stay outside + this boundary. + """ + + def fetch_model_providers(self) -> Sequence[ProviderEntity]: ... + + def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: ... + + def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: ... + + def validate_model_credentials( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> None: ... + + def get_model_schema( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> AIModelEntity | None: ... + + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: ... + + def get_llm_num_tokens( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: Sequence[PromptMessageTool] | None, + ) -> int: ... + + def invoke_text_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: ... + + def invoke_multimodal_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + documents: list[dict[str, Any]], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: ... + + def get_text_embedding_num_tokens( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + ) -> list[int]: ... + + def invoke_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: str, + docs: list[str], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: ... + + def invoke_multimodal_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: ... + + def invoke_tts( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + content_text: str, + voice: str, + ) -> Iterable[bytes]: ... + + def get_tts_model_voices( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + language: str | None, + ) -> Any: ... + + def invoke_speech_to_text( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + file: IO[bytes], + ) -> str: ... + + def invoke_moderation( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + text: str, + ) -> bool: ... 
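As an illustration of what this new port enables, a minimal in-process adapter for tests might look like the sketch below. This is not part of the patch: the `StaticModelRuntime` name and the idea of replaying pre-built entities are assumptions, and a real adapter would live outside `dify_graph` (for example alongside the application factories) and bridge these calls to the plugin daemon. Only the protocol methods exercised here are shown; because `ModelRuntime` is a structural `Protocol`, no inheritance is required, and the remaining methods can raise `NotImplementedError` in the same style.

from collections.abc import Generator, Sequence
from typing import Any, Union

from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from dify_graph.model_runtime.entities.provider_entities import ProviderEntity


class StaticModelRuntime:
    """Hypothetical test double satisfying part of the ModelRuntime port."""

    def __init__(self, providers: Sequence[ProviderEntity], llm_result: LLMResult) -> None:
        # Entities are built by the caller; the double never touches a transport.
        self._providers = list(providers)
        self._llm_result = llm_result

    def fetch_model_providers(self) -> Sequence[ProviderEntity]:
        return self._providers

    def invoke_llm(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        model_parameters: dict[str, Any],
        prompt_messages: Sequence[PromptMessage],
        tools: list[PromptMessageTool] | None,
        stop: Sequence[str] | None,
        stream: bool,
    ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]:
        # A real adapter would dispatch over its transport here; the double
        # replays the canned result regardless of the requested provider/model.
        return self._llm_result

With such a double, `ModelProviderFactory(model_runtime=StaticModelRuntime(...))` can exercise the factory and the `__base` model classes without any plugin daemon, database, or Redis access, which is what the package's import contracts now enforce.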
diff --git a/api/dify_graph/model_runtime/utils/encoders.py b/api/dify_graph/model_runtime/utils/encoders.py index c85152463e..13abf74767 100644 --- a/api/dify_graph/model_runtime/utils/encoders.py +++ b/api/dify_graph/model_runtime/utils/encoders.py @@ -1,7 +1,7 @@ import dataclasses import datetime from collections import defaultdict, deque -from collections.abc import Callable +from collections.abc import Callable, Sequence from decimal import Decimal from enum import Enum from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network @@ -99,7 +99,7 @@ def jsonable_encoder( exclude_defaults: bool = False, exclude_none: bool = False, custom_encoder: dict[Any, Callable[[Any], Any]] | None = None, - sqlalchemy_safe: bool = True, + excluded_key_prefixes: Sequence[str] = (), ) -> Any: custom_encoder = custom_encoder or {} if custom_encoder: @@ -126,7 +126,7 @@ def jsonable_encoder( obj_dict, exclude_none=exclude_none, exclude_defaults=exclude_defaults, - sqlalchemy_safe=sqlalchemy_safe, + excluded_key_prefixes=excluded_key_prefixes, ) if dataclasses.is_dataclass(obj): # Ensure obj is a dataclass instance, not a dataclass type @@ -139,7 +139,7 @@ def jsonable_encoder( exclude_defaults=exclude_defaults, exclude_none=exclude_none, custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, + excluded_key_prefixes=excluded_key_prefixes, ) if isinstance(obj, Enum): return obj.value @@ -152,26 +152,28 @@ def jsonable_encoder( if isinstance(obj, dict): encoded_dict = {} for key, value in obj.items(): - if (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and ( - value is not None or not exclude_none - ): - encoded_key = jsonable_encoder( - key, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_value = jsonable_encoder( - value, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_dict[encoded_key] = encoded_value + if isinstance(key, str) and any(key.startswith(prefix) for prefix in excluded_key_prefixes): + continue + if value is None and exclude_none: + continue + + encoded_key = jsonable_encoder( + key, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + excluded_key_prefixes=excluded_key_prefixes, + ) + encoded_value = jsonable_encoder( + value, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + excluded_key_prefixes=excluded_key_prefixes, + ) + encoded_dict[encoded_key] = encoded_value return encoded_dict if isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque): encoded_list = [] @@ -184,7 +186,7 @@ def jsonable_encoder( exclude_defaults=exclude_defaults, exclude_none=exclude_none, custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, + excluded_key_prefixes=excluded_key_prefixes, ) ) return encoded_list @@ -212,5 +214,5 @@ def jsonable_encoder( exclude_defaults=exclude_defaults, exclude_none=exclude_none, custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, + excluded_key_prefixes=excluded_key_prefixes, ) diff --git a/api/dify_graph/node_events/__init__.py b/api/dify_graph/node_events/__init__.py index a9bef8f9a2..a2bbf9f176 100644 --- a/api/dify_graph/node_events/__init__.py +++ b/api/dify_graph/node_events/__init__.py @@ 
-21,6 +21,7 @@ from .node import ( RunRetryEvent, StreamChunkEvent, StreamCompletedEvent, + VariableUpdatedEvent, ) __all__ = [ @@ -43,4 +44,5 @@ __all__ = [ "RunRetryEvent", "StreamChunkEvent", "StreamCompletedEvent", + "VariableUpdatedEvent", ] diff --git a/api/dify_graph/node_events/node.py b/api/dify_graph/node_events/node.py index 2e3973b8fa..70aecc58b0 100644 --- a/api/dify_graph/node_events/node.py +++ b/api/dify_graph/node_events/node.py @@ -8,6 +8,7 @@ from dify_graph.entities.pause_reason import PauseReason from dify_graph.file import File from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.node_events import NodeRunResult +from dify_graph.variables.variables import Variable from .base import NodeEventBase @@ -45,6 +46,12 @@ class StreamCompletedEvent(NodeEventBase): node_run_result: NodeRunResult = Field(..., description="run result") +class VariableUpdatedEvent(NodeEventBase): + """Notify the engine that a single variable should be applied to the shared pool.""" + + variable: Variable = Field(..., description="Updated variable payload to apply.") + + class PauseRequestedEvent(NodeEventBase): reason: PauseReason = Field(..., description="pause reason") diff --git a/api/dify_graph/nodes/base/node.py b/api/dify_graph/nodes/base/node.py index 56b46a5894..adff2f6a93 100644 --- a/api/dify_graph/nodes/base/node.py +++ b/api/dify_graph/nodes/base/node.py @@ -4,15 +4,15 @@ import logging import operator from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence +from datetime import UTC, datetime from functools import singledispatchmethod from types import MappingProxyType -from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin +from typing import Any, ClassVar, Generic, TypeVar, cast, get_args, get_origin from uuid import uuid4 from dify_graph.entities import GraphInitParams from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import ( ErrorStrategy, NodeExecutionType, @@ -39,6 +39,7 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) from dify_graph.node_events import ( AgentLogEvent, @@ -58,9 +59,9 @@ from dify_graph.node_events import ( RunRetrieverResourceEvent, StreamChunkEvent, StreamCompletedEvent, + VariableUpdatedEvent, ) from dify_graph.runtime import GraphRuntimeState -from libs.datetime_utils import naive_utc_now NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData) _MISSING_RUN_CONTEXT_VALUE = object() @@ -68,23 +69,6 @@ _MISSING_RUN_CONTEXT_VALUE = object() logger = logging.getLogger(__name__) -class DifyRunContextProtocol(Protocol): - tenant_id: str - app_id: str - user_id: str - user_from: Any - invoke_from: Any - - -class _MappingDifyRunContext: - def __init__(self, mapping: Mapping[str, Any]) -> None: - self.tenant_id = str(mapping["tenant_id"]) - self.app_id = str(mapping["app_id"]) - self.user_id = str(mapping["user_id"]) - self.user_from = mapping["user_from"] - self.invoke_from = mapping["invoke_from"] - - class Node(Generic[NodeDataT]): """BaseNode serves as the foundational class for all node implementations. @@ -177,8 +161,9 @@ class Node(Generic[NodeDataT]): # Skip base class itself if cls is Node: return - # Only register production node implementations defined under the - # canonical workflow namespaces. 
+ # Only treat nodes from the base dify_graph package as production + # registrations. Higher-layer packages may still register subclasses, + # but dify_graph itself should not know their module identities. # This prevents test helper subclasses from polluting the global registry and # accidentally overriding real node types (e.g., a test Answer node). module_name = getattr(cls, "__module__", "") @@ -186,7 +171,7 @@ node_type = cls.node_type version = cls.version() bucket = Node._registry.setdefault(node_type, {}) - if module_name.startswith(("dify_graph.nodes.", "core.workflow.nodes.")): + if module_name.startswith("dify_graph.nodes."): # Production node definitions take precedence and may override bucket[version] = cls # type: ignore[index] else: @@ -263,16 +248,25 @@ self._node_id = node_id self._node_execution_id: str = "" - self._start_at = naive_utc_now() + self._start_at = datetime.now(UTC).replace(tzinfo=None) self._node_data = self.validate_node_data(config["data"]) self.post_init() @classmethod - def validate_node_data(cls, node_data: BaseNodeData) -> NodeDataT: - """Validate shared graph node payloads against the subclass-declared NodeData model.""" - return cast(NodeDataT, cls._node_data_type.model_validate(node_data, from_attributes=True)) + def validate_node_data(cls, node_data: BaseNodeData | Mapping[str, Any]) -> NodeDataT: + """Validate shared graph node payloads against the subclass-declared NodeData model. + + Re-validate from a dumped payload instead of using `from_attributes=True` so compatibility + extras stored on `BaseNodeData` survive the handoff to the concrete node data model. + The Human Input node's delivery methods are one such extra until dify_graph owns that schema.
+ """ + if isinstance(node_data, BaseNodeData): + payload = node_data.model_dump(mode="python") + else: + payload = dict(node_data) + return cast(NodeDataT, cls._node_data_type.model_validate(payload)) def init_node_data(self, data: BaseNodeData | Mapping[str, Any]) -> None: """Hydrate `_node_data` for legacy callers that bypass `__init__`.""" @@ -299,25 +293,6 @@ class Node(Generic[NodeDataT]): raise ValueError(f"run_context missing required key: {key}") return value - def require_dify_context(self) -> DifyRunContextProtocol: - raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY) - if raw_ctx is None: - raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") - - if isinstance(raw_ctx, Mapping): - missing_keys = [ - key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx - ] - if missing_keys: - raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}") - return _MappingDifyRunContext(raw_ctx) - - for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"): - if not hasattr(raw_ctx, attr): - raise TypeError(f"invalid dify context object, missing attribute: {attr}") - - return cast(DifyRunContextProtocol, raw_ctx) - @property def execution_id(self) -> str: return self._node_execution_id @@ -364,7 +339,7 @@ class Node(Generic[NodeDataT]): def run(self) -> Generator[GraphNodeEventBase, None, None]: execution_id = self.ensure_execution_id() - self._start_at = naive_utc_now() + self._start_at = datetime.now(UTC).replace(tzinfo=None) # Create and push start event with required fields start_event = NodeRunStartedEvent( @@ -406,7 +381,7 @@ class Node(Generic[NodeDataT]): error=str(e), error_type="WorkflowNodeError", ) - finished_at = naive_utc_now() + finished_at = datetime.now(UTC).replace(tzinfo=None) yield NodeRunFailedEvent( id=self.execution_id, node_id=self._node_id, @@ -570,7 +545,7 @@ class Node(Generic[NodeDataT]): return self._node_data def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase: - finished_at = naive_utc_now() + finished_at = datetime.now(UTC).replace(tzinfo=None) match result.status: case WorkflowNodeExecutionStatus.FAILED: return NodeRunFailedEvent( @@ -611,7 +586,7 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent: - finished_at = naive_utc_now() + finished_at = datetime.now(UTC).replace(tzinfo=None) match event.node_run_result.status: case WorkflowNodeExecutionStatus.SUCCEEDED: return NodeRunSucceededEvent( @@ -637,6 +612,15 @@ class Node(Generic[NodeDataT]): f"Node {self._node_id} does not support status {event.node_run_result.status}" ) + @_dispatch.register + def _(self, event: VariableUpdatedEvent) -> NodeRunVariableUpdatedEvent: + return NodeRunVariableUpdatedEvent( + id=self.execution_id, + node_id=self._node_id, + node_type=self.node_type, + variable=event.variable, + ) + @_dispatch.register def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent: return NodeRunPauseRequestedEvent( @@ -793,16 +777,11 @@ class Node(Generic[NodeDataT]): @_dispatch.register def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: - from core.rag.entities.citation_metadata import RetrievalSourceMetadata - - retriever_resources = [ - RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources - ] return NodeRunRetrieverResourceEvent( id=self.execution_id, 
node_id=self._node_id, node_type=self.node_type, - retriever_resources=retriever_resources, + retriever_resources=event.retriever_resources, context=event.context, node_version=self.version(), ) diff --git a/api/dify_graph/nodes/http_request/executor.py b/api/dify_graph/nodes/http_request/executor.py index 892b0fc688..5798e563d9 100644 --- a/api/dify_graph/nodes/http_request/executor.py +++ b/api/dify_graph/nodes/http_request/executor.py @@ -246,7 +246,7 @@ class Executor: files: dict[str, list[tuple[str | None, bytes, str]]] = {} for key, files_in_segment in files_list: for file in files_in_segment: - if file.related_id is not None or ( + if file.reference is not None or ( file.transfer_method == FileTransferMethod.REMOTE_URL and file.remote_url is not None ): file_tuple = ( diff --git a/api/dify_graph/nodes/http_request/node.py b/api/dify_graph/nodes/http_request/node.py index 3e5253d809..6f13085b36 100644 --- a/api/dify_graph/nodes/http_request/node.py +++ b/api/dify_graph/nodes/http_request/node.py @@ -11,9 +11,13 @@ from dify_graph.nodes.base import variable_template_parser from dify_graph.nodes.base.entities import VariableSelector from dify_graph.nodes.base.node import Node from dify_graph.nodes.http_request.executor import Executor -from dify_graph.nodes.protocols import FileManagerProtocol, HttpClientProtocol, ToolFileManagerProtocol +from dify_graph.nodes.protocols import ( + FileManagerProtocol, + FileReferenceFactoryProtocol, + HttpClientProtocol, + ToolFileManagerProtocol, +) from dify_graph.variables.segments import ArrayFileSegment -from factories import file_factory from .config import build_http_request_config, resolve_http_request_config from .entities import ( @@ -46,6 +50,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): http_client: HttpClientProtocol, tool_file_manager_factory: Callable[[], ToolFileManagerProtocol], file_manager: FileManagerProtocol, + file_reference_factory: FileReferenceFactoryProtocol, ) -> None: super().__init__( id=id, @@ -58,6 +63,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]): self._http_client = http_client self._tool_file_manager_factory = tool_file_manager_factory self._file_manager = file_manager + self._file_reference_factory = file_reference_factory @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: @@ -212,7 +218,6 @@ class HttpRequestNode(Node[HttpRequestNodeData]): """ Extract files from response by checking both Content-Type header and URL """ - dify_ctx = self.require_dify_context() files: list[File] = [] is_file = response.is_file content_type = response.content_type @@ -237,20 +242,15 @@ class HttpRequestNode(Node[HttpRequestNodeData]): tool_file_manager = self._tool_file_manager_factory() tool_file = tool_file_manager.create_file_by_raw( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - conversation_id=None, file_binary=content, mimetype=mime_type, ) - mapping = { - "tool_file_id": tool_file.id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=dify_ctx.tenant_id, + file = self._file_reference_factory.build_from_mapping( + mapping={ + "tool_file_id": tool_file.id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } ) files.append(file) diff --git a/api/dify_graph/nodes/human_input/entities.py b/api/dify_graph/nodes/human_input/entities.py index 2a33b4a0a8..675c1a0a3c 100644 --- a/api/dify_graph/nodes/human_input/entities.py +++ 
b/api/dify_graph/nodes/human_input/entities.py @@ -1,230 +1,27 @@ -""" -Human Input node entities. +"""Human Input node entities. + +The graph package owns the workflow-facing form schema and keeps it transportable +across runtimes. Dify-specific delivery surface and recipient translation stay +outside `dify_graph`. """ import re -import uuid from collections.abc import Mapping, Sequence from datetime import datetime, timedelta -from typing import Annotated, Any, ClassVar, Literal, Self +from typing import Any, Self -import bleach -import markdown from pydantic import BaseModel, Field, field_validator, model_validator from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.runtime import VariablePool from dify_graph.variables.consts import SELECTORS_LENGTH -from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit +from .enums import ButtonStyle, FormInputType, PlaceholderType, TimeoutUnit _OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") -class _WebAppDeliveryConfig(BaseModel): - """Configuration for webapp delivery method.""" - - pass # Empty for webapp delivery - - -class MemberRecipient(BaseModel): - """Member recipient for email delivery.""" - - type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER - user_id: str - - -class ExternalRecipient(BaseModel): - """External recipient for email delivery.""" - - type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL - email: str - - -EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")] - - -class EmailRecipients(BaseModel): - """Email recipients configuration.""" - - # When true, recipients are the union of all workspace members and external items. - # Member items are ignored because they are already covered by the workspace scope. - # De-duplication is applied by email, with member recipients taking precedence. - whole_workspace: bool = False - items: list[EmailRecipient] = Field(default_factory=list) - - -class EmailDeliveryConfig(BaseModel): - """Configuration for email delivery method.""" - - URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}" - _SUBJECT_NEWLINE_PATTERN: ClassVar[re.Pattern[str]] = re.compile(r"[\r\n]+") - _ALLOWED_HTML_TAGS: ClassVar[list[str]] = [ - "a", - "blockquote", - "br", - "code", - "em", - "h1", - "h2", - "h3", - "h4", - "h5", - "h6", - "hr", - "li", - "ol", - "p", - "pre", - "strong", - "table", - "tbody", - "td", - "th", - "thead", - "tr", - "ul", - ] - _ALLOWED_HTML_ATTRIBUTES: ClassVar[dict[str, list[str]]] = { - "a": ["href", "title"], - "td": ["align"], - "th": ["align"], - } - _ALLOWED_PROTOCOLS: ClassVar[list[str]] = ["http", "https", "mailto"] - - recipients: EmailRecipients - - # the subject of email - subject: str - - # Body is the content of email.It may contain the speical placeholder `{{#url#}}`, which - # represent the url to submit the form. - # - # It may also reference the output variable of the previous node with the syntax - # `{{#.#}}`. 
- body: str - debug_mode: bool = False - - def with_debug_recipient(self, user_id: str | None) -> "EmailDeliveryConfig": - if user_id is None: - debug_recipients = EmailRecipients(whole_workspace=False, items=[]) - return self.model_copy(update={"recipients": debug_recipients}) - debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)]) - return self.model_copy(update={"recipients": debug_recipients}) - - @classmethod - def replace_url_placeholder(cls, body: str, url: str | None) -> str: - """Replace the url placeholder with provided value.""" - return body.replace(cls.URL_PLACEHOLDER, url or "") - - @classmethod - def render_body_template( - cls, - *, - body: str, - url: str | None, - variable_pool: VariablePool | None = None, - ) -> str: - """Render email body by replacing placeholders with runtime values.""" - templated_body = cls.replace_url_placeholder(body, url) - if variable_pool is None: - return templated_body - return variable_pool.convert_template(templated_body).text - - @classmethod - def render_markdown_body(cls, body: str) -> str: - """Render markdown to safe HTML for email delivery.""" - sanitized_markdown = bleach.clean( - body, - tags=[], - attributes={}, - strip=True, - strip_comments=True, - ) - rendered_html = markdown.markdown( - sanitized_markdown, - extensions=["nl2br", "tables"], - extension_configs={"tables": {"use_align_attribute": True}}, - ) - return bleach.clean( - rendered_html, - tags=cls._ALLOWED_HTML_TAGS, - attributes=cls._ALLOWED_HTML_ATTRIBUTES, - protocols=cls._ALLOWED_PROTOCOLS, - strip=True, - strip_comments=True, - ) - - @classmethod - def sanitize_subject(cls, subject: str) -> str: - """Sanitize email subject to plain text and prevent CRLF injection.""" - sanitized_subject = bleach.clean( - subject, - tags=[], - attributes={}, - strip=True, - strip_comments=True, - ) - sanitized_subject = cls._SUBJECT_NEWLINE_PATTERN.sub(" ", sanitized_subject) - return " ".join(sanitized_subject.split()) - - -class _DeliveryMethodBase(BaseModel): - """Base delivery method configuration.""" - - enabled: bool = True - id: uuid.UUID = Field(default_factory=uuid.uuid4) - - def extract_variable_selectors(self) -> Sequence[Sequence[str]]: - return () - - -class WebAppDeliveryMethod(_DeliveryMethodBase): - """Webapp delivery method configuration.""" - - type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP - # The config field is not used currently. 
- config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig) - - -class EmailDeliveryMethod(_DeliveryMethodBase): - """Email delivery method configuration.""" - - type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL - config: EmailDeliveryConfig - - def extract_variable_selectors(self) -> Sequence[Sequence[str]]: - variable_template_parser = VariableTemplateParser(template=self.config.body) - selectors: list[Sequence[str]] = [] - for variable_selector in variable_template_parser.extract_variable_selectors(): - value_selector = list(variable_selector.value_selector) - if len(value_selector) < SELECTORS_LENGTH: - continue - selectors.append(value_selector[:SELECTORS_LENGTH]) - return selectors - - -DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")] - - -def apply_debug_email_recipient( - method: DeliveryChannelConfig, - *, - enabled: bool, - user_id: str | None, -) -> DeliveryChannelConfig: - if not enabled: - return method - if not isinstance(method, EmailDeliveryMethod): - return method - if not method.config.debug_mode: - return method - debug_config = method.config.with_debug_recipient(user_id) - return method.model_copy(update={"config": debug_config}) - - class FormInputDefault(BaseModel): """Default configuration for form inputs.""" @@ -288,7 +85,6 @@ class HumanInputNodeData(BaseNodeData): """Human Input node data.""" type: NodeType = BuiltinNodeTypes.HUMAN_INPUT - delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list) form_content: str = "" inputs: list[FormInput] = Field(default_factory=list) user_actions: list[UserAction] = Field(default_factory=list) @@ -317,14 +113,6 @@ class HumanInputNodeData(BaseNodeData): seen_ids.add(action_id) return user_actions - def is_webapp_enabled(self) -> bool: - for dm in self.delivery_methods: - if not dm.enabled: - continue - if dm.type == DeliveryMethodType.WEBAPP: - return True - return False - def expiration_time(self, start_time: datetime) -> datetime: if self.timeout_unit == TimeoutUnit.HOUR: return start_time + timedelta(hours=self.timeout) @@ -353,10 +141,6 @@ class HumanInputNodeData(BaseNodeData): _add_variable_selectors( [selector.value_selector for selector in form_template_parser.extract_variable_selectors()] ) - for delivery_method in self.delivery_methods: - if not delivery_method.enabled: - continue - _add_variable_selectors(delivery_method.extract_variable_selectors()) for input in self.inputs: default_value = input.default diff --git a/api/dify_graph/nodes/human_input/enums.py b/api/dify_graph/nodes/human_input/enums.py index da85728828..3fb0ab4499 100644 --- a/api/dify_graph/nodes/human_input/enums.py +++ b/api/dify_graph/nodes/human_input/enums.py @@ -25,16 +25,6 @@ class HumanInputFormKind(enum.StrEnum): DELIVERY_TEST = enum.auto() # Form created for delivery tests. -class DeliveryMethodType(enum.StrEnum): - """Delivery method types for human input forms.""" - - # WEBAPP controls whether the form is delivered to the web app. It not only controls - # the standalone web app, but also controls the installed apps in the console. 
- WEBAPP = enum.auto() - - EMAIL = enum.auto() - - class ButtonStyle(enum.StrEnum): """Button styles for user actions.""" @@ -63,10 +53,3 @@ class PlaceholderType(enum.StrEnum): VARIABLE = enum.auto() CONSTANT = enum.auto() - - -class EmailRecipientType(enum.StrEnum): - """Email recipient types.""" - - MEMBER = enum.auto() - EXTERNAL = enum.auto() diff --git a/api/dify_graph/nodes/human_input/human_input_node.py b/api/dify_graph/nodes/human_input/human_input_node.py index 794e33d92e..d08ca181e3 100644 --- a/api/dify_graph/nodes/human_input/human_input_node.py +++ b/api/dify_graph/nodes/human_input/human_input_node.py @@ -1,7 +1,8 @@ import json import logging from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any, cast from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.entities.pause_reason import HumanInputRequired @@ -15,16 +16,11 @@ from dify_graph.node_events import ( from dify_graph.node_events.base import NodeEventBase from dify_graph.node_events.node import StreamCompletedEvent from dify_graph.nodes.base.node import Node -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) +from dify_graph.nodes.runtime import HumanInputFormStateProtocol, HumanInputNodeRuntimeProtocol from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter -from libs.datetime_utils import naive_utc_now -from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient -from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType +from .entities import HumanInputNodeData +from .enums import HumanInputFormStatus, PlaceholderType if TYPE_CHECKING: from dify_graph.entities.graph_init_params import GraphInitParams @@ -32,8 +28,6 @@ if TYPE_CHECKING: _SELECTED_BRANCH_KEY = "selected_branch" -_INVOKE_FROM_DEBUGGER = "debugger" -_INVOKE_FROM_EXPLORE = "explore" logger = logging.getLogger(__name__) @@ -56,7 +50,6 @@ class HumanInputNode(Node[HumanInputNodeData]): ) _node_data: HumanInputNodeData - _form_repository: HumanInputFormRepository _OUTPUT_FIELD_ACTION_ID = "__action_id" _OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content" _TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout" @@ -67,7 +60,8 @@ class HumanInputNode(Node[HumanInputNodeData]): config: NodeConfigDict, graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", - form_repository: HumanInputFormRepository, + runtime: HumanInputNodeRuntimeProtocol | None = None, + form_repository: object | None = None, ) -> None: super().__init__( id=id, @@ -75,7 +69,14 @@ class HumanInputNode(Node[HumanInputNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._form_repository = form_repository + resolved_runtime = runtime + if resolved_runtime is None: + raise ValueError("runtime is required") + if form_repository is not None: + with_form_repository = getattr(resolved_runtime, "with_form_repository", None) + if callable(with_form_repository): + resolved_runtime = cast(HumanInputNodeRuntimeProtocol, with_form_repository(form_repository)) + self._runtime: HumanInputNodeRuntimeProtocol = resolved_runtime @classmethod def version(cls) -> str: @@ -128,13 +129,7 @@ class HumanInputNode(Node[HumanInputNodeData]): return None - @property - def _workflow_execution_id(self) -> str: - workflow_exec_id = 
self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id - assert workflow_exec_id is not None - return workflow_exec_id - - def _form_to_pause_event(self, form_entity: HumanInputFormEntity): + def _form_to_pause_event(self, form_entity: HumanInputFormStateProtocol): required_event = self._human_input_required_event(form_entity) pause_requested_event = PauseRequestedEvent(reason=required_event) return pause_requested_event @@ -157,56 +152,16 @@ class HumanInputNode(Node[HumanInputNodeData]): return resolved_defaults - def _should_require_console_recipient(self) -> bool: - invoke_from = self._invoke_from_value() - if invoke_from == _INVOKE_FROM_DEBUGGER: - return True - if invoke_from == _INVOKE_FROM_EXPLORE: - return self._node_data.is_webapp_enabled() - return False - - def _display_in_ui(self) -> bool: - if self._invoke_from_value() == _INVOKE_FROM_DEBUGGER: - return True - return self._node_data.is_webapp_enabled() - - def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]: - dify_ctx = self.require_dify_context() - invoke_from = self._invoke_from_value() - enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled] - if invoke_from in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}: - enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP] - return [ - apply_debug_email_recipient( - method, - enabled=invoke_from == _INVOKE_FROM_DEBUGGER, - user_id=dify_ctx.user_id, - ) - for method in enabled_methods - ] - - def _invoke_from_value(self) -> str: - invoke_from = self.require_dify_context().invoke_from - if isinstance(invoke_from, str): - return invoke_from - return str(getattr(invoke_from, "value", invoke_from)) - - def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired: + def _human_input_required_event(self, form_entity: HumanInputFormStateProtocol) -> HumanInputRequired: node_data = self._node_data resolved_default_values = self.resolve_default_values() - display_in_ui = self._display_in_ui() - form_token = form_entity.web_app_token - if display_in_ui and form_token is None: - raise AssertionError("Form token should be available for UI execution.") return HumanInputRequired( form_id=form_entity.id, form_content=form_entity.rendered_content, inputs=node_data.inputs, actions=node_data.user_actions, - display_in_ui=display_in_ui, node_id=self.id, node_title=node_data.title, - form_token=form_token, resolved_default_values=resolved_default_values, ) @@ -217,49 +172,32 @@ class HumanInputNode(Node[HumanInputNodeData]): This method will: 1. Generate a unique form ID 2. Create form content with variable substitution - 3. Create form in database + 3. Persist the form through the configured repository 4. Send form via configured delivery methods 5. Suspend workflow execution 6. 
Wait for form submission to resume """ - repo = self._form_repository - form = repo.get_form(self._workflow_execution_id, self.id) - dify_ctx = self.require_dify_context() + form = self._runtime.get_form(node_id=self.id) if form is None: - display_in_ui = self._display_in_ui() - params = FormCreateParams( - app_id=dify_ctx.app_id, - workflow_execution_id=self._workflow_execution_id, + form_entity = self._runtime.create_form( node_id=self.id, - form_config=self._node_data, + node_data=self._node_data, rendered_content=self.render_form_content_before_submission(), - delivery_methods=self._effective_delivery_methods(), - display_in_ui=display_in_ui, resolved_default_values=self.resolve_default_values(), - console_recipient_required=self._should_require_console_recipient(), - console_creator_account_id=( - dify_ctx.user_id - if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE} - else None - ), - backstage_recipient_required=True, ) - form_entity = self._form_repository.create_form(params) - # Create human input required event logger.info( - "Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s", - self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id, + "Human Input node suspended workflow for form. node_id=%s, form_id=%s", self.id, form_entity.id, ) yield self._form_to_pause_event(form_entity) return - if ( - form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED} - or form.expiration_time <= naive_utc_now() - ): + if form.status in { + HumanInputFormStatus.TIMEOUT, + HumanInputFormStatus.EXPIRED, + } or form.expiration_time <= datetime.now(UTC).replace(tzinfo=None): yield HumanInputFormTimeoutEvent( node_title=self._node_data.title, expiration_time=form.expiration_time, diff --git a/api/dify_graph/nodes/if_else/if_else_node.py b/api/dify_graph/nodes/if_else/if_else_node.py index 7c0370e48c..6ca5684666 100644 --- a/api/dify_graph/nodes/if_else/if_else_node.py +++ b/api/dify_graph/nodes/if_else/if_else_node.py @@ -57,8 +57,8 @@ class IfElseNode(Node[IfElseNodeData]): break else: - # TODO: Update database then remove this - # Fallback to old structure if cases are not defined + # TODO: Remove this once all graph definitions use the `cases` structure. + # Fallback to the legacy node shape when `cases` are not defined. 
input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated] condition_processor=condition_processor, variable_pool=self.graph_runtime_state.variable_pool, diff --git a/api/dify_graph/nodes/iteration/exc.py b/api/dify_graph/nodes/iteration/exc.py index d9947e09bc..7b6af61b9d 100644 --- a/api/dify_graph/nodes/iteration/exc.py +++ b/api/dify_graph/nodes/iteration/exc.py @@ -20,3 +20,7 @@ class IterationGraphNotFoundError(IterationNodeError): class IterationIndexNotFoundError(IterationNodeError): """Raised when the iteration index is not found.""" + + +class ChildGraphAbortedError(IterationNodeError): + """Raised when a child graph aborts and the container must stop immediately.""" diff --git a/api/dify_graph/nodes/iteration/iteration_node.py b/api/dify_graph/nodes/iteration/iteration_node.py index 033ec8672f..a24e6c33e5 100644 --- a/api/dify_graph/nodes/iteration/iteration_node.py +++ b/api/dify_graph/nodes/iteration/iteration_node.py @@ -1,12 +1,13 @@ import logging from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, ThreadPoolExecutor, as_completed +from contextlib import suppress from datetime import UTC, datetime +from threading import Lock from typing import TYPE_CHECKING, Any, NewType, cast from typing_extensions import TypeIs -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID from dify_graph.entities.graph_config import NodeConfigDictAdapter from dify_graph.enums import ( BuiltinNodeTypes, @@ -16,6 +17,7 @@ from dify_graph.enums import ( ) from dify_graph.graph_events import ( GraphNodeEventBase, + GraphRunAbortedEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, GraphRunSucceededEvent, @@ -36,10 +38,9 @@ from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeDa from dify_graph.runtime import VariablePool from dify_graph.variables import IntegerVariable, NoneSegment from dify_graph.variables.segments import ArrayAnySegment, ArraySegment -from dify_graph.variables.variables import Variable -from libs.datetime_utils import naive_utc_now from .exc import ( + ChildGraphAbortedError, InvalidIteratorValueError, IterationGraphNotFoundError, IterationIndexNotFoundError, @@ -49,10 +50,10 @@ from .exc import ( ) if TYPE_CHECKING: - from dify_graph.context import IExecutionContext from dify_graph.graph_engine import GraphEngine logger = logging.getLogger(__name__) +_DEFAULT_CHILD_ABORT_REASON = "child graph aborted" EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) @@ -93,7 +94,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): self._validate_start_node() - started_at = naive_utc_now() + started_at = datetime.now(UTC).replace(tzinfo=None) iter_run_map: dict[str, float] = {} outputs: list[object] = [] usage_accumulator = [LLMUsage.empty_usage()] @@ -199,23 +200,14 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): graph_engine = self._create_graph_engine(index, item) # Run the iteration - yield from self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs, - graph_engine=graph_engine, - ) - - # Sync conversation variables after each iteration completes - self._sync_conversation_variables_from_snapshot( - self._extract_conversation_variable_snapshot( - variable_pool=graph_engine.graph_runtime_state.variable_pool + try: + yield from self._run_single_iter( + variable_pool=graph_engine.graph_runtime_state.variable_pool, + outputs=outputs, + 
graph_engine=graph_engine, ) - ) - - # Accumulate usage from this iteration - usage_accumulator[0] = self._merge_usage( - usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage - ) + finally: + self._merge_graph_engine_usage(usage_accumulator=usage_accumulator, graph_engine=graph_engine) iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() def _execute_parallel_iterations( @@ -233,13 +225,15 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): with ThreadPoolExecutor(max_workers=max_workers) as executor: # Submit all iteration tasks + started_child_engines: dict[int, GraphEngine] = {} + started_child_engines_lock = Lock() + merged_usage_indexes: set[int] = set() future_to_index: dict[ Future[ tuple[ float, list[GraphNodeEventBase], object | None, - dict[str, Variable], LLMUsage, ] ], @@ -248,10 +242,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): for index, item in enumerate(iterator_list_value): yield IterationNextEvent(index=index) future = executor.submit( - self._execute_single_iteration_parallel, + self._execute_tracked_iteration_parallel, index=index, item=item, - execution_context=self._capture_execution_context(), + started_child_engines=started_child_engines, + started_child_engines_lock=started_child_engines_lock, ) future_to_index[future] = index @@ -264,7 +259,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): iteration_duration, events, output_value, - conversation_snapshot, iteration_usage, ) = result @@ -279,11 +273,31 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): iter_run_map[str(index)] = iteration_duration usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage) - - # Sync conversation variables after iteration completion - self._sync_conversation_variables_from_snapshot(conversation_snapshot) + merged_usage_indexes.add(index) except Exception as e: + if index not in merged_usage_indexes: + self._merge_graph_engine_usage( + usage_accumulator=usage_accumulator, + graph_engine=started_child_engines.get(index), + ) + merged_usage_indexes.add(index) + if isinstance(e, ChildGraphAbortedError): + self._abort_parallel_siblings( + future_to_index=future_to_index, + current_future=future, + started_child_engines=started_child_engines, + reason=str(e) or _DEFAULT_CHILD_ABORT_REASON, + ) + self._drain_parallel_siblings( + future_to_index=future_to_index, + current_future=future, + started_child_engines=started_child_engines, + usage_accumulator=usage_accumulator, + merged_usage_indexes=merged_usage_indexes, + ) + raise e + # Handle errors based on error_handle_mode match self.node_data.error_handle_mode: case ErrorHandleMode.TERMINATED: @@ -301,48 +315,118 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: outputs[:] = [output for output in outputs if output is not None] + @staticmethod + def _merge_graph_engine_usage( + *, + usage_accumulator: list[LLMUsage], + graph_engine: "GraphEngine | None", + ) -> None: + if graph_engine is None: + return + usage_accumulator[0] = IterationNode._merge_usage( + usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage + ) + + def _abort_parallel_siblings( + self, + *, + future_to_index: Mapping[Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], int], + current_future: Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], + 
started_child_engines: Mapping[int, "GraphEngine"], + reason: str, + ) -> None: + for future, index in future_to_index.items(): + if future == current_future: + continue + + graph_engine = started_child_engines.get(index) + if graph_engine is not None: + graph_engine.request_abort(reason) + + future.cancel() + + def _drain_parallel_siblings( + self, + *, + future_to_index: Mapping[Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], int], + current_future: Future[tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]], + started_child_engines: Mapping[int, "GraphEngine"], + usage_accumulator: list[LLMUsage], + merged_usage_indexes: set[int], + ) -> None: + for future, index in future_to_index.items(): + if future == current_future: + continue + if future.cancelled(): + continue + + with suppress(Exception): + future.result() + + if index in merged_usage_indexes: + continue + + self._merge_graph_engine_usage( + usage_accumulator=usage_accumulator, + graph_engine=started_child_engines.get(index), + ) + merged_usage_indexes.add(index) + + def _execute_tracked_iteration_parallel( + self, + *, + index: int, + item: object, + started_child_engines: dict[int, "GraphEngine"], + started_child_engines_lock: Lock, + ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: + graph_engine = self._create_graph_engine(index, item) + with started_child_engines_lock: + started_child_engines[index] = graph_engine + + return self._execute_parallel_iteration_with_graph_engine( + index=index, + graph_engine=graph_engine, + ) + def _execute_single_iteration_parallel( self, index: int, item: object, - execution_context: "IExecutionContext", - ) -> tuple[float, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]: + ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: """Execute a single iteration in parallel mode and return results.""" - with execution_context: - iter_start_at = datetime.now(UTC).replace(tzinfo=None) - events: list[GraphNodeEventBase] = [] - outputs_temp: list[object] = [] + graph_engine = self._create_graph_engine(index, item) + return self._execute_parallel_iteration_with_graph_engine(index=index, graph_engine=graph_engine) - graph_engine = self._create_graph_engine(index, item) + def _execute_parallel_iteration_with_graph_engine( + self, + *, + index: int, + graph_engine: "GraphEngine", + ) -> tuple[float, list[GraphNodeEventBase], object | None, LLMUsage]: + """Execute a prepared child engine in parallel mode and return results.""" + iter_start_at = datetime.now(UTC).replace(tzinfo=None) + events: list[GraphNodeEventBase] = [] + outputs_temp: list[object] = [] - # Collect events instead of yielding them directly - for event in self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs_temp, - graph_engine=graph_engine, - ): - events.append(event) + # Collect events instead of yielding them directly + for event in self._run_single_iter( + variable_pool=graph_engine.graph_runtime_state.variable_pool, + outputs=outputs_temp, + graph_engine=graph_engine, + ): + events.append(event) - # Get the output value from the temporary outputs list - output_value = outputs_temp[0] if outputs_temp else None - conversation_snapshot = self._extract_conversation_variable_snapshot( - variable_pool=graph_engine.graph_runtime_state.variable_pool - ) - iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() + # Get the output value from the temporary 
outputs list + output_value = outputs_temp[0] if outputs_temp else None + iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() - return ( - iteration_duration, - events, - output_value, - conversation_snapshot, - graph_engine.graph_runtime_state.llm_usage, - ) - - def _capture_execution_context(self) -> "IExecutionContext": - """Capture current execution context for parallel iterations.""" - from dify_graph.context import capture_current_context - - return capture_current_context() + return ( + iteration_duration, + events, + output_value, + graph_engine.graph_runtime_state.llm_usage, + ) def _handle_iteration_success( self, @@ -516,23 +600,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return variable_mapping - def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]: - conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) - return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()} - - def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None: - parent_pool = self.graph_runtime_state.variable_pool - parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) - - current_keys = set(parent_conversations.keys()) - snapshot_keys = set(snapshot.keys()) - - for removed_key in current_keys - snapshot_keys: - parent_pool.remove((CONVERSATION_VARIABLE_NODE_ID, removed_key)) - - for name, variable in snapshot.items(): - parent_pool.add((CONVERSATION_VARIABLE_NODE_ID, name), variable) - def _append_iteration_info_to_event( self, event: GraphNodeEventBase, @@ -575,6 +642,8 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): else: outputs.append(result.to_object()) return + elif isinstance(event, GraphRunAbortedEvent): + raise ChildGraphAbortedError(event.reason or _DEFAULT_CHILD_ABORT_REASON) elif isinstance(event, GraphRunFailedEvent): match self.node_data.error_handle_mode: case ErrorHandleMode.TERMINATED: @@ -587,7 +656,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): def _create_graph_engine(self, index: int, item: object): from dify_graph.entities import GraphInitParams - from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState + from dify_graph.runtime import ChildGraphNotFoundError # Create GraphInitParams for child graph execution. 
graph_init_params = GraphInitParams( @@ -602,14 +671,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): # append iteration variable (item, index) to variable pool variable_pool_copy.add([self._node_id, "index"], index) variable_pool_copy.add([self._node_id, "item"], item) - - # Create a new GraphRuntimeState for this iteration - graph_runtime_state_copy = GraphRuntimeState( - variable_pool=variable_pool_copy, - start_at=self.graph_runtime_state.start_at, - total_tokens=0, - node_run_steps=0, - ) root_node_id = self.node_data.start_node_id if root_node_id is None: raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") @@ -618,9 +679,8 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): return self.graph_runtime_state.create_child_engine( workflow_id=self.workflow_id, graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state_copy, - graph_config=self.graph_config, root_node_id=root_node_id, + variable_pool=variable_pool_copy, ) except ChildGraphNotFoundError as exc: raise IterationGraphNotFoundError("iteration graph not found") from exc diff --git a/api/dify_graph/nodes/llm/entities.py b/api/dify_graph/nodes/llm/entities.py index 6ca01a21da..24b6035c44 100644 --- a/api/dify_graph/nodes/llm/entities.py +++ b/api/dify_graph/nodes/llm/entities.py @@ -3,11 +3,11 @@ from typing import Any, Literal from pydantic import BaseModel, Field, field_validator -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode from dify_graph.nodes.base.entities import VariableSelector +from dify_graph.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig class ModelConfig(BaseModel): diff --git a/api/dify_graph/nodes/llm/file_saver.py b/api/dify_graph/nodes/llm/file_saver.py index 50e52a3b6f..e5df51c978 100644 --- a/api/dify_graph/nodes/llm/file_saver.py +++ b/api/dify_graph/nodes/llm/file_saver.py @@ -1,11 +1,9 @@ import mimetypes import typing as tp -from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE -from core.tools.signature import sign_tool_file -from core.tools.tool_file_manager import ToolFileManager from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.nodes.protocols import HttpClientProtocol +from dify_graph.file.constants import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE +from dify_graph.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol class LLMFileSaver(tp.Protocol): @@ -57,17 +55,20 @@ class LLMFileSaver(tp.Protocol): class FileSaverImpl(LLMFileSaver): - _tenant_id: str - _user_id: str + _tool_file_manager: ToolFileManagerProtocol + _file_reference_factory: FileReferenceFactoryProtocol - def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol): - self._user_id = user_id - self._tenant_id = tenant_id + def __init__( + self, + *, + tool_file_manager: ToolFileManagerProtocol, + file_reference_factory: FileReferenceFactoryProtocol, + http_client: HttpClientProtocol, + ): + self._tool_file_manager = tool_file_manager + self._file_reference_factory = file_reference_factory self._http_client = http_client - def _get_tool_file_manager(self): - return ToolFileManager() - def save_remote_url(self, url: str, 
file_type: FileType) -> File: http_response = self._http_client.get(url) http_response.raise_for_status() @@ -83,30 +84,24 @@ class FileSaverImpl(LLMFileSaver): file_type: FileType, extension_override: str | None = None, ) -> File: - tool_file_manager = self._get_tool_file_manager() - tool_file = tool_file_manager.create_file_by_raw( - user_id=self._user_id, - tenant_id=self._tenant_id, - # TODO(QuantumGhost): what is conversation id? - conversation_id=None, + tool_file = self._tool_file_manager.create_file_by_raw( file_binary=data, mimetype=mime_type, ) extension_override = _validate_extension_override(extension_override) extension = _get_extension(mime_type, extension_override) - url = sign_tool_file(tool_file.id, extension) - - return File( - tenant_id=self._tenant_id, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - filename=tool_file.name, - extension=extension, - mime_type=mime_type, - size=len(data), - related_id=tool_file.id, - url=url, - storage_key=tool_file.file_key, + return self._file_reference_factory.build_from_mapping( + mapping={ + "type": file_type, + "transfer_method": FileTransferMethod.TOOL_FILE, + "filename": tool_file.name, + "extension": extension, + "mime_type": mime_type, + "size": len(data), + "tool_file_id": str(tool_file.id), + "related_id": str(tool_file.id), + "storage_key": tool_file.file_key, + } ) diff --git a/api/dify_graph/nodes/llm/llm_utils.py b/api/dify_graph/nodes/llm/llm_utils.py index 8682c3682c..add634e7b8 100644 --- a/api/dify_graph/nodes/llm/llm_utils.py +++ b/api/dify_graph/nodes/llm/llm_utils.py @@ -4,9 +4,8 @@ import json import logging import re from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any -from core.model_manager import ModelInstance from dify_graph.file import FileType, file_manager from dify_graph.file.models import File from dify_graph.model_runtime.entities import ( @@ -24,9 +23,9 @@ from dify_graph.model_runtime.entities.message_entities import ( ) from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from dify_graph.nodes.base.entities import VariableSelector from dify_graph.runtime import VariablePool +from dify_graph.template_rendering import Jinja2TemplateRenderer from dify_graph.variables import ArrayFileSegment, FileSegment from dify_graph.variables.segments import ArrayAnySegment, NoneSegment @@ -37,7 +36,9 @@ from .exc import ( NoPromptFoundError, TemplateTypeNotSupportError, ) -from .protocols import TemplateRenderer +from .runtime_protocols import PreparedLLMProtocol + +CONTEXT_PLACEHOLDER = "{{#context#}}" logger = logging.getLogger(__name__) @@ -45,13 +46,10 @@ VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}") MAX_RESOLVED_VALUE_LENGTH = 1024 -def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity: - model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema( - model_instance.model_name, - dict(model_instance.credentials), - ) +def fetch_model_schema(*, model_instance: PreparedLLMProtocol) -> AIModelEntity: + model_schema = model_instance.get_model_schema() if not model_schema: - raise ValueError(f"Model schema not found for {model_instance.model_name}") + raise ValueError(f"Model schema not found for {getattr(model_instance, 'model_name', 'unknown model')}") return model_schema @@ 
-122,9 +120,9 @@ def fetch_prompt_messages( *, sys_query: str | None = None, sys_files: Sequence[File], - context: str | None = None, + context: str = "", memory: PromptMessageMemory | None = None, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, stop: Sequence[str] | None = None, memory_config: MemoryConfig | None = None, @@ -133,7 +131,7 @@ def fetch_prompt_messages( variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], context_files: list[File] | None = None, - template_renderer: TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: prompt_messages: list[PromptMessage] = [] model_schema = fetch_model_schema(model_instance=model_instance) @@ -285,11 +283,11 @@ def fetch_prompt_messages( def handle_list_messages( *, messages: Sequence[LLMNodeChatModelMessage], - context: str | None, + context: str, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, vision_detail_config: ImagePromptMessageContent.DETAIL, - template_renderer: TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> Sequence[PromptMessage]: prompt_messages: list[PromptMessage] = [] for message in messages: @@ -308,7 +306,7 @@ def handle_list_messages( ) continue - template = message.text.replace("{#context#}", context) if context else message.text + template = message.text.replace(CONTEXT_PLACEHOLDER, context) segment_group = variable_pool.convert_template(template) file_contents: list[PromptMessageContentUnionTypes] = [] for segment in segment_group.value: @@ -343,7 +341,7 @@ def render_jinja2_message( template: str, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, - template_renderer: TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> str: if not template: return "" @@ -354,16 +352,16 @@ def render_jinja2_message( for jinja2_variable in jinja2_variables: variable = variable_pool.get(jinja2_variable.value_selector) jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" - return template_renderer.render_jinja2(template=template, inputs=jinja2_inputs) + return template_renderer.render_template(template, jinja2_inputs) def handle_completion_template( *, template: LLMNodeCompletionModelPromptTemplate, - context: str | None, + context: str, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, - template_renderer: TemplateRenderer | None = None, + template_renderer: Jinja2TemplateRenderer | None = None, ) -> Sequence[PromptMessage]: if template.edition_type == "jinja2": result_text = render_jinja2_message( @@ -373,7 +371,7 @@ def handle_completion_template( template_renderer=template_renderer, ) else: - template_text = template.text.replace("{#context#}", context) if context else template.text + template_text = template.text.replace(CONTEXT_PLACEHOLDER, context) result_text = variable_pool.convert_template(template_text).text return [ combine_message_content_with_role( @@ -399,7 +397,11 @@ def combine_message_content_with_role( raise NotImplementedError(f"Role {role} is not supported") -def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: ModelInstance) -> int: +def calculate_rest_token( + *, + prompt_messages: list[PromptMessage], + model_instance: PreparedLLMProtocol, +) -> int: rest_tokens = 
2000 runtime_model_schema = fetch_model_schema(model_instance=model_instance) runtime_model_parameters = model_instance.parameters @@ -429,7 +431,7 @@ def handle_memory_chat_mode( *, memory: PromptMessageMemory | None, memory_config: MemoryConfig | None, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, ) -> Sequence[PromptMessage]: if not memory or not memory_config: return [] @@ -444,7 +446,7 @@ def handle_memory_completion_mode( *, memory: PromptMessageMemory | None, memory_config: MemoryConfig | None, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, ) -> str: if not memory or not memory_config: return "" diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py index a5492aee6b..6f51809d7f 100644 --- a/api/dify_graph/nodes/llm/node.py +++ b/api/dify_graph/nodes/llm/node.py @@ -7,30 +7,21 @@ import logging import re import time from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any, Literal, cast -from sqlalchemy import select - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output -from core.model_manager import ModelInstance -from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.tools.signature import sign_upload_file -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( BuiltinNodeTypes, NodeType, - SystemVariableKey, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.file import File, FileTransferMethod, FileType +from dify_graph.file import File, FileType, file_manager from dify_graph.model_runtime.entities import ( ImagePromptMessageContent, PromptMessage, + PromptMessageContentType, TextPromptMessageContent, ) from dify_graph.model_runtime.entities.llm_entities import ( @@ -41,7 +32,14 @@ from dify_graph.model_runtime.entities.llm_entities import ( LLMStructuredOutput, LLMUsage, ) -from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessageContentUnionTypes, + PromptMessageRole, + SystemPromptMessage, + UserPromptMessage, +) +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from dify_graph.model_runtime.memory import PromptMessageMemory from dify_graph.model_runtime.utils.encoders import jsonable_encoder from dify_graph.node_events import ( @@ -55,19 +53,23 @@ from dify_graph.node_events import ( from dify_graph.nodes.base.entities import VariableSelector from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer +from dify_graph.nodes.llm.runtime_protocols import ( + PreparedLLMProtocol, + PromptMessageSerializerProtocol, + RetrieverAttachmentLoaderProtocol, +) from dify_graph.nodes.protocols import HttpClientProtocol +from dify_graph.prompt_entities import CompletionModelPromptTemplate, MemoryConfig from dify_graph.runtime import VariablePool +from 
dify_graph.template_rendering import Jinja2TemplateRenderer, TemplateRenderError from dify_graph.variables import ( ArrayFileSegment, ArraySegment, + FileSegment, NoneSegment, ObjectSegment, StringSegment, ) -from extensions.ext_database import db -from models.dataset import SegmentAttachmentBinding -from models.model import UploadFile from . import llm_utils from .entities import ( @@ -79,9 +81,12 @@ from .exc import ( InvalidContextStructureError, InvalidVariableTypeError, LLMNodeError, + MemoryRolePrefixRequiredError, + NoPromptFoundError, + TemplateTypeNotSupportError, VariableNotFoundError, ) -from .file_saver import FileSaverImpl, LLMFileSaver +from .file_saver import LLMFileSaver if TYPE_CHECKING: from dify_graph.file.models import File @@ -101,11 +106,12 @@ class LLMNode(Node[LLMNodeData]): _file_outputs: list[File] _llm_file_saver: LLMFileSaver - _credentials_provider: CredentialsProvider - _model_factory: ModelFactory - _model_instance: ModelInstance + _retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None + _prompt_message_serializer: PromptMessageSerializerProtocol + _jinja2_template_renderer: Jinja2TemplateRenderer | None + _model_instance: PreparedLLMProtocol _memory: PromptMessageMemory | None - _template_renderer: TemplateRenderer + _default_query_selector: tuple[str, ...] | None def __init__( self, @@ -114,13 +120,16 @@ class LLMNode(Node[LLMNodeData]): graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState, *, - credentials_provider: CredentialsProvider, - model_factory: ModelFactory, - model_instance: ModelInstance, + credentials_provider: object | None = None, + model_factory: object | None = None, + model_instance: PreparedLLMProtocol, http_client: HttpClientProtocol, - template_renderer: TemplateRenderer, memory: PromptMessageMemory | None = None, - llm_file_saver: LLMFileSaver | None = None, + llm_file_saver: LLMFileSaver, + prompt_message_serializer: PromptMessageSerializerProtocol, + retriever_attachment_loader: RetrieverAttachmentLoaderProtocol | None = None, + jinja2_template_renderer: Jinja2TemplateRenderer | None = None, + default_query_selector: Sequence[str] | None = None, ): super().__init__( id=id, @@ -131,20 +140,15 @@ class LLMNode(Node[LLMNodeData]): # LLM file outputs, used for MultiModal outputs. 
self._file_outputs = [] - self._credentials_provider = credentials_provider - self._model_factory = model_factory + _ = credentials_provider, model_factory, http_client self._model_instance = model_instance self._memory = memory - self._template_renderer = template_renderer - if llm_file_saver is None: - dify_ctx = self.require_dify_context() - llm_file_saver = FileSaverImpl( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - http_client=http_client, - ) self._llm_file_saver = llm_file_saver + self._prompt_message_serializer = prompt_message_serializer + self._retriever_attachment_loader = retriever_attachment_loader + self._jinja2_template_renderer = jinja2_template_renderer + self._default_query_selector = tuple(default_query_selector) if default_query_selector is not None else None @classmethod def version(cls) -> str: @@ -190,10 +194,11 @@ class LLMNode(Node[LLMNodeData]): generator = self._fetch_context(node_data=self.node_data) context = None context_files: list[File] = [] - for event in generator: - context = event.context - context_files = event.context_files or [] - yield event + if generator is not None: + for event in generator: + context = event.context + context_files = event.context_files or [] + yield event if context: node_inputs["#context#"] = context @@ -215,15 +220,17 @@ class LLMNode(Node[LLMNodeData]): query: str | None = None if self.node_data.memory: query = self.node_data.memory.query_prompt_template - if not query and ( - query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) + if ( + not query + and self._default_query_selector + and (query_variable := variable_pool.get(self._default_query_selector)) ): query = query_variable.text prompt_messages, stop = LLMNode.fetch_prompt_messages( sys_query=query, sys_files=files, - context=context, + context=context or "", memory=memory, model_instance=model_instance, stop=model_stop, @@ -234,7 +241,7 @@ class LLMNode(Node[LLMNodeData]): variable_pool=variable_pool, jinja2_variables=self.node_data.prompt_config.jinja2_variables, context_files=context_files, - template_renderer=self._template_renderer, + jinja2_template_renderer=self._jinja2_template_renderer, ) # handle invoke result @@ -242,7 +249,6 @@ class LLMNode(Node[LLMNodeData]): model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, - user_id=self.require_dify_context().user_id, structured_output_enabled=self.node_data.structured_output_enabled, structured_output=self.node_data.structured_output, file_saver=self._llm_file_saver, @@ -285,7 +291,7 @@ class LLMNode(Node[LLMNodeData]): process_data = { "model_mode": self.node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + "prompts": self._prompt_message_serializer.serialize( model_mode=self.node_data.model.mode, prompt_messages=prompt_messages ), "usage": jsonable_encoder(usage), @@ -353,10 +359,9 @@ class LLMNode(Node[LLMNodeData]): @staticmethod def invoke_llm( *, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, prompt_messages: Sequence[PromptMessage], stop: Sequence[str] | None = None, - user_id: str, structured_output_enabled: bool, structured_output: Mapping[str, Any] | None = None, file_saver: LLMFileSaver, @@ -367,35 +372,35 @@ class LLMNode(Node[LLMNodeData]): ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: model_parameters = model_instance.parameters invoke_model_parameters = dict(model_parameters) - - model_schema = 
llm_utils.fetch_model_schema(model_instance=model_instance) - + invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None] if structured_output_enabled: output_schema = LLMNode.fetch_structured_output_schema( structured_output=structured_output or {}, ) request_start_time = time.perf_counter() - invoke_result = invoke_llm_with_structured_output( - provider=model_instance.provider, - model_schema=model_schema, - model_instance=model_instance, - prompt_messages=prompt_messages, - json_schema=output_schema, - model_parameters=invoke_model_parameters, - stop=list(stop or []), - stream=True, - user=user_id, + invoke_result = cast( + LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], + model_instance.invoke_llm_with_structured_output( + prompt_messages=prompt_messages, + json_schema=output_schema, + model_parameters=invoke_model_parameters, + stop=stop, + stream=True, + ), ) else: request_start_time = time.perf_counter() - invoke_result = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), - model_parameters=invoke_model_parameters, - stop=list(stop or []), - stream=True, - user=user_id, + invoke_result = cast( + LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=invoke_model_parameters, + tools=None, + stop=stop, + stream=True, + ), ) return LLMNode.handle_invoke_result( @@ -404,6 +409,7 @@ class LLMNode(Node[LLMNodeData]): file_outputs=file_outputs, node_id=node_id, node_type=node_type, + model_instance=model_instance, reasoning_format=reasoning_format, request_start_time=request_start_time, ) @@ -416,6 +422,7 @@ class LLMNode(Node[LLMNodeData]): file_outputs: list[File], node_id: str, node_type: NodeType, + model_instance: PreparedLLMProtocol | object, reasoning_format: Literal["separated", "tagged"] = "tagged", request_start_time: float | None = None, ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: @@ -487,8 +494,14 @@ class LLMNode(Node[LLMNodeData]): usage = result.delta.usage if finish_reason is None and result.delta.finish_reason: finish_reason = result.delta.finish_reason - except OutputParserError as e: - raise LLMNodeError(f"Failed to parse structured output: {e}") + except Exception as e: + if hasattr(model_instance, "is_structured_output_parse_error") and cast( + PreparedLLMProtocol, model_instance + ).is_structured_output_parse_error(e): + raise LLMNodeError(f"Failed to parse structured output: {e}") from e + if type(e).__name__ == "OutputParserError": + raise LLMNodeError(f"Failed to parse structured output: {e}") from e + raise # Extract reasoning content from tags in the main text full_text = full_text_buffer.getvalue() @@ -691,30 +704,8 @@ class LLMNode(Node[LLMNodeData]): segment_id = retriever_resource.get("segment_id") if not segment_id: continue - attachments_with_bindings = db.session.execute( - select(SegmentAttachmentBinding, UploadFile) - .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) - .where( - SegmentAttachmentBinding.segment_id == segment_id, - ) - ).all() - if attachments_with_bindings: - for _, upload_file in attachments_with_bindings: - attachment_info = File( - id=upload_file.id, - filename=upload_file.name, - extension="." 
+ upload_file.extension, - mime_type=upload_file.mime_type, - tenant_id=self.require_dify_context().tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.LOCAL_FILE, - remote_url=upload_file.source_url, - related_id=upload_file.id, - size=upload_file.size, - storage_key=upload_file.key, - url=sign_upload_file(upload_file.id, upload_file.extension), - ) - context_files.append(attachment_info) + if self._retriever_attachment_loader is not None: + context_files.extend(self._retriever_attachment_loader.load(segment_id=segment_id)) yield RunRetrieverResourceEvent( retriever_resources=original_retriever_resource, context=context_str.strip(), @@ -757,9 +748,9 @@ class LLMNode(Node[LLMNodeData]): *, sys_query: str | None = None, sys_files: Sequence[File], - context: str | None = None, + context: str = "", memory: PromptMessageMemory | None = None, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, stop: Sequence[str] | None = None, memory_config: MemoryConfig | None = None, @@ -768,24 +759,186 @@ class LLMNode(Node[LLMNodeData]): variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], context_files: list[File] | None = None, - template_renderer: TemplateRenderer | None = None, + jinja2_template_renderer: Jinja2TemplateRenderer | None = None, ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: - return llm_utils.fetch_prompt_messages( - sys_query=sys_query, - sys_files=sys_files, - context=context, - memory=memory, - model_instance=model_instance, - prompt_template=prompt_template, - stop=stop, - memory_config=memory_config, - vision_enabled=vision_enabled, - vision_detail=vision_detail, - variable_pool=variable_pool, - jinja2_variables=jinja2_variables, - context_files=context_files, - template_renderer=template_renderer, - ) + prompt_messages: list[PromptMessage] = [] + model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) + + if isinstance(prompt_template, list): + # For chat model + prompt_messages.extend( + LLMNode.handle_list_messages( + messages=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + vision_detail_config=vision_detail, + jinja2_template_renderer=jinja2_template_renderer, + ) + ) + + # Get memory messages for chat mode + memory_messages = _handle_memory_chat_mode( + memory=memory, + memory_config=memory_config, + model_instance=model_instance, + ) + # Extend prompt_messages with memory messages + prompt_messages.extend(memory_messages) + + # Add current query to the prompt messages + if sys_query: + message = LLMNodeChatModelMessage( + text=sys_query, + role=PromptMessageRole.USER, + edition_type="basic", + ) + prompt_messages.extend( + LLMNode.handle_list_messages( + messages=[message], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=vision_detail, + jinja2_template_renderer=jinja2_template_renderer, + ) + ) + + elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): + # For completion model + prompt_messages.extend( + _handle_completion_template( + template=prompt_template, + context=context, + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + jinja2_template_renderer=jinja2_template_renderer, + ) + ) + + # Get memory text for completion model + memory_text = _handle_memory_completion_mode( + memory=memory, + memory_config=memory_config, + model_instance=model_instance, + ) + 
# Insert histories into the prompt + prompt_content = prompt_messages[0].content + # For issue #11247 - Check if prompt content is a string or a list + if isinstance(prompt_content, str): + prompt_content = str(prompt_content) + if "#histories#" in prompt_content: + prompt_content = prompt_content.replace("#histories#", memory_text) + else: + prompt_content = memory_text + "\n" + prompt_content + prompt_messages[0].content = prompt_content + elif isinstance(prompt_content, list): + for content_item in prompt_content: + if isinstance(content_item, TextPromptMessageContent): + if "#histories#" in content_item.data: + content_item.data = content_item.data.replace("#histories#", memory_text) + else: + content_item.data = memory_text + "\n" + content_item.data + else: + raise ValueError("Invalid prompt content type") + + # Add current query to the prompt message + if sys_query: + if isinstance(prompt_content, str): + prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) + prompt_messages[0].content = prompt_content + elif isinstance(prompt_content, list): + for content_item in prompt_content: + if isinstance(content_item, TextPromptMessageContent): + content_item.data = sys_query + "\n" + content_item.data + else: + raise ValueError("Invalid prompt content type") + else: + raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) + + # The sys_files will be deprecated later + if vision_enabled and sys_files: + file_prompts = [] + for file in sys_files: + file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) + file_prompts.append(file_prompt) + # If last prompt is a user prompt, add files into its contents, + # otherwise append a new user prompt + if ( + len(prompt_messages) > 0 + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) + else: + prompt_messages.append(UserPromptMessage(content=file_prompts)) + + # The context_files + if vision_enabled and context_files: + file_prompts = [] + for file in context_files: + file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) + file_prompts.append(file_prompt) + # If last prompt is a user prompt, add files into its contents, + # otherwise append a new user prompt + if ( + len(prompt_messages) > 0 + and isinstance(prompt_messages[-1], UserPromptMessage) + and isinstance(prompt_messages[-1].content, list) + ): + prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content) + else: + prompt_messages.append(UserPromptMessage(content=file_prompts)) + + # Remove empty messages and filter unsupported content + filtered_prompt_messages = [] + for prompt_message in prompt_messages: + if isinstance(prompt_message.content, list): + prompt_message_content: list[PromptMessageContentUnionTypes] = [] + for content_item in prompt_message.content: + # Skip content if features are not defined + if not model_schema.features: + if content_item.type != PromptMessageContentType.TEXT: + continue + prompt_message_content.append(content_item) + continue + + # Skip content if corresponding feature is not supported + if ( + ( + content_item.type == PromptMessageContentType.IMAGE + and ModelFeature.VISION not in model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.DOCUMENT + and ModelFeature.DOCUMENT not in model_schema.features + ) + or ( + 
content_item.type == PromptMessageContentType.VIDEO + and ModelFeature.VIDEO not in model_schema.features + ) + or ( + content_item.type == PromptMessageContentType.AUDIO + and ModelFeature.AUDIO not in model_schema.features + ) + ): + continue + prompt_message_content.append(content_item) + if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: + prompt_message.content = prompt_message_content[0].data + else: + prompt_message.content = prompt_message_content + if prompt_message.is_empty(): + continue + filtered_prompt_messages.append(prompt_message) + + if len(filtered_prompt_messages) == 0: + raise NoPromptFoundError( + "No prompt found in the LLM configuration. " + "Please ensure a prompt is properly configured before proceeding." + ) + + return filtered_prompt_messages, stop @classmethod def _extract_variable_selector_to_variable_mapping( @@ -829,9 +982,6 @@ class LLMNode(Node[LLMNodeData]): if node_data.vision.enabled: variable_mapping["#files#"] = node_data.vision.configs.variable_selector - if node_data.memory: - variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY] - if node_data.prompt_config: enable_jinja = False @@ -881,20 +1031,62 @@ class LLMNode(Node[LLMNodeData]): def handle_list_messages( *, messages: Sequence[LLMNodeChatModelMessage], - context: str | None, + context: str, jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, vision_detail_config: ImagePromptMessageContent.DETAIL, - template_renderer: TemplateRenderer | None = None, + jinja2_template_renderer: Jinja2TemplateRenderer | None = None, ) -> Sequence[PromptMessage]: - return llm_utils.handle_list_messages( - messages=messages, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - vision_detail_config=vision_detail_config, - template_renderer=template_renderer, - ) + prompt_messages: list[PromptMessage] = [] + for message in messages: + if message.edition_type == "jinja2": + result_text = _render_jinja2_message( + template=message.jinja2_text or "", + jinja2_variables=jinja2_variables, + variable_pool=variable_pool, + jinja2_template_renderer=jinja2_template_renderer, + ) + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=result_text)], role=message.role + ) + prompt_messages.append(prompt_message) + else: + # Get segment group from basic message + template = message.text.replace(llm_utils.CONTEXT_PLACEHOLDER, context) + segment_group = variable_pool.convert_template(template) + + # Process segments for images + file_contents = [] + for segment in segment_group.value: + if isinstance(segment, ArrayFileSegment): + for file in segment.value: + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + elif isinstance(segment, FileSegment): + file = segment.value + if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: + file_content = file_manager.to_prompt_message_content( + file, image_detail_config=vision_detail_config + ) + file_contents.append(file_content) + + # Create message with text from all segments + plain_text = segment_group.text + if plain_text: + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=plain_text)], role=message.role + ) + prompt_messages.append(prompt_message) + + if file_contents: 
+ # Create message with image contents + prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) + prompt_messages.append(prompt_message) + + return prompt_messages @staticmethod def handle_blocking_result( @@ -1031,5 +1223,150 @@ class LLMNode(Node[LLMNodeData]): return self.node_data.retry_config.retry_enabled @property - def model_instance(self) -> ModelInstance: + def model_instance(self) -> PreparedLLMProtocol: return self._model_instance + + +def _combine_message_content_with_role( + *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole +): + match role: + case PromptMessageRole.USER: + return UserPromptMessage(content=contents) + case PromptMessageRole.ASSISTANT: + return AssistantPromptMessage(content=contents) + case PromptMessageRole.SYSTEM: + return SystemPromptMessage(content=contents) + case _: + raise NotImplementedError(f"Role {role} is not supported") + + +def _render_jinja2_message( + *, + template: str, + jinja2_variables: Sequence[VariableSelector], + variable_pool: VariablePool, + jinja2_template_renderer: Jinja2TemplateRenderer | None, +): + if not template: + return "" + + jinja2_inputs = {} + for jinja2_variable in jinja2_variables: + variable = variable_pool.get(jinja2_variable.value_selector) + jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" + if jinja2_template_renderer is None: + raise TemplateRenderError("LLMNode requires an injected jinja2_template_renderer for jinja2 prompts.") + return jinja2_template_renderer.render_template(template, jinja2_inputs) + + +def _calculate_rest_token( + *, + prompt_messages: list[PromptMessage], + model_instance: PreparedLLMProtocol, +) -> int: + rest_tokens = 2000 + runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) + runtime_model_parameters = model_instance.parameters + + model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) + if model_context_tokens: + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + + max_tokens = 0 + for parameter_rule in runtime_model_schema.parameter_rules: + if parameter_rule.name == "max_tokens" or ( + parameter_rule.use_template and parameter_rule.use_template == "max_tokens" + ): + max_tokens = ( + runtime_model_parameters.get(parameter_rule.name) + or runtime_model_parameters.get(str(parameter_rule.use_template)) + or 0 + ) + + rest_tokens = model_context_tokens - max_tokens - curr_message_tokens + rest_tokens = max(rest_tokens, 0) + + return rest_tokens + + +def _handle_memory_chat_mode( + *, + memory: PromptMessageMemory | None, + memory_config: MemoryConfig | None, + model_instance: PreparedLLMProtocol, +) -> Sequence[PromptMessage]: + memory_messages: Sequence[PromptMessage] = [] + # Get messages from memory for chat model + if memory and memory_config: + rest_tokens = _calculate_rest_token( + prompt_messages=[], + model_instance=model_instance, + ) + memory_messages = memory.get_history_prompt_messages( + max_token_limit=rest_tokens, + message_limit=memory_config.window.size if memory_config.window.enabled else None, + ) + return memory_messages + + +def _handle_memory_completion_mode( + *, + memory: PromptMessageMemory | None, + memory_config: MemoryConfig | None, + model_instance: PreparedLLMProtocol, +) -> str: + memory_text = "" + # Get history text from memory for completion model + if memory and memory_config: + rest_tokens = _calculate_rest_token( + prompt_messages=[], + 
+            model_instance=model_instance,
+        )
+        if not memory_config.role_prefix:
+            raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
+        memory_text = llm_utils.fetch_memory_text(
+            memory=memory,
+            max_token_limit=rest_tokens,
+            message_limit=memory_config.window.size if memory_config.window.enabled else None,
+            human_prefix=memory_config.role_prefix.user,
+            ai_prefix=memory_config.role_prefix.assistant,
+        )
+    return memory_text
+
+
+def _handle_completion_template(
+    *,
+    template: LLMNodeCompletionModelPromptTemplate,
+    context: str,
+    jinja2_variables: Sequence[VariableSelector],
+    variable_pool: VariablePool,
+    jinja2_template_renderer: Jinja2TemplateRenderer | None = None,
+) -> Sequence[PromptMessage]:
+    """Handle completion template processing outside of LLMNode class.
+
+    Args:
+        template: The completion model prompt template
+        context: Context string
+        jinja2_variables: Variables for jinja2 template rendering
+        variable_pool: Variable pool for template conversion
+        jinja2_template_renderer: Injected renderer used when the template's edition_type is "jinja2"
+
+    Returns:
+        Sequence of prompt messages
+    """
+    prompt_messages = []
+    if template.edition_type == "jinja2":
+        result_text = _render_jinja2_message(
+            template=template.jinja2_text or "",
+            jinja2_variables=jinja2_variables,
+            variable_pool=variable_pool,
+            jinja2_template_renderer=jinja2_template_renderer,
+        )
+    else:
+        template_text = template.text.replace(llm_utils.CONTEXT_PLACEHOLDER, context)
+        result_text = variable_pool.convert_template(template_text).text
+    prompt_message = _combine_message_content_with_role(
+        contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER
+    )
+    prompt_messages.append(prompt_message)
+    return prompt_messages
diff --git a/api/dify_graph/nodes/llm/protocols.py b/api/dify_graph/nodes/llm/protocols.py
index 9e95d341c9..e1b01b2b1e 100644
--- a/api/dify_graph/nodes/llm/protocols.py
+++ b/api/dify_graph/nodes/llm/protocols.py
@@ -1,9 +1,8 @@
 from __future__ import annotations

-from collections.abc import Mapping
 from typing import Any, Protocol

-from core.model_manager import ModelInstance
+from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol


 class CredentialsProvider(Protocol):
@@ -15,16 +14,8 @@ class CredentialsProvider(Protocol):


 class ModelFactory(Protocol):
-    """Port for creating initialized LLM model instances for execution."""
+    """Port for creating prepared graph-facing LLM runtimes for execution."""

-    def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
-        """Create a model instance that is ready for schema lookup and invocation."""
-        ...
-
-
-class TemplateRenderer(Protocol):
-    """Port for rendering prompt templates used by LLM-compatible nodes."""
-
-    def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
-        """Render the given Jinja2 template into plain text."""
+    def init_model_instance(self, provider_name: str, model_name: str) -> PreparedLLMProtocol:
+        """Create a prepared LLM runtime that is ready for graph execution."""
         ...
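For context, a workflow-layer adapter can satisfy the reshaped `ModelFactory` port without `dify_graph` ever importing provider machinery. A minimal sketch, not part of this patch; `WorkflowModelFactory`, `model_manager`, and `get_instance` are assumed names:

from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol


class WorkflowModelFactory:
    """Hypothetical workflow-layer adapter; provider setup stays outside dify_graph."""

    def __init__(self, model_manager) -> None:
        # Assumed application service that owns credentials, quotas and
        # provider configuration.
        self._model_manager = model_manager

    def init_model_instance(self, provider_name: str, model_name: str) -> PreparedLLMProtocol:
        # All provider-specific preparation happens here; the graph layer only
        # ever sees a ready-to-invoke PreparedLLMProtocol object.
        return self._model_manager.get_instance(provider=provider_name, model=model_name)

With this shape, nodes receive the prepared runtime by injection instead of constructing it themselves.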
diff --git a/api/dify_graph/nodes/llm/runtime_protocols.py b/api/dify_graph/nodes/llm/runtime_protocols.py new file mode 100644 index 0000000000..f7d2322289 --- /dev/null +++ b/api/dify_graph/nodes/llm/runtime_protocols.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +from collections.abc import Generator, Mapping, Sequence +from typing import Any, Protocol + +from dify_graph.file import File +from dify_graph.model_runtime.entities import LLMMode, PromptMessage +from dify_graph.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) +from dify_graph.model_runtime.entities.message_entities import PromptMessageTool +from dify_graph.model_runtime.entities.model_entities import AIModelEntity + + +class PreparedLLMProtocol(Protocol): + """A graph-facing LLM runtime with provider-specific setup already applied.""" + + @property + def provider(self) -> str: ... + + @property + def model_name(self) -> str: ... + + @property + def parameters(self) -> Mapping[str, Any]: ... + + @parameters.setter + def parameters(self, value: Mapping[str, Any]) -> None: ... + + @property + def stop(self) -> Sequence[str] | None: ... + + def get_model_schema(self) -> AIModelEntity: ... + + def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int: ... + + def invoke_llm( + self, + *, + prompt_messages: Sequence[PromptMessage], + model_parameters: Mapping[str, Any], + tools: Sequence[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResult | Generator[LLMResultChunk, None, None]: ... + + def invoke_llm_with_structured_output( + self, + *, + prompt_messages: Sequence[PromptMessage], + json_schema: Mapping[str, Any], + model_parameters: Mapping[str, Any], + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + + def is_structured_output_parse_error(self, error: Exception) -> bool: ... + + +class PromptMessageSerializerProtocol(Protocol): + """Port for converting compiled prompt messages into persisted process data.""" + + def serialize( + self, + *, + model_mode: LLMMode, + prompt_messages: Sequence[PromptMessage], + ) -> Any: ... + + +class RetrieverAttachmentLoaderProtocol(Protocol): + """Port for resolving retriever segment attachments into graph file references.""" + + def load(self, *, segment_id: str) -> Sequence[File]: ... 
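Because `PreparedLLMProtocol` is structural, tests can exercise LLM-backed nodes with an in-memory double. A minimal sketch (all names illustrative, not part of this patch); the invoke methods are stubbed because canned `LLMResult` payloads are test-specific:

from collections.abc import Mapping, Sequence
from typing import Any

from dify_graph.model_runtime.entities import PromptMessage
from dify_graph.model_runtime.entities.model_entities import AIModelEntity


class FakePreparedLLM:
    """Illustrative test double satisfying PreparedLLMProtocol structurally."""

    def __init__(self, schema: AIModelEntity) -> None:
        self._schema = schema
        self._parameters: Mapping[str, Any] = {}

    @property
    def provider(self) -> str:
        return "fake-provider"

    @property
    def model_name(self) -> str:
        return self._schema.model

    @property
    def parameters(self) -> Mapping[str, Any]:
        return self._parameters

    @parameters.setter
    def parameters(self, value: Mapping[str, Any]) -> None:
        self._parameters = value

    @property
    def stop(self) -> Sequence[str] | None:
        return None

    def get_model_schema(self) -> AIModelEntity:
        return self._schema

    def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
        # Crude length-based estimate; enough for token-budget code paths.
        return sum(len(str(message.content or "")) for message in prompt_messages) // 4

    def invoke_llm(self, **kwargs: Any):
        # A real double would return a canned LLMResult (or a chunk generator
        # when stream=True); stubbed here to keep the sketch short.
        raise NotImplementedError

    def invoke_llm_with_structured_output(self, **kwargs: Any):
        raise NotImplementedError

    def is_structured_output_parse_error(self, error: Exception) -> bool:
        return False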
diff --git a/api/dify_graph/nodes/loop/loop_node.py b/api/dify_graph/nodes/loop/loop_node.py index 3c546ffa23..8542972a30 100644 --- a/api/dify_graph/nodes/loop/loop_node.py +++ b/api/dify_graph/nodes/loop/loop_node.py @@ -2,7 +2,7 @@ import contextlib import json import logging from collections.abc import Callable, Generator, Mapping, Sequence -from datetime import datetime +from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, Literal, cast from dify_graph.entities.graph_config import NodeConfigDictAdapter @@ -14,6 +14,7 @@ from dify_graph.enums import ( ) from dify_graph.graph_events import ( GraphNodeEventBase, + GraphRunAbortedEvent, GraphRunFailedEvent, NodeRunSucceededEvent, ) @@ -31,14 +32,13 @@ from dify_graph.nodes.base import LLMUsageTrackingMixin from dify_graph.nodes.base.node import Node from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData from dify_graph.utils.condition.processor import ConditionProcessor -from dify_graph.variables import Segment, SegmentType -from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable -from libs.datetime_utils import naive_utc_now +from dify_graph.variables import Segment, SegmentType, TypeMismatchError, build_segment_with_type, segment_to_variable if TYPE_CHECKING: from dify_graph.graph_engine import GraphEngine logger = logging.getLogger(__name__) +_DEFAULT_CHILD_ABORT_REASON = "child graph aborted" class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): @@ -91,7 +91,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): loop_variable_selectors[loop_variable.label] = variable_selector inputs[loop_variable.label] = processed_segment.value - start_at = naive_utc_now() + start_at = datetime.now(UTC).replace(tzinfo=None) condition_processor = ConditionProcessor() loop_duration_map: dict[str, float] = {} @@ -124,10 +124,13 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): self._clear_loop_subgraph_variables(loop_node_ids) graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id) - loop_start_time = naive_utc_now() - reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i) + loop_start_time = datetime.now(UTC).replace(tzinfo=None) + try: + reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i) + finally: + loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) # Track loop duration - loop_duration_map[str(i)] = (naive_utc_now() - loop_start_time).total_seconds() + loop_duration_map[str(i)] = (datetime.now(UTC).replace(tzinfo=None) - loop_start_time).total_seconds() # Accumulate outputs from the sub-graph's response nodes for key, value in graph_engine.graph_runtime_state.outputs.items(): @@ -142,9 +145,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): # For other outputs, just update self.graph_runtime_state.set_output(key, value) - # Accumulate usage from the sub-graph execution - loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) - # Collect loop variable values after iteration single_loop_variable = {} for key, selector in loop_variable_selectors.items(): @@ -256,6 +256,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): yield event if isinstance(event, NodeRunSucceededEvent) and event.node_type == BuiltinNodeTypes.LOOP_END: reach_break_node = True + if isinstance(event, GraphRunAbortedEvent): + raise 
RuntimeError(event.reason or _DEFAULT_CHILD_ABORT_REASON) if isinstance(event, GraphRunFailedEvent): raise Exception(event.error) @@ -410,7 +412,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): def _create_graph_engine(self, start_at: datetime, root_node_id: str): from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState # Create GraphInitParams for child graph execution. graph_init_params = GraphInitParams( @@ -420,16 +421,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): call_depth=self.workflow_call_depth, ) - # Create a new GraphRuntimeState for this iteration - graph_runtime_state_copy = GraphRuntimeState( - variable_pool=self.graph_runtime_state.variable_pool, - start_at=start_at.timestamp(), - ) - return self.graph_runtime_state.create_child_engine( workflow_id=self.workflow_id, graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state_copy, - graph_config=self.graph_config, root_node_id=root_node_id, ) diff --git a/api/dify_graph/nodes/parameter_extractor/entities.py b/api/dify_graph/nodes/parameter_extractor/entities.py index 2fb042c16c..bc6c85c8bc 100644 --- a/api/dify_graph/nodes/parameter_extractor/entities.py +++ b/api/dify_graph/nodes/parameter_extractor/entities.py @@ -7,10 +7,10 @@ from pydantic import ( field_validator, ) -from core.prompt.entities.advanced_prompt_entities import MemoryConfig from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig +from dify_graph.prompt_entities import MemoryConfig from dify_graph.variables.types import SegmentType _OLD_BOOL_TYPE_NAME = "bool" diff --git a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py index e6e8a44d06..25715f9d92 100644 --- a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py @@ -5,11 +5,6 @@ import uuid from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, cast -from core.model_manager import ModelInstance -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( BuiltinNodeTypes, @@ -17,8 +12,8 @@ from dify_graph.enums import ( WorkflowNodeExecutionStatus, ) from dify_graph.file import File -from dify_graph.model_runtime.entities import ImagePromptMessageContent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode +from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, @@ -27,17 +22,18 @@ from dify_graph.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey +from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey, ModelType from dify_graph.model_runtime.memory import 
PromptMessageMemory -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from dify_graph.model_runtime.utils.encoders import jsonable_encoder from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base import variable_template_parser from dify_graph.nodes.base.node import Node -from dify_graph.nodes.llm import llm_utils +from dify_graph.nodes.llm import LLMNode, llm_utils +from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate +from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol from dify_graph.runtime import VariablePool +from dify_graph.variables import build_segment_with_type from dify_graph.variables.types import ArrayValidation, SegmentType -from factories.variable_factory import build_segment_with_type from .entities import ParameterExtractorNodeData from .exc import ( @@ -66,7 +62,6 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: from dify_graph.entities import GraphInitParams - from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory from dify_graph.runtime import GraphRuntimeState @@ -99,9 +94,8 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_type = BuiltinNodeTypes.PARAMETER_EXTRACTOR - _model_instance: ModelInstance - _credentials_provider: "CredentialsProvider" - _model_factory: "ModelFactory" + _model_instance: PreparedLLMProtocol + _prompt_message_serializer: PromptMessageSerializerProtocol _memory: PromptMessageMemory | None def __init__( @@ -111,10 +105,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, - credentials_provider: "CredentialsProvider", - model_factory: "ModelFactory", - model_instance: ModelInstance, + credentials_provider: object | None = None, + model_factory: object | None = None, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None = None, + prompt_message_serializer: PromptMessageSerializerProtocol, ) -> None: super().__init__( id=id, @@ -122,9 +117,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._credentials_provider = credentials_provider - self._model_factory = model_factory + _ = credentials_provider, model_factory self._model_instance = model_instance + self._prompt_message_serializer = prompt_message_serializer self._memory = memory @classmethod @@ -168,13 +163,12 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): model_instance.parameters = llm_utils.resolve_completion_params_variables( model_instance.parameters, variable_pool ) - if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise InvalidModelTypeError("Model is not a Large Language Model") - try: model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) except ValueError as exc: raise ModelSchemaNotFoundError("Model schema not found") from exc + if model_schema.model_type != ModelType.LLM: + raise InvalidModelTypeError("Model is not a Large Language Model") memory = self._memory if ( @@ -214,8 +208,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): process_data = { "model_mode": node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=node_data.model.mode, prompt_messages=prompt_messages + "prompts": 
self._prompt_message_serializer.serialize( + model_mode=node_data.model.mode, + prompt_messages=prompt_messages, ), "usage": None, "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), @@ -291,18 +286,20 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): def _invoke( self, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, prompt_messages: list[PromptMessage], tools: list[PromptMessageTool], - stop: Sequence[str], + stop: Sequence[str] | None, ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]: - invoke_result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=dict(model_instance.parameters), - tools=tools, - stop=list(stop), - stream=False, - user=self.require_dify_context().user_id, + invoke_result = cast( + LLMResult, + model_instance.invoke_llm( + prompt_messages=prompt_messages, + model_parameters=dict(model_instance.parameters), + tools=tools or None, + stop=stop, + stream=False, + ), ) # handle invoke result @@ -321,7 +318,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -333,7 +330,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): content=query, structure=json.dumps(node_data.get_parameter_json_schema()) ) - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( node_data=node_data, query=query, @@ -344,15 +340,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): prompt_template = self._get_function_calling_prompt_template( node_data, query, variable_pool, memory, rest_token ) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=files, - context="", - memory_config=node_data.memory, - memory=None, + prompt_messages = self._compile_prompt_messages( model_instance=model_instance, + prompt_template=prompt_template, + files=files, + vision_enabled=node_data.vision.enabled, image_detail_config=vision_detail, ) @@ -409,7 +401,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -417,9 +409,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ Generate prompt engineering prompt. 
""" - model_mode = ModelMode(data.model.mode) - - if model_mode == ModelMode.COMPLETION: + if data.model.mode == LLMMode.COMPLETION: return self._generate_prompt_engineering_completion_prompt( node_data=data, query=query, @@ -429,7 +419,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): files=files, vision_detail=vision_detail, ) - elif model_mode == ModelMode.CHAT: + if data.model.mode == LLMMode.CHAT: return self._generate_prompt_engineering_chat_prompt( node_data=data, query=query, @@ -439,15 +429,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): files=files, vision_detail=vision_detail, ) - else: - raise InvalidModelModeError(f"Invalid model mode: {model_mode}") + raise InvalidModelModeError(f"Invalid model mode: {data.model.mode}") def _generate_prompt_engineering_completion_prompt( self, node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -455,7 +444,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ Generate completion prompt. """ - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( node_data=node_data, query=query, @@ -466,27 +454,20 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): prompt_template = self._get_prompt_engineering_prompt_template( node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token ) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={"structure": json.dumps(node_data.get_parameter_json_schema())}, - query="", - files=files, - context="", - memory_config=node_data.memory, - # AdvancedPromptTransform is still typed against TokenBufferMemory. - memory=cast(Any, memory), + return self._compile_prompt_messages( model_instance=model_instance, + prompt_template=prompt_template, + files=files, + vision_enabled=node_data.vision.enabled, image_detail_config=vision_detail, ) - return prompt_messages - def _generate_prompt_engineering_chat_prompt( self, node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, memory: PromptMessageMemory | None, files: Sequence[File], vision_detail: ImagePromptMessageContent.DETAIL | None = None, @@ -494,7 +475,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): """ Generate chat prompt. 
""" - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) rest_token = self._calculate_rest_token( node_data=node_data, query=query, @@ -512,15 +492,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): max_token_limit=rest_token, ) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=files, - context="", - memory_config=node_data.memory, - memory=None, + prompt_messages = self._compile_prompt_messages( model_instance=model_instance, + prompt_template=prompt_template, + files=files, + vision_enabled=node_data.vision.enabled, image_detail_config=vision_detail, ) @@ -721,8 +697,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): variable_pool: VariablePool, memory: PromptMessageMemory | None, max_token_limit: int = 2000, - ) -> list[ChatModelMessage]: - model_mode = ModelMode(node_data.model.mode) + ) -> list[LLMNodeChatModelMessage]: input_text = query memory_str = "" instruction = variable_pool.convert_template(node_data.instruction or "").text @@ -731,15 +706,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): memory_str = llm_utils.fetch_memory_text( memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size ) - if model_mode == ModelMode.CHAT: - system_prompt_messages = ChatModelMessage( + if node_data.model.mode == LLMMode.CHAT: + system_prompt_messages = LLMNodeChatModelMessage( role=PromptMessageRole.SYSTEM, text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), ) - user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) + user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] - else: - raise InvalidModelModeError(f"Model mode {model_mode} not support.") + raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.") def _get_prompt_engineering_prompt_template( self, @@ -748,8 +722,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): variable_pool: VariablePool, memory: PromptMessageMemory | None, max_token_limit: int = 2000, - ): - model_mode = ModelMode(node_data.model.mode) + ) -> list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: input_text = query memory_str = "" instruction = variable_pool.convert_template(node_data.instruction or "").text @@ -758,64 +731,54 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): memory_str = llm_utils.fetch_memory_text( memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size ) - if model_mode == ModelMode.CHAT: - system_prompt_messages = ChatModelMessage( + if node_data.model.mode == LLMMode.CHAT: + system_prompt_messages = LLMNodeChatModelMessage( role=PromptMessageRole.SYSTEM, text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction), ) - user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) + user_prompt_message = LLMNodeChatModelMessage(role=PromptMessageRole.USER, text=input_text) return [system_prompt_messages, user_prompt_message] - elif model_mode == ModelMode.COMPLETION: - return CompletionModelPromptTemplate( + if node_data.model.mode == LLMMode.COMPLETION: + return LLMNodeCompletionModelPromptTemplate( text=COMPLETION_GENERATE_JSON_PROMPT.format( histories=memory_str, text=input_text, instruction=instruction ) .replace("{γγγ", "") .replace("}γγγ", "") + 
.replace("{ structure }", json.dumps(node_data.get_parameter_json_schema())), ) - else: - raise InvalidModelModeError(f"Model mode {model_mode} not support.") + raise InvalidModelModeError(f"Model mode {node_data.model.mode} not support.") def _calculate_rest_token( self, node_data: ParameterExtractorNodeData, query: str, variable_pool: VariablePool, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, context: str | None, ) -> int: try: model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) except ValueError as exc: raise ModelSchemaNotFoundError("Model schema not found") from exc - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) + prompt_template: list[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) else: prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=[], - context=context, - memory_config=node_data.memory, - memory=None, + prompt_messages = self._compile_prompt_messages( model_instance=model_instance, + prompt_template=prompt_template, + files=[], + vision_enabled=False, + context=context, ) rest_tokens = 2000 - model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) if model_context_tokens: - model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) - curr_message_tokens = ( - model_type_instance.get_num_tokens( - model_instance.model_name, model_instance.credentials, prompt_messages - ) - + 1000 - ) # add 1000 to ensure tool call messages + curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) + 1000 max_tokens = 0 for parameter_rule in model_schema.parameter_rules: @@ -832,8 +795,34 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]): return rest_tokens + def _compile_prompt_messages( + self, + *, + model_instance: PreparedLLMProtocol, + prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, + files: Sequence[File], + vision_enabled: bool, + context: str | None = "", + image_detail_config: ImagePromptMessageContent.DETAIL | None = None, + ) -> list[PromptMessage]: + prompt_messages, _ = LLMNode.fetch_prompt_messages( + sys_query="", + sys_files=files, + context=context or "", + memory=None, + model_instance=model_instance, + prompt_template=prompt_template, + stop=model_instance.stop, + memory_config=None, + vision_enabled=vision_enabled, + vision_detail=image_detail_config or ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=self.graph_runtime_state.variable_pool, + jinja2_variables=[], + ) + return list(prompt_messages) + @property - def model_instance(self) -> ModelInstance: + def model_instance(self) -> PreparedLLMProtocol: return self._model_instance @classmethod diff --git a/api/dify_graph/nodes/protocols.py b/api/dify_graph/nodes/protocols.py index 62d3bcdca1..b2046b8da8 100644 --- a/api/dify_graph/nodes/protocols.py +++ b/api/dify_graph/nodes/protocols.py @@ -1,10 +1,9 @@ -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import Any, Protocol import httpx from dify_graph.file import File -from dify_graph.file.models import ToolFile class 
HttpClientProtocol(Protocol): @@ -35,12 +34,13 @@ class ToolFileManagerProtocol(Protocol): def create_file_by_raw( self, *, - user_id: str, - tenant_id: str, - conversation_id: str | None, file_binary: bytes, mimetype: str, filename: str | None = None, ) -> Any: ... - def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ... + def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, File | None]: ... + + +class FileReferenceFactoryProtocol(Protocol): + def build_from_mapping(self, *, mapping: Mapping[str, Any]) -> File: ... diff --git a/api/dify_graph/nodes/question_classifier/entities.py b/api/dify_graph/nodes/question_classifier/entities.py index 0c1601d439..04c8e76ef9 100644 --- a/api/dify_graph/nodes/question_classifier/entities.py +++ b/api/dify_graph/nodes/question_classifier/entities.py @@ -1,9 +1,9 @@ from pydantic import BaseModel, Field -from core.prompt.entities.advanced_prompt_entities import MemoryConfig from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import BuiltinNodeTypes, NodeType from dify_graph.nodes.llm import ModelConfig, VisionConfig +from dify_graph.prompt_entities import MemoryConfig class ClassConfig(BaseModel): diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py index 928618fdbc..6c2a20d18a 100644 --- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py +++ b/api/dify_graph/nodes/question_classifier/question_classifier_node.py @@ -3,9 +3,6 @@ import re from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.model_manager import ModelInstance -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( @@ -14,7 +11,7 @@ from dify_graph.enums import ( WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole +from dify_graph.model_runtime.entities import LLMMode, LLMUsage, ModelPropertyKey, PromptMessageRole from dify_graph.model_runtime.memory import PromptMessageMemory from dify_graph.model_runtime.utils.encoders import jsonable_encoder from dify_graph.node_events import ModelInvokeCompletedEvent, NodeRunResult @@ -27,10 +24,11 @@ from dify_graph.nodes.llm import ( LLMNodeCompletionModelPromptTemplate, llm_utils, ) -from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer +from dify_graph.nodes.llm.file_saver import LLMFileSaver +from dify_graph.nodes.llm.runtime_protocols import PreparedLLMProtocol, PromptMessageSerializerProtocol from dify_graph.nodes.protocols import HttpClientProtocol -from libs.json_in_md_parser import parse_and_check_json_markdown +from dify_graph.template_rendering import Jinja2TemplateRenderer +from dify_graph.utils.json_in_md_parser import parse_and_check_json_markdown from .entities import QuestionClassifierNodeData from .exc import InvalidModelTypeError @@ -49,17 +47,22 @@ if TYPE_CHECKING: from dify_graph.runtime import GraphRuntimeState +class _PassthroughPromptMessageSerializer: + def serialize(self, *, model_mode: Any, prompt_messages: 
Sequence[Any]) -> Any: + _ = model_mode + return list(prompt_messages) + + class QuestionClassifierNode(Node[QuestionClassifierNodeData]): node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER execution_type = NodeExecutionType.BRANCH _file_outputs: list["File"] _llm_file_saver: LLMFileSaver - _credentials_provider: "CredentialsProvider" - _model_factory: "ModelFactory" - _model_instance: ModelInstance + _prompt_message_serializer: PromptMessageSerializerProtocol + _model_instance: PreparedLLMProtocol _memory: PromptMessageMemory | None - _template_renderer: TemplateRenderer + _template_renderer: Jinja2TemplateRenderer def __init__( self, @@ -68,13 +71,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, - credentials_provider: "CredentialsProvider", - model_factory: "ModelFactory", - model_instance: ModelInstance, + credentials_provider: object | None = None, + model_factory: object | None = None, + model_instance: PreparedLLMProtocol, http_client: HttpClientProtocol, - template_renderer: TemplateRenderer, + template_renderer: Jinja2TemplateRenderer, memory: PromptMessageMemory | None = None, - llm_file_saver: LLMFileSaver | None = None, + llm_file_saver: LLMFileSaver, + prompt_message_serializer: PromptMessageSerializerProtocol | None = None, ): super().__init__( id=id, @@ -85,20 +89,13 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): # LLM file outputs, used for MultiModal outputs. self._file_outputs = [] - self._credentials_provider = credentials_provider - self._model_factory = model_factory + _ = credentials_provider, model_factory, http_client self._model_instance = model_instance self._memory = memory self._template_renderer = template_renderer - if llm_file_saver is None: - dify_ctx = self.require_dify_context() - llm_file_saver = FileSaverImpl( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - http_client=http_client, - ) self._llm_file_saver = llm_file_saver + self._prompt_message_serializer = prompt_message_serializer or _PassthroughPromptMessageSerializer() @classmethod def version(cls): @@ -173,7 +170,6 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, - user_id=self.require_dify_context().user_id, structured_output_enabled=False, structured_output=None, file_saver=self._llm_file_saver, @@ -209,7 +205,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): category_id = category_id_result process_data = { "model_mode": node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( + "prompts": self._prompt_message_serializer.serialize( model_mode=node_data.model.mode, prompt_messages=prompt_messages ), "usage": jsonable_encoder(usage), @@ -251,7 +247,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): ) @property - def model_instance(self) -> ModelInstance: + def model_instance(self) -> PreparedLLMProtocol: return self._model_instance @classmethod @@ -289,7 +285,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): self, node_data: QuestionClassifierNodeData, query: str, - model_instance: ModelInstance, + model_instance: PreparedLLMProtocol, context: str | None, ) -> int: model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) @@ -299,7 +295,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): prompt_template=prompt_template, sys_query="", sys_files=[], - 
context=context, + context=context or "", memory=None, model_instance=model_instance, stop=model_instance.stop, @@ -338,7 +334,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): memory: PromptMessageMemory | None, max_token_limit: int = 2000, ): - model_mode = ModelMode(node_data.model.mode) + model_mode = LLMMode(node_data.model.mode) classes = node_data.classes categories = [] for class_ in classes: @@ -354,7 +350,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None, ) prompt_messages: list[LLMNodeChatModelMessage] = [] - if model_mode == ModelMode.CHAT: + if model_mode == LLMMode.CHAT: system_prompt_messages = LLMNodeChatModelMessage( role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) ) @@ -385,7 +381,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]): ) prompt_messages.append(user_prompt_message_3) return prompt_messages - elif model_mode == ModelMode.COMPLETION: + elif model_mode == LLMMode.COMPLETION: return LLMNodeCompletionModelPromptTemplate( text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format( histories=memory_str, diff --git a/api/dify_graph/nodes/runtime.py b/api/dify_graph/nodes/runtime.py new file mode 100644 index 0000000000..528592e025 --- /dev/null +++ b/api/dify_graph/nodes/runtime.py @@ -0,0 +1,106 @@ +from __future__ import annotations + +from collections.abc import Generator, Mapping, Sequence +from datetime import datetime +from typing import TYPE_CHECKING, Any, Protocol + +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.nodes.tool_runtime_entities import ( + ToolRuntimeHandle, + ToolRuntimeMessage, + ToolRuntimeParameter, +) + +if TYPE_CHECKING: + from dify_graph.nodes.human_input.entities import HumanInputNodeData + from dify_graph.nodes.human_input.enums import HumanInputFormStatus + from dify_graph.nodes.tool.entities import ToolNodeData + from dify_graph.runtime import VariablePool + + +class ToolNodeRuntimeProtocol(Protocol): + """Workflow-layer adapter owned by `core.workflow` and consumed by `dify_graph`. + + The graph package depends only on these DTOs and lets the workflow layer + translate between graph-owned abstractions and `core.tools` internals. + """ + + def get_runtime( + self, + *, + node_id: str, + node_data: ToolNodeData, + variable_pool: VariablePool | None, + ) -> ToolRuntimeHandle: ... + + def get_runtime_parameters( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> Sequence[ToolRuntimeParameter]: ... + + def invoke( + self, + *, + tool_runtime: ToolRuntimeHandle, + tool_parameters: Mapping[str, Any], + workflow_call_depth: int, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: ... + + def get_usage( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> LLMUsage: ... + + def build_file_reference(self, *, mapping: Mapping[str, Any]) -> Any: ... + + def resolve_provider_icons( + self, + *, + provider_name: str, + default_icon: str | None = None, + ) -> tuple[str | Mapping[str, str] | None, str | Mapping[str, str] | None]: ... + + +class HumanInputNodeRuntimeProtocol(Protocol): + """Workflow-layer adapter for human-input runtime persistence and delivery.""" + + def get_form( + self, + *, + node_id: str, + ) -> HumanInputFormStateProtocol | None: ... 
+ + def create_form( + self, + *, + node_id: str, + node_data: HumanInputNodeData, + rendered_content: str, + resolved_default_values: Mapping[str, Any], + ) -> HumanInputFormStateProtocol: ... + + +class HumanInputFormStateProtocol(Protocol): + @property + def id(self) -> str: ... + + @property + def rendered_content(self) -> str: ... + + @property + def selected_action_id(self) -> str | None: ... + + @property + def submitted_data(self) -> Mapping[str, Any] | None: ... + + @property + def submitted(self) -> bool: ... + + @property + def status(self) -> HumanInputFormStatus: ... + + @property + def expiration_time(self) -> datetime: ... diff --git a/api/dify_graph/nodes/start/start_node.py b/api/dify_graph/nodes/start/start_node.py index 5e6055ea34..8d0589f150 100644 --- a/api/dify_graph/nodes/start/start_node.py +++ b/api/dify_graph/nodes/start/start_node.py @@ -2,7 +2,6 @@ from typing import Any from jsonschema import Draft7Validator, ValidationError -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult from dify_graph.nodes.base.node import Node @@ -19,15 +18,10 @@ class StartNode(Node[StartNodeData]): return "1" def _run(self) -> NodeRunResult: - node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) + node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) self._validate_and_normalize_json_object_inputs(node_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() - - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + "." 
+ var] = system_inputs[var] - outputs = dict(node_inputs) + outputs = dict(self.graph_runtime_state.variable_pool.flatten(unprefixed_node_id=self.id)) + outputs.update(node_inputs) return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs) diff --git a/api/dify_graph/nodes/template_transform/template_renderer.py b/api/dify_graph/nodes/template_transform/template_renderer.py deleted file mode 100644 index 9b679d4497..0000000000 --- a/api/dify_graph/nodes/template_transform/template_renderer.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from typing import Any, Protocol - -from dify_graph.nodes.code.code_node import WorkflowCodeExecutor -from dify_graph.nodes.code.entities import CodeLanguage - - -class TemplateRenderError(ValueError): - """Raised when rendering a Jinja2 template fails.""" - - -class Jinja2TemplateRenderer(Protocol): - """Render Jinja2 templates for template transform nodes.""" - - def render_template(self, template: str, variables: Mapping[str, Any]) -> str: - """Render a Jinja2 template with provided variables.""" - raise NotImplementedError - - -class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): - """Adapter that renders Jinja2 templates via CodeExecutor.""" - - _code_executor: WorkflowCodeExecutor - - def __init__(self, code_executor: WorkflowCodeExecutor) -> None: - self._code_executor = code_executor - - def render_template(self, template: str, variables: Mapping[str, Any]) -> str: - try: - result = self._code_executor.execute(language=CodeLanguage.JINJA2, code=template, inputs=variables) - except Exception as exc: - if self._code_executor.is_execution_error(exc): - raise TemplateRenderError(str(exc)) from exc - raise - - rendered = result.get("result") - if not isinstance(rendered, str): - raise TemplateRenderError("Template render result must be a string.") - return rendered diff --git a/api/dify_graph/nodes/template_transform/template_transform_node.py b/api/dify_graph/nodes/template_transform/template_transform_node.py index dc6fce2b0a..2f90f80112 100644 --- a/api/dify_graph/nodes/template_transform/template_transform_node.py +++ b/api/dify_graph/nodes/template_transform/template_transform_node.py @@ -4,9 +4,10 @@ from typing import TYPE_CHECKING, Any from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult +from dify_graph.nodes.base.entities import VariableSelector from dify_graph.nodes.base.node import Node from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData -from dify_graph.nodes.template_transform.template_renderer import ( +from dify_graph.template_rendering import ( Jinja2TemplateRenderer, TemplateRenderError, ) @@ -20,7 +21,7 @@ DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000 class TemplateTransformNode(Node[TemplateTransformNodeData]): node_type = BuiltinNodeTypes.TEMPLATE_TRANSFORM - _template_renderer: Jinja2TemplateRenderer + _jinja2_template_renderer: Jinja2TemplateRenderer _max_output_length: int def __init__( @@ -30,7 +31,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", *, - template_renderer: Jinja2TemplateRenderer, + jinja2_template_renderer: Jinja2TemplateRenderer, max_output_length: int | None = None, ) -> None: super().__init__( @@ -39,7 +40,7 @@ class 
TemplateTransformNode(Node[TemplateTransformNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) - self._template_renderer = template_renderer + self._jinja2_template_renderer = jinja2_template_renderer if max_output_length is not None and max_output_length <= 0: raise ValueError("max_output_length must be a positive integer") @@ -70,7 +71,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): variables[variable_name] = value.to_object() if value else None # Run code try: - rendered = self._template_renderer.render_template(self.node_data.template, variables) + rendered = self._jinja2_template_renderer.render_template(self.node_data.template, variables) except TemplateRenderError as e: return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) @@ -87,9 +88,32 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData + cls, + *, + graph_config: Mapping[str, Any], + node_id: str, + node_data: TemplateTransformNodeData | Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - return { - node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables - } + _ = graph_config + raw_variables = ( + node_data.variables if isinstance(node_data, TemplateTransformNodeData) else node_data.get("variables", []) + ) + variable_mapping: dict[str, Sequence[str]] = {} + for variable_selector in raw_variables: + if isinstance(variable_selector, VariableSelector): + variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector + continue + + if not isinstance(variable_selector, Mapping): + continue + + variable = variable_selector.get("variable") + value_selector = variable_selector.get("value_selector") + if ( + isinstance(variable, str) + and isinstance(value_selector, Sequence) + and all(isinstance(selector_part, str) for selector_part in value_selector) + ): + variable_mapping[node_id + "." + variable] = list(value_selector) + + return variable_mapping diff --git a/api/dify_graph/nodes/tool/entities.py b/api/dify_graph/nodes/tool/entities.py index b041ee66fd..fc322fd912 100644 --- a/api/dify_graph/nodes/tool/entities.py +++ b/api/dify_graph/nodes/tool/entities.py @@ -1,13 +1,27 @@ +from enum import StrEnum, auto from typing import Any, Literal, Union from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo -from core.tools.entities.tool_entities import ToolProviderType from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import BuiltinNodeTypes, NodeType +class ToolProviderType(StrEnum): + """ + Graph-owned enum for persisted tool provider kinds. 
+ """ + + PLUGIN = auto() + BUILT_IN = "builtin" + WORKFLOW = auto() + API = auto() + APP = auto() + DATASET_RETRIEVAL = "dataset-retrieval" + MCP = auto() + + class ToolEntity(BaseModel): provider_id: str provider_type: ToolProviderType diff --git a/api/dify_graph/nodes/tool/exc.py b/api/dify_graph/nodes/tool/exc.py index 7212e8bfc0..1a309e1084 100644 --- a/api/dify_graph/nodes/tool/exc.py +++ b/api/dify_graph/nodes/tool/exc.py @@ -4,6 +4,18 @@ class ToolNodeError(ValueError): pass +class ToolRuntimeResolutionError(ToolNodeError): + """Raised when the workflow layer cannot construct a tool runtime.""" + + pass + + +class ToolRuntimeInvocationError(ToolNodeError): + """Raised when the workflow layer fails while invoking a tool runtime.""" + + pass + + class ToolParameterError(ToolNodeError): """Exception raised for errors in tool parameters.""" diff --git a/api/dify_graph/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py index 598f0da92e..ce8ceb999c 100644 --- a/api/dify_graph/nodes/tool/tool_node.py +++ b/api/dify_graph/nodes/tool/tool_node.py @@ -1,29 +1,25 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from core.tools.errors import ToolInvokeError -from core.tools.tool_engine import ToolEngine -from core.tools.utils.message_transformer import ToolFileMessageTransformer from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import ( BuiltinNodeTypes, - SystemVariableKey, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.file import File, FileTransferMethod +from dify_graph.file import File, FileTransferMethod, get_file_type_by_mime_type from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent from dify_graph.nodes.base.node import Node from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser from dify_graph.nodes.protocols import ToolFileManagerProtocol -from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment -from dify_graph.variables.variables import ArrayAnyVariable -from factories import file_factory -from services.tools.builtin_tools_manage_service import BuiltinToolManageService +from dify_graph.nodes.runtime import ToolNodeRuntimeProtocol +from dify_graph.nodes.tool_runtime_entities import ( + ToolRuntimeHandle, + ToolRuntimeMessage, + ToolRuntimeParameter, +) +from dify_graph.variables.segments import ArrayFileSegment from .entities import ToolNodeData from .exc import ( @@ -52,6 +48,7 @@ class ToolNode(Node[ToolNodeData]): graph_runtime_state: "GraphRuntimeState", *, tool_file_manager_factory: ToolFileManagerProtocol, + runtime: ToolNodeRuntimeProtocol | None = None, ): super().__init__( id=id, @@ -60,6 +57,9 @@ class ToolNode(Node[ToolNodeData]): graph_runtime_state=graph_runtime_state, ) self._tool_file_manager_factory = tool_file_manager_factory + if runtime is None: + raise ValueError("runtime is required") + self._runtime = runtime @classmethod def version(cls) -> str: @@ -73,10 +73,6 @@ class ToolNode(Node[ToolNodeData]): """ Run the tool node """ - from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError - - dify_ctx = self.require_dify_context() - # fetch 
tool icon tool_info = { "provider_type": self.node_data.provider_type.value, @@ -86,8 +82,6 @@ class ToolNode(Node[ToolNodeData]): # get tool runtime try: - from core.tools.tool_manager import ToolManager - # This is an issue that caused problems before. # Logically, we shouldn't use the node_data.version field for judgment # But for backward compatibility with historical data @@ -95,13 +89,10 @@ class ToolNode(Node[ToolNodeData]): variable_pool: VariablePool | None = None if self.node_data.version != "1" or self.node_data.tool_node_version is not None: variable_pool = self.graph_runtime_state.variable_pool - tool_runtime = ToolManager.get_workflow_tool_runtime( - dify_ctx.tenant_id, - dify_ctx.app_id, - self._node_id, - self.node_data, - dify_ctx.invoke_from, - variable_pool, + tool_runtime = self._runtime.get_runtime( + node_id=self._node_id, + node_data=self.node_data, + variable_pool=variable_pool, ) except ToolNodeError as e: yield StreamCompletedEvent( @@ -116,7 +107,7 @@ class ToolNode(Node[ToolNodeData]): return # get parameters - tool_parameters = tool_runtime.get_merged_runtime_parameters() or [] + tool_parameters = self._runtime.get_runtime_parameters(tool_runtime=tool_runtime) parameters = self._generate_parameters( tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, @@ -128,18 +119,12 @@ class ToolNode(Node[ToolNodeData]): node_data=self.node_data, for_log=True, ) - # get conversation id - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) - try: - message_stream = ToolEngine.generic_invoke( - tool=tool_runtime, + message_stream = self._runtime.invoke( + tool_runtime=tool_runtime, tool_parameters=parameters, - user_id=dify_ctx.user_id, - workflow_tool_callback=DifyWorkflowCallbackHandler(), workflow_call_depth=self.workflow_call_depth, - app_id=dify_ctx.app_id, - conversation_id=conversation_id.text if conversation_id else None, + provider_name=self.node_data.provider_name, ) except ToolNodeError as e: yield StreamCompletedEvent( @@ -159,38 +144,16 @@ class ToolNode(Node[ToolNodeData]): messages=message_stream, tool_info=tool_info, parameters_for_log=parameters_for_log, - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, node_id=self._node_id, tool_runtime=tool_runtime, ) - except ToolInvokeError as e: + except ToolNodeError as e: yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool {self.node_data.provider_name}: {str(e)}", - error_type=type(e).__name__, - ) - ) - except PluginInvokeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=e.to_user_friendly_error(plugin_name=self.node_data.provider_name), - error_type=type(e).__name__, - ) - ) - except PluginDaemonClientSideError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool, error: {e.description}", + error=str(e), error_type=type(e).__name__, ) ) @@ -198,7 +161,7 @@ class ToolNode(Node[ToolNodeData]): def _generate_parameters( self, *, - tool_parameters: Sequence[ToolParameter], + tool_parameters: 
Sequence[ToolRuntimeParameter], variable_pool: "VariablePool", node_data: ToolNodeData, for_log: bool = False, @@ -207,7 +170,7 @@ class ToolNode(Node[ToolNodeData]): Generate parameters based on the given tool parameters, variable pool, and node data. Args: - tool_parameters (Sequence[ToolParameter]): The list of tool parameters. + tool_parameters (Sequence[ToolRuntimeParameter]): The list of tool parameters. variable_pool (VariablePool): The variable pool containing the variables. node_data (ToolNodeData): The data associated with the tool node. @@ -240,107 +203,89 @@ class ToolNode(Node[ToolNodeData]): return result - def _fetch_files(self, variable_pool: "VariablePool") -> list[File]: - variable = variable_pool.get(["sys", SystemVariableKey.FILES]) - assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) - return list(variable.value) if variable else [] - def _transform_message( self, - messages: Generator[ToolInvokeMessage, None, None], + messages: Generator[ToolRuntimeMessage, None, None], tool_info: Mapping[str, Any], parameters_for_log: dict[str, Any], - user_id: str, - tenant_id: str, node_id: str, - tool_runtime: Tool, + tool_runtime: ToolRuntimeHandle, + **_: Any, ) -> Generator[NodeEventBase, None, LLMUsage]: """ - Convert ToolInvokeMessages into tuple[plain_text, files] + Convert graph-owned tool runtime messages into node outputs. """ - # transform message and handle file storage - from core.plugin.impl.plugin import PluginInstaller - - message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( - messages=messages, - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - ) - text = "" files: list[File] = [] json: list[dict | list] = [] variables: dict[str, Any] = {} - for message in message_stream: + for message in messages: if message.type in { - ToolInvokeMessage.MessageType.IMAGE_LINK, - ToolInvokeMessage.MessageType.BINARY_LINK, - ToolInvokeMessage.MessageType.IMAGE, + ToolRuntimeMessage.MessageType.IMAGE_LINK, + ToolRuntimeMessage.MessageType.BINARY_LINK, + ToolRuntimeMessage.MessageType.IMAGE, }: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert isinstance(message.message, ToolRuntimeMessage.TextMessage) url = message.message.text if message.meta: transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + tool_file_id = message.meta.get("tool_file_id") else: transfer_method = FileTransferMethod.TOOL_FILE + tool_file_id = None + if not isinstance(tool_file_id, str) or not tool_file_id: + raise ToolFileError("tool message is missing tool_file_id metadata") - tool_file_id = str(url).split("/")[-1].split(".")[0] - - _, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) + _stream, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) if not tool_file: raise ToolFileError(f"tool file {tool_file_id} not found") + if tool_file.mime_type is None: + raise ToolFileError(f"tool file {tool_file_id} is missing mime type") - mapping = { + file_mapping: dict[str, Any] = { "tool_file_id": tool_file_id, - "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "type": get_file_type_by_mime_type(tool_file.mime_type), "transfer_method": transfer_method, "url": url, } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) + file = self._runtime.build_file_reference(mapping=file_mapping) files.append(file) - elif message.type == ToolInvokeMessage.MessageType.BLOB: + elif 
message.type == ToolRuntimeMessage.MessageType.BLOB: # get tool file id - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert isinstance(message.message, ToolRuntimeMessage.TextMessage) assert message.meta - tool_file_id = message.message.text.split("/")[-1].split(".")[0] - _, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) + tool_file_id = message.meta.get("tool_file_id") + if not isinstance(tool_file_id, str) or not tool_file_id: + raise ToolFileError("tool blob message is missing tool_file_id metadata") + _stream, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) if not tool_file: raise ToolFileError(f"tool file {tool_file_id} not exists") - mapping = { + blob_file_mapping: dict[str, Any] = { "tool_file_id": tool_file_id, "transfer_method": FileTransferMethod.TOOL_FILE, } - files.append( - file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) - ) - elif message.type == ToolInvokeMessage.MessageType.TEXT: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + files.append(self._runtime.build_file_reference(mapping=blob_file_mapping)) + elif message.type == ToolRuntimeMessage.MessageType.TEXT: + assert isinstance(message.message, ToolRuntimeMessage.TextMessage) text += message.message.text yield StreamChunkEvent( selector=[node_id, "text"], chunk=message.message.text, is_final=False, ) - elif message.type == ToolInvokeMessage.MessageType.JSON: - assert isinstance(message.message, ToolInvokeMessage.JsonMessage) + elif message.type == ToolRuntimeMessage.MessageType.JSON: + assert isinstance(message.message, ToolRuntimeMessage.JsonMessage) # JSON message handling for tool node if message.message.json_object: json.append(message.message.json_object) - elif message.type == ToolInvokeMessage.MessageType.LINK: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) + elif message.type == ToolRuntimeMessage.MessageType.LINK: + assert isinstance(message.message, ToolRuntimeMessage.TextMessage) # Check if this LINK message is a file link file_obj = (message.meta or {}).get("file") @@ -356,8 +301,8 @@ class ToolNode(Node[ToolNodeData]): chunk=stream_text, is_final=False, ) - elif message.type == ToolInvokeMessage.MessageType.VARIABLE: - assert isinstance(message.message, ToolInvokeMessage.VariableMessage) + elif message.type == ToolRuntimeMessage.MessageType.VARIABLE: + assert isinstance(message.message, ToolRuntimeMessage.VariableMessage) variable_name = message.message.variable_name variable_value = message.message.variable_value if message.message.stream: @@ -374,7 +319,7 @@ class ToolNode(Node[ToolNodeData]): ) else: variables[variable_name] = variable_value - elif message.type == ToolInvokeMessage.MessageType.FILE: + elif message.type == ToolRuntimeMessage.MessageType.FILE: assert message.meta is not None assert isinstance(message.meta, dict) # Validate that meta contains a 'file' key @@ -385,38 +330,16 @@ class ToolNode(Node[ToolNodeData]): if not isinstance(message.meta["file"], File): raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") files.append(message.meta["file"]) - elif message.type == ToolInvokeMessage.MessageType.LOG: - assert isinstance(message.message, ToolInvokeMessage.LogMessage) + elif message.type == ToolRuntimeMessage.MessageType.LOG: + assert isinstance(message.message, ToolRuntimeMessage.LogMessage) if message.message.metadata: icon = tool_info.get("icon", "") dict_metadata = 
dict(message.message.metadata) if dict_metadata.get("provider"): - manager = PluginInstaller() - plugins = manager.list_plugins(tenant_id) - try: - current_plugin = next( - plugin - for plugin in plugins - if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] - ) - icon = current_plugin.declaration.icon - except StopIteration: - pass - icon_dark = None - try: - builtin_tool = next( - provider - for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - ) - if provider.name == dict_metadata["provider"] - ) - icon = builtin_tool.icon - icon_dark = builtin_tool.icon_dark - except StopIteration: - pass - + icon, icon_dark = self._runtime.resolve_provider_icons( + provider_name=dict_metadata["provider"], + default_icon=icon, + ) dict_metadata["icon"] = icon dict_metadata["icon_dark"] = icon_dark message.message.metadata = dict_metadata @@ -446,7 +369,7 @@ class ToolNode(Node[ToolNodeData]): is_final=True, ) - usage = self._extract_tool_usage(tool_runtime) + usage = self._runtime.get_usage(tool_runtime=tool_runtime) metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, @@ -468,21 +391,6 @@ class ToolNode(Node[ToolNodeData]): return usage - @staticmethod - def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage: - # Avoid importing WorkflowTool at module import time; rely on duck typing - # Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes. - latest = getattr(tool_runtime, "latest_usage", None) - # Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects - # for any name, so we must type-check here. - if isinstance(latest, LLMUsage): - return latest - if isinstance(latest, dict): - # Allow dict payloads from external runtimes - return LLMUsage.model_validate(latest) - # Fallback to empty usage when attribute is missing or not a valid payload - return LLMUsage.empty_usage() - @classmethod def _extract_variable_selector_to_variable_mapping( cls, diff --git a/api/dify_graph/nodes/tool_runtime_entities.py b/api/dify_graph/nodes/tool_runtime_entities.py new file mode 100644 index 0000000000..5bb0c16573 --- /dev/null +++ b/api/dify_graph/nodes/tool_runtime_entities.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from dataclasses import dataclass +from enum import StrEnum, auto +from typing import Any + +from pydantic import BaseModel, ConfigDict, Field + + +class _ToolRuntimeModel(BaseModel): + model_config = ConfigDict(extra="forbid") + + +@dataclass(frozen=True, slots=True) +class ToolRuntimeHandle: + """Opaque graph-owned handle for a workflow-layer tool runtime. + + Workflow-specific execution context must stay behind `raw` so the graph + contract does not absorb application-owned concepts. 
+ """ + + raw: object + + +@dataclass(frozen=True, slots=True) +class ToolRuntimeParameter: + """Graph-owned parameter shape used by tool nodes.""" + + name: str + required: bool = False + + +class ToolRuntimeMessage(_ToolRuntimeModel): + """Graph-owned tool invocation message DTO.""" + + class TextMessage(_ToolRuntimeModel): + text: str + + class JsonMessage(_ToolRuntimeModel): + json_object: dict[str, Any] | list[Any] + suppress_output: bool = Field(default=False) + + class BlobMessage(_ToolRuntimeModel): + blob: bytes + + class BlobChunkMessage(_ToolRuntimeModel): + id: str + sequence: int + total_length: int + blob: bytes + end: bool + + class FileMessage(_ToolRuntimeModel): + file_marker: str = Field(default="file_marker") + + class VariableMessage(_ToolRuntimeModel): + variable_name: str + variable_value: dict[str, Any] | list[Any] | str | int | float | bool | None + stream: bool = Field(default=False) + + class LogMessage(_ToolRuntimeModel): + class LogStatus(StrEnum): + START = auto() + ERROR = auto() + SUCCESS = auto() + + id: str + label: str + parent_id: str | None = None + error: str | None = None + status: LogStatus + data: dict[str, Any] + metadata: dict[str, Any] = Field(default_factory=dict) + + class RetrieverResourceMessage(_ToolRuntimeModel): + retriever_resources: list[dict[str, Any]] + context: str + + class MessageType(StrEnum): + TEXT = auto() + IMAGE = auto() + LINK = auto() + BLOB = auto() + JSON = auto() + IMAGE_LINK = auto() + BINARY_LINK = auto() + VARIABLE = auto() + FILE = auto() + LOG = auto() + BLOB_CHUNK = auto() + RETRIEVER_RESOURCES = auto() + + type: MessageType = MessageType.TEXT + message: ( + JsonMessage + | TextMessage + | BlobChunkMessage + | BlobMessage + | LogMessage + | FileMessage + | None + | VariableMessage + | RetrieverResourceMessage + ) + meta: dict[str, Any] | None = None diff --git a/api/dify_graph/nodes/variable_assigner/v1/node.py b/api/dify_graph/nodes/variable_assigner/v1/node.py index f9b261b191..813d82b61e 100644 --- a/api/dify_graph/nodes/variable_assigner/v1/node.py +++ b/api/dify_graph/nodes/variable_assigner/v1/node.py @@ -1,15 +1,14 @@ -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any +from collections.abc import Generator, Mapping, Sequence +from typing import TYPE_CHECKING, Any, cast -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult +from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent, VariableUpdatedEvent from dify_graph.nodes.base.node import Node from dify_graph.nodes.variable_assigner.common import helpers as common_helpers from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from dify_graph.variables import SegmentType, VariableBase +from dify_graph.variables import SegmentType, Variable, VariableBase from .node_data import VariableAssignerData, WriteMode @@ -56,18 +55,16 @@ class VariableAssignerNode(Node[VariableAssignerData]): node_data: VariableAssignerData, ) -> Mapping[str, Sequence[str]]: mapping = {} - assigned_variable_node_id = node_data.assigned_variable_selector[0] - if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: - selector_key = ".".join(node_data.assigned_variable_selector) - key = f"{node_id}.#{selector_key}#" - mapping[key] = 
node_data.assigned_variable_selector + selector_key = ".".join(node_data.assigned_variable_selector) + key = f"{node_id}.#{selector_key}#" + mapping[key] = node_data.assigned_variable_selector selector_key = ".".join(node_data.input_variable_selector) key = f"{node_id}.#{selector_key}#" mapping[key] = node_data.input_variable_selector return mapping - def _run(self) -> NodeRunResult: + def _run(self) -> Generator[NodeEventBase, None, None]: assigned_variable_selector = self.node_data.assigned_variable_selector # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) @@ -92,18 +89,18 @@ class VariableAssignerNode(Node[VariableAssignerData]): income_value = SegmentType.get_zero_value(original_variable.value_type) updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) - # Over write the variable. - self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) - updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={ - "value": income_value.to_object(), - }, - # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`, - # we still set `output_variables` as a list to ensure the schema of output is - # compatible with `v2.VariableAssignerNode`. - process_data=common_helpers.set_updated_variables({}, updated_variables), - outputs={}, + yield VariableUpdatedEvent(variable=cast(Variable, updated_variable)) + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={ + "value": income_value.to_object(), + }, + # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`, + # we still set `output_variables` as a list to ensure the schema of output is + # compatible with `v2.VariableAssignerNode`. + process_data=common_helpers.set_updated_variables({}, updated_variables), + outputs={}, + ) ) diff --git a/api/dify_graph/nodes/variable_assigner/v2/node.py b/api/dify_graph/nodes/variable_assigner/v2/node.py index f04a6b3b80..08d86613e6 100644 --- a/api/dify_graph/nodes/variable_assigner/v2/node.py +++ b/api/dify_graph/nodes/variable_assigner/v2/node.py @@ -1,15 +1,14 @@ import json -from collections.abc import Mapping, MutableMapping, Sequence -from typing import TYPE_CHECKING, Any +from collections.abc import Generator, Mapping, MutableMapping, Sequence +from typing import TYPE_CHECKING, Any, cast -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID from dify_graph.entities.graph_config import NodeConfigDict from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult +from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent, VariableUpdatedEvent from dify_graph.nodes.base.node import Node from dify_graph.nodes.variable_assigner.common import helpers as common_helpers from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from dify_graph.variables import SegmentType, VariableBase +from dify_graph.variables import SegmentType, Variable, VariableBase from dify_graph.variables.consts import SELECTORS_LENGTH from . 
import helpers @@ -29,9 +28,6 @@ if TYPE_CHECKING: def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): - selector_node_id = item.variable_selector[0] - if selector_node_id != CONVERSATION_VARIABLE_NODE_ID: - return selector_str = ".".join(item.variable_selector) key = f"{node_id}.#{selector_str}#" mapping[key] = item.variable_selector @@ -103,15 +99,18 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): _source_mapping_from_item(var_mapping, node_id, item) return var_mapping - def _run(self) -> NodeRunResult: + def _run(self) -> Generator[NodeEventBase, None, None]: inputs = self.node_data.model_dump() process_data: dict[str, Any] = {} # NOTE: This node has no outputs updated_variable_selectors: list[Sequence[str]] = [] + # Preserve intra-node read-after-write behavior without mutating the shared pool + # until the engine processes the emitted VariableUpdatedEvent instances. + working_variable_pool = self.graph_runtime_state.variable_pool.model_copy(deep=True) try: for item in self.node_data.items: - variable = self.graph_runtime_state.variable_pool.get(item.variable_selector) + variable = working_variable_pool.get(item.variable_selector) # ==================== Validation Part @@ -136,60 +135,64 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation) # Get value from variable pool + input_value = item.value if ( item.input_type == InputType.VARIABLE and item.operation not in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST} and item.value is not None ): - value = self.graph_runtime_state.variable_pool.get(item.value) + value = working_variable_pool.get(item.value) if value is None: raise VariableNotFoundError(variable_selector=item.value) # Skip if value is NoneSegment if value.value_type == SegmentType.NONE: continue - item.value = value.value + input_value = value.value # If set string / bytes / bytearray to object, try convert string to object. 
if ( item.operation == Operation.SET and variable.value_type == SegmentType.OBJECT - and isinstance(item.value, str | bytes | bytearray) + and isinstance(input_value, str | bytes | bytearray) ): try: - item.value = json.loads(item.value) + input_value = json.loads(input_value) except json.JSONDecodeError: - raise InvalidInputValueError(value=item.value) + raise InvalidInputValueError(value=input_value) # Check if input value is valid if not helpers.is_input_value_valid( - variable_type=variable.value_type, operation=item.operation, value=item.value + variable_type=variable.value_type, operation=item.operation, value=input_value ): - raise InvalidInputValueError(value=item.value) + raise InvalidInputValueError(value=input_value) # ==================== Execution Part updated_value = self._handle_item( variable=variable, operation=item.operation, - value=item.value, + value=input_value, ) - variable = variable.model_copy(update={"value": updated_value}) - self.graph_runtime_state.variable_pool.add(variable.selector, variable) - updated_variable_selectors.append(variable.selector) + updated_variable = variable.model_copy(update={"value": updated_value}) + working_variable_pool.add(updated_variable.selector, updated_variable) + updated_variable_selectors.append(updated_variable.selector) except VariableOperatorNodeError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=inputs, - process_data=process_data, - error=str(e), + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=inputs, + process_data=process_data, + error=str(e), + ) ) + return # The `updated_variable_selectors` is a list contains list[str] which not hashable, - # remove the duplicated items first. - updated_variable_selectors = list(set(map(tuple, updated_variable_selectors))) + # remove duplicated items while preserving the first update order. 
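# (Editor's illustration, not part of the patch.) `dict.fromkeys` de-duplicates
# while keeping first-seen order, which `set()` does not guarantee:
#   list(dict.fromkeys([("a",), ("b",), ("a",)]))  ->  [('a',), ('b',)]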
+ updated_variable_selectors = list(dict.fromkeys(map(tuple, updated_variable_selectors))) for selector in updated_variable_selectors: - variable = self.graph_runtime_state.variable_pool.get(selector) + variable = working_variable_pool.get(selector) if not isinstance(variable, VariableBase): raise VariableNotFoundError(variable_selector=selector) process_data[variable.name] = variable.value @@ -197,15 +200,23 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]): updated_variables = [ common_helpers.variable_to_processed_data(selector, seg) for selector in updated_variable_selectors - if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None + if (seg := working_variable_pool.get(selector)) is not None ] process_data = common_helpers.set_updated_variables(process_data, updated_variables) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={}, + for selector in updated_variable_selectors: + variable = working_variable_pool.get(selector) + if not isinstance(variable, VariableBase): + raise VariableNotFoundError(variable_selector=selector) + yield VariableUpdatedEvent(variable=cast(Variable, variable)) + + yield StreamCompletedEvent( + node_run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + outputs={}, + ) ) def _handle_item( diff --git a/api/dify_graph/prompt_entities.py b/api/dify_graph/prompt_entities.py new file mode 100644 index 0000000000..877b530613 --- /dev/null +++ b/api/dify_graph/prompt_entities.py @@ -0,0 +1,47 @@ +from typing import Literal + +from pydantic import BaseModel + +from dify_graph.model_runtime.entities.message_entities import PromptMessageRole + + +class ChatModelMessage(BaseModel): + """Graph-owned chat prompt template message.""" + + text: str + role: PromptMessageRole + edition_type: Literal["basic", "jinja2"] | None = None + + +class CompletionModelPromptTemplate(BaseModel): + """Graph-owned completion prompt template.""" + + text: str + edition_type: Literal["basic", "jinja2"] | None = None + + +class MemoryConfig(BaseModel): + """Graph-owned memory configuration for prompt assembly.""" + + class RolePrefix(BaseModel): + """Role labels used when serializing completion-model histories.""" + + user: str + assistant: str + + class WindowConfig(BaseModel): + """History windowing controls.""" + + enabled: bool + size: int | None = None + + role_prefix: RolePrefix | None = None + window: WindowConfig + query_prompt_template: str | None = None + + +__all__ = [ + "ChatModelMessage", + "CompletionModelPromptTemplate", + "MemoryConfig", +] diff --git a/api/dify_graph/repositories/__init__.py b/api/dify_graph/repositories/__init__.py deleted file mode 100644 index ef70eb09cc..0000000000 --- a/api/dify_graph/repositories/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Repository interfaces for data access. - -This package contains repository interfaces that define the contract -for accessing and manipulating data, regardless of the underlying -storage mechanism. 
-""" - -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository - -__all__ = [ - "OrderConfig", - "WorkflowNodeExecutionRepository", -] diff --git a/api/dify_graph/repositories/human_input_form_repository.py b/api/dify_graph/repositories/human_input_form_repository.py deleted file mode 100644 index 88966831cb..0000000000 --- a/api/dify_graph/repositories/human_input_form_repository.py +++ /dev/null @@ -1,152 +0,0 @@ -import abc -import dataclasses -from collections.abc import Mapping, Sequence -from datetime import datetime -from typing import Any, Protocol - -from dify_graph.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus - - -class HumanInputError(Exception): - pass - - -class FormNotFoundError(HumanInputError): - pass - - -@dataclasses.dataclass -class FormCreateParams: - # app_id is the identifier for the app that the form belongs to. - # It is a string with uuid format. - app_id: str - # None when creating a delivery test form; set for runtime forms. - workflow_execution_id: str | None - - # node_id is the identifier for a specific - # node in the graph. - # - # TODO: for node inside loop / iteration, this would - # cause problems, as a single node may be executed multiple times. - node_id: str - - form_config: HumanInputNodeData - rendered_content: str - # Delivery methods already filtered by runtime context (invoke_from). - delivery_methods: Sequence[DeliveryChannelConfig] - # UI display flag computed by runtime context. - display_in_ui: bool - - # resolved_default_values saves the values for defaults with - # type = VARIABLE. - # - # For type = CONSTANT, the value is not stored inside `resolved_default_values` - resolved_default_values: Mapping[str, Any] - form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME - - # Force creating a console-only recipient for submission in Console. - console_recipient_required: bool = False - console_creator_account_id: str | None = None - # Force creating a backstage recipient for submission in Console. - backstage_recipient_required: bool = False - - -class HumanInputFormEntity(abc.ABC): - @property - @abc.abstractmethod - def id(self) -> str: - """id returns the identifer of the form.""" - pass - - @property - @abc.abstractmethod - def web_app_token(self) -> str | None: - """web_app_token returns the token for submission inside webapp. - - For console/debug execution, this may point to the console submission token - if the form is configured to require console delivery. - """ - - # TODO: what if the users are allowed to add multiple - # webapp delivery? - pass - - @property - @abc.abstractmethod - def recipients(self) -> list["HumanInputFormRecipientEntity"]: ... - - @property - @abc.abstractmethod - def rendered_content(self) -> str: - """Rendered markdown content associated with the form.""" - ... - - @property - @abc.abstractmethod - def selected_action_id(self) -> str | None: - """Identifier of the selected user action if the form has been submitted.""" - ... - - @property - @abc.abstractmethod - def submitted_data(self) -> Mapping[str, Any] | None: - """Submitted form data if available.""" - ... - - @property - @abc.abstractmethod - def submitted(self) -> bool: - """Whether the form has been submitted.""" - ... - - @property - @abc.abstractmethod - def status(self) -> HumanInputFormStatus: - """Current status of the form.""" - ... 
- - @property - @abc.abstractmethod - def expiration_time(self) -> datetime: - """When the form expires.""" - ... - - -class HumanInputFormRecipientEntity(abc.ABC): - @property - @abc.abstractmethod - def id(self) -> str: - """id returns the identifer of this recipient.""" - ... - - @property - @abc.abstractmethod - def token(self) -> str: - """token returns a random string used to submit form""" - ... - - -class HumanInputFormRepository(Protocol): - """ - Repository interface for HumanInputForm. - - This interface defines the contract for accessing and manipulating - HumanInputForm data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and other implementation details should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - """Get the form created for a given human input node in a workflow execution. Returns - `None` if the form has not been created yet.""" - ... - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - """ - Create a human input form from form definition. - """ - ... diff --git a/api/dify_graph/repositories/workflow_execution_repository.py b/api/dify_graph/repositories/workflow_execution_repository.py deleted file mode 100644 index ef83f07649..0000000000 --- a/api/dify_graph/repositories/workflow_execution_repository.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Protocol - -from dify_graph.entities import WorkflowExecution - - -class WorkflowExecutionRepository(Protocol): - """ - Repository interface for WorkflowExecution. - - This interface defines the contract for accessing and manipulating - WorkflowExecution data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and other implementation details should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def save(self, execution: WorkflowExecution): - """ - Save or update a WorkflowExecution instance. - - This method handles both creating new records and updating existing ones. - The implementation should determine whether to create or update based on - the execution's ID or other identifying fields. - - Args: - execution: The WorkflowExecution instance to save or update - """ - ... diff --git a/api/dify_graph/repositories/workflow_node_execution_repository.py b/api/dify_graph/repositories/workflow_node_execution_repository.py deleted file mode 100644 index e6c1c3e497..0000000000 --- a/api/dify_graph/repositories/workflow_node_execution_repository.py +++ /dev/null @@ -1,73 +0,0 @@ -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Literal, Protocol - -from dify_graph.entities import WorkflowNodeExecution - - -@dataclass -class OrderConfig: - """Configuration for ordering NodeExecution instances.""" - - order_by: list[str] - order_direction: Literal["asc", "desc"] | None = None - - -class WorkflowNodeExecutionRepository(Protocol): - """ - Repository interface for NodeExecution. 
- - This interface defines the contract for accessing and manipulating - NodeExecution data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and trigger sources (triggered_from) should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def save(self, execution: WorkflowNodeExecution): - """ - Save or update a NodeExecution instance. - - This method saves all data on the `WorkflowNodeExecution` object, except for `inputs`, `process_data`, - and `outputs`. Its primary purpose is to persist the status and various metadata, such as execution time - and execution-related details. - - It's main purpose is to save the status and various metadata (execution time, execution metadata etc.) - - This method handles both creating new records and updating existing ones. - The implementation should determine whether to create or update based on - the execution's ID or other identifying fields. - - Args: - execution: The NodeExecution instance to save or update - """ - ... - - def save_execution_data(self, execution: WorkflowNodeExecution): - """Save or update the inputs, process_data, or outputs associated with a specific - node_execution record. - - If any of the inputs, process_data, or outputs are None, those fields will not be updated. - """ - ... - - def get_by_workflow_run( - self, - workflow_run_id: str, - order_config: OrderConfig | None = None, - ) -> Sequence[WorkflowNodeExecution]: - """ - Retrieve all NodeExecution instances for a specific workflow run. - - Args: - workflow_run_id: The workflow run ID - order_config: Optional configuration for ordering results - order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) - order_config.order_direction: Direction to order ("asc" or "desc") - - Returns: - A list of NodeExecution instances - """ - ... diff --git a/api/dify_graph/runtime/graph_runtime_state.py b/api/dify_graph/runtime/graph_runtime_state.py index 41acc6db35..df65bbaeef 100644 --- a/api/dify_graph/runtime/graph_runtime_state.py +++ b/api/dify_graph/runtime/graph_runtime_state.py @@ -3,6 +3,7 @@ from __future__ import annotations import importlib import json from collections.abc import Mapping, Sequence +from contextlib import AbstractContextManager, nullcontext from copy import deepcopy from dataclasses import dataclass from typing import TYPE_CHECKING, Any, ClassVar, Protocol @@ -142,10 +143,9 @@ class ChildGraphEngineBuilderProtocol(Protocol): *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> Any: ... 
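# (Editor's illustration, not part of the patch.) A minimal sketch of a builder
# satisfying the revised ChildGraphEngineBuilderProtocol. `ChildEngine` is a
# hypothetical stand-in for the engine type the application constructs, and the
# runtime-state types are loosened to `Any` so the sketch stays self-contained.
from dataclasses import dataclass
from typing import Any


@dataclass
class ChildEngine:  # hypothetical stand-in, not a real dify_graph type
    workflow_id: str
    root_node_id: str
    variable_pool: Any


class SimpleChildEngineBuilder:
    def build_child_engine(
        self,
        *,
        workflow_id: str,
        graph_init_params: Any,  # GraphInitParams in the real signature
        parent_graph_runtime_state: Any,  # GraphRuntimeState in the real signature
        root_node_id: str,
        variable_pool: Any | None = None,
    ) -> Any:
        # Under the new contract the child derives its state from the parent;
        # callers no longer hand over a pre-built GraphRuntimeState, a
        # graph_config mapping, or a layers sequence.
        pool = variable_pool if variable_pool is not None else parent_graph_runtime_state.variable_pool
        return ChildEngine(workflow_id=workflow_id, root_node_id=root_node_id, variable_pool=pool)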
@@ -211,6 +211,7 @@ class GraphRuntimeState: graph_execution: GraphExecutionProtocol | None = None, response_coordinator: ResponseStreamCoordinatorProtocol | None = None, graph: GraphProtocol | None = None, + execution_context: AbstractContextManager[object] | None = None, ) -> None: self._variable_pool = variable_pool self._start_at = start_at @@ -231,6 +232,9 @@ class GraphRuntimeState: self._ready_queue = ready_queue self._graph_execution = graph_execution self._response_coordinator = response_coordinator + # Application code injects this when worker threads must restore request + # or framework-local state. It is intentionally excluded from snapshots. + self._execution_context = execution_context if execution_context is not None else nullcontext(None) self._pending_response_coordinator_dump: str | None = None self._pending_graph_execution_workflow_id: str | None = None self._paused_nodes: set[str] = set() @@ -285,21 +289,19 @@ class GraphRuntimeState: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> Any: + """Create a child graph engine that derives its runtime state from the parent.""" if self._child_engine_builder is None: raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.") return self._child_engine_builder.build_child_engine( workflow_id=workflow_id, graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - graph_config=graph_config, + parent_graph_runtime_state=self, root_node_id=root_node_id, - layers=layers, + variable_pool=variable_pool, ) # ------------------------------------------------------------------ @@ -329,6 +331,14 @@ class GraphRuntimeState: self._response_coordinator = self._build_response_coordinator(self._graph) return self._response_coordinator + @property + def execution_context(self) -> AbstractContextManager[object]: + return self._execution_context + + @execution_context.setter + def execution_context(self, value: AbstractContextManager[object] | None) -> None: + self._execution_context = value if value is not None else nullcontext(None) + # ------------------------------------------------------------------ # Scalar state # ------------------------------------------------------------------ diff --git a/api/dify_graph/runtime/graph_runtime_state_protocol.py b/api/dify_graph/runtime/graph_runtime_state_protocol.py index 7e55ece3f1..b9674edfed 100644 --- a/api/dify_graph/runtime/graph_runtime_state_protocol.py +++ b/api/dify_graph/runtime/graph_runtime_state_protocol.py @@ -2,7 +2,6 @@ from collections.abc import Mapping, Sequence from typing import Any, Protocol from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.system_variable import SystemVariableReadOnlyView from dify_graph.variables.segments import Segment @@ -31,9 +30,6 @@ class ReadOnlyGraphRuntimeState(Protocol): All methods return defensive copies to ensure immutability. """ - @property - def system_variable(self) -> SystemVariableReadOnlyView: ... 
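# (Editor's illustration, not part of the patch.) With the `system_variable`
# view removed from the read-only protocol, callers reach system values through
# the pool by selector instead, e.g. (assuming a "query" system variable was
# seeded into the pool):
#   state.variable_pool.get(["sys", "query"])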
- @property def variable_pool(self) -> ReadOnlyVariablePool: """Get read-only access to the variable pool.""" diff --git a/api/dify_graph/runtime/read_only_wrappers.py b/api/dify_graph/runtime/read_only_wrappers.py index ca06d88c3d..66252ca3fb 100644 --- a/api/dify_graph/runtime/read_only_wrappers.py +++ b/api/dify_graph/runtime/read_only_wrappers.py @@ -5,7 +5,6 @@ from copy import deepcopy from typing import Any from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.system_variable import SystemVariableReadOnlyView from dify_graph.variables.segments import Segment from .graph_runtime_state import GraphRuntimeState @@ -43,10 +42,6 @@ class ReadOnlyGraphRuntimeStateWrapper: self._state = state self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool) - @property - def system_variable(self) -> SystemVariableReadOnlyView: - return self._state.variable_pool.system_variables.as_view() - @property def variable_pool(self) -> ReadOnlyVariablePoolWrapper: return self._variable_pool_wrapper diff --git a/api/dify_graph/runtime/variable_pool.py b/api/dify_graph/runtime/variable_pool.py index e3ef6a2897..fe4aeed0fa 100644 --- a/api/dify_graph/runtime/variable_pool.py +++ b/api/dify_graph/runtime/variable_pool.py @@ -6,84 +6,84 @@ from collections.abc import Mapping, Sequence from copy import deepcopy from typing import Annotated, Any, Union, cast -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, model_validator -from dify_graph.constants import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, - RAG_PIPELINE_VARIABLE_NODE_ID, - SYSTEM_VARIABLE_NODE_ID, -) from dify_graph.file import File, FileAttribute, file_manager -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import Segment, SegmentGroup, VariableBase +from dify_graph.variables import Segment, SegmentGroup, VariableBase, build_segment, segment_to_variable from dify_graph.variables.consts import SELECTORS_LENGTH from dify_graph.variables.segments import FileSegment, ObjectSegment from dify_graph.variables.variables import RAGPipelineVariableInput, Variable -from factories import variable_factory VariableValue = Union[str, int, float, dict[str, object], list[object], File] VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") +def _default_variable_dictionary() -> defaultdict[str, dict[str, Variable]]: + return defaultdict(dict) + + class VariablePool(BaseModel): + _SYSTEM_VARIABLE_NODE_ID = "sys" + _ENVIRONMENT_VARIABLE_NODE_ID = "env" + _CONVERSATION_VARIABLE_NODE_ID = "conversation" + _RAG_PIPELINE_VARIABLE_NODE_ID = "rag" + # Variable dictionary is a dictionary for looking up variables by their selector. # The first element of the selector is the node id, it's the first-level key in the dictionary. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. 
variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field( description="Variables mapping", - default=defaultdict(dict), + default_factory=_default_variable_dictionary, ) + system_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) + environment_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) + conversation_variables: Sequence[Variable] = Field(default_factory=tuple, exclude=True) + rag_pipeline_variables: Sequence[RAGPipelineVariableInput] = Field(default_factory=tuple, exclude=True) + user_inputs: Mapping[str, Any] = Field(default_factory=dict, exclude=True) - # The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere. - user_inputs: Mapping[str, Any] = Field( - description="User inputs", - default_factory=dict, - ) - system_variables: SystemVariable = Field( - description="System variables", - default_factory=SystemVariable.default, - ) - environment_variables: Sequence[Variable] = Field( - description="Environment variables.", - default_factory=list[Variable], - ) - conversation_variables: Sequence[Variable] = Field( - description="Conversation variables.", - default_factory=list[Variable], - ) - rag_pipeline_variables: list[RAGPipelineVariableInput] = Field( - description="RAG pipeline variables.", - default_factory=list, - ) + @model_validator(mode="after") + def _load_legacy_bootstrap_inputs(self) -> VariablePool: + """ + Accept legacy constructor kwargs that still appear throughout the workflow + layer while keeping serialized state focused on `variable_dictionary`. + """ - def model_post_init(self, context: Any, /): - # Create a mapping from field names to SystemVariableKey enum values - self._add_system_variables(self.system_variables) - # Add environment variables to the variable pool - for var in self.environment_variables: - self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) - # Add conversation variables to the variable pool. When restoring from a serialized - # snapshot, `variable_dictionary` already carries the latest runtime values. - # In that case, keep existing entries instead of overwriting them with the - # bootstrap list. - for var in self.conversation_variables: - selector = (CONVERSATION_VARIABLE_NODE_ID, var.name) - if self._has(selector): - continue - self.add(selector, var) - # Add rag pipeline variables to the variable pool - if self.rag_pipeline_variables: - rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict) - for rag_var in self.rag_pipeline_variables: - node_id = rag_var.variable.belong_to_node_id - key = rag_var.variable.variable - value = rag_var.value - rag_pipeline_variables_map[node_id][key] = value - for key, value in rag_pipeline_variables_map.items(): - self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value) + self._ingest_legacy_variables(self.system_variables, node_id=self._SYSTEM_VARIABLE_NODE_ID) + self._ingest_legacy_variables(self.environment_variables, node_id=self._ENVIRONMENT_VARIABLE_NODE_ID) + self._ingest_legacy_variables(self.conversation_variables, node_id=self._CONVERSATION_VARIABLE_NODE_ID) + self._ingest_legacy_rag_variables(self.rag_pipeline_variables) + + # These kwargs are accepted for compatibility but should not affect the + # stable serialized form or model equality. 
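# (Editor's illustration, not part of the patch.) Because the bootstrap kwargs are
# cleared after ingestion, only `variable_dictionary` drives equality and the
# serialized form. Assuming `api_key` is a Variable with selector ["env", "api_key"]:
#   a = VariablePool(environment_variables=[api_key])       # legacy kwarg path
#   b = VariablePool(); b.add(["env", "api_key"], api_key)  # direct path
#   assert a == b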
+ self.system_variables = () + self.environment_variables = () + self.conversation_variables = () + self.rag_pipeline_variables = () + self.user_inputs = {} + return self + + def _ingest_legacy_variables(self, variables: Sequence[Variable], *, node_id: str) -> None: + for variable in variables: + selector = [node_id, variable.name] + normalized_variable = variable + if list(variable.selector) != selector: + normalized_variable = variable.model_copy(update={"selector": selector}) + self.add(normalized_variable.selector, normalized_variable) + + def _ingest_legacy_rag_variables(self, rag_pipeline_variables: Sequence[RAGPipelineVariableInput]) -> None: + if not rag_pipeline_variables: + return + + values_by_node_id: defaultdict[str, dict[str, Any]] = defaultdict(dict) + for rag_variable_input in rag_pipeline_variables: + values_by_node_id[rag_variable_input.variable.belong_to_node_id][rag_variable_input.variable.variable] = ( + rag_variable_input.value + ) + + for node_id, value in values_by_node_id.items(): + self.add((self._RAG_PIPELINE_VARIABLE_NODE_ID, node_id), value) def add(self, selector: Sequence[str], value: Any, /): """ @@ -114,10 +114,10 @@ class VariablePool(BaseModel): if isinstance(value, VariableBase): variable = value elif isinstance(value, Segment): - variable = variable_factory.segment_to_variable(segment=value, selector=selector) + variable = segment_to_variable(segment=value, selector=selector) else: - segment = variable_factory.build_segment(value) - variable = variable_factory.segment_to_variable(segment=segment, selector=selector) + segment = build_segment(value) + variable = segment_to_variable(segment=segment, selector=selector) node_id, name = self._selector_to_keys(selector) # Based on the definition of `Variable`, @@ -180,7 +180,7 @@ class VariablePool(BaseModel): return None attr = FileAttribute(attr) attr_value = file_manager.get_attr(file=segment.value, attr=attr) - return variable_factory.build_segment(attr_value) + return build_segment(attr_value) # Navigate through nested attributes result: Any = segment @@ -191,7 +191,7 @@ class VariablePool(BaseModel): return None # Return result as Segment - return result if isinstance(result, Segment) else variable_factory.build_segment(result) + return result if isinstance(result, Segment) else build_segment(result) def _extract_value(self, obj: Any): """Extract the actual value from an ObjectSegment.""" @@ -212,7 +212,7 @@ class VariablePool(BaseModel): """ if not isinstance(obj, dict) or attr not in obj: return None - return variable_factory.build_segment(obj.get(attr)) + return build_segment(obj.get(attr)) def remove(self, selector: Sequence[str], /): """ @@ -239,7 +239,7 @@ class VariablePool(BaseModel): if "." in part and (variable := self.get(part.split("."))): segments.append(variable) else: - segments.append(variable_factory.build_segment(part)) + segments.append(build_segment(part)) return SegmentGroup(value=segments) def get_file(self, selector: Sequence[str], /) -> FileSegment | None: @@ -262,19 +262,18 @@ class VariablePool(BaseModel): return result - def _add_system_variables(self, system_variable: SystemVariable): - sys_var_mapping = system_variable.to_dict() - for key, value in sys_var_mapping.items(): - if value is None: - continue - selector = (SYSTEM_VARIABLE_NODE_ID, key) - # If the system variable already exists, do not add it again. - # This ensures that we can keep the id of the system variables intact. 
- if self._has(selector): - continue - self.add(selector, value) + def flatten(self, *, unprefixed_node_id: str | None = None) -> Mapping[str, object]: + """Return a selector-style snapshot of the entire variable pool.""" + + result: dict[str, object] = {} + for node_id, variables in self.variable_dictionary.items(): + for name, variable in variables.items(): + output_name = name if node_id == unprefixed_node_id else f"{node_id}.{name}" + result[output_name] = deepcopy(variable.value) + + return result @classmethod def empty(cls) -> VariablePool: """Create an empty variable pool.""" - return cls(system_variables=SystemVariable.default()) + return cls() diff --git a/api/dify_graph/system_variable.py b/api/dify_graph/system_variable.py deleted file mode 100644 index cc5deda892..0000000000 --- a/api/dify_graph/system_variable.py +++ /dev/null @@ -1,217 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from types import MappingProxyType -from typing import Any -from uuid import uuid4 - -from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator - -from dify_graph.enums import SystemVariableKey -from dify_graph.file.models import File - - -class SystemVariable(BaseModel): - """A model for managing system variables. - - Fields with a value of `None` are treated as absent and will not be included - in the variable pool. - """ - - model_config = ConfigDict( - extra="forbid", - serialize_by_alias=True, - validate_by_alias=True, - ) - - user_id: str | None = None - - # Ideally, `app_id` and `workflow_id` should be required and not `None`. - # However, there are scenarios in the codebase where these fields are not set. - # To maintain compatibility, they are marked as optional here. - app_id: str | None = None - workflow_id: str | None = None - - timestamp: int | None = None - - files: Sequence[File] = Field(default_factory=list) - - # NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`. - # To maintain compatibility with existing workflows, it must be serialized - # as `workflow_run_id` in dictionaries or JSON objects, and also referenced - # as `workflow_run_id` in the variable pool. - workflow_execution_id: str | None = Field( - validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"), - serialization_alias="workflow_run_id", - default=None, - ) - # Chatflow related fields. - query: str | None = None - conversation_id: str | None = None - dialogue_count: int | None = None - document_id: str | None = None - original_document_id: str | None = None - dataset_id: str | None = None - batch: str | None = None - datasource_type: str | None = None - datasource_info: Mapping[str, Any] | None = None - invoke_from: str | None = None - - @model_validator(mode="before") - @classmethod - def validate_json_fields(cls, data): - if isinstance(data, dict): - # For JSON validation, only allow workflow_run_id - if "workflow_execution_id" in data and "workflow_run_id" not in data: - # This is likely from direct instantiation, allow it - return data - elif "workflow_execution_id" in data and "workflow_run_id" in data: - # Both present, remove workflow_execution_id - data = data.copy() - data.pop("workflow_execution_id") - return data - return data - - @classmethod - def default(cls) -> SystemVariable: - return cls(workflow_execution_id=str(uuid4())) - - def to_dict(self) -> dict[SystemVariableKey, Any]: - # NOTE: This method is provided for compatibility with legacy code. 
- # New code should use the `SystemVariable` object directly instead of converting - # it to a dictionary, as this conversion results in the loss of type information - # for each key, making static analysis more difficult. - - d: dict[SystemVariableKey, Any] = { - SystemVariableKey.FILES: self.files, - } - if self.user_id is not None: - d[SystemVariableKey.USER_ID] = self.user_id - if self.app_id is not None: - d[SystemVariableKey.APP_ID] = self.app_id - if self.workflow_id is not None: - d[SystemVariableKey.WORKFLOW_ID] = self.workflow_id - if self.workflow_execution_id is not None: - d[SystemVariableKey.WORKFLOW_EXECUTION_ID] = self.workflow_execution_id - if self.query is not None: - d[SystemVariableKey.QUERY] = self.query - if self.conversation_id is not None: - d[SystemVariableKey.CONVERSATION_ID] = self.conversation_id - if self.dialogue_count is not None: - d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count - if self.document_id is not None: - d[SystemVariableKey.DOCUMENT_ID] = self.document_id - if self.original_document_id is not None: - d[SystemVariableKey.ORIGINAL_DOCUMENT_ID] = self.original_document_id - if self.dataset_id is not None: - d[SystemVariableKey.DATASET_ID] = self.dataset_id - if self.batch is not None: - d[SystemVariableKey.BATCH] = self.batch - if self.datasource_type is not None: - d[SystemVariableKey.DATASOURCE_TYPE] = self.datasource_type - if self.datasource_info is not None: - d[SystemVariableKey.DATASOURCE_INFO] = self.datasource_info - if self.invoke_from is not None: - d[SystemVariableKey.INVOKE_FROM] = self.invoke_from - if self.timestamp is not None: - d[SystemVariableKey.TIMESTAMP] = self.timestamp - return d - - def as_view(self) -> SystemVariableReadOnlyView: - return SystemVariableReadOnlyView(self) - - -class SystemVariableReadOnlyView: - """ - A read-only view of a SystemVariable that implements the ReadOnlySystemVariable protocol. - - This class wraps a SystemVariable instance and provides read-only access to all its fields. - It always reads the latest data from the wrapped instance and prevents any write operations. - """ - - def __init__(self, system_variable: SystemVariable) -> None: - """ - Initialize the read-only view with a SystemVariable instance. 
- - Args: - system_variable: The SystemVariable instance to wrap - """ - self._system_variable = system_variable - - @property - def user_id(self) -> str | None: - return self._system_variable.user_id - - @property - def app_id(self) -> str | None: - return self._system_variable.app_id - - @property - def workflow_id(self) -> str | None: - return self._system_variable.workflow_id - - @property - def workflow_execution_id(self) -> str | None: - return self._system_variable.workflow_execution_id - - @property - def query(self) -> str | None: - return self._system_variable.query - - @property - def conversation_id(self) -> str | None: - return self._system_variable.conversation_id - - @property - def dialogue_count(self) -> int | None: - return self._system_variable.dialogue_count - - @property - def document_id(self) -> str | None: - return self._system_variable.document_id - - @property - def original_document_id(self) -> str | None: - return self._system_variable.original_document_id - - @property - def dataset_id(self) -> str | None: - return self._system_variable.dataset_id - - @property - def batch(self) -> str | None: - return self._system_variable.batch - - @property - def datasource_type(self) -> str | None: - return self._system_variable.datasource_type - - @property - def invoke_from(self) -> str | None: - return self._system_variable.invoke_from - - @property - def files(self) -> Sequence[File]: - """ - Get a copy of the files from the wrapped SystemVariable. - - Returns: - A defensive copy of the files sequence to prevent modification - """ - return tuple(self._system_variable.files) # Convert to immutable tuple - - @property - def datasource_info(self) -> Mapping[str, Any] | None: - """ - Get a copy of the datasource info from the wrapped SystemVariable. 
-
-        Returns:
-            A view of the datasource info mapping to prevent modification
-        """
-        if self._system_variable.datasource_info is None:
-            return None
-        return MappingProxyType(self._system_variable.datasource_info)
-
-    def __repr__(self) -> str:
-        """Return a string representation of the read-only view."""
-        return f"SystemVariableReadOnlyView(system_variable={self._system_variable!r})"
diff --git a/api/dify_graph/template_rendering.py b/api/dify_graph/template_rendering.py
new file mode 100644
index 0000000000..0527e58f6d
--- /dev/null
+++ b/api/dify_graph/template_rendering.py
@@ -0,0 +1,18 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from collections.abc import Mapping
+from typing import Any
+
+
+class TemplateRenderError(ValueError):
+    """Raised when rendering a template fails."""
+
+
+class Jinja2TemplateRenderer(ABC):
+    """Nominal renderer contract for Jinja2 template rendering in graph nodes."""
+
+    @abstractmethod
+    def render_template(self, template: str, variables: Mapping[str, Any]) -> str:
+        """Render the template into plain text."""
+        raise NotImplementedError
diff --git a/api/dify_graph/utils/json_in_md_parser.py b/api/dify_graph/utils/json_in_md_parser.py
new file mode 100644
index 0000000000..4416b4774b
--- /dev/null
+++ b/api/dify_graph/utils/json_in_md_parser.py
@@ -0,0 +1,58 @@
+from __future__ import annotations
+
+import json
+
+
+class OutputParserError(ValueError):
+    """Raised when a markdown-wrapped JSON payload cannot be parsed or validated."""
+
+
+def parse_json_markdown(json_string: str) -> dict | list:
+    """Extract and parse the first JSON object or array embedded in markdown text."""
+    json_string = json_string.strip()
+    starts = ["```json", "```", "``", "`", "{", "["]
+    ends = ["```", "``", "`", "}", "]"]
+    end_index = -1
+    start_index = 0
+
+    for start_marker in starts:
+        start_index = json_string.find(start_marker)
+        if start_index != -1:
+            if json_string[start_index] not in ("{", "["):
+                start_index += len(start_marker)
+            break
+
+    if start_index != -1:
+        for end_marker in ends:
+            end_index = json_string.rfind(end_marker, start_index)
+            if end_index != -1:
+                if json_string[end_index] in ("}", "]"):
+                    end_index += 1
+                break
+
+    if start_index == -1 or end_index == -1 or start_index >= end_index:
+        raise ValueError("could not find json block in the output.")
+
+    extracted_content = json_string[start_index:end_index].strip()
+    return json.loads(extracted_content)
+
+
+def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
+    try:
+        json_obj = parse_json_markdown(text)
+    except ValueError as exc:  # JSONDecodeError, or the plain ValueError raised when no block is found
+        raise OutputParserError(f"got invalid json object. error: {exc}") from exc
+
+    if isinstance(json_obj, list):
+        if len(json_obj) == 1 and isinstance(json_obj[0], dict):
+            json_obj = json_obj[0]
+        else:
+            raise OutputParserError(f"got invalid return object. obj:{json_obj}")
+
+    for key in expected_keys:
+        if key not in json_obj:
+            raise OutputParserError(
+                f"got invalid return object. expected key `{key}` to be present, but got {json_obj}"
+            )
+
+    return json_obj
diff --git a/api/dify_graph/variable_loader.py b/api/dify_graph/variable_loader.py
index d263450334..707fb9fb7d 100644
--- a/api/dify_graph/variable_loader.py
+++ b/api/dify_graph/variable_loader.py
@@ -13,14 +13,6 @@ class VariableLoader(Protocol):
     A `VariableLoader` is responsible for retrieving additional variables required
     during the execution of a single node, which are not provided as user inputs.
- NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same - application and share the same `app_id`. However, this interface does not enforce that constraint, - and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of - concern and allow for flexible implementations. - - Implementations of `VariableLoader` should almost always have an `app_id` parameter in - their constructor. - TODO(QuantumGhost): this is a temporally workaround. If we can move the creation of node instance into `WorkflowService.single_step_run`, we may get rid of this interface. """ diff --git a/api/dify_graph/variables/__init__.py b/api/dify_graph/variables/__init__.py index be3fc8d97a..e9beb6cb95 100644 --- a/api/dify_graph/variables/__init__.py +++ b/api/dify_graph/variables/__init__.py @@ -1,3 +1,10 @@ +from .factory import ( + TypeMismatchError, + UnsupportedSegmentTypeError, + build_segment, + build_segment_with_type, + segment_to_variable, +) from .input_entities import VariableEntity, VariableEntityType from .segment_group import SegmentGroup from .segments import ( @@ -63,8 +70,13 @@ __all__ = [ "SegmentType", "StringSegment", "StringVariable", + "TypeMismatchError", + "UnsupportedSegmentTypeError", "Variable", "VariableBase", "VariableEntity", "VariableEntityType", + "build_segment", + "build_segment_with_type", + "segment_to_variable", ] diff --git a/api/dify_graph/variables/factory.py b/api/dify_graph/variables/factory.py new file mode 100644 index 0000000000..25d48d75b0 --- /dev/null +++ b/api/dify_graph/variables/factory.py @@ -0,0 +1,202 @@ +"""Graph-owned helpers for converting runtime values, segments, and variables. + +These conversions are part of the `dify_graph` runtime model and must stay +independent from top-level API factory modules so graph nodes and state +containers can operate without importing application-layer packages. 
+""" + +from collections.abc import Mapping, Sequence +from typing import Any, cast +from uuid import uuid4 + +from dify_graph.file import File + +from .segments import ( + ArrayAnySegment, + ArrayBooleanSegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayObjectSegment, + ArraySegment, + ArrayStringSegment, + BooleanSegment, + FileSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + Segment, + StringSegment, +) +from .types import SegmentType +from .variables import ( + ArrayAnyVariable, + ArrayBooleanVariable, + ArrayFileVariable, + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + BooleanVariable, + FileVariable, + FloatVariable, + IntegerVariable, + NoneVariable, + ObjectVariable, + StringVariable, + VariableBase, +) + + +class UnsupportedSegmentTypeError(Exception): + pass + + +class TypeMismatchError(Exception): + pass + + +SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[Any]] = { + ArrayAnySegment: ArrayAnyVariable, + ArrayBooleanSegment: ArrayBooleanVariable, + ArrayFileSegment: ArrayFileVariable, + ArrayNumberSegment: ArrayNumberVariable, + ArrayObjectSegment: ArrayObjectVariable, + ArrayStringSegment: ArrayStringVariable, + BooleanSegment: BooleanVariable, + FileSegment: FileVariable, + FloatSegment: FloatVariable, + IntegerSegment: IntegerVariable, + NoneSegment: NoneVariable, + ObjectSegment: ObjectVariable, + StringSegment: StringVariable, +} + + +def build_segment(value: Any, /) -> Segment: + """Build a runtime segment from a Python value.""" + if value is None: + return NoneSegment() + if isinstance(value, Segment): + return value + if isinstance(value, str): + return StringSegment(value=value) + if isinstance(value, bool): + return BooleanSegment(value=value) + if isinstance(value, int): + return IntegerSegment(value=value) + if isinstance(value, float): + return FloatSegment(value=value) + if isinstance(value, dict): + return ObjectSegment(value=value) + if isinstance(value, File): + return FileSegment(value=value) + if isinstance(value, list): + items = [build_segment(item) for item in value] + types = {item.value_type for item in items} + if all(isinstance(item, ArraySegment) for item in items): + return ArrayAnySegment(value=value) + if len(types) != 1: + if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}): + return ArrayNumberSegment(value=value) + return ArrayAnySegment(value=value) + + match types.pop(): + case SegmentType.STRING: + return ArrayStringSegment(value=value) + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: + return ArrayNumberSegment(value=value) + case SegmentType.BOOLEAN: + return ArrayBooleanSegment(value=value) + case SegmentType.OBJECT: + return ArrayObjectSegment(value=value) + case SegmentType.FILE: + return ArrayFileSegment(value=value) + case SegmentType.NONE: + return ArrayAnySegment(value=value) + case _: + raise ValueError(f"not supported value {value}") + raise ValueError(f"not supported value {value}") + + +_SEGMENT_FACTORY: Mapping[SegmentType, type[Segment]] = { + SegmentType.NONE: NoneSegment, + SegmentType.STRING: StringSegment, + SegmentType.INTEGER: IntegerSegment, + SegmentType.FLOAT: FloatSegment, + SegmentType.FILE: FileSegment, + SegmentType.BOOLEAN: BooleanSegment, + SegmentType.OBJECT: ObjectSegment, + SegmentType.ARRAY_ANY: ArrayAnySegment, + SegmentType.ARRAY_STRING: ArrayStringSegment, + SegmentType.ARRAY_NUMBER: ArrayNumberSegment, + SegmentType.ARRAY_OBJECT: ArrayObjectSegment, + SegmentType.ARRAY_FILE: ArrayFileSegment, + 
SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment, +} + + +def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: + """Build a segment while enforcing compatibility with the expected runtime type.""" + if value is None: + if segment_type == SegmentType.NONE: + return NoneSegment() + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None") + + if isinstance(value, list) and len(value) == 0: + if segment_type == SegmentType.ARRAY_ANY: + return ArrayAnySegment(value=value) + if segment_type == SegmentType.ARRAY_STRING: + return ArrayStringSegment(value=value) + if segment_type == SegmentType.ARRAY_BOOLEAN: + return ArrayBooleanSegment(value=value) + if segment_type == SegmentType.ARRAY_NUMBER: + return ArrayNumberSegment(value=value) + if segment_type == SegmentType.ARRAY_OBJECT: + return ArrayObjectSegment(value=value) + if segment_type == SegmentType.ARRAY_FILE: + return ArrayFileSegment(value=value) + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list") + + inferred_type = SegmentType.infer_segment_type(value) + if inferred_type is None: + raise TypeMismatchError( + f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}" + ) + if inferred_type == segment_type: + segment_class = _SEGMENT_FACTORY[segment_type] + return segment_class(value_type=segment_type, value=value) + if segment_type == SegmentType.NUMBER and inferred_type in (SegmentType.INTEGER, SegmentType.FLOAT): + segment_class = _SEGMENT_FACTORY[inferred_type] + return segment_class(value_type=inferred_type, value=value) + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") + + +def segment_to_variable( + *, + segment: Segment, + selector: Sequence[str], + id: str | None = None, + name: str | None = None, + description: str = "", +) -> VariableBase: + """Convert a runtime segment into a runtime variable for storage in the pool.""" + if isinstance(segment, VariableBase): + return segment + name = name or selector[-1] + id = id or str(uuid4()) + + segment_type = type(segment) + if segment_type not in SEGMENT_TO_VARIABLE_MAP: + raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") + + variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] + return cast( + VariableBase, + variable_class( + id=id, + name=name, + description=description, + value=segment.value, + selector=list(selector), + ), + ) diff --git a/api/dify_graph/variables/types.py b/api/dify_graph/variables/types.py index 53bf495a27..bb249b4498 100644 --- a/api/dify_graph/variables/types.py +++ b/api/dify_graph/variables/types.py @@ -220,8 +220,8 @@ class SegmentType(StrEnum): @staticmethod def get_zero_value(t: SegmentType) -> Segment: - # Lazy import to avoid circular dependency - from factories import variable_factory + # Lazy import to avoid circular dependency between segment types and factory helpers. 
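+        # (`factory` imports `SegmentType` from this module, so a module-level
+        # import here would create an import cycle.)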
+ from dify_graph.variables.factory import build_segment, build_segment_with_type match t: case ( @@ -231,19 +231,19 @@ class SegmentType(StrEnum): | SegmentType.ARRAY_NUMBER | SegmentType.ARRAY_BOOLEAN ): - return variable_factory.build_segment_with_type(t, []) + return build_segment_with_type(t, []) case SegmentType.OBJECT: - return variable_factory.build_segment({}) + return build_segment({}) case SegmentType.STRING: - return variable_factory.build_segment("") + return build_segment("") case SegmentType.INTEGER: - return variable_factory.build_segment(0) + return build_segment(0) case SegmentType.FLOAT: - return variable_factory.build_segment(0.0) + return build_segment(0.0) case SegmentType.NUMBER: - return variable_factory.build_segment(0) + return build_segment(0) case SegmentType.BOOLEAN: - return variable_factory.build_segment(False) + return build_segment(False) case _: raise ValueError(f"unsupported variable type: {t}") diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index c43e99f0f4..bd5bc08bff 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,5 +1,6 @@ import logging +from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from dify_graph.nodes import BuiltinNodeTypes @@ -19,8 +20,9 @@ def handle(sender, **kwargs): if node_data.get("data", {}).get("type") == BuiltinNodeTypes.TOOL: try: tool_entity = ToolEntity.model_validate(node_data["data"]) + provider_type = ToolProviderType(tool_entity.provider_type.value) tool_runtime = ToolManager.get_tool_runtime( - provider_type=tool_entity.provider_type, + provider_type=provider_type, provider_id=tool_entity.provider_id, tool_name=tool_entity.tool_name, tenant_id=app.tenant_id, @@ -30,7 +32,7 @@ def handle(sender, **kwargs): tenant_id=app.tenant_id, tool_runtime=tool_runtime, provider_name=tool_entity.provider_name, - provider_type=tool_entity.provider_type, + provider_type=provider_type, identity_id=f"WORKFLOW.{app.id}.{node_data.get('id')}", ) manager.delete_tool_parameters_cache() diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 7b6a73af52..367a4c1ede 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -2,7 +2,7 @@ import ssl from datetime import timedelta from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from celery import Celery, Task from celery.schedules import crontab diff --git a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index c58aa6adbb..c0dc3c1a0a 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -7,9 +7,9 @@ from typing import Union from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker +from core.repositories.factory import WorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from dify_graph.entities import WorkflowExecution -from dify_graph.repositories.workflow_execution_repository import 
WorkflowExecutionRepository from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore from libs.helper import extract_tenant_id diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index d84c0bc432..c38f9950f6 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -17,10 +17,10 @@ from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from dify_graph.entities import WorkflowNodeExecution from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int @@ -304,35 +304,39 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id) # Don't raise - LogStore write succeeded, SQL is just a backup - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all NodeExecution instances for a specific workflow run. + Retrieve all node executions for a workflow execution. Uses LogStore SQL query with window function to get the latest version of each node execution. This ensures we only get the most recent version of each node execution record. Args: - workflow_run_id: The workflow run ID + workflow_execution_id: The workflow execution identifier order_config: Optional configuration for ordering results order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of NodeExecution instances + A list of workflow node execution instances Note: This method uses ROW_NUMBER() window function partitioned by node_execution_id to get the latest version (highest log_version) of each node execution. 
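+
+        Illustrative query shape (simplified from the statement built below;
+        identifiers are placeholders):
+
+            SELECT * FROM (
+                SELECT *, ROW_NUMBER() OVER (
+                    PARTITION BY node_execution_id ORDER BY log_version DESC
+                ) AS rn
+                FROM <node_execution_logstore>
+                WHERE workflow_run_id = '<workflow_execution_id>'
+                  AND tenant_id = '<tenant_id>'
+            ) t
+            WHERE rn = 1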
""" - logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config) + logger.debug( + "get_by_workflow_execution: workflow_execution_id=%s, order_config=%s", + workflow_execution_id, + order_config, + ) # Build SQL query with deduplication using window function # ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) # ensures we get the latest version of each node execution # Escape parameters to prevent SQL injection - escaped_workflow_run_id = escape_identifier(workflow_run_id) + escaped_workflow_execution_id = escape_identifier(workflow_execution_id) escaped_tenant_id = escape_identifier(self._tenant_id) # Build ORDER BY clause for outer query @@ -360,7 +364,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): SELECT * FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_node_execution_logstore} - WHERE workflow_run_id='{escaped_workflow_run_id}' + WHERE workflow_run_id='{escaped_workflow_execution_id}' AND tenant_id='{escaped_tenant_id}' {app_id_filter} ) t @@ -391,5 +395,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): return executions except Exception: - logger.exception("Failed to retrieve node executions from LogStore: workflow_run_id=%s", workflow_run_id) + logger.exception( + "Failed to retrieve node executions from LogStore: workflow_execution_id=%s", + workflow_execution_id, + ) raise diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py deleted file mode 100644 index cb07ba58ae..0000000000 --- a/api/factories/file_factory.py +++ /dev/null @@ -1,618 +0,0 @@ -import logging -import mimetypes -import os -import re -import urllib.parse -import uuid -from collections.abc import Callable, Mapping, Sequence -from typing import Any - -import httpx -from sqlalchemy import select -from sqlalchemy.orm import Session -from werkzeug.http import parse_options_header - -from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS -from core.helper import ssrf_proxy -from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers -from extensions.ext_database import db -from models import MessageFile, ToolFile, UploadFile - -logger = logging.getLogger(__name__) - - -def build_from_message_files( - *, - message_files: Sequence["MessageFile"], - tenant_id: str, - config: FileUploadConfig | None = None, -) -> Sequence[File]: - results = [ - build_from_message_file(message_file=file, tenant_id=tenant_id, config=config) - for file in message_files - if file.belongs_to != FileBelongsTo.ASSISTANT - ] - return results - - -def build_from_message_file( - *, - message_file: "MessageFile", - tenant_id: str, - config: FileUploadConfig | None, -): - mapping = { - "transfer_method": message_file.transfer_method, - "url": message_file.url, - "type": message_file.type, - } - - # Only include id if it exists (message_file has been committed to DB) - if message_file.id: - mapping["id"] = message_file.id - - # Set the correct ID field based on transfer method - if message_file.transfer_method == FileTransferMethod.TOOL_FILE: - mapping["tool_file_id"] = message_file.upload_file_id - else: - mapping["upload_file_id"] = message_file.upload_file_id - - return build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - config=config, - ) - - -def build_from_mapping( - *, - mapping: Mapping[str, Any], - 
tenant_id: str, - config: FileUploadConfig | None = None, - strict_type_validation: bool = False, -) -> File: - transfer_method_value = mapping.get("transfer_method") - if not transfer_method_value: - raise ValueError("transfer_method is required in file mapping") - transfer_method = FileTransferMethod.value_of(transfer_method_value) - - build_functions: dict[FileTransferMethod, Callable] = { - FileTransferMethod.LOCAL_FILE: _build_from_local_file, - FileTransferMethod.REMOTE_URL: _build_from_remote_url, - FileTransferMethod.TOOL_FILE: _build_from_tool_file, - FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file, - } - - build_func = build_functions.get(transfer_method) - if not build_func: - raise ValueError(f"Invalid file transfer method: {transfer_method}") - - file: File = build_func( - mapping=mapping, - tenant_id=tenant_id, - transfer_method=transfer_method, - strict_type_validation=strict_type_validation, - ) - - if config and not _is_file_valid_with_config( - input_file_type=mapping.get("type", FileType.CUSTOM), - file_extension=file.extension or "", - file_transfer_method=file.transfer_method, - config=config, - ): - raise ValueError(f"File validation failed for file: {file.filename}") - - return file - - -def build_from_mappings( - *, - mappings: Sequence[Mapping[str, Any]], - config: FileUploadConfig | None = None, - tenant_id: str, - strict_type_validation: bool = False, -) -> Sequence[File]: - # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query. - # Implement batch processing to reduce database load when handling multiple files. - # Filter out None/empty mappings to avoid errors - def is_valid_mapping(m: Mapping[str, Any]) -> bool: - if not m or not m.get("transfer_method"): - return False - # For REMOTE_URL transfer method, ensure url or remote_url is provided and not None - transfer_method = m.get("transfer_method") - if transfer_method == FileTransferMethod.REMOTE_URL: - url = m.get("url") or m.get("remote_url") - if not url: - return False - return True - - valid_mappings = [m for m in mappings if is_valid_mapping(m)] - files = [ - build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - config=config, - strict_type_validation=strict_type_validation, - ) - for mapping in valid_mappings - ] - - if ( - config - # If image config is set. - and config.image_config - # And the number of image files exceeds the maximum limit - and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits - ): - raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") - if config and config.number_limits and len(files) > config.number_limits: - raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") - - return files - - -def _build_from_local_file( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - upload_file_id = mapping.get("upload_file_id") - if not upload_file_id: - raise ValueError("Invalid upload file id") - # check if upload_file_id is a valid uuid - try: - uuid.UUID(upload_file_id) - except ValueError: - raise ValueError("Invalid upload file id format") - stmt = select(UploadFile).where( - UploadFile.id == upload_file_id, - UploadFile.tenant_id == tenant_id, - ) - - row = db.session.scalar(stmt) - if row is None: - raise ValueError("Invalid upload file") - - detected_file_type = _standardize_file_type(extension="." 
+ row.extension, mime_type=row.mime_type) - specified_type = mapping.get("type", "custom") - - if strict_type_validation and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - filename=row.name, - extension="." + row.extension, - mime_type=row.mime_type, - tenant_id=tenant_id, - type=file_type, - transfer_method=transfer_method, - remote_url=row.source_url, - related_id=mapping.get("upload_file_id"), - size=row.size, - storage_key=row.key, - ) - - -def _build_from_remote_url( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - upload_file_id = mapping.get("upload_file_id") - if upload_file_id: - try: - uuid.UUID(upload_file_id) - except ValueError: - raise ValueError("Invalid upload file id format") - stmt = select(UploadFile).where( - UploadFile.id == upload_file_id, - UploadFile.tenant_id == tenant_id, - ) - - upload_file = db.session.scalar(stmt) - if upload_file is None: - raise ValueError("Invalid upload file") - - detected_file_type = _standardize_file_type( - extension="." + upload_file.extension, mime_type=upload_file.mime_type - ) - - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - filename=upload_file.name, - extension="." + upload_file.extension, - mime_type=upload_file.mime_type, - tenant_id=tenant_id, - type=file_type, - transfer_method=transfer_method, - remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), - related_id=mapping.get("upload_file_id"), - size=upload_file.size, - storage_key=upload_file.key, - ) - url = mapping.get("url") or mapping.get("remote_url") - if not url: - raise ValueError("Invalid file url") - - mime_type, filename, file_size = _get_remote_file_info(url) - extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin") - - detected_file_type = _standardize_file_type(extension=extension, mime_type=mime_type) - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. 
Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - filename=filename, - tenant_id=tenant_id, - type=file_type, - transfer_method=transfer_method, - remote_url=url, - mime_type=mime_type, - extension=extension, - size=file_size, - storage_key="", - ) - - -def _extract_filename(url_path: str, content_disposition: str | None) -> str | None: - filename: str | None = None - # Try to extract from Content-Disposition header first - if content_disposition: - # Manually extract filename* parameter since parse_options_header doesn't support it - filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition) - if filename_star_match: - raw_star = filename_star_match.group(1).strip() - # Remove trailing quotes if present - raw_star = raw_star.removesuffix('"') - # format: charset'lang'value - try: - parts = raw_star.split("'", 2) - charset = (parts[0] or "utf-8").lower() if len(parts) >= 1 else "utf-8" - value = parts[2] if len(parts) == 3 else parts[-1] - filename = urllib.parse.unquote(value, encoding=charset, errors="replace") - except Exception: - # Fallback: try to extract value after the last single quote - if "''" in raw_star: - filename = urllib.parse.unquote(raw_star.split("''")[-1]) - else: - filename = urllib.parse.unquote(raw_star) - - if not filename: - # Fallback to regular filename parameter - _, params = parse_options_header(content_disposition) - raw = params.get("filename") - if raw: - # Strip surrounding quotes and percent-decode if present - if len(raw) >= 2 and raw[0] == raw[-1] == '"': - raw = raw[1:-1] - filename = urllib.parse.unquote(raw) - # Fallback to URL path if no filename from header - if not filename: - candidate = os.path.basename(url_path) - filename = urllib.parse.unquote(candidate) if candidate else None - # Defense-in-depth: ensure basename only - if filename: - filename = os.path.basename(filename) - # Return None if filename is empty or only whitespace - if not filename or not filename.strip(): - filename = None - return filename or None - - -def _guess_mime_type(filename: str) -> str: - """Guess MIME type from filename, returning empty string if None.""" - guessed_mime, _ = mimetypes.guess_type(filename) - return guessed_mime or "" - - -def _get_remote_file_info(url: str): - file_size = -1 - parsed_url = urllib.parse.urlparse(url) - url_path = parsed_url.path - filename = os.path.basename(url_path) - - # Initialize mime_type from filename as fallback - mime_type = _guess_mime_type(filename) - - resp = ssrf_proxy.head(url, follow_redirects=True) - if resp.status_code == httpx.codes.OK: - content_disposition = resp.headers.get("Content-Disposition") - extracted_filename = _extract_filename(url_path, content_disposition) - if extracted_filename: - filename = extracted_filename - mime_type = _guess_mime_type(filename) - file_size = int(resp.headers.get("Content-Length", file_size)) - # Fallback to Content-Type header if mime_type is still empty - if not mime_type: - mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip() - - if not filename: - extension = mimetypes.guess_extension(mime_type) or ".bin" - filename = f"{uuid.uuid4().hex}{extension}" - if not mime_type: - mime_type = _guess_mime_type(filename) - - return mime_type, filename, file_size - - -def _build_from_tool_file( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - 
strict_type_validation: bool = False, -) -> File: - # Backward/interop compatibility: allow tool_file_id to come from related_id or URL - tool_file_id = mapping.get("tool_file_id") - - if not tool_file_id: - raise ValueError(f"ToolFile {tool_file_id} not found") - tool_file = db.session.scalar( - select(ToolFile).where( - ToolFile.id == tool_file_id, - ToolFile.tenant_id == tenant_id, - ) - ) - - if tool_file is None: - raise ValueError(f"ToolFile {tool_file_id} not found") - - extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - - detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype) - - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - tenant_id=tenant_id, - filename=tool_file.name, - type=file_type, - transfer_method=transfer_method, - remote_url=tool_file.original_url, - related_id=tool_file.id, - extension=extension, - mime_type=tool_file.mimetype, - size=tool_file.size, - storage_key=tool_file.file_key, - ) - - -def _build_from_datasource_file( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - datasource_file_id = mapping.get("datasource_file_id") - if not datasource_file_id: - raise ValueError(f"DatasourceFile {datasource_file_id} not found") - datasource_file = db.session.scalar( - select(UploadFile).where( - UploadFile.id == datasource_file_id, - UploadFile.tenant_id == tenant_id, - ) - ) - - if datasource_file is None: - raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") - - extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" - - detected_file_type = _standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) - - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. 
Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("datasource_file_id"), - tenant_id=tenant_id, - filename=datasource_file.name, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - remote_url=datasource_file.source_url, - related_id=datasource_file.id, - extension=extension, - mime_type=datasource_file.mime_type, - size=datasource_file.size, - storage_key=datasource_file.key, - url=datasource_file.source_url, - ) - - -def _is_file_valid_with_config( - *, - input_file_type: str, - file_extension: str, - file_transfer_method: FileTransferMethod, - config: FileUploadConfig, -) -> bool: - # FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model) - # These are internally generated and should bypass user upload restrictions - if file_transfer_method == FileTransferMethod.TOOL_FILE: - return True - - if ( - config.allowed_file_types - and input_file_type not in config.allowed_file_types - and input_file_type != FileType.CUSTOM - ): - return False - - if ( - input_file_type == FileType.CUSTOM - and config.allowed_file_extensions is not None - and file_extension not in config.allowed_file_extensions - ): - return False - - if input_file_type == FileType.IMAGE: - if ( - config.image_config - and config.image_config.transfer_methods - and file_transfer_method not in config.image_config.transfer_methods - ): - return False - elif config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods: - return False - - return True - - -def _standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType: - """ - Infer the possible actual type of the file based on the extension and mime_type - """ - guessed_type = None - if extension: - guessed_type = _get_file_type_by_extension(extension) - if guessed_type is None and mime_type: - guessed_type = _get_file_type_by_mimetype(mime_type) - return guessed_type or FileType.CUSTOM - - -def _get_file_type_by_extension(extension: str) -> FileType | None: - extension = extension.lstrip(".") - if extension in IMAGE_EXTENSIONS: - return FileType.IMAGE - elif extension in VIDEO_EXTENSIONS: - return FileType.VIDEO - elif extension in AUDIO_EXTENSIONS: - return FileType.AUDIO - elif extension in DOCUMENT_EXTENSIONS: - return FileType.DOCUMENT - return None - - -def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: - if "image" in mime_type: - file_type = FileType.IMAGE - elif "video" in mime_type: - file_type = FileType.VIDEO - elif "audio" in mime_type: - file_type = FileType.AUDIO - elif "text" in mime_type or "pdf" in mime_type: - file_type = FileType.DOCUMENT - else: - file_type = FileType.CUSTOM - return file_type - - -def get_file_type_by_mime_type(mime_type: str) -> FileType: - return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM - - -class StorageKeyLoader: - """FileKeyLoader load the storage key from database for a list of files. - This loader is batched, the database query count is constant regardless of the input size. 
- """ - - def __init__(self, session: Session, tenant_id: str): - self._session = session - self._tenant_id = tenant_id - - def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]: - stmt = select(UploadFile).where( - UploadFile.id.in_(upload_file_ids), - UploadFile.tenant_id == self._tenant_id, - ) - - return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} - - def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]: - stmt = select(ToolFile).where( - ToolFile.id.in_(tool_file_ids), - ToolFile.tenant_id == self._tenant_id, - ) - return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} - - def load_storage_keys(self, files: Sequence[File]): - """Loads storage keys for a sequence of files by retrieving the corresponding - `UploadFile` or `ToolFile` records from the database based on their transfer method. - - This method doesn't modify the input sequence structure but updates the `_storage_key` - property of each file object by extracting the relevant key from its database record. - - Performance note: This is a batched operation where database query count remains constant - regardless of input size. However, for optimal performance, input sequences should contain - fewer than 1000 files. For larger collections, split into smaller batches and process each - batch separately. - """ - - upload_file_ids: list[uuid.UUID] = [] - tool_file_ids: list[uuid.UUID] = [] - for file in files: - related_model_id = file.related_id - if file.related_id is None: - raise ValueError("file id should not be None.") - if file.tenant_id != self._tenant_id: - err_msg = ( - f"invalid file, expected tenant_id={self._tenant_id}, " - f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}" - ) - raise ValueError(err_msg) - model_id = uuid.UUID(related_model_id) - - if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): - upload_file_ids.append(model_id) - elif file.transfer_method == FileTransferMethod.TOOL_FILE: - tool_file_ids.append(model_id) - - tool_files = self._load_tool_files(tool_file_ids) - upload_files = self._load_upload_files(upload_file_ids) - for file in files: - model_id = uuid.UUID(file.related_id) - if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): - upload_file_row = upload_files.get(model_id) - if upload_file_row is None: - raise ValueError(f"Upload file not found for id: {model_id}") - file.storage_key = upload_file_row.key - elif file.transfer_method == FileTransferMethod.TOOL_FILE: - tool_file_row = tool_files.get(model_id) - if tool_file_row is None: - raise ValueError(f"Tool file not found for id: {model_id}") - file.storage_key = tool_file_row.file_key diff --git a/api/factories/file_factory/__init__.py b/api/factories/file_factory/__init__.py new file mode 100644 index 0000000000..ae0cd972ec --- /dev/null +++ b/api/factories/file_factory/__init__.py @@ -0,0 +1,18 @@ +"""Workflow file factory package. + +This package normalizes workflow-layer file payloads into graph-layer ``File`` +values. It keeps tenancy and ownership checks in the application layer and +exports the workflow-facing file builders for callers. 
+""" + +from .builders import build_from_mapping, build_from_mappings +from .message_files import build_from_message_file, build_from_message_files +from .storage_keys import StorageKeyLoader + +__all__ = [ + "StorageKeyLoader", + "build_from_mapping", + "build_from_mappings", + "build_from_message_file", + "build_from_message_files", +] diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py new file mode 100644 index 0000000000..d2c60aebb7 --- /dev/null +++ b/api/factories/file_factory/builders.py @@ -0,0 +1,329 @@ +"""Core builders for workflow file mappings.""" + +from __future__ import annotations + +import mimetypes +import uuid +from collections.abc import Mapping, Sequence +from typing import Any + +from sqlalchemy import select + +from core.app.file_access import FileAccessControllerProtocol +from core.workflow.file_reference import build_file_reference +from dify_graph.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers +from dify_graph.file.file_factory import standardize_file_type +from extensions.ext_database import db +from models import ToolFile, UploadFile + +from .common import resolve_mapping_file_id +from .remote import get_remote_file_info +from .validation import is_file_valid_with_config + + +def build_from_mapping( + *, + mapping: Mapping[str, Any], + tenant_id: str, + config: FileUploadConfig | None = None, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + transfer_method_value = mapping.get("transfer_method") + if not transfer_method_value: + raise ValueError("transfer_method is required in file mapping") + + transfer_method = FileTransferMethod.value_of(transfer_method_value) + build_func = _get_build_function(transfer_method) + file = build_func( + mapping=mapping, + tenant_id=tenant_id, + transfer_method=transfer_method, + strict_type_validation=strict_type_validation, + access_controller=access_controller, + ) + + if config and not is_file_valid_with_config( + input_file_type=mapping.get("type", FileType.CUSTOM), + file_extension=file.extension or "", + file_transfer_method=file.transfer_method, + config=config, + ): + raise ValueError(f"File validation failed for file: {file.filename}") + + return file + + +def build_from_mappings( + *, + mappings: Sequence[Mapping[str, Any]], + config: FileUploadConfig | None = None, + tenant_id: str, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> Sequence[File]: + # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query. + # Implement batch processing to reduce database load when handling multiple files. 
+ valid_mappings = [mapping for mapping in mappings if _is_valid_mapping(mapping)] + files = [ + build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + config=config, + strict_type_validation=strict_type_validation, + access_controller=access_controller, + ) + for mapping in valid_mappings + ] + + if ( + config + and config.image_config + and sum(1 for file in files if file.type == FileType.IMAGE) > config.image_config.number_limits + ): + raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") + if config and config.number_limits and len(files) > config.number_limits: + raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") + + return files + + +def _get_build_function(transfer_method: FileTransferMethod): + build_functions = { + FileTransferMethod.LOCAL_FILE: _build_from_local_file, + FileTransferMethod.REMOTE_URL: _build_from_remote_url, + FileTransferMethod.TOOL_FILE: _build_from_tool_file, + FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file, + } + build_func = build_functions.get(transfer_method) + if build_func is None: + raise ValueError(f"Invalid file transfer method: {transfer_method}") + return build_func + + +def _resolve_file_type( + *, + detected_file_type: FileType, + specified_type: str | None, + strict_type_validation: bool, +) -> FileType: + if strict_type_validation and specified_type and detected_file_type.value != specified_type: + raise ValueError("Detected file type does not match the specified type. Please verify the file.") + + if specified_type and specified_type != "custom": + return FileType(specified_type) + return detected_file_type + + +def _build_from_local_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + upload_file_id = resolve_mapping_file_id(mapping, "upload_file_id") + if not upload_file_id: + raise ValueError("Invalid upload file id") + + try: + uuid.UUID(upload_file_id) + except ValueError as exc: + raise ValueError("Invalid upload file id format") from exc + + stmt = select(UploadFile).where( + UploadFile.id == upload_file_id, + UploadFile.tenant_id == tenant_id, + ) + row = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) + if row is None: + raise ValueError("Invalid upload file") + + detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type", "custom"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=row.name, + extension="." 
+ row.extension, + mime_type=row.mime_type, + type=file_type, + transfer_method=transfer_method, + remote_url=row.source_url, + reference=build_file_reference(record_id=str(row.id)), + size=row.size, + storage_key=row.key, + ) + + +def _build_from_remote_url( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + upload_file_id = resolve_mapping_file_id(mapping, "upload_file_id") + if upload_file_id: + try: + uuid.UUID(upload_file_id) + except ValueError as exc: + raise ValueError("Invalid upload file id format") from exc + + stmt = select(UploadFile).where( + UploadFile.id == upload_file_id, + UploadFile.tenant_id == tenant_id, + ) + upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) + if upload_file is None: + raise ValueError("Invalid upload file") + + detected_file_type = standardize_file_type( + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + ) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + type=file_type, + transfer_method=transfer_method, + remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), + reference=build_file_reference(record_id=str(upload_file.id)), + size=upload_file.size, + storage_key=upload_file.key, + ) + + url = mapping.get("url") or mapping.get("remote_url") + if not url: + raise ValueError("Invalid file url") + + mime_type, filename, file_size = get_remote_file_info(url) + extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin") + detected_file_type = standardize_file_type(extension=extension, mime_type=mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=filename, + type=file_type, + transfer_method=transfer_method, + remote_url=url, + mime_type=mime_type, + extension=extension, + size=file_size, + ) + + +def _build_from_tool_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + tool_file_id = resolve_mapping_file_id(mapping, "tool_file_id") + if not tool_file_id: + raise ValueError(f"ToolFile {tool_file_id} not found") + + stmt = select(ToolFile).where( + ToolFile.id == tool_file_id, + ToolFile.tenant_id == tenant_id, + ) + tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt)) + if tool_file is None: + raise ValueError(f"ToolFile {tool_file_id} not found") + + extension = "." + tool_file.file_key.split(".")[-1] if "." 
in tool_file.file_key else ".bin" + detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=tool_file.name, + type=file_type, + transfer_method=transfer_method, + remote_url=tool_file.original_url, + reference=build_file_reference(record_id=str(tool_file.id)), + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + storage_key=tool_file.file_key, + ) + + +def _build_from_datasource_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + datasource_file_id = resolve_mapping_file_id(mapping, "datasource_file_id") + if not datasource_file_id: + raise ValueError(f"DatasourceFile {datasource_file_id} not found") + + stmt = select(UploadFile).where( + UploadFile.id == datasource_file_id, + UploadFile.tenant_id == tenant_id, + ) + datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) + if datasource_file is None: + raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") + + extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" + detected_file_type = standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("datasource_file_id"), + filename=datasource_file.name, + type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url=datasource_file.source_url, + reference=build_file_reference(record_id=str(datasource_file.id)), + extension=extension, + mime_type=datasource_file.mime_type, + size=datasource_file.size, + storage_key=datasource_file.key, + url=datasource_file.source_url, + ) + + +def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool: + if not mapping or not mapping.get("transfer_method"): + return False + + if mapping.get("transfer_method") == FileTransferMethod.REMOTE_URL: + url = mapping.get("url") or mapping.get("remote_url") + if not url: + return False + + return True diff --git a/api/factories/file_factory/common.py b/api/factories/file_factory/common.py new file mode 100644 index 0000000000..2e1c95ab3f --- /dev/null +++ b/api/factories/file_factory/common.py @@ -0,0 +1,27 @@ +"""Shared helpers for workflow file factory modules.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from core.workflow.file_reference import resolve_file_record_id + + +def resolve_mapping_file_id(mapping: Mapping[str, Any], *keys: str) -> str | None: + """Resolve historical file identifiers from persisted mapping payloads. + + Workflow and model payloads can outlive file schema changes. Older rows may + still carry concrete identifiers in legacy fields such as ``related_id``, + while newer payloads use opaque references. Keep this compatibility lookup in + the factory layer so historical data remains readable without reintroducing + storage details into graph-layer ``File`` values. 
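+
+    Lookup order is the caller-provided ``keys`` first, then the ``reference``
+    and ``related_id`` compatibility fields; the first value accepted by
+    ``resolve_file_record_id`` wins.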
+ """ + + for key in (*keys, "reference", "related_id"): + raw_value = mapping.get(key) + if isinstance(raw_value, str) and raw_value: + resolved_value = resolve_file_record_id(raw_value) + if resolved_value: + return resolved_value + return None diff --git a/api/factories/file_factory/message_files.py b/api/factories/file_factory/message_files.py new file mode 100644 index 0000000000..0978d336e8 --- /dev/null +++ b/api/factories/file_factory/message_files.py @@ -0,0 +1,59 @@ +"""Adapters from persisted message files to graph-layer file values.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from core.app.file_access import FileAccessControllerProtocol +from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig +from models import MessageFile + +from .builders import build_from_mapping + + +def build_from_message_files( + *, + message_files: Sequence[MessageFile], + tenant_id: str, + config: FileUploadConfig | None = None, + access_controller: FileAccessControllerProtocol, +) -> Sequence[File]: + return [ + build_from_message_file( + message_file=message_file, + tenant_id=tenant_id, + config=config, + access_controller=access_controller, + ) + for message_file in message_files + if message_file.belongs_to != FileBelongsTo.ASSISTANT + ] + + +def build_from_message_file( + *, + message_file: MessageFile, + tenant_id: str, + config: FileUploadConfig | None, + access_controller: FileAccessControllerProtocol, +) -> File: + mapping = { + "transfer_method": message_file.transfer_method, + "url": message_file.url, + "type": message_file.type, + } + + if message_file.id: + mapping["id"] = message_file.id + + if message_file.transfer_method == FileTransferMethod.TOOL_FILE: + mapping["tool_file_id"] = message_file.upload_file_id + else: + mapping["upload_file_id"] = message_file.upload_file_id + + return build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + config=config, + access_controller=access_controller, + ) diff --git a/api/factories/file_factory/remote.py b/api/factories/file_factory/remote.py new file mode 100644 index 0000000000..e5a7186007 --- /dev/null +++ b/api/factories/file_factory/remote.py @@ -0,0 +1,91 @@ +"""Remote file metadata helpers used by workflow file normalization. + +These helpers are part of the ``factories.file_factory`` package surface +because both workflow builders and tests rely on the same RFC5987 filename +parsing and HEAD-response normalization rules. 
+""" + +from __future__ import annotations + +import mimetypes +import os +import re +import urllib.parse +import uuid + +import httpx +from werkzeug.http import parse_options_header + +from core.helper import ssrf_proxy + + +def extract_filename(url_path: str, content_disposition: str | None) -> str | None: + """Extract a safe filename from Content-Disposition or the request URL path.""" + filename: str | None = None + if content_disposition: + filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition) + if filename_star_match: + raw_star = filename_star_match.group(1).strip() + raw_star = raw_star.removesuffix('"') + try: + parts = raw_star.split("'", 2) + charset = (parts[0] or "utf-8").lower() if len(parts) >= 1 else "utf-8" + value = parts[2] if len(parts) == 3 else parts[-1] + filename = urllib.parse.unquote(value, encoding=charset, errors="replace") + except Exception: + if "''" in raw_star: + filename = urllib.parse.unquote(raw_star.split("''")[-1]) + else: + filename = urllib.parse.unquote(raw_star) + + if not filename: + _, params = parse_options_header(content_disposition) + raw = params.get("filename") + if raw: + if len(raw) >= 2 and raw[0] == raw[-1] == '"': + raw = raw[1:-1] + filename = urllib.parse.unquote(raw) + + if not filename: + candidate = os.path.basename(url_path) + filename = urllib.parse.unquote(candidate) if candidate else None + + if filename: + filename = os.path.basename(filename) + if not filename or not filename.strip(): + filename = None + + return filename or None + + +def _guess_mime_type(filename: str) -> str: + guessed_mime, _ = mimetypes.guess_type(filename) + return guessed_mime or "" + + +def get_remote_file_info(url: str) -> tuple[str, str, int]: + """Resolve remote file metadata with SSRF-safe HEAD probing.""" + file_size = -1 + parsed_url = urllib.parse.urlparse(url) + url_path = parsed_url.path + filename = os.path.basename(url_path) + mime_type = _guess_mime_type(filename) + + resp = ssrf_proxy.head(url, follow_redirects=True) + if resp.status_code == httpx.codes.OK: + content_disposition = resp.headers.get("Content-Disposition") + extracted_filename = extract_filename(url_path, content_disposition) + if extracted_filename: + filename = extracted_filename + mime_type = _guess_mime_type(filename) + file_size = int(resp.headers.get("Content-Length", file_size)) + if not mime_type: + mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip() + + if not filename: + extension = mimetypes.guess_extension(mime_type) or ".bin" + filename = f"{uuid.uuid4().hex}{extension}" + if not mime_type: + mime_type = _guess_mime_type(filename) + + return mime_type, filename, file_size diff --git a/api/factories/file_factory/storage_keys.py b/api/factories/file_factory/storage_keys.py new file mode 100644 index 0000000000..17edf54278 --- /dev/null +++ b/api/factories/file_factory/storage_keys.py @@ -0,0 +1,106 @@ +"""Batched storage-key hydration for workflow files.""" + +from __future__ import annotations + +import uuid +from collections.abc import Mapping, Sequence + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.file_access import FileAccessControllerProtocol +from core.workflow.file_reference import build_file_reference, parse_file_reference +from dify_graph.file import File, FileTransferMethod +from models import ToolFile, UploadFile + + +class StorageKeyLoader: + """Load storage keys for files with a constant number of database queries.""" + + _session: Session + _tenant_id: str + 
_access_controller: FileAccessControllerProtocol + + def __init__( + self, + session: Session, + tenant_id: str, + access_controller: FileAccessControllerProtocol, + ) -> None: + self._session = session + self._tenant_id = tenant_id + self._access_controller = access_controller + + def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]: + stmt = select(UploadFile).where( + UploadFile.id.in_(upload_file_ids), + UploadFile.tenant_id == self._tenant_id, + ) + scoped_stmt = self._access_controller.apply_upload_file_filters(stmt) + return {uuid.UUID(upload_file.id): upload_file for upload_file in self._session.scalars(scoped_stmt)} + + def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]: + stmt = select(ToolFile).where( + ToolFile.id.in_(tool_file_ids), + ToolFile.tenant_id == self._tenant_id, + ) + scoped_stmt = self._access_controller.apply_tool_file_filters(stmt) + return {uuid.UUID(tool_file.id): tool_file for tool_file in self._session.scalars(scoped_stmt)} + + def load_storage_keys(self, files: Sequence[File]) -> None: + """Hydrate storage keys by loading their backing file rows in batches. + + The sequence shape is preserved. Each file is updated in place with a + canonical record reference and storage key loaded from an authorized + database row. Tenant scoping is enforced by this loader's context + rather than by embedding tenant identity or storage paths inside + graph-layer ``File`` values. + + For best performance, prefer batches smaller than 1000 files. + """ + + upload_file_ids: list[uuid.UUID] = [] + tool_file_ids: list[uuid.UUID] = [] + for file in files: + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("file id should not be None.") + + model_id = uuid.UUID(parsed_reference.record_id) + if file.transfer_method in ( + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + FileTransferMethod.DATASOURCE_FILE, + ): + upload_file_ids.append(model_id) + elif file.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file_ids.append(model_id) + + tool_files = self._load_tool_files(tool_file_ids) + upload_files = self._load_upload_files(upload_file_ids) + for file in files: + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("file id should not be None.") + + model_id = uuid.UUID(parsed_reference.record_id) + if file.transfer_method in ( + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + FileTransferMethod.DATASOURCE_FILE, + ): + upload_file_row = upload_files.get(model_id) + if upload_file_row is None: + raise ValueError(f"Upload file not found for id: {model_id}") + file.reference = build_file_reference( + record_id=str(upload_file_row.id), + ) + file.storage_key = upload_file_row.key + elif file.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file_row = tool_files.get(model_id) + if tool_file_row is None: + raise ValueError(f"Tool file not found for id: {model_id}") + file.reference = build_file_reference( + record_id=str(tool_file_row.id), + ) + file.storage_key = tool_file_row.file_key diff --git a/api/factories/file_factory/validation.py b/api/factories/file_factory/validation.py new file mode 100644 index 0000000000..93cd4cd167 --- /dev/null +++ b/api/factories/file_factory/validation.py @@ -0,0 +1,44 @@ +"""Validation helpers for workflow file inputs.""" + +from __future__ import annotations + +from dify_graph.file import 
FileTransferMethod, FileType, FileUploadConfig + + +def is_file_valid_with_config( + *, + input_file_type: str, + file_extension: str, + file_transfer_method: FileTransferMethod, + config: FileUploadConfig, +) -> bool: + # FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model) + # These are internally generated and should bypass user upload restrictions + if file_transfer_method == FileTransferMethod.TOOL_FILE: + return True + + if ( + config.allowed_file_types + and input_file_type not in config.allowed_file_types + and input_file_type != FileType.CUSTOM + ): + return False + + if ( + input_file_type == FileType.CUSTOM + and config.allowed_file_extensions is not None + and file_extension not in config.allowed_file_extensions + ): + return False + + if input_file_type == FileType.IMAGE: + if ( + config.image_config + and config.image_config.transfer_methods + and file_transfer_method not in config.image_config.transfer_methods + ): + return False + elif config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods: + return False + + return True diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 14a56bf4a2..7dad008df9 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -1,75 +1,51 @@ +"""Compatibility factory for non-graph variable bootstrapping. + +Graph runtime segment/variable conversions live under `dify_graph.variables`. +This module keeps the application-layer mapping helpers and re-exports the +shared conversion functions for legacy callers and tests. +""" + from collections.abc import Mapping, Sequence from typing import Any, cast -from uuid import uuid4 from configs import dify_config -from dify_graph.constants import ( +from core.workflow.variable_prefixes import ( CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, ) -from dify_graph.file import File from dify_graph.variables.exc import VariableError -from dify_graph.variables.segments import ( - ArrayAnySegment, - ArrayBooleanSegment, - ArrayFileSegment, - ArrayNumberSegment, - ArrayObjectSegment, - ArraySegment, - ArrayStringSegment, - BooleanSegment, - FileSegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, +from dify_graph.variables.factory import ( + TypeMismatchError, + UnsupportedSegmentTypeError, + build_segment, + build_segment_with_type, + segment_to_variable, ) from dify_graph.variables.types import SegmentType from dify_graph.variables.variables import ( - ArrayAnyVariable, ArrayBooleanVariable, - ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, BooleanVariable, - FileVariable, FloatVariable, IntegerVariable, - NoneVariable, ObjectVariable, SecretVariable, StringVariable, VariableBase, ) - -class UnsupportedSegmentTypeError(Exception): - pass - - -class TypeMismatchError(Exception): - pass - - -# Define the constant -SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[VariableBase]] = { - ArrayAnySegment: ArrayAnyVariable, - ArrayBooleanSegment: ArrayBooleanVariable, - ArrayFileSegment: ArrayFileVariable, - ArrayNumberSegment: ArrayNumberVariable, - ArrayObjectSegment: ArrayObjectVariable, - ArrayStringSegment: ArrayStringVariable, - BooleanSegment: BooleanVariable, - FileSegment: FileVariable, - FloatSegment: FloatVariable, - IntegerSegment: IntegerVariable, - NoneSegment: NoneVariable, - ObjectSegment: ObjectVariable, - StringSegment: StringVariable, -} +__all__ = [ + 
"TypeMismatchError", + "UnsupportedSegmentTypeError", + "build_conversation_variable_from_mapping", + "build_environment_variable_from_mapping", + "build_pipeline_variable_from_mapping", + "build_segment", + "build_segment_with_type", + "segment_to_variable", +] def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase: @@ -135,172 +111,3 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen if not result.selector: result = result.model_copy(update={"selector": selector}) return cast(VariableBase, result) - - -def build_segment(value: Any, /) -> Segment: - # NOTE: If you have runtime type information available, consider using the `build_segment_with_type` - # below - if value is None: - return NoneSegment() - if isinstance(value, Segment): - return value - if isinstance(value, str): - return StringSegment(value=value) - if isinstance(value, bool): - return BooleanSegment(value=value) - if isinstance(value, int): - return IntegerSegment(value=value) - if isinstance(value, float): - return FloatSegment(value=value) - if isinstance(value, dict): - return ObjectSegment(value=value) - if isinstance(value, File): - return FileSegment(value=value) - if isinstance(value, list): - items = [build_segment(item) for item in value] - types = {item.value_type for item in items} - if all(isinstance(item, ArraySegment) for item in items): - return ArrayAnySegment(value=value) - elif len(types) != 1: - if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}): - return ArrayNumberSegment(value=value) - return ArrayAnySegment(value=value) - - match types.pop(): - case SegmentType.STRING: - return ArrayStringSegment(value=value) - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - return ArrayNumberSegment(value=value) - case SegmentType.BOOLEAN: - return ArrayBooleanSegment(value=value) - case SegmentType.OBJECT: - return ArrayObjectSegment(value=value) - case SegmentType.FILE: - return ArrayFileSegment(value=value) - case SegmentType.NONE: - return ArrayAnySegment(value=value) - case _: - # This should be unreachable. - raise ValueError(f"not supported value {value}") - raise ValueError(f"not supported value {value}") - - -_segment_factory: Mapping[SegmentType, type[Segment]] = { - SegmentType.NONE: NoneSegment, - SegmentType.STRING: StringSegment, - SegmentType.INTEGER: IntegerSegment, - SegmentType.FLOAT: FloatSegment, - SegmentType.FILE: FileSegment, - SegmentType.BOOLEAN: BooleanSegment, - SegmentType.OBJECT: ObjectSegment, - # Array types - SegmentType.ARRAY_ANY: ArrayAnySegment, - SegmentType.ARRAY_STRING: ArrayStringSegment, - SegmentType.ARRAY_NUMBER: ArrayNumberSegment, - SegmentType.ARRAY_OBJECT: ArrayObjectSegment, - SegmentType.ARRAY_FILE: ArrayFileSegment, - SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment, -} - - -def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: - """ - Build a segment with explicit type checking. - - This function creates a segment from a value while enforcing type compatibility - with the specified segment_type. It provides stricter type validation compared - to the standard build_segment function. 
- - Args: - segment_type: The expected SegmentType for the resulting segment - value: The value to be converted into a segment - - Returns: - Segment: A segment instance of the appropriate type - - Raises: - TypeMismatchError: If the value type doesn't match the expected segment_type - - Special Cases: - - For empty list [] values, if segment_type is array[*], returns the corresponding array type - - Type validation is performed before segment creation - - Examples: - >>> build_segment_with_type(SegmentType.STRING, "hello") - StringSegment(value="hello") - - >>> build_segment_with_type(SegmentType.ARRAY_STRING, []) - ArrayStringSegment(value=[]) - - >>> build_segment_with_type(SegmentType.STRING, 123) - # Raises TypeMismatchError - """ - # Handle None values - if value is None: - if segment_type == SegmentType.NONE: - return NoneSegment() - else: - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None") - - # Handle empty list special case for array types - if isinstance(value, list) and len(value) == 0: - if segment_type == SegmentType.ARRAY_ANY: - return ArrayAnySegment(value=value) - elif segment_type == SegmentType.ARRAY_STRING: - return ArrayStringSegment(value=value) - elif segment_type == SegmentType.ARRAY_BOOLEAN: - return ArrayBooleanSegment(value=value) - elif segment_type == SegmentType.ARRAY_NUMBER: - return ArrayNumberSegment(value=value) - elif segment_type == SegmentType.ARRAY_OBJECT: - return ArrayObjectSegment(value=value) - elif segment_type == SegmentType.ARRAY_FILE: - return ArrayFileSegment(value=value) - else: - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list") - - inferred_type = SegmentType.infer_segment_type(value) - # Type compatibility checking - if inferred_type is None: - raise TypeMismatchError( - f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}" - ) - if inferred_type == segment_type: - segment_class = _segment_factory[segment_type] - return segment_class(value_type=segment_type, value=value) - elif segment_type == SegmentType.NUMBER and inferred_type in ( - SegmentType.INTEGER, - SegmentType.FLOAT, - ): - segment_class = _segment_factory[inferred_type] - return segment_class(value_type=inferred_type, value=value) - else: - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") - - -def segment_to_variable( - *, - segment: Segment, - selector: Sequence[str], - id: str | None = None, - name: str | None = None, - description: str = "", -) -> VariableBase: - if isinstance(segment, VariableBase): - return segment - name = name or selector[-1] - id = id or str(uuid4()) - - segment_type = type(segment) - if segment_type not in SEGMENT_TO_VARIABLE_MAP: - raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") - - variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] - return variable_class( - id=id, - name=name, - description=description, - value_type=segment.value_type, - value=segment.value, - selector=list(selector), - ) diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index a5c7ddbb11..07d28bae99 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -311,7 +311,9 @@ def to_timestamp(value: datetime | None) -> int | None: def format_files_contained(value: JSONValue) -> JSONValue: if isinstance(value, File): - return value.model_dump() + # Response payloads must preserve legacy file keys like 
`related_id`/`url` + # while still exposing the new graph-layer `reference` field. + return value.to_dict() if isinstance(value, dict): return {k: format_files_contained(v) for k, v in value.items()} if isinstance(value, list): diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 428f92ed33..7f042a6b7e 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -133,7 +133,9 @@ def to_timestamp(value: datetime | None) -> int | None: def format_files_contained(value: JSONValueType) -> JSONValueType: if isinstance(value, File): - return value.model_dump() + # Response payloads must preserve legacy file keys like `related_id`/`url` + # while still exposing the new graph-layer `reference` field. + return value.to_dict() if isinstance(value, dict): return {k: format_files_contained(v) for k, v in value.items()} if isinstance(value, list): diff --git a/api/libs/datetime_utils.py b/api/libs/datetime_utils.py index c08578981b..e0a6ec2cac 100644 --- a/api/libs/datetime_utils.py +++ b/api/libs/datetime_utils.py @@ -2,7 +2,7 @@ import abc import datetime from typing import Protocol -import pytz +import pytz # type: ignore[import-untyped] class _NowFunction(Protocol): diff --git a/api/libs/schedule_utils.py b/api/libs/schedule_utils.py index 1ab5f499e9..b80a5ea722 100644 --- a/api/libs/schedule_utils.py +++ b/api/libs/schedule_utils.py @@ -1,6 +1,6 @@ from datetime import UTC, datetime -import pytz +import pytz # type: ignore[import-untyped] from croniter import croniter diff --git a/api/models/human_input.py b/api/models/human_input.py index 48e7fbb9ea..93efd55e34 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -6,11 +6,8 @@ import sqlalchemy as sa from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship -from dify_graph.nodes.human_input.enums import ( - DeliveryMethodType, - HumanInputFormKind, - HumanInputFormStatus, -) +from core.workflow.human_input_compat import DeliveryMethodType +from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.helper import generate_string from .base import Base, DefaultFieldsMixin diff --git a/api/models/model.py b/api/models/model.py index 68ff37bcaa..6bf994001b 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -3,10 +3,11 @@ from __future__ import annotations import json import re import uuid -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from datetime import datetime from decimal import Decimal from enum import StrEnum, auto +from functools import lru_cache from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast from uuid import uuid4 @@ -26,6 +27,7 @@ from dify_graph.file import helpers as file_helpers from extensions.storage.storage_type import StorageType from libs.helper import generate_string # type: ignore[import-not-found] from libs.uuid_utils import uuidv7 +from models.utils.file_input_compat import build_file_from_input_mapping from .account import Account, Tenant from .base import Base, TypeBase, gen_uuidv4_string @@ -57,6 +59,32 @@ if TYPE_CHECKING: # --- TypedDict definitions for structured dict return types --- +@lru_cache(maxsize=1) +def _get_file_access_controller(): + from core.app.file_access import DatabaseFileAccessController + + return DatabaseFileAccessController() + + +def _resolve_app_tenant_id(app_id: str) -> str: + resolved_tenant_id = db.session.scalar(select(App.tenant_id).where(App.id == app_id)) + if 
not resolved_tenant_id: + raise ValueError(f"Unable to resolve tenant_id for app {app_id}") + return resolved_tenant_id + + +def _build_app_tenant_resolver(app_id: str, owner_tenant_id: str | None = None) -> Callable[[], str]: + resolved_tenant_id = owner_tenant_id + + def resolve_owner_tenant_id() -> str: + nonlocal resolved_tenant_id + if resolved_tenant_id is None: + resolved_tenant_id = _resolve_app_tenant_id(app_id) + return resolved_tenant_id + + return resolve_owner_tenant_id + + class EnabledConfig(TypedDict): enabled: bool @@ -1057,23 +1085,26 @@ class Conversation(Base): @property def inputs(self) -> dict[str, Any]: inputs = self._inputs.copy() + # Compatibility bridge: stored input payloads may come from before or after the + # graph-layer file refactor. Newer rows may omit `tenant_id`, so keep tenant + # resolution at the SQLAlchemy model boundary instead of pushing ownership back + # into `dify_graph.file.File`. + tenant_resolver = _build_app_tenant_resolver( + app_id=self.app_id, + owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)), + ) # Convert file mapping to File object for key, value in inputs.items(): - # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. - from factories import file_factory - if ( isinstance(value, dict) and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY ): value_dict = cast(dict[str, Any], value) - if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - value_dict["tool_file_id"] = value_dict["related_id"] - elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - value_dict["upload_file_id"] = value_dict["related_id"] - tenant_id = cast(str, value_dict.get("tenant_id", "")) - inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) + inputs[key] = build_file_from_input_mapping( + file_mapping=value_dict, + tenant_resolver=tenant_resolver, + ) elif isinstance(value, list): value_list = cast(list[Any], value) if all( @@ -1086,15 +1117,12 @@ class Conversation(Base): if not isinstance(item, dict): continue item_dict = cast(dict[str, Any], item) - if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - item_dict["tool_file_id"] = item_dict["related_id"] - elif item_dict["transfer_method"] in [ - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.REMOTE_URL, - ]: - item_dict["upload_file_id"] = item_dict["related_id"] - tenant_id = cast(str, item_dict.get("tenant_id", "")) - file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) + file_list.append( + build_file_from_input_mapping( + file_mapping=item_dict, + tenant_resolver=tenant_resolver, + ) + ) inputs[key] = file_list return inputs @@ -1402,21 +1430,23 @@ class Message(Base): @property def inputs(self) -> dict[str, Any]: inputs = self._inputs.copy() + # Compatibility bridge: message inputs are persisted as JSON and must remain + # readable across file payload shape changes. Do not assume `tenant_id` + # is serialized into each file mapping going forward. + tenant_resolver = _build_app_tenant_resolver( + app_id=self.app_id, + owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)), + ) for key, value in inputs.items(): - # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. 
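Note: `_build_app_tenant_resolver` defers the `App.tenant_id` lookup until a stored file mapping actually lacks a `tenant_id`, then caches the result so repeated inputs cost at most one query. A standalone sketch of the same memoization pattern (names and tenant id are illustrative, not the real models):

```python
from collections.abc import Callable


def make_memoized_resolver(lookup: Callable[[], str]) -> Callable[[], str]:
    cached: str | None = None

    def resolve() -> str:
        nonlocal cached
        if cached is None:
            cached = lookup()  # e.g. the App.tenant_id query above
        return cached

    return resolve


calls = 0


def fake_lookup() -> str:
    global calls
    calls += 1
    return "tenant-123"


resolver = make_memoized_resolver(fake_lookup)
assert resolver() == resolver() == "tenant-123"
assert calls == 1  # second call served from the cache
```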
- from factories import file_factory - if ( isinstance(value, dict) and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY ): value_dict = cast(dict[str, Any], value) - if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - value_dict["tool_file_id"] = value_dict["related_id"] - elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - value_dict["upload_file_id"] = value_dict["related_id"] - tenant_id = cast(str, value_dict.get("tenant_id", "")) - inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) + inputs[key] = build_file_from_input_mapping( + file_mapping=value_dict, + tenant_resolver=tenant_resolver, + ) elif isinstance(value, list): value_list = cast(list[Any], value) if all( @@ -1429,15 +1459,12 @@ class Message(Base): if not isinstance(item, dict): continue item_dict = cast(dict[str, Any], item) - if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - item_dict["tool_file_id"] = item_dict["related_id"] - elif item_dict["transfer_method"] in [ - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.REMOTE_URL, - ]: - item_dict["upload_file_id"] = item_dict["related_id"] - tenant_id = cast(str, item_dict.get("tenant_id", "")) - file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) + file_list.append( + build_file_from_input_mapping( + file_mapping=item_dict, + tenant_resolver=tenant_resolver, + ) + ) inputs[key] = file_list return inputs @@ -1612,6 +1639,7 @@ class Message(Base): "upload_file_id": message_file.upload_file_id, }, tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), ) elif message_file.transfer_method == FileTransferMethod.REMOTE_URL: if message_file.url is None: @@ -1625,6 +1653,7 @@ class Message(Base): "url": message_file.url, }, tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), ) elif message_file.transfer_method == FileTransferMethod.TOOL_FILE: if message_file.upload_file_id is None: @@ -1639,6 +1668,7 @@ class Message(Base): file = file_factory.build_from_mapping( mapping=mapping, tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), ) else: raise ValueError( diff --git a/api/models/utils/__init__.py b/api/models/utils/__init__.py new file mode 100644 index 0000000000..b390b8106b --- /dev/null +++ b/api/models/utils/__init__.py @@ -0,0 +1,3 @@ +from .file_input_compat import build_file_from_input_mapping + +__all__ = ["build_file_from_input_mapping"] diff --git a/api/models/utils/file_input_compat.py b/api/models/utils/file_input_compat.py new file mode 100644 index 0000000000..2c73dd1558 --- /dev/null +++ b/api/models/utils/file_input_compat.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +from collections.abc import Callable, Mapping +from functools import lru_cache +from typing import Any + +from core.workflow.file_reference import parse_file_reference +from dify_graph.file import File, FileTransferMethod + + +@lru_cache(maxsize=1) +def _get_file_access_controller(): + from core.app.file_access import DatabaseFileAccessController + + return DatabaseFileAccessController() + + +def resolve_file_record_id(file_mapping: Mapping[str, Any]) -> str | None: + reference = file_mapping.get("reference") + if isinstance(reference, str) and reference: + parsed_reference = parse_file_reference(reference) + if parsed_reference is not None: + return parsed_reference.record_id + + related_id = 
file_mapping.get("related_id") + if isinstance(related_id, str) and related_id: + parsed_reference = parse_file_reference(related_id) + if parsed_reference is not None: + return parsed_reference.record_id + + return None + + +def resolve_file_mapping_tenant_id( + *, + file_mapping: Mapping[str, Any], + tenant_resolver: Callable[[], str], +) -> str: + tenant_id = file_mapping.get("tenant_id") + if isinstance(tenant_id, str) and tenant_id: + return tenant_id + + return tenant_resolver() + + +def build_file_from_stored_mapping( + *, + file_mapping: Mapping[str, Any], + tenant_id: str, +) -> File: + """ + Canonicalize a persisted file payload against the current tenant context. + + Stored JSON rows can outlive file schema changes, so rebuild storage-backed + files through the workflow factory instead of trusting serialized metadata. + Pure external ``REMOTE_URL`` payloads without a backing upload row are + passed through because there is no server-owned record to rebind. + """ + + # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. + from factories import file_factory + + mapping = dict(file_mapping) + mapping.pop("tenant_id", None) + record_id = resolve_file_record_id(mapping) + transfer_method = FileTransferMethod.value_of(mapping["transfer_method"]) + + if transfer_method == FileTransferMethod.TOOL_FILE and record_id: + mapping["tool_file_id"] = record_id + elif transfer_method in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL] and record_id: + mapping["upload_file_id"] = record_id + elif transfer_method == FileTransferMethod.DATASOURCE_FILE and record_id: + mapping["datasource_file_id"] = record_id + + if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: + remote_url = mapping.get("remote_url") + if not isinstance(remote_url, str) or not remote_url: + url = mapping.get("url") + if isinstance(url, str) and url: + mapping["remote_url"] = url + return File.model_validate(mapping) + + return file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + access_controller=_get_file_access_controller(), + ) + + +def build_file_from_input_mapping( + *, + file_mapping: Mapping[str, Any], + tenant_resolver: Callable[[], str], +) -> File: + """ + Rehydrate persisted model input payloads into graph `File` objects. + + This compatibility layer exists because model JSON rows can outlive file payload + schema changes. Legacy rows may carry `related_id` and `tenant_id`, while newer + rows may only carry `reference`. Keep ownership resolution here, at the model + boundary, instead of pushing tenant data back into `dify_graph.file.File`. 
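Note: a hedged usage sketch of the compatibility path this docstring describes: a purely external remote-URL payload never consults the tenant resolver, while a legacy row carrying `related_id` is rebuilt through the factory. Key names mirror the diff; the transfer-method strings, reference format, and tenant id are assumptions:

```python
from models.utils.file_input_compat import build_file_from_input_mapping


def resolver() -> str:
    # Would normally hit the database; hypothetical tenant id for illustration.
    return "tenant-123"


# External URL with no backing upload record: resolver() is never called.
external = {"transfer_method": "remote_url", "url": "https://example.com/a.png"}
f1 = build_file_from_input_mapping(file_mapping=external, tenant_resolver=resolver)

# Legacy row: `related_id` is re-parsed and the file is rebuilt via the factory
# under the resolved tenant. The reference string format here is assumed.
legacy = {"transfer_method": "tool_file", "related_id": "a-parseable-file-reference"}
f2 = build_file_from_input_mapping(file_mapping=legacy, tenant_resolver=resolver)
```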
+ """ + + transfer_method = FileTransferMethod.value_of(file_mapping["transfer_method"]) + record_id = resolve_file_record_id(file_mapping) + if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: + return build_file_from_stored_mapping( + file_mapping=file_mapping, + tenant_id="", + ) + + tenant_id = resolve_file_mapping_tenant_id(file_mapping=file_mapping, tenant_resolver=tenant_resolver) + return build_file_from_stored_mapping( + file_mapping=file_mapping, + tenant_id=tenant_id, + ) diff --git a/api/models/workflow.py b/api/models/workflow.py index 334ec42058..320e0eecd9 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -24,7 +24,8 @@ from sqlalchemy.orm import Mapped, mapped_column from typing_extensions import deprecated from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from dify_graph.constants import ( +from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.variable_prefixes import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) @@ -57,6 +58,7 @@ from .base import Base, DefaultFieldsMixin, TypeBase from .engine import db from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom from .types import EnumText, LongText, StringUUID +from .utils.file_input_compat import build_file_from_stored_mapping logger = logging.getLogger(__name__) @@ -64,6 +66,15 @@ SerializedWorkflowValue = dict[str, Any] SerializedWorkflowVariables = dict[str, SerializedWorkflowValue] +def _resolve_workflow_app_tenant_id(app_id: str) -> str: + from .model import App + + tenant_id = db.session.scalar(select(App.tenant_id).where(App.id == app_id)) + if not tenant_id: + raise ValueError(f"Unable to resolve tenant_id for app {app_id}") + return tenant_id + + class WorkflowContentDict(TypedDict): graph: Mapping[str, Any] features: dict[str, Any] @@ -273,7 +284,7 @@ class Workflow(Base): # bug node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) except StopIteration: raise NodeNotFoundError(node_id) - return NodeConfigDictAdapter.validate_python(node_config) + return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) @staticmethod def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType: @@ -1565,10 +1576,9 @@ class WorkflowDraftVariable(Base): def _loads_value(self) -> Segment: value = json.loads(self.value) - return self.build_segment_with_type(self.value_type, value) + return self.build_segment_from_serialized_value(self.value_type, value) - @staticmethod - def rebuild_file_types(value: Any): + def _rebuild_file_types(self, value: Any): # NOTE(QuantumGhost): Temporary workaround for structured data handling. 
# By this point, `output` has been converted to dict by # `WorkflowEntry.handle_special_values`, so we need to @@ -1582,13 +1592,72 @@ class WorkflowDraftVariable(Base): if isinstance(value, dict): if not maybe_file_object(value): return cast(Any, value) - return File.model_validate(value) + tenant_id = _resolve_workflow_app_tenant_id(self.app_id) + return build_file_from_stored_mapping( + file_mapping=cast(dict[str, Any], value), + tenant_id=tenant_id, + ) elif isinstance(value, list) and value: value_list = cast(list[Any], value) first: Any = value_list[0] if not maybe_file_object(first): return cast(Any, value) - file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list] + tenant_id = _resolve_workflow_app_tenant_id(self.app_id) + file_list: list[File] = [] + for item in value_list: + file_list.append( + build_file_from_stored_mapping( + file_mapping=cast(dict[str, Any], item), + tenant_id=tenant_id, + ) + ) + return cast(Any, file_list) + else: + return cast(Any, value) + + def build_segment_from_serialized_value(self, segment_type: SegmentType, value: Any) -> Segment: + # Persisted draft variable rows may contain historical file payloads. + # Rebuild them through the file factory so tenant ownership, signed URLs, + # and storage-backed metadata come from canonical records instead of the + # serialized JSON blob. + if segment_type == SegmentType.FILE: + if isinstance(value, File): + return build_segment_with_type(segment_type, value) + elif isinstance(value, dict): + file = self._rebuild_file_types(value) + return build_segment_with_type(segment_type, file) + else: + raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") + if segment_type == SegmentType.ARRAY_FILE: + if not isinstance(value, list): + raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") + file_list = self._rebuild_file_types(value) + return build_segment_with_type(segment_type=segment_type, value=file_list) + + return build_segment_with_type(segment_type=segment_type, value=value) + + @staticmethod + def rebuild_file_types(value: Any): + # Keep the class-level fallback for callers that only need lightweight + # structural reconstruction. Persisted draft-variable payloads should go + # through `build_segment_from_serialized_value()` so file metadata is + # rebuilt from canonical storage records. 
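Note: the split above gives draft variables two deserialization paths: file-typed segments are rebuilt from canonical storage records, and everything else goes straight to the typed segment factory. A self-contained sketch of that dispatch shape (toy enum, not the real `SegmentType`):

```python
from enum import StrEnum


class Kind(StrEnum):
    STRING = "string"
    FILE = "file"
    ARRAY_FILE = "array[file]"


def deserialize(kind: Kind, value: object) -> str:
    # File-shaped payloads take the rebuild path; scalars use the typed factory.
    if kind in (Kind.FILE, Kind.ARRAY_FILE):
        return f"rebuilt-from-storage:{value!r}"
    return f"typed-factory:{value!r}"


assert deserialize(Kind.STRING, "hello").startswith("typed-factory")
assert deserialize(Kind.FILE, {"reference": "..."}).startswith("rebuilt-from-storage")
```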
+ if isinstance(value, dict): + if not maybe_file_object(value): + return cast(Any, value) + normalized_file = dict(value) + normalized_file.pop("tenant_id", None) + return File.model_validate(normalized_file) + elif isinstance(value, list) and value: + value_list = cast(list[Any], value) + first: Any = value_list[0] + if not maybe_file_object(first): + return cast(Any, value) + file_list: list[File] = [] + for item in value_list: + normalized_file = dict(cast(dict[str, Any], item)) + normalized_file.pop("tenant_id", None) + file_list.append(File.model_validate(normalized_file)) return cast(Any, file_list) else: return cast(Any, value) diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 2fa065bcc8..3595ea33f0 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -16,7 +16,7 @@ from typing import Protocol from sqlalchemy.orm import Session -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.repositories.factory import WorkflowNodeExecutionRepository from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index a96c4acb31..a48e67b0e2 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -40,9 +40,9 @@ from typing import Protocol from sqlalchemy.orm import Session +from core.repositories.factory import WorkflowExecutionRepository from dify_graph.entities.pause_reason import PauseReason from dify_graph.enums import WorkflowType -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index fdd3e123e4..75ea7bc102 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -43,7 +43,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom -from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType +from models.human_input import HumanInputForm from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.entities.workflow_pause import WorkflowPauseEntity @@ -61,25 +61,13 @@ class _WorkflowRunError(Exception): pass -def _select_recipient_token( - recipients: Sequence[HumanInputFormRecipient], - recipient_type: RecipientType, -) -> str | None: - for recipient in recipients: - if recipient.recipient_type == recipient_type and recipient.access_token: - return recipient.access_token - return None - - def _build_human_input_required_reason( reason_model: WorkflowPauseReason, form_model: HumanInputForm | None, - recipients: Sequence[HumanInputFormRecipient], ) -> 
HumanInputRequired: form_content = "" inputs = [] actions = [] - display_in_ui = False resolved_default_values: dict[str, Any] = {} node_title = "Human Input" form_id = reason_model.form_id @@ -99,25 +87,16 @@ def _build_human_input_required_reason( form_content = definition.form_content inputs = list(definition.inputs) actions = list(definition.user_actions) - display_in_ui = bool(definition.display_in_ui) resolved_default_values = dict(definition.default_values) node_title = definition.node_title or node_title - form_token = ( - _select_recipient_token(recipients, RecipientType.BACKSTAGE) - or _select_recipient_token(recipients, RecipientType.CONSOLE) - or _select_recipient_token(recipients, RecipientType.STANDALONE_WEB_APP) - ) - return HumanInputRequired( form_id=form_id, form_content=form_content, inputs=inputs, actions=actions, - display_in_ui=display_in_ui, node_id=node_id, node_title=node_title, - form_token=form_token, resolved_default_values=resolved_default_values, ) @@ -823,22 +802,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED and reason.form_id ] form_models: dict[str, HumanInputForm] = {} - recipient_models_by_form: dict[str, list[HumanInputFormRecipient]] = {} if form_ids: form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids)) for form in session.scalars(form_stmt).all(): form_models[form.id] = form - recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) - for recipient in session.scalars(recipient_stmt).all(): - recipient_models_by_form.setdefault(recipient.form_id, []).append(recipient) - pause_reasons: list[PauseReason] = [] for reason in pause_reason_models: if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: form_model = form_models.get(reason.form_id) - recipients = recipient_models_by_form.get(reason.form_id, []) - pause_reasons.append(_build_human_input_required_reason(reason, form_model, recipients)) + pause_reasons.append(_build_human_input_required_reason(reason, form_model)) else: pause_reasons.append(reason.to_entity()) return pause_reasons diff --git a/api/services/app_service.py b/api/services/app_service.py index 69c7c0c95a..aba0256d12 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -92,7 +92,7 @@ class AppService: default_model_config = default_model_config.copy() if default_model_config else None if default_model_config and "model" in default_model_config: # get model provider - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=account.current_tenant_id or "") # get default model instance try: @@ -124,11 +124,19 @@ class AppService: "completion_params": {}, } else: - provider, model = model_manager.get_default_provider_model_name( - tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM - ) - default_model_config["model"]["provider"] = provider - default_model_config["model"]["name"] = model + try: + provider, model = model_manager.get_default_provider_model_name( + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM + ) + except Exception: + logger.exception("Get default provider model failed, tenant_id: %s", tenant_id) + provider = default_model_config["model"].get("provider") + model = default_model_config["model"].get("name") + + if provider: + default_model_config["model"]["provider"] = provider + if model: + default_model_config["model"]["name"] = model default_model_dict = default_model_config["model"] 
default_model_config["model"] = json.dumps(default_model_dict) @@ -197,6 +205,7 @@ class AppService: tenant_id=current_user.current_tenant_id, app_id=app.id, agent_tool=agent_tool_entity, + user_id=current_user.id, ) manager = ToolParameterConfigurationManager( tenant_id=current_user.current_tenant_id, diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 1794ea9947..a4bd96b399 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -61,7 +61,7 @@ class AudioService: message = f"Audio size larger than {FILE_SIZE} mb" raise AudioTooLargeServiceError(message) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id, user_id=end_user) model_instance = model_manager.get_default_model_instance( tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT ) @@ -71,7 +71,7 @@ class AudioService: buffer = io.BytesIO(file_content) buffer.name = "temp.mp3" - return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} + return {"text": model_instance.invoke_speech2text(file=buffer)} @classmethod def transcript_tts( @@ -109,7 +109,7 @@ class AudioService: voice = cast(str | None, text_to_speech_dict.get("voice")) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id, user_id=end_user) model_instance = model_manager.get_default_model_instance( tenant_id=app_model.tenant_id, model_type=ModelType.TTS ) @@ -123,9 +123,7 @@ class AudioService: else: raise ValueError("Sorry, no voice available.") - return model_instance.invoke_tts( - content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice - ) + return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice) except Exception as e: raise e @@ -155,7 +153,7 @@ class AudioService: @classmethod def transcript_tts_voices(cls, tenant_id: str, language: str): - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS) if model_instance is None: raise ProviderNotSupportTextToSpeechServiceError() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 969ca68545..673b15c4e7 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -229,7 +229,7 @@ class DatasetService: raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None if indexing_technique == IndexTechniqueType.HIGH_QUALITY: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) if embedding_model_provider and embedding_model_name: # check if embedding model setting is valid DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model_name) @@ -354,7 +354,7 @@ class DatasetService: def check_dataset_model_setting(dataset): if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, @@ -371,7 +371,7 @@ class DatasetService: @staticmethod def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) 
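Note: the recurring change across these service diffs: callers construct `ModelManager.for_tenant(...)` instead of a bare `ModelManager()`, binding tenant (and optionally user) scope at construction so per-call `user=` arguments can be dropped, as in the `invoke_speech2text`/`invoke_tts` calls above. A hedged sketch (method names from the diff, arguments illustrative):

```python
from core.model_manager import ModelManager
from dify_graph.model_runtime.entities.model_entities import ModelType

# Scope is fixed once; no unscoped ModelManager() threaded through the call chain.
manager = ModelManager.for_tenant(tenant_id="tenant-123")  # hypothetical tenant id
instance = manager.get_default_model_instance(
    tenant_id="tenant-123",  # per the diff, still passed even on a scoped manager
    model_type=ModelType.SPEECH2TEXT,
)
```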
model_manager.get_model_instance( tenant_id=tenant_id, provider=embedding_model_provider, @@ -388,7 +388,7 @@ class DatasetService: @staticmethod def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, provider=model_provider, @@ -409,7 +409,7 @@ class DatasetService: @staticmethod def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_manager.get_model_instance( tenant_id=tenant_id, provider=reranking_model_provider, @@ -746,7 +746,7 @@ class DatasetService: """ # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None embedding_model = model_manager.get_model_instance( @@ -864,7 +864,7 @@ class DatasetService: """ # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) try: assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None @@ -958,7 +958,7 @@ class DatasetService: dataset.chunk_structure = knowledge_configuration.chunk_structure dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, # ignore type error provider=knowledge_configuration.embedding_model_provider or "", @@ -1000,7 +1000,7 @@ class DatasetService: action = "add" # get embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=knowledge_configuration.embedding_model_provider, @@ -1053,7 +1053,7 @@ class DatasetService: or knowledge_configuration.embedding_model != dataset.embedding_model ): action = "update" - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = None try: embedding_model = model_manager.get_model_instance( @@ -1912,7 +1912,7 @@ class DocumentService: dataset.indexing_technique = IndexTechniqueType(knowledge_config.indexing_technique) if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: dataset_embedding_model = knowledge_config.embedding_model dataset_embedding_model_provider = knowledge_config.embedding_model_provider @@ -2224,7 +2224,7 @@ class DocumentService: # dataset.indexing_technique = knowledge_config.indexing_technique # if knowledge_config.indexing_technique == "high_quality": - # model_manager = ModelManager() + # 
model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) # if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: # dataset_embedding_model = knowledge_config.embedding_model # dataset_embedding_model_provider = knowledge_config.embedding_model_provider @@ -3129,7 +3129,7 @@ class SegmentService: segment_hash = helper.generate_text_hash(content) tokens = 0 if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3212,7 +3212,7 @@ class SegmentService: with redis_client.lock(lock_name, timeout=600): embedding_model = None if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3350,7 +3350,7 @@ class SegmentService: # get embedding model instance if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( @@ -3413,7 +3413,7 @@ class SegmentService: segment_hash = helper.generate_text_hash(content) tokens = 0 if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3454,7 +3454,7 @@ class SegmentService: # get embedding model instance if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 229e6608da..9ba58ac6ec 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -8,7 +8,7 @@ from sqlalchemy import Engine, select from sqlalchemy.orm import sessionmaker from configs import dify_config -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, @@ -177,21 +177,21 @@ class EmailDeliveryTestHandler: def _resolve_recipients(self, *, tenant_id: str, method: EmailDeliveryMethod) -> list[str]: recipients = method.config.recipients emails: list[str] = [] - member_user_ids: list[str] = [] + bound_reference_ids: list[str] = [] for recipient in recipients.items: if isinstance(recipient, MemberRecipient): - member_user_ids.append(recipient.user_id) + bound_reference_ids.append(recipient.reference_id) elif isinstance(recipient, ExternalRecipient): if recipient.email: emails.append(recipient.email) - if recipients.whole_workspace: - member_user_ids = [] + if recipients.include_bound_group: + 
bound_reference_ids = [] member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=None) emails.extend(member_emails.values()) - elif member_user_ids: - member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=member_user_ids) - for user_id in member_user_ids: + elif bound_reference_ids: + member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=bound_reference_ids) + for user_id in bound_reference_ids: email = member_emails.get(user_id) if email: emails.append(email) diff --git a/api/services/message_service.py b/api/services/message_service.py index fc87802f51..035a4b99a6 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -255,7 +255,7 @@ class MessageService: app_model=app_model, conversation_id=message.conversation_id, user=user ) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id) if app_model.mode == AppMode.ADVANCED_CHAT: workflow_service = WorkflowService() diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index bf3b6db3ed..926e04d503 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -10,6 +10,7 @@ from core.entities.provider_configuration import ProviderConfiguration from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_manager import LBModelManager +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.provider_manager import ProviderManager from dify_graph.model_runtime.entities.model_entities import ModelType from dify_graph.model_runtime.entities.provider_entities import ( @@ -26,8 +27,9 @@ logger = logging.getLogger(__name__) class ModelLoadBalancingService: - def __init__(self): - self.provider_manager = ProviderManager() + @staticmethod + def _get_provider_manager(tenant_id: str) -> ProviderManager: + return create_plugin_provider_manager(tenant_id=tenant_id) def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str): """ @@ -40,7 +42,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -61,7 +63,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -83,7 +85,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -222,8 +224,8 @@ class ModelLoadBalancingService: :param config_id: load balancing config id :return: """ - # Get all provider 
configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) + provider_configurations = provider_manager.get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -310,7 +312,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -495,8 +497,8 @@ class ModelLoadBalancingService: :param config_id: load balancing config id :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + assembly = create_plugin_model_assembly(tenant_id=tenant_id) + provider_configurations = assembly.provider_manager.get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -532,6 +534,7 @@ class ModelLoadBalancingService: model=model, credentials=credentials, load_balancing_model_config=load_balancing_model_config, + model_provider_factory=assembly.model_provider_factory, ) def _custom_credentials_validate( @@ -542,6 +545,7 @@ class ModelLoadBalancingService: model: str, credentials: dict, load_balancing_model_config: LoadBalancingModelConfig | None = None, + model_provider_factory: ModelProviderFactory | None = None, validate: bool = True, ): """ @@ -552,6 +556,7 @@ class ModelLoadBalancingService: :param model: model name :param credentials: credentials :param load_balancing_model_config: load balancing model config + :param model_provider_factory: model provider factory sharing the active runtime :param validate: validate credentials :return: """ @@ -581,7 +586,8 @@ class ModelLoadBalancingService: credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key]) if validate: - model_provider_factory = ModelProviderFactory(tenant_id) + if model_provider_factory is None: + model_provider_factory = provider_configuration.get_model_provider_factory() if isinstance(credential_schemas, ModelCredentialSchema): credentials = model_provider_factory.model_credentials_validate( provider=provider_configuration.provider.provider, diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 0ddd6b9b1a..fae3c66e2a 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,9 +1,9 @@ import logging from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager from core.provider_manager import ProviderManager from dify_graph.model_runtime.entities.model_entities import ModelType, ParameterRule -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ProviderType from services.entities.model_provider_entities import ( CustomConfigurationResponse, @@ -25,8 +25,9 @@ class ModelProviderService: Model Provider Service """ - def __init__(self): - self.provider_manager = ProviderManager() + @staticmethod + def _get_provider_manager(tenant_id: 
str) -> ProviderManager: + return create_plugin_provider_manager(tenant_id=tenant_id) def _get_provider_configuration(self, tenant_id: str, provider: str): """ @@ -43,7 +44,7 @@ class ModelProviderService: ProviderNotFoundError: If provider doesn't exist """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) provider_configuration = provider_configurations.get(provider) if not provider_configuration: @@ -60,7 +61,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) provider_responses = [] for provider_configuration in provider_configurations.values(): @@ -138,7 +139,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider available models return [ @@ -146,6 +147,26 @@ class ModelProviderService: for model in provider_configurations.get_models(provider=provider) ] + def get_provider_available_credentials(self, tenant_id: str, provider: str): + return self._get_provider_manager(tenant_id).get_provider_available_credentials( + tenant_id=tenant_id, + provider_name=provider, + ) + + def get_provider_model_available_credentials( + self, + tenant_id: str, + provider: str, + model_type: str, + model: str, + ): + return self._get_provider_manager(tenant_id).get_provider_model_available_credentials( + tenant_id=tenant_id, + provider_name=provider, + model_type=model_type, + model_name=model, + ) + def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None: """ get provider credentials. 
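Note: with `_get_provider_manager` static, every `ModelProviderService` method builds a tenant-scoped `ProviderManager` on demand instead of sharing one instance. A hedged call sketch using the signatures visible in this diff (tenant, provider, and model names are hypothetical):

```python
from services.model_provider_service import ModelProviderService

service = ModelProviderService()  # no longer holds a ProviderManager of its own
creds = service.get_provider_available_credentials(
    tenant_id="tenant-123",
    provider="openai",
)
model_creds = service.get_provider_model_available_credentials(
    tenant_id="tenant-123",
    provider="openai",
    model_type="llm",
    model="gpt-4o",
)
```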
@@ -391,7 +412,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider available models models = provider_configurations.get_models(model_type=ModelType.value_of(model_type), only_active=True) @@ -476,7 +497,9 @@ class ModelProviderService: model_type_enum = ModelType.value_of(model_type) try: - result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum) + result = self._get_provider_manager(tenant_id).get_default_model( + tenant_id=tenant_id, model_type=model_type_enum + ) return ( DefaultModelResponse( model=result.model, @@ -507,7 +530,7 @@ class ModelProviderService: :return: """ model_type_enum = ModelType.value_of(model_type) - self.provider_manager.update_default_model_record( + self._get_provider_manager(tenant_id).update_default_model_record( tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model ) @@ -523,7 +546,7 @@ class ModelProviderService: :param lang: language (zh_Hans or en_US) :return: """ - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) byte_data, mime_type = model_provider_factory.get_provider_icon(provider, icon_type, lang) return byte_data, mime_type diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 296b9f0890..28047bcceb 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -34,25 +34,31 @@ from core.rag.entities.event import ( DatasourceErrorEvent, DatasourceProcessingEvent, ) -from core.repositories.factory import DifyCoreRepositoryFactory +from core.repositories.factory import DifyCoreRepositoryFactory, OrderConfig from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping +from core.workflow.system_variables import ( + SystemVariableKey, + build_bootstrap_variables, + build_system_variables, + default_system_variables, + get_system_segment, +) +from core.workflow.variable_pool_initializer import add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, SystemVariableKey +from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeType from dify_graph.errors import WorkflowNodeRunFailedError from dify_graph.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent from dify_graph.graph_events.base import GraphNodeEventBase from dify_graph.node_events.base import NodeRunResult from dify_graph.nodes.base.node import Node from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.variables import VariableBase +from dify_graph.variables.variables import Variable, VariableBase from extensions.ext_database import db from libs.infinite_scroll_pagination import 
InfiniteScrollPagination from models import Account @@ -88,6 +94,12 @@ from services.workflow_restore import apply_published_workflow_snapshot_to_draft logger = logging.getLogger(__name__) +def _build_seeded_variable_pool(variables: Sequence[Variable]) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, variables) + return variable_pool + + class RagPipelineService: def __init__(self, session_maker: sessionmaker | None = None): """Initialize RagPipelineService with repository dependencies.""" @@ -521,13 +533,7 @@ class RagPipelineService: node_id=node_id, user_inputs=user_inputs, user_id=account.id, - variable_pool=VariablePool( - system_variables=SystemVariable.default(), - user_inputs=user_inputs, - environment_variables=[], - conversation_variables=[], - rag_pipeline_variables=[], - ), + variable_pool=_build_seeded_variable_pool(default_system_variables()), variable_loader=DraftVarLoader( engine=db.engine, app_id=pipeline.id, @@ -959,10 +965,10 @@ class RagPipelineService: workflow_node_execution.error = error # update document status variable_pool = node_instance.graph_runtime_state.variable_pool - invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) + invoke_from = get_system_segment(variable_pool, SystemVariableKey.INVOKE_FROM) if invoke_from: if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE: - document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID) if document_id: document = db.session.query(Document).where(Document.id == document_id.value).first() if document: @@ -1276,7 +1282,7 @@ class RagPipelineService: else: enclosing_node_id = None - system_inputs = SystemVariable( + system_inputs = build_system_variables( datasource_type=args.get("datasource_type", "online_document"), datasource_info=args.get("datasource_info", {}), ) @@ -1287,12 +1293,11 @@ class RagPipelineService: node_id=node_id, user_inputs={}, user_id=current_user.id, - variable_pool=VariablePool( - system_variables=system_inputs, - user_inputs={}, - environment_variables=[], - conversation_variables=[], - rag_pipeline_variables=[], + variable_pool=_build_seeded_variable_pool( + build_bootstrap_variables( + system_variables=system_inputs, + rag_pipeline_variables=(), + ) ), variable_loader=DraftVarLoader( engine=db.engine, diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index ed7a33feae..5d9821baf6 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -192,7 +192,7 @@ class SummaryIndexService: # Calculate embedding tokens for summary (for logging and statistics) embedding_tokens = 0 try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, @@ -201,7 +201,8 @@ class SummaryIndexService: ) if embedding_model: tokens_list = embedding_model.get_text_embedding_num_tokens([summary_content]) - embedding_tokens = tokens_list[0] if tokens_list else 0 + raw_embedding_tokens = tokens_list[0] if tokens_list else 0 + embedding_tokens = raw_embedding_tokens if isinstance(raw_embedding_tokens, int) else 0 except Exception as e: logger.warning("Failed to calculate embedding tokens for summary: %s", str(e)) diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 
3c1a4cc747..a171c793fd 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -15,6 +15,7 @@ from werkzeug.exceptions import RequestEntityTooLarge from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.file_access import DatabaseFileAccessController from core.tools.tool_file_manager import ToolFileManager from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from core.workflow.nodes.trigger_webhook.entities import ( @@ -46,6 +47,7 @@ except ImportError: magic = None # type: ignore[assignment] logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class WebhookService: @@ -422,6 +424,7 @@ class WebhookService: return file_factory.build_from_mapping( mapping=mapping, tenant_id=webhook_trigger.tenant_id, + access_controller=_file_access_controller, ) @classmethod diff --git a/api/services/vector_service.py b/api/services/vector_service.py index bb94a03ba3..16a522d26d 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -47,7 +47,7 @@ class VectorService: # get embedding model instance if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index f124e137c3..b839e26056 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -14,9 +14,16 @@ from sqlalchemy.sql.expression import and_, or_ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.file_access import DatabaseFileAccessController from core.trigger.constants import is_trigger_node_type -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.enums import NodeType, SystemVariableKey +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) +from dify_graph.enums import NodeType from dify_graph.file.models import File from dify_graph.nodes import BuiltinNodeTypes from dify_graph.nodes.variable_assigner.common.helpers import get_updated_variables @@ -36,6 +43,7 @@ from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models import Account, App, Conversation from models.enums import ConversationFromSource, DraftVariableType +from models.utils.file_input_compat import build_file_from_stored_mapping from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, is_system_variable_editable from repositories.factory import DifyAPIRepositoryFactory from services.file_service import FileService @@ -120,7 +128,11 @@ class DraftVarLoader(VariableLoader): elif isinstance(value, ArrayFileSegment): files.extend(value.value) with Session(bind=self._engine) as session: - storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id) + storage_key_loader = StorageKeyLoader( + session, + tenant_id=self._tenant_id, + access_controller=DatabaseFileAccessController(), + ) 
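Note: here and in `webhook_service`/`workflow_service`, the file-access policy is injected as a controller object rather than implied by the storage layer. A minimal sketch of the shape, assuming a `can_access`-style interface (the real `DatabaseFileAccessController` API is not shown in this diff):

```python
from typing import Protocol


class FileAccessController(Protocol):
    def can_access(self, tenant_id: str, file_id: str) -> bool: ...


class AllowAllController:
    def can_access(self, tenant_id: str, file_id: str) -> bool:
        return True


def load_storage_keys(
    file_ids: list[str], *, tenant_id: str, access_controller: FileAccessController
) -> list[str]:
    # Only files the injected policy approves get their storage keys resolved.
    return [fid for fid in file_ids if access_controller.can_access(tenant_id, fid)]


assert load_storage_keys(
    ["f1", "f2"], tenant_id="tenant-123", access_controller=AllowAllController()
) == ["f1", "f2"]
```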
storage_key_loader.load_storage_keys(files) offloaded_draft_vars = [] @@ -174,7 +186,7 @@ class DraftVarLoader(VariableLoader): return (draft_var.node_id, draft_var.name), variable deserialized = json.loads(content) - segment = WorkflowDraftVariable.build_segment_with_type(variable_file.value_type, deserialized) + segment = draft_var.build_segment_from_serialized_value(variable_file.value_type, deserialized) variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), @@ -838,6 +850,12 @@ class DraftVariableSaver: self._user = user self._enclosing_node_id = enclosing_node_id + def _resolve_app_tenant_id(self) -> str: + tenant_id = self._session.scalar(select(App.tenant_id).where(App.id == self._app_id)) + if not tenant_id: + raise ValueError(f"Unable to resolve tenant_id for app {self._app_id}") + return tenant_id + def _create_dummy_output_variable(self): return WorkflowDraftVariable.new_node_variable( app_id=self._app_id, @@ -892,27 +910,18 @@ class DraftVariableSaver: for name, value in output.items(): value_seg = _build_segment_for_serialized_values(value) node_id, name = self._normalize_variable_for_start_node(name) - # If node_id is not `sys`, it means that the variable is a user-defined input field - # in `Start` node. - if node_id != SYSTEM_VARIABLE_NODE_ID: - draft_vars.append( - WorkflowDraftVariable.new_node_variable( - app_id=self._app_id, - user_id=self._user.id, - node_id=self._node_id, - name=name, - node_execution_id=self._node_execution_id, - value=value_seg, - visible=True, - editable=True, - ) - ) - has_non_sys_variables = True - else: + if node_id == SYSTEM_VARIABLE_NODE_ID: if name == SystemVariableKey.FILES: # Here we know the type of variable must be `array[file]`, we - # just build files from the value. - files = [File.model_validate(v) for v in value] + # just rebuild files from the serialized payload. + tenant_id = self._resolve_app_tenant_id() + files = [ + build_file_from_stored_mapping( + file_mapping=v, + tenant_id=tenant_id, + ) + for v in value + ] if files: value_seg = WorkflowDraftVariable.build_segment_with_type(SegmentType.ARRAY_FILE, files) else: @@ -928,15 +937,47 @@ class DraftVariableSaver: editable=self._should_variable_be_editable(node_id, name), ) ) + elif node_id == CONVERSATION_VARIABLE_NODE_ID: + draft_vars.append( + WorkflowDraftVariable.new_conversation_variable( + app_id=self._app_id, + user_id=self._user.id, + name=name, + value=value_seg, + ) + ) + has_non_sys_variables = True + else: + draft_vars.append( + WorkflowDraftVariable.new_node_variable( + app_id=self._app_id, + user_id=self._user.id, + node_id=node_id, + name=name, + node_execution_id=self._node_execution_id, + value=value_seg, + visible=self._should_variable_be_visible(node_id, self._node_type, name), + editable=self._should_variable_be_editable(node_id, name), + ) + ) + has_non_sys_variables = True if not has_non_sys_variables: draft_vars.append(self._create_dummy_output_variable()) return draft_vars def _normalize_variable_for_start_node(self, name: str) -> tuple[str, str]: - if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."): - return self._node_id, name - _, name_ = name.split(".", maxsplit=1) - return SYSTEM_VARIABLE_NODE_ID, name_ + for reserved_node_id in ( + SYSTEM_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + CONVERSATION_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + ): + prefix = f"{reserved_node_id}." 
+ if name.startswith(prefix): + _, name_ = name.split(".", maxsplit=1) + return reserved_node_id, name_ + + return self._node_id, name def _build_variables_from_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]: draft_vars = [] diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 66976058c0..1872299b62 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -12,10 +12,20 @@ from configs import dify_config from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.file_access import DatabaseFileAccessController +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.repositories import DifyCoreRepositoryFactory -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl from core.trigger.constants import is_trigger_node_type +from core.workflow.human_input_compat import ( + DeliveryChannelConfig, + normalize_human_input_node_data_for_graph, + parse_human_input_delivery_methods, +) from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type +from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry from dify_graph.entities import GraphInitParams, WorkflowNodeExecution from dify_graph.entities.graph_config import NodeConfigDict @@ -33,18 +43,11 @@ from dify_graph.node_events import NodeRunResult from dify_graph.nodes import BuiltinNodeTypes from dify_graph.nodes.base.node import Node from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from dify_graph.nodes.human_input.entities import ( - DeliveryChannelConfig, - HumanInputNodeData, - apply_debug_email_recipient, - validate_human_input_submission, -) +from dify_graph.nodes.human_input.entities import HumanInputNodeData, validate_human_input_submission from dify_graph.nodes.human_input.enums import HumanInputFormKind from dify_graph.nodes.human_input.human_input_node import HumanInputNode from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.repositories.human_input_form_repository import FormCreateParams from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variable_loader import load_into_variable_pool from dify_graph.variables import VariableBase from dify_graph.variables.input_entities import VariableEntityType @@ -82,6 +85,8 @@ from .human_input_delivery_test_service import ( from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService from .workflow_restore import apply_published_workflow_snapshot_to_draft +_file_access_controller = DatabaseFileAccessController() + class WorkflowService: """ @@ -486,13 +491,15 @@ class WorkflowService: :raises ValueError: If the model configuration is invalid or 
credentials fail policy checks """ try: - from core.model_manager import ModelManager - from core.provider_manager import ProviderManager from dify_graph.model_runtime.entities.model_entities import ModelType + # Model instance resolution and provider status lookup must reuse the + # same request-scoped runtime so validation does not silently split + # provider discovery and credential reads across different caches. + assembly = create_plugin_model_assembly(tenant_id=tenant_id) + # Get model instance to validate provider+model combination - model_manager = ModelManager() - model_manager.get_model_instance( + assembly.model_manager.get_model_instance( tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name ) @@ -501,8 +508,7 @@ class WorkflowService: # If it fails, an exception will be raised # Additionally, check the model status to ensure it's ACTIVE - provider_manager = ProviderManager() - provider_configurations = provider_manager.get_configurations(tenant_id) + provider_configurations = assembly.provider_manager.get_configurations(tenant_id) models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM) target_model = None @@ -607,11 +613,10 @@ class WorkflowService: :return: True if load balancing is enabled, False otherwise """ try: - from core.provider_manager import ProviderManager from dify_graph.model_runtime.entities.model_entities import ModelType # Get provider configurations - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) provider_configurations = provider_manager.get_configurations(tenant_id) provider_configuration = provider_configurations.get(provider) @@ -765,6 +770,7 @@ class WorkflowService: user_id=account.id, user_inputs=user_inputs, workflow=draft_workflow, + node_id=node_id, # NOTE(QuantumGhost): We rely on `DraftVarLoader` to load conversation variables. 
conversation_variables=[], node_type=node_type, @@ -772,11 +778,13 @@ class WorkflowService: ) else: - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs=user_inputs, - environment_variables=draft_workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), + environment_variables=draft_workflow.environment_variables, + ), ) variable_loader = DraftVarLoader( @@ -895,7 +903,6 @@ class WorkflowService: node_id=node_id, node_title=node.title, resolved_default_values=resolved_default_values, - form_token=None, ) return human_input_required.model_dump(mode="json") @@ -995,17 +1002,20 @@ class WorkflowService: if node_type != BuiltinNodeTypes.HUMAN_INPUT: raise ValueError("Node type must be human-input.") - node_data = HumanInputNodeData.model_validate(node_config["data"], from_attributes=True) + node_data = HumanInputNodeData.model_validate( + normalize_human_input_node_data_for_graph(node_config["data"]), + from_attributes=True, + ) delivery_method = self._resolve_human_input_delivery_method( node_data=node_data, delivery_method_id=delivery_method_id, ) if delivery_method is None: raise ValueError("Delivery method not found.") - delivery_method = apply_debug_email_recipient( + delivery_method = apply_dify_debug_email_recipient( delivery_method, enabled=True, - user_id=account.id, + actor_id=account.id, ) variable_pool = self._build_human_input_variable_pool( @@ -1055,7 +1065,7 @@ class WorkflowService: node_data: HumanInputNodeData, delivery_method_id: str, ) -> DeliveryChannelConfig | None: - for method in node_data.delivery_methods: + for method in parse_human_input_delivery_methods(node_data): if str(method.id) == delivery_method_id: return method return None @@ -1070,9 +1080,8 @@ class WorkflowService: rendered_content: str, resolved_default_values: Mapping[str, Any], ) -> tuple[str, list[DeliveryTestEmailRecipient]]: - repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id) + repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id, app_id=app_model.id) params = FormCreateParams( - app_id=app_model.id, workflow_execution_id=None, node_id=node_id, form_config=node_data, @@ -1138,7 +1147,7 @@ class WorkflowService: config=node_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=HumanInputFormRepositoryImpl(tenant_id=workflow.tenant_id), + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) return node @@ -1155,11 +1164,13 @@ class WorkflowService: draft_var_srv = WorkflowDraftVariableService(session) draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user_id) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - environment_variables=workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), + environment_variables=workflow.environment_variables, + ), ) variable_loader = DraftVarLoader( @@ -1422,7 +1433,7 @@ class WorkflowService: from dify_graph.nodes.human_input.entities import HumanInputNodeData try: - HumanInputNodeData.model_validate(node_data) + HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data)) except Exception as e: raise ValueError(f"Invalid HumanInput 
node data: {str(e)}") @@ -1511,38 +1522,48 @@ def _setup_variable_pool( user_id: str, user_inputs: Mapping[str, Any], workflow: Workflow, + node_id: str, node_type: NodeType, conversation_id: str, conversation_variables: list[VariableBase], ): # Only inject system variables for START node type. if is_start_node_type(node_type): - system_variable = SystemVariable( - user_id=user_id, - app_id=workflow.app_id, - timestamp=int(naive_utc_now().timestamp()), - workflow_id=workflow.id, - files=files or [], - workflow_execution_id=str(uuid.uuid4()), - ) + system_variable_values: dict[str, Any] = { + "user_id": user_id, + "app_id": workflow.app_id, + "timestamp": int(naive_utc_now().timestamp()), + "workflow_id": workflow.id, + "files": files or [], + "workflow_execution_id": str(uuid.uuid4()), + } - # Only add chatflow-specific variables for non-workflow types + # Only add chatflow-specific variables for non-workflow types. if workflow.type != WorkflowType.WORKFLOW: - system_variable.query = query - system_variable.conversation_id = conversation_id - system_variable.dialogue_count = 1 + system_variable_values.update( + { + "query": query, + "conversation_id": conversation_id, + "dialogue_count": 1, + } + ) + + system_variable = build_system_variables(system_variable_values) else: - system_variable = SystemVariable.default() + system_variable = default_system_variables() # init variable pool - variable_pool = VariablePool( - system_variables=system_variable, - user_inputs=user_inputs, - environment_variables=workflow.environment_variables, - # Based on the definition of `Variable`, - # `VariableBase` instances can be safely used as `Variable` since they are compatible. - conversation_variables=cast(list[Variable], conversation_variables), # + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variable, + environment_variables=workflow.environment_variables, + conversation_variables=cast(list[Variable], conversation_variables), + ), ) + if is_start_node_type(node_type): + add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=user_inputs) return variable_pool @@ -1567,7 +1588,7 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia if variable_entity_type == VariableEntityType.FILE: if not isinstance(value, dict): raise ValueError(f"expected dict for file object, got {type(value)}") - return build_from_mapping(mapping=value, tenant_id=tenant_id) + return build_from_mapping(mapping=value, tenant_id=tenant_id, access_controller=_file_access_controller) elif variable_entity_type == VariableEntityType.FILE_LIST: if not isinstance(value, list): raise ValueError(f"expected list for file list object, got {type(value)}") @@ -1575,6 +1596,6 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia return [] if not isinstance(value[0], dict): raise ValueError(f"expected dict for first element in the file list, got {type(value)}") - return build_from_mappings(mappings=value, tenant_id=tenant_id) + return build_from_mappings(mappings=value, tenant_id=tenant_id, access_controller=_file_access_controller) else: raise Exception("unreachable") diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index 174aa50343..c386fd76a3 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -239,13 +239,18 @@ def _resolve_user_for_run(session: Session, workflow_run: 
WorkflowRun) -> Accoun def _publish_streaming_response( - response_stream: Generator[str | Mapping[str, Any], None, None], workflow_run_id: str, app_mode: AppMode + response_stream: Generator[str | Mapping[str, Any] | BaseModel, None, None], + workflow_run_id: str, + app_mode: AppMode, ) -> None: topic = MessageBasedAppGenerator.get_response_topic(app_mode, workflow_run_id) for event in response_stream: try: - payload = json.dumps(event) - except TypeError: + if isinstance(event, BaseModel): + payload = json.dumps(event.model_dump(mode="json"), ensure_ascii=False) + else: + payload = json.dumps(event, ensure_ascii=False, default=str) + except (TypeError, ValueError): logger.exception("error while encoding event") continue diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index dd58378e0e..e6c4fc31a1 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -121,7 +121,7 @@ def batch_create_segment_to_index_task( document_segments = [] embedding_model = None if dataset_config["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset_config["tenant_id"]) embedding_model = model_manager.get_model_instance( tenant_id=dataset_config["tenant_id"], provider=dataset_config["embedding_model_provider"], diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index d241783359..ac5f873e61 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -11,7 +11,7 @@ from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from dify_graph.nodes.human_input.entities import EmailDeliveryConfig, EmailDeliveryMethod +from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod from dify_graph.runtime import GraphRuntimeState, VariablePool from extensions.ext_database import db from extensions.ext_mail import mail diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index 3e79792b5b..d20a5a2525 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,5 +1,5 @@ +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult, StreamCompletedEvent diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py index db4bbc1ca1..62d1826475 100644 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -6,6 +6,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.app.file_access import DatabaseFileAccessController from dify_graph.file import File, FileTransferMethod, FileType from 
extensions.ext_database import db from extensions.storage.storage_type import StorageType @@ -35,7 +36,11 @@ class TestStorageKeyLoader(unittest.TestCase): self.test_tool_files = [] # Create StorageKeyLoader instance - self.loader = StorageKeyLoader(self.session, self.tenant_id) + self.loader = StorageKeyLoader( + self.session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) def tearDown(self): """Clean up test data after each test method.""" @@ -192,19 +197,16 @@ class TestStorageKeyLoader(unittest.TestCase): # Should not raise any exceptions self.loader.load_storage_keys([]) - def test_load_storage_keys_tenant_mismatch(self): - """Test tenant_id validation.""" - # Create file with different tenant_id + def test_load_storage_keys_ignores_legacy_file_tenant_id(self): + """Legacy file tenant_id should not override the loader tenant scope.""" upload_file = self._create_upload_file() file = self._create_file( related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) ) - # Should raise ValueError for tenant mismatch - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file]) + self.loader.load_storage_keys([file]) - assert "invalid file, expected tenant_id" in str(context.value) + assert file._storage_key == upload_file.key def test_load_storage_keys_missing_file_id(self): """Test with None file.related_id.""" @@ -313,7 +315,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) # Current tenant's file should still work self.loader.load_storage_keys([file_current]) @@ -337,7 +339,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_current, file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) def test_load_storage_keys_duplicate_file_ids(self): """Test handling of duplicate file IDs in the batch.""" @@ -364,6 +366,10 @@ class TestStorageKeyLoader(unittest.TestCase): # Create loader with different session (same underlying connection) with Session(bind=db.engine) as other_session: - other_loader = StorageKeyLoader(other_session, self.tenant_id) + other_loader = StorageKeyLoader( + other_session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) with pytest.raises(ValueError): other_loader.load_storage_keys([file]) diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index 9d3a869691..c3b52adef9 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -6,7 +6,7 @@ import pytest from sqlalchemy import delete from sqlalchemy.orm import Session -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from dify_graph.nodes import BuiltinNodeTypes from dify_graph.variables.segments import StringSegment from dify_graph.variables.types import SegmentType diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py 
b/api/tests/integration_tests/workflow/nodes/__mock/model.py index 5b0f86fed1..a66fb0a6b1 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -4,8 +4,8 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ProviderType @@ -15,7 +15,7 @@ def get_mocked_fetch_model_config( mode: str, credentials: dict, ): - model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") + model_provider_factory = create_plugin_model_provider_factory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index e3a2b6b866..989078bba6 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -6,13 +6,13 @@ import pytest from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.node_events import NodeRunResult from dify_graph.nodes.code.code_node import CodeNode from dify_graph.nodes.code.limits import CodeNodeLimits from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock from tests.workflow_test_utils import build_test_graph_init_params @@ -44,7 +44,7 @@ def init_code_node(code_config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index f885f69e55..e866b15bab 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -9,12 +9,13 @@ from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_factory import DifyNodeFactory +from core.workflow.node_runtime import DifyFileReferenceFactory +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.file.file_manager import file_manager from dify_graph.graph import Graph from 
dify_graph.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock from tests.workflow_test_utils import build_test_graph_init_params @@ -54,7 +55,7 @@ def init_http_node(config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -81,6 +82,7 @@ def init_http_node(config: dict): http_client=ssrf_proxy, tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(init_params.run_context), ) return node @@ -189,6 +191,7 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" + from core.workflow.system_variables import build_system_variables from dify_graph.enums import BuiltinNodeTypes from dify_graph.nodes.http_request.entities import ( HttpRequestNodeAuthorization, @@ -198,11 +201,10 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): from dify_graph.nodes.http_request.exc import AuthorizationConfigError from dify_graph.nodes.http_request.executor import Executor from dify_graph.runtime import VariablePool - from dify_graph.system_variable import SystemVariable # Create variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="test", files=[]), + system_variables=build_system_variables(user_id="test", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -700,7 +702,7 @@ def test_nested_object_variable_selector(setup_http_mock): # Create independent variable pool for this test only variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -728,6 +730,7 @@ def test_nested_object_variable_selector(setup_http_mock): http_client=ssrf_proxy, tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(init_params.run_context), ) result = node._run() diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index d628348f1e..f2a1e0b53a 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -7,13 +7,15 @@ from unittest.mock import MagicMock, patch from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.llm_generator.output_parser.structured_output import _parse_structured_output from core.model_manager import ModelInstance +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.node_events import StreamCompletedEvent +from dify_graph.nodes.llm.file_saver import LLMFileSaver from dify_graph.nodes.llm.node import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer +from 
dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from dify_graph.nodes.protocols import HttpClientProtocol from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params @@ -51,7 +53,7 @@ def init_llm_node(config: dict) -> LLMNode: # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="aaa", app_id=app_id, workflow_id=workflow_id, @@ -66,6 +68,11 @@ def init_llm_node(config: dict) -> LLMNode: variable_pool.add(["abc", "output"], "sunny") graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + prompt_message_serializer = MagicMock(spec=PromptMessageSerializerProtocol) + prompt_message_serializer.serialize.side_effect = lambda *, model_mode, prompt_messages: [ + message.model_dump(mode="json") for message in prompt_messages + ] + llm_file_saver = MagicMock(spec=LLMFileSaver) node = LLMNode( id=str(uuid.uuid4()), @@ -75,7 +82,8 @@ def init_llm_node(config: dict) -> LLMNode: credentials_provider=MagicMock(spec=CredentialsProvider), model_factory=MagicMock(spec=ModelFactory), model_instance=MagicMock(spec=ModelInstance), - template_renderer=MagicMock(spec=TemplateRenderer), + llm_file_saver=llm_file_saver, + prompt_message_serializer=prompt_message_serializer, http_client=MagicMock(spec=HttpClientProtocol), ) @@ -159,7 +167,7 @@ def test_execute_llm(): return mock_model_instance # Mock fetch_prompt_messages to avoid database calls - def mock_fetch_prompt_messages_1(*_args, **_kwargs): + def mock_fetch_prompt_messages_1(**_kwargs): from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage return [ diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 62d9af0196..d8ce3fe98c 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -5,12 +5,13 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance +from core.workflow.node_runtime import DifyPromptMessageSerializer +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.model_runtime.entities import AssistantPromptMessage, UserPromptMessage from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from extensions.ext_database import db from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance from tests.workflow_test_utils import build_test_graph_init_params @@ -56,7 +57,7 @@ def init_parameter_extractor_node(config: dict, memory=None): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa" ), 
user_inputs={}, @@ -77,6 +78,7 @@ def init_parameter_extractor_node(config: dict, memory=None): model_factory=MagicMock(spec=ModelFactory), model_instance=MagicMock(spec=ModelInstance), memory=memory, + prompt_message_serializer=DifyPromptMessageSerializer(), ) return node diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 7bb4f905c3..bd41a5fbf9 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -3,12 +3,12 @@ import uuid from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph -from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from dify_graph.template_rendering import TemplateRenderError from tests.workflow_test_utils import build_test_graph_init_params @@ -66,7 +66,7 @@ def test_execute_template_transform(): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -90,7 +90,7 @@ def test_execute_template_transform(): config=config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - template_renderer=_SimpleJinja2Renderer(), + jinja2_template_renderer=_SimpleJinja2Renderer(), ) # execute node diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 818ae46625..cdf09c1a7b 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -5,13 +5,14 @@ from unittest.mock import MagicMock, patch from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.configuration import ToolParameterConfigurationManager from core.workflow.node_factory import DifyNodeFactory +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.node_events import StreamCompletedEvent from dify_graph.nodes.protocols import ToolFileManagerProtocol from dify_graph.nodes.tool.tool_node import ToolNode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params @@ -40,7 +41,7 @@ def init_tool_node(config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -64,6 +65,7 @@ def init_tool_node(config: dict): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, + 
runtime=DifyToolNodeRuntime(init_params.run_context), ) return node diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index dd5aea399e..48bf3ca446 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -33,6 +33,9 @@ from extensions.ext_database import db logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) +DEFAULT_SANDBOX_TEST_IMAGE = "langgenius/dify-sandbox:0.2.14" +SANDBOX_TEST_IMAGE_ENV = "DIFY_SANDBOX_TEST_IMAGE" + class _CloserProtocol(Protocol): """_Closer is any type which implement the close() method.""" @@ -164,8 +167,10 @@ class DifyTestContainers: logger.info("Redis container is ready and accepting connections") # Start Dify Sandbox container for code execution environment. + # Default to the production-pinned image while allowing local overrides for debugging. logger.info("Initializing Dify Sandbox container...") - self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:0.2.14").with_network(self.network) + sandbox_image = os.getenv(SANDBOX_TEST_IMAGE_ENV, DEFAULT_SANDBOX_TEST_IMAGE) + self.dify_sandbox = DockerContainer(image=sandbox_image).with_network(self.network) self.dify_sandbox.with_exposed_ports(8194) self.dify_sandbox.env = { "API_KEY": "test_api_key", @@ -175,7 +180,12 @@ class DifyTestContainers: sandbox_port = self.dify_sandbox.get_exposed_port(8194) os.environ["CODE_EXECUTION_ENDPOINT"] = f"http://{sandbox_host}:{sandbox_port}" os.environ["CODE_EXECUTION_API_KEY"] = "test_api_key" - logger.info("Dify Sandbox container started successfully - Host: %s, Port: %s", sandbox_host, sandbox_port) + logger.info( + "Dify Sandbox container started successfully - Image: %s Host: %s, Port: %s", + sandbox_image, + sandbox_host, + sandbox_port, + ) # Wait for Dify Sandbox to be ready logger.info("Waiting for Dify Sandbox to be ready to accept connections...") diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py index f037ad77c0..32cad239e6 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py @@ -6,7 +6,7 @@ from flask.testing import FlaskClient from sqlalchemy import select from sqlalchemy.orm import Session -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from dify_graph.variables.segments import StringSegment from factories.variable_factory import segment_to_variable from models import Workflow diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index 96fb7ea293..7107c3d56b 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -31,6 +31,7 @@ from core.app.layers.pause_state_persist_layer import ( PauseStatePersistenceLayer, WorkflowResumptionContext, ) 
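The same mechanical migration repeats through the remaining test hunks: the `SystemVariable(...)` value object imported from `dify_graph.system_variable` is replaced by the `core.workflow.system_variables.build_system_variables(...)` factory, with the surrounding `VariablePool` seeding otherwise unchanged. A before/after sketch taken directly from these hunks (return types are not shown in the diff and are assumed compatible):

```python
from core.workflow.system_variables import build_system_variables
from dify_graph.runtime import VariablePool

# Removed form:
#   from dify_graph.system_variable import SystemVariable
#   pool = VariablePool(system_variables=SystemVariable(user_id="aaa", files=[]), ...)

# Added form, as used throughout the updated tests:
variable_pool = VariablePool(
    system_variables=build_system_variables(user_id="aaa", files=[]),
    user_inputs={},
    environment_variables=[],
    conversation_variables=[],
)
```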
+from core.workflow.system_variables import build_system_variables from dify_graph.entities.pause_reason import SchedulingPause from dify_graph.enums import WorkflowExecutionStatus from dify_graph.graph_engine.entities.commands import GraphEngineCommand @@ -40,7 +41,7 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.runtime.graph_runtime_state import GraphRuntimeState from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState from dify_graph.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper -from dify_graph.runtime.variable_pool import SystemVariable, VariablePool +from dify_graph.runtime.variable_pool import VariablePool from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models import Account @@ -212,7 +213,7 @@ class TestPauseStatePersistenceLayerTestContainers: execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4()) # Create variable pool - variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=execution_id)) + variable_pool = VariablePool(system_variables=build_system_variables(workflow_execution_id=execution_id)) if variables: for (node_id, var_key), value in variables.items(): variable_pool.add([node_id, var_key], value) diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 9d0fad4b12..fc25972a3b 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -7,20 +7,17 @@ from uuid import uuid4 from sqlalchemy import Engine, select from sqlalchemy.orm import Session -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from dify_graph.nodes.human_input.entities import ( +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl +from core.workflow.human_input_compat import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, - HumanInputNodeData, MemberRecipient, - UserAction, WebAppDeliveryMethod, ) -from dify_graph.repositories.human_input_form_repository import FormCreateParams +from dify_graph.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.human_input import ( EmailExternalRecipientPayload, @@ -68,7 +65,6 @@ def _build_form_params(delivery_methods: list[DeliveryChannelConfig]) -> FormCre user_actions=[UserAction(id="approve", title="Approve")], ) return FormCreateParams( - app_id=str(uuid4()), workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=form_config, @@ -84,7 +80,7 @@ def _build_email_delivery( ) -> EmailDeliveryMethod: return EmailDeliveryMethod( config=EmailDeliveryConfig( - recipients=EmailRecipients(whole_workspace=whole_workspace, items=recipients), + recipients=EmailRecipients(include_bound_group=whole_workspace, items=recipients), subject="Approval Needed", body="Please review", ) @@ -100,7 +96,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["member1@example.com", "member2@example.com"], ) - repository = 
HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = _build_form_params( delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])], ) @@ -129,13 +125,13 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["primary@example.com", "secondary@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = _build_form_params( delivery_methods=[ _build_email_delivery( whole_workspace=False, recipients=[ - MemberRecipient(user_id=members[0].id), + MemberRecipient(reference_id=members[0].id), ExternalRecipient(email="external@example.com"), ], ) @@ -173,10 +169,9 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["prefill@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) resolved_values = {"greeting": "Hello!"} params = FormCreateParams( - app_id=str(uuid4()), workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=HumanInputNodeData( @@ -210,9 +205,8 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["ui@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = FormCreateParams( - app_id=str(uuid4()), workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=HumanInputNodeData( diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index 9733735df3..1efb0db9c2 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -10,8 +10,11 @@ from sqlalchemy.orm import Session from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowType from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -23,9 +26,7 @@ from dify_graph.nodes.human_input.enums import HumanInputFormStatus from dify_graph.nodes.human_input.human_input_node import HumanInputNode from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from 
libs.datetime_utils import naive_utc_now from models import Account from models.account import Tenant, TenantAccountJoin, TenantAccountRole @@ -39,7 +40,7 @@ def _mock_form_repository_without_submission() -> HumanInputFormRepository: repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = False @@ -52,7 +53,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = True @@ -66,7 +67,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( workflow_execution_id=workflow_execution_id, app_id=app_id, workflow_id=workflow_id, @@ -120,6 +121,7 @@ def _build_graph( graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(params.run_context), ) end_data = EndNodeData( diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index 8e70fc0bb0..3c28dcc9ce 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -6,6 +6,7 @@ from uuid import uuid4 import pytest from sqlalchemy.orm import Session +from core.app.file_access import DatabaseFileAccessController from dify_graph.file import File, FileTransferMethod, FileType from extensions.ext_database import db from extensions.storage.storage_type import StorageType @@ -35,7 +36,11 @@ class TestStorageKeyLoader(unittest.TestCase): self.test_tool_files = [] # Create StorageKeyLoader instance - self.loader = StorageKeyLoader(self.session, self.tenant_id) + self.loader = StorageKeyLoader( + self.session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) def tearDown(self): """Clean up test data after each test method.""" @@ -193,19 +198,16 @@ class TestStorageKeyLoader(unittest.TestCase): # Should not raise any exceptions self.loader.load_storage_keys([]) - def test_load_storage_keys_tenant_mismatch(self): - """Test tenant_id validation.""" - # Create file with different tenant_id + def test_load_storage_keys_ignores_legacy_file_tenant_id(self): + """Legacy file tenant_id should not override the loader tenant scope.""" upload_file = self._create_upload_file() file = self._create_file( related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) ) - # Should raise ValueError for tenant mismatch - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file]) + self.loader.load_storage_keys([file]) - assert "invalid file, expected tenant_id" in str(context.value) + assert file._storage_key == upload_file.key 
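The renamed test above pins the revised cross-tenant contract: a file stamped with a foreign tenant_id no longer raises an eager tenant-mismatch error; the loader scopes its lookup to its own tenant, so a cross-tenant file surfaces as a missing upload. A compact pytest sketch of that contract (`loader` and `create_foreign_tenant_file` are hypothetical fixtures standing in for the setUp helpers used in this test class):

```python
import pytest


def test_cross_tenant_file_surfaces_as_missing(loader, create_foreign_tenant_file):
    # Hypothetical fixtures: `loader` is a tenant-scoped StorageKeyLoader and
    # `create_foreign_tenant_file` builds a file under a different tenant_id.
    foreign_file = create_foreign_tenant_file()
    with pytest.raises(ValueError) as excinfo:
        loader.load_storage_keys([foreign_file])
    # The message matches the assertions in the updated tests above.
    assert "Upload file not found for id:" in str(excinfo.value)
```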
def test_load_storage_keys_missing_file_id(self): """Test with None file.related_id.""" @@ -314,7 +316,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) # Current tenant's file should still work self.loader.load_storage_keys([file_current]) @@ -338,7 +340,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_current, file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) def test_load_storage_keys_duplicate_file_ids(self): """Test handling of duplicate file IDs in the batch.""" @@ -365,6 +367,10 @@ class TestStorageKeyLoader(unittest.TestCase): # Create loader with different session (same underlying connection) with Session(bind=db.engine) as other_session: - other_loader = StorageKeyLoader(other_session, self.tenant_id) + other_loader = StorageKeyLoader( + other_session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) with pytest.raises(ValueError): other_loader.load_storage_keys([file]) diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 49b370990a..eb85ac4ca5 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -2,7 +2,6 @@ from __future__ import annotations -import secrets from dataclasses import dataclass, field from datetime import datetime, timedelta from unittest.mock import Mock @@ -16,16 +15,14 @@ from dify_graph.entities import WorkflowExecution from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType from dify_graph.enums import WorkflowExecutionStatus from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction -from dify_graph.nodes.human_input.enums import DeliveryMethodType, FormInputType, HumanInputFormStatus +from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormStatus from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import ( - BackstageRecipientPayload, HumanInputDelivery, HumanInputForm, HumanInputFormRecipient, - RecipientType, ) from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity @@ -636,12 +633,12 @@ class TestPrivateWorkflowPauseEntity: class TestBuildHumanInputRequiredReason: """Integration tests for _build_human_input_required_reason using real DB models.""" - def test_prefers_backstage_token_when_available( + def test_builds_reason_from_form_definition( self, db_session_with_containers: Session, test_scope: _TestScope, ) -> None: - """Use backstage token when multiple recipient types may exist.""" + """Build the graph pause reason from the stored form definition.""" expiration_time = naive_utc_now() form_definition = FormDefinition( @@ -668,25 +665,6 @@ class 
TestBuildHumanInputRequiredReason: db_session_with_containers.add(form_model) db_session_with_containers.flush() - delivery = HumanInputDelivery( - form_id=form_model.id, - delivery_method_type=DeliveryMethodType.WEBAPP, - channel_payload="{}", - ) - db_session_with_containers.add(delivery) - db_session_with_containers.flush() - - access_token = secrets.token_urlsafe(8) - recipient = HumanInputFormRecipient( - form_id=form_model.id, - delivery_id=delivery.id, - recipient_type=RecipientType.BACKSTAGE, - recipient_payload=BackstageRecipientPayload().model_dump_json(), - access_token=access_token, - ) - db_session_with_containers.add(recipient) - db_session_with_containers.flush() - # Create a pause so the reason has a valid pause_id workflow_run = _create_workflow_run( db_session_with_containers, @@ -716,13 +694,12 @@ class TestBuildHumanInputRequiredReason: # Refresh to ensure we have DB-round-tripped objects db_session_with_containers.refresh(form_model) db_session_with_containers.refresh(reason_model) - db_session_with_containers.refresh(recipient) - reason = _build_human_input_required_reason(reason_model, form_model, [recipient]) + reason = _build_human_input_required_reason(reason_model, form_model) assert isinstance(reason, HumanInputRequired) - assert reason.form_token == access_token assert reason.node_title == "Ask Name" assert reason.form_content == "content" assert reason.inputs[0].output_variable_name == "name" assert reason.actions[0].id == "approve" + assert reason.resolved_default_values == {"name": "Alice"} diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py index ed998c9ed0..e9faa319c2 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -270,7 +270,7 @@ def _create_recipient( def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery: - from dify_graph.nodes.human_input.enums import DeliveryMethodType + from core.workflow.human_input_compat import DeliveryMethodType from models.human_input import ConsoleDeliveryPayload delivery = HumanInputDelivery( diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index b51fbc3a42..9c36ab03c8 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -28,7 +28,7 @@ class TestAgentService: patch("services.agent_service.current_user", create_autospec(Account, instance=True)) as mock_current_user, patch("services.app_service.FeatureService", autospec=True) as mock_feature_service, patch("services.app_service.EnterpriseService", autospec=True) as mock_enterprise_service, - patch("services.app_service.ModelManager", autospec=True) as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant", autospec=True) as mock_model_manager, patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service, ): # Setup default mock returns for agent service diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py 
b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index 8a362e1f5e..33955d5d84 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -26,7 +26,7 @@ class TestAppDslService: patch("services.app_dsl_service.redis_client") as mock_redis_client, patch("services.app_dsl_service.app_was_created") as mock_app_was_created, patch("services.app_dsl_service.app_model_config_was_updated") as mock_app_model_config_was_updated, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, ): diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index a83af30fb9..fa57dd4a6f 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -23,7 +23,7 @@ class TestAppService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index 0702680f5c..2ef630ec88 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -174,7 +174,7 @@ class TestDatasetServiceCreateDataset: embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() # Act - with patch("services.dataset_service.ModelManager") as mock_model_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager: mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model result = DatasetService.create_empty_dataset( @@ -264,7 +264,7 @@ class TestDatasetServiceCreateDataset: # Act with ( - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking, ): mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model @@ -297,7 +297,7 @@ class TestDatasetServiceCreateDataset: # Act with ( - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding, ): mock_model_manager.return_value.get_model_instance.return_value = embedding_model diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index 
2899d5b8a5..aab96f6f4c 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -363,7 +363,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch( "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, @@ -458,7 +458,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch( "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, @@ -544,7 +544,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, ): mock_model_manager.return_value.get_model_instance.side_effect = Exception("No Embedding Model available") diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index 70d05792ce..4c6718f959 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -4,14 +4,14 @@ from unittest.mock import MagicMock import pytest -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, ) +from dify_graph.enums import BuiltinNodeTypes +from dify_graph.nodes.human_input.entities import HumanInputNodeData from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.model import App, AppMode from models.workflow import Workflow, WorkflowType @@ -54,7 +54,7 @@ def _create_app_with_draft_workflow(session, *, delivery_method_id: uuid.UUID) - enabled=True, config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ExternalRecipient(email="recipient@example.com")], ), subject="Test {{recipient_email}}", diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index 85dc04b162..bdf6d9b951 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -25,7 +25,7 @@ class TestMessageService: """Mock setup for external service dependencies.""" with ( patch("services.account_service.FeatureService") as mock_account_feature_service, - patch("services.message_service.ModelManager") as mock_model_manager, + patch("services.message_service.ModelManager.for_tenant") as mock_model_manager, patch("services.message_service.WorkflowService") as mock_workflow_service, 
patch("services.message_service.AdvancedChatAppConfigManager") as mock_app_config_manager, patch("services.message_service.LLMGenerator") as mock_llm_generator, diff --git a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index 989df42499..ca6e7afeab 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -18,11 +18,10 @@ class TestModelLoadBalancingService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("services.model_load_balancing_service.ProviderManager", autospec=True) as mock_provider_manager, - patch("services.model_load_balancing_service.LBModelManager", autospec=True) as mock_lb_model_manager, patch( - "services.model_load_balancing_service.ModelProviderFactory", autospec=True - ) as mock_model_provider_factory, + "services.model_load_balancing_service.create_plugin_provider_manager", autospec=True + ) as mock_provider_manager, + patch("services.model_load_balancing_service.LBModelManager", autospec=True) as mock_lb_model_manager, patch("services.model_load_balancing_service.encrypter", autospec=True) as mock_encrypter, ): # Setup default mock returns @@ -46,9 +45,6 @@ class TestModelLoadBalancingService: # Mock LBModelManager mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0) - # Mock ModelProviderFactory - mock_model_provider_factory_instance = mock_model_provider_factory.return_value - # Mock credential schemas mock_credential_schema = MagicMock() mock_credential_schema.credential_form_schemas = [] @@ -61,7 +57,6 @@ class TestModelLoadBalancingService: yield { "provider_manager": mock_provider_manager, "lb_model_manager": mock_lb_model_manager, - "model_provider_factory": mock_model_provider_factory, "encrypter": mock_encrypter, "provider_config": mock_provider_config, "provider_model_setting": mock_provider_model_setting, diff --git a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index 6afc5aa43c..b3233852c5 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -18,8 +18,12 @@ class TestModelProviderService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("services.model_provider_service.ProviderManager", autospec=True) as mock_provider_manager, - patch("services.model_provider_service.ModelProviderFactory", autospec=True) as mock_model_provider_factory, + patch( + "services.model_provider_service.create_plugin_provider_manager", autospec=True + ) as mock_provider_manager, + patch( + "services.model_provider_service.create_plugin_model_provider_factory", autospec=True + ) as mock_model_provider_factory, ): # Setup default mock returns mock_provider_manager.return_value.get_configurations.return_value = MagicMock() diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index d256c0d90b..70aa813142 100644 --- 
a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -20,7 +20,7 @@ class TestSavedMessageService: """Mock setup for external service dependencies.""" with ( patch("services.account_service.FeatureService") as mock_account_feature_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.saved_message_service.MessageService") as mock_message_service, ): # Setup default mock returns diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index 6b95954480..f2307fbd7d 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -25,7 +25,7 @@ class TestWebConversationService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 880143013e..053886edad 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -31,7 +31,7 @@ class TestWorkflowAppService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 572cf72fa0..a22981ca93 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -2,7 +2,7 @@ import pytest from faker import Faker from sqlalchemy.orm import Session -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from dify_graph.variables.segments import StringSegment from models import App, Workflow from models.enums import DraftVariableType diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index 731770e01a..d02a078281 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ 
b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -27,7 +27,7 @@ class TestWorkflowRunService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index e3c0749494..21a1975879 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -25,7 +25,7 @@ class TestWorkflowToolManageService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, patch( "services.tools.workflow_tools_manage_service.WorkflowToolProviderController" diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index d2e343ef52..f9ae33b32f 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -54,7 +54,10 @@ class TestBatchCreateSegmentToIndexTask: """Mock setup for external service dependencies.""" with ( patch("tasks.batch_create_segment_to_index_task.storage", autospec=True) as mock_storage, - patch("tasks.batch_create_segment_to_index_task.ModelManager", autospec=True) as mock_model_manager, + patch( + "tasks.batch_create_segment_to_index_task.ModelManager.for_tenant", + autospec=True, + ) as mock_model_manager, patch("tasks.batch_create_segment_to_index_task.VectorService", autospec=True) as mock_vector_service, ): # Setup default mock returns diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index 0876a39f82..241bd9f9f8 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -9,15 +9,15 @@ from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - 
HumanInputNodeData, MemberRecipient, ) +from dify_graph.enums import WorkflowExecutionStatus +from dify_graph.nodes.human_input.entities import HumanInputNodeData from dify_graph.runtime import GraphRuntimeState, VariablePool from extensions.ext_storage import storage from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole @@ -79,9 +79,9 @@ def _build_form(db_session_with_containers, tenant, account, *, app_id: str, wor delivery_method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id=account.id), + MemberRecipient(reference_id=account.id), ExternalRecipient(email="external@example.com"), ], ), @@ -96,9 +96,8 @@ def _build_form(db_session_with_containers, tenant, account, *, app_id: str, wor delivery_methods=[delivery_method], ) - repo = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repo = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=app_id) params = FormCreateParams( - app_id=app_id, workflow_execution_id=workflow_execution_id, node_id="node-1", form_config=node_data, diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py index 83601dc1b9..1095e79da8 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -67,7 +67,6 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte actions=[UserAction(id="approve", title="Approve")], node_id="node-1", node_title="Ask Name", - form_token="backstage-token", ) pause_entity = _PauseEntity(paused_at=datetime(2024, 1, 1, 12, 0, 0), reasons=[reason]) @@ -78,6 +77,11 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte "create_api_workflow_run_repository", lambda *_, **__: repo, ) + monkeypatch.setattr( + workflow_run_module, + "_load_form_tokens_by_form_id", + lambda _form_ids: {"form-1": "backstage-token"}, + ) with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index f34702a257..530a4f5906 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -13,7 +13,7 @@ from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, _serialize_full_content, ) -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from dify_graph.variables.types import SegmentType from factories.variable_factory import build_segment from libs.datetime_utils import naive_utc_now @@ -316,7 +316,6 @@ def test_workflow_file_variable_with_signed_url(): # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) test_file = File( id="test_file_id", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="test_upload_file_id", @@ -374,7 
+373,6 @@ def test_workflow_file_variable_remote_url(): # Create a File object with REMOTE_URL transfer method test_file = File( id="test_file_id", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/test.jpg", diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py index b4c0903f63..1f5f0646be 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py @@ -14,7 +14,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable impor RagPipelineVariableResetApi, ) from controllers.web.error import InvalidArgumentError, NotFoundError -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from dify_graph.variables.types import SegmentType from models.account import Account diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index ff565f19fd..8555900f4e 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -417,7 +417,7 @@ class TestDatasetApiGet: "check_dataset_permission", return_value=None, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): # embedding models exist → embedding_available stays True provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] @@ -521,7 +521,7 @@ class TestDatasetApiGet: "check_dataset_permission", return_value=None, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): # embedding model NOT configured provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] @@ -580,7 +580,7 @@ class TestDatasetApiGet: "get_dataset_partial_member_list", return_value=partial_members, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py index 94c3019d5e..44feacf2ad 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py @@ -4,7 +4,7 @@ from __future__ import annotations import builtins import importlib -from contextlib import contextmanager +from contextlib import ExitStack, contextmanager from types import ModuleType, SimpleNamespace from unittest.mock import MagicMock, patch @@ -18,7 +18,6 @@ if not hasattr(builtins, "MethodView"): _CONTROLLER_MODULE: ModuleType | None = None _WRAPS_MODULE: ModuleType | None = None 
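# --- Illustrative aside, not part of the patch ---
# A minimal sketch of the ExitStack pattern this file migrates to below:
# every patcher started via enter_context() is stopped when the stack
# unwinds, even if an exception escapes, so no auth patch can leak into
# later tests. All names here are hypothetical.
from contextlib import ExitStack
from unittest.mock import patch

def call_with_patches(patch_targets, fn):
    # Start all patches, run fn, then stop every patch in reverse order.
    with ExitStack() as stack:
        for target, value in patch_targets:
            stack.enter_context(patch(target, value))
        return fn()
# Contrast with the removed approach: patcher.start() appended to a
# module-level list is never unwound if the import in between raises.
# --- end aside ---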
-_CONTROLLER_PATCHERS: list[patch] = [] @contextmanager @@ -37,6 +36,14 @@ def app() -> Flask: @pytest.fixture def controller_module(monkeypatch: pytest.MonkeyPatch): + """ + Import the controller with auth decorators neutralized only during import. + + The imported view classes retain those no-op decorators after import, so we + can restore the original globals immediately and avoid leaking auth patches + into unrelated tests such as libs.login unit coverage. + """ + module_name = "controllers.console.workspace.tool_providers" global _CONTROLLER_MODULE if _CONTROLLER_MODULE is None: @@ -51,13 +58,12 @@ def controller_module(monkeypatch: pytest.MonkeyPatch): ("controllers.console.wraps.is_admin_or_owner_required", _noop), ("controllers.console.wraps.enterprise_license_required", _noop), ] - for target, value in patch_targets: - patcher = patch(target, value) - patcher.start() - _CONTROLLER_PATCHERS.append(patcher) monkeypatch.setenv("DIFY_SETUP_READY", "true") - with _mock_db(): - _CONTROLLER_MODULE = importlib.import_module(module_name) + with ExitStack() as stack: + for target, value in patch_targets: + stack.enter_context(patch(target, value)) + with _mock_db(): + _CONTROLLER_MODULE = importlib.import_module(module_name) module = _CONTROLLER_MODULE monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload) diff --git a/api/tests/unit_tests/controllers/files/test_tool_files.py b/api/tests/unit_tests/controllers/files/test_tool_files.py index e5df7a1eea..edb91c3f26 100644 --- a/api/tests/unit_tests/controllers/files/test_tool_files.py +++ b/api/tests/unit_tests/controllers/files/test_tool_files.py @@ -18,10 +18,10 @@ def fake_request(args: dict): class DummyToolFile: - def __init__(self, mimetype="text/plain", size=10, name="tool.txt"): - self.mimetype = mimetype + def __init__(self, mime_type="text/plain", size=10, filename="tool.txt"): + self.mime_type = mime_type self.size = size - self.name = name + self.filename = filename @pytest.fixture(autouse=True) @@ -87,8 +87,8 @@ class TestToolFileApi: stream = iter([b"data"]) tool_file = DummyToolFile( - mimetype="application/pdf", - name="doc.pdf", + mime_type="application/pdf", + filename="doc.pdf", ) mock_tool_file_manager.return_value.get_file_generator_by_tool_file_id.return_value = ( diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py index 8fe41cd19f..910d781cd0 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py @@ -942,11 +942,11 @@ class TestDatasetListApiGet: """Test suite for DatasetListApi.get() endpoint. ``get`` has no billing decorators but calls ``current_user``, - ``DatasetService``, ``ProviderManager``, and ``marshal``. + ``DatasetService``, ``create_plugin_provider_manager``, and ``marshal``. """ @patch("controllers.service_api.dataset.dataset.marshal") - @patch("controllers.service_api.dataset.dataset.ProviderManager") + @patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager") @patch("controllers.service_api.dataset.dataset.current_user") @patch("controllers.service_api.dataset.dataset.DatasetService") def test_list_datasets_success( @@ -1044,12 +1044,12 @@ class TestDatasetApiGet: """Test suite for DatasetApi.get() endpoint. ``get`` has no billing decorators but calls ``DatasetService``, - ``ProviderManager``, ``marshal``, and ``current_user``. 
+ ``create_plugin_provider_manager``, ``marshal``, and ``current_user``. """ @patch("controllers.service_api.dataset.dataset.DatasetPermissionService") @patch("controllers.service_api.dataset.dataset.marshal") - @patch("controllers.service_api.dataset.dataset.ProviderManager") + @patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager") @patch("controllers.service_api.dataset.dataset.current_user") @patch("controllers.service_api.dataset.dataset.DatasetService") def test_get_dataset_success( diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py index aed1651511..a6a205d086 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py @@ -73,7 +73,7 @@ class TestModelConfigConverter: mock_manager = MagicMock() mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle mocker.patch( - "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", return_value=mock_manager, ) return mock_manager @@ -109,7 +109,7 @@ class TestModelConfigConverter: mock_manager = MagicMock() mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle mocker.patch( - "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", return_value=mock_manager, ) @@ -124,7 +124,7 @@ class TestModelConfigConverter: mock_manager = MagicMock() mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle mocker.patch( - "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", return_value=mock_manager, ) @@ -141,7 +141,7 @@ class TestModelConfigConverter: mock_manager = MagicMock() mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle mocker.patch( - "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", return_value=mock_manager, ) @@ -158,7 +158,7 @@ class TestModelConfigConverter: mock_manager = MagicMock() mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle mocker.patch( - "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", return_value=mock_manager, ) @@ -183,7 +183,7 @@ class TestModelConfigConverter: mock_manager = MagicMock() mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle mocker.patch( - "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", return_value=mock_manager, ) @@ -200,7 +200,7 @@ class TestModelConfigConverter: mock_manager = MagicMock() mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle mocker.patch( - "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + 
            "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager",
+
"core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", return_value=mock_manager, ) diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py index e2ba276d8e..68bca485bb 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py @@ -43,6 +43,17 @@ def valid_config(): class TestModelConfigManager: + @staticmethod + def _patch_model_assembly(mocker, *, provider_entities, model_list): + assembly = MagicMock() + assembly.model_provider_factory.get_providers.return_value = provider_entities + assembly.provider_manager.get_configurations.return_value.get_models.return_value = model_list + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.manager.create_plugin_model_assembly", + return_value=assembly, + ) + return assembly + # ========================================================== # convert # ========================================================== @@ -97,11 +108,11 @@ class TestModelConfigManager: # ========================================================== def test_validate_and_set_defaults_success(self, mocker, valid_config, provider_entities, valid_model_list): - mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") - mock_factory.return_value.get_providers.return_value = provider_entities - - mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") - mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list + self._patch_model_assembly( + mocker, + provider_entities=provider_entities, + model_list=valid_model_list, + ) updated_config, keys = ModelConfigManager.validate_and_set_defaults("tenant1", valid_config) @@ -118,51 +129,39 @@ class TestModelConfigManager: def test_validate_and_set_defaults_missing_provider(self, mocker, provider_entities): config = {"model": {"name": "gpt-4", "completion_params": {}}} - - mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") - mock_factory.return_value.get_providers.return_value = provider_entities + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) with pytest.raises(ValueError, match="model.provider is required"): ModelConfigManager.validate_and_set_defaults("tenant1", config) def test_validate_and_set_defaults_invalid_provider(self, mocker, provider_entities): config = {"model": {"provider": "invalid/provider", "name": "gpt-4", "completion_params": {}}} - - mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") - mock_factory.return_value.get_providers.return_value = provider_entities + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) with pytest.raises(ValueError, match="model.provider is required"): ModelConfigManager.validate_and_set_defaults("tenant1", config) def test_validate_and_set_defaults_missing_name(self, mocker, provider_entities): config = {"model": {"provider": "openai/gpt", "completion_params": {}}} - - mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") - mock_factory.return_value.get_providers.return_value = provider_entities 
+ self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) with pytest.raises(ValueError, match="model.name is required"): ModelConfigManager.validate_and_set_defaults("tenant1", config) def test_validate_and_set_defaults_empty_models(self, mocker, provider_entities): config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}} - - mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") - mock_factory.return_value.get_providers.return_value = provider_entities - - mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") - mock_pm.return_value.get_configurations.return_value.get_models.return_value = [] + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) with pytest.raises(ValueError, match="must be in the specified model list"): ModelConfigManager.validate_and_set_defaults("tenant1", config) def test_validate_and_set_defaults_invalid_model_name(self, mocker, provider_entities, valid_model_list): config = {"model": {"provider": "openai/gpt", "name": "invalid", "completion_params": {}}} - - mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") - mock_factory.return_value.get_providers.return_value = provider_entities - - mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") - mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list + self._patch_model_assembly( + mocker, + provider_entities=provider_entities, + model_list=valid_model_list, + ) with pytest.raises(ValueError, match="must be in the specified model list"): ModelConfigManager.validate_and_set_defaults("tenant1", config) @@ -173,12 +172,7 @@ class TestModelConfigManager: model.model_properties = {} config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}} - - mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") - mock_factory.return_value.get_providers.return_value = provider_entities - - mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") - mock_pm.return_value.get_configurations.return_value.get_models.return_value = [model] + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[model]) updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config) @@ -186,12 +180,11 @@ class TestModelConfigManager: def test_validate_and_set_defaults_missing_completion_params(self, mocker, provider_entities, valid_model_list): config = {"model": {"provider": "openai/gpt", "name": "gpt-4"}} - - mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") - mock_factory.return_value.get_providers.return_value = provider_entities - - mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") - mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list + self._patch_model_assembly( + mocker, + provider_entities=provider_entities, + model_list=valid_model_list, + ) with pytest.raises(ValueError, match="completion_params is required"): ModelConfigManager.validate_and_set_defaults("tenant1", config) @@ -212,16 +205,9 @@ class TestModelConfigManager: # Mock ModelProviderID to return formatted provider 
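# --- Illustrative aside, not part of the patch ---
# The pytest.raises(..., match=...) assertions used throughout these tests
# apply re.search to str(excinfo.value): plain substrings match anywhere in
# the message, but regex metacharacters must be escaped. A self-contained
# example (hypothetical error text):
import re
import pytest

def test_match_is_a_regex_search():
    with pytest.raises(ValueError, match=re.escape("model.name is required")):
        raise ValueError("config invalid: model.name is required")
# --- end aside ---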
mock_provider_id = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderID") mock_provider_id.return_value = "openai/gpt" - - # Mock provider factory - mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") provider_entity = MagicMock() provider_entity.provider = "openai/gpt" - mock_factory.return_value.get_providers.return_value = [provider_entity] - - # Mock provider manager - mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") - mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list + self._patch_model_assembly(mocker, provider_entities=[provider_entity], model_list=valid_model_list) updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config) diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 15aceef2c7..5d241432a0 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -11,6 +11,19 @@ from dify_graph.variables import SegmentType from factories import variable_factory from models import ConversationVariable, Workflow +MINIMAL_GRAPH = { + "nodes": [ + { + "id": "start", + "data": { + "type": "start", + "title": "Start", + }, + } + ], + "edges": [], +} + class TestAdvancedChatAppRunnerConversationVariables: """Test that AdvancedChatAppRunner correctly handles conversation variables.""" @@ -49,7 +62,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Create existing conversation variable (only var1 exists in DB) @@ -200,7 +213,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Mock conversation and message @@ -349,7 +362,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Create existing conversation variables (both exist in DB) diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py index 5792a2f1e2..079df0b4e6 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py @@ -8,6 +8,19 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, from core.app.entities.queue_entities import QueueStopEvent from core.moderation.base import ModerationError +MINIMAL_GRAPH = { + "nodes": [ + { + "id": "start", + "data": { + "type": "start", + "title": "Start", + }, + } + ], + "edges": [], +} + @pytest.fixture def build_runner(): @@ -30,7 +43,7 @@ def build_runner(): mock_workflow.tenant_id = 
str(uuid4()) mock_workflow.app_id = app_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] mock_app_config = MagicMock() diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py index 83a6e0f231..c7cb5f142d 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py @@ -137,7 +137,6 @@ def test_handle_workflow_paused_event_persists_human_input_extra_content() -> No actions=[], node_id="node-1", node_title="Approval", - form_token="token-1", resolved_default_values={}, ) event = QueueWorkflowPausedEvent(reasons=[reason], outputs={}, paused_nodes=["node-1"]) diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index 0a244b3fea..e10973f271 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -42,11 +42,12 @@ from core.app.entities.task_entities import ( PingStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk +from core.workflow.system_variables import build_system_variables from dify_graph.enums import BuiltinNodeTypes from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from models.enums import MessageStatus from models.model import AppMode, EndUser +from tests.workflow_test_utils import build_test_variable_pool def _make_pipeline(): @@ -166,7 +167,7 @@ class TestAdvancedChatGenerateTaskPipeline: def test_handle_workflow_started_event_sets_run_id(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" @@ -311,7 +312,7 @@ class TestAdvancedChatGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_run_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -522,7 +523,7 @@ class TestAdvancedChatGenerateTaskPipeline: self.items = items graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) @@ -556,7 +557,7 @@ class TestAdvancedChatGenerateTaskPipeline: def test_handle_message_end_event_applies_output_moderation(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + 
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe" diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py index b0789bbc1e..e796303583 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py +++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py @@ -3,13 +3,15 @@ from types import SimpleNamespace import pytest from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport +from core.workflow.system_variables import build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool from dify_graph.runtime import GraphRuntimeState from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable def _make_state(workflow_run_id: str | None) -> GraphRuntimeState: - variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=workflow_run_id)) + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, build_system_variables(workflow_execution_id=workflow_run_id)) return GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 72430a3347..433e7c6b0f 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -12,7 +12,6 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: """Create a test File object""" return File( id=file_id, - tenant_id="test_tenant", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related_123", @@ -223,7 +222,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: assert len(result) == 1 file_dict = result[0] assert file_dict["id"] == "property_test" - assert file_dict["tenant_id"] == "test_tenant" + assert "tenant_id" not in file_dict assert file_dict["type"] == "document" assert file_dict["transfer_method"] == "local_file" assert file_dict["filename"] == "property_test.txt" diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py index 4ed7d73cd0..643f863479 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -4,13 +4,13 @@ from types import SimpleNamespace from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent +from core.workflow.system_variables import build_system_variables from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable def _build_converter(): - system_variables = SystemVariable( + system_variables = build_system_variables( 
files=[], user_id="user-1", app_id="app-1", diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py index 5879e8fb9b..c1d7f4687e 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -2,14 +2,14 @@ from types import SimpleNamespace from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.system_variables import build_system_variables from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable def _build_converter() -> WorkflowResponseConverter: """Construct a minimal WorkflowResponseConverter for testing.""" - system_variables = SystemVariable( + system_variables = build_system_variables( files=[], user_id="user-1", app_id="app-1", diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 374af5ddc4..4e406dc12a 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -24,9 +24,9 @@ from core.app.entities.queue_entities import ( QueueNodeStartedEvent, QueueNodeSucceededEvent, ) +from core.workflow.system_variables import build_system_variables from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.enums import BuiltinNodeTypes -from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from models import Account from models.model import AppMode @@ -54,7 +54,7 @@ class TestWorkflowResponseConverter: mock_user.name = "Test User" mock_user.email = "test@example.com" - system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id") + system_variables = build_system_variables(workflow_id="wf-id", workflow_execution_id="initial-run-id") return WorkflowResponseConverter( application_generate_entity=mock_entity, user=mock_user, @@ -451,9 +451,9 @@ class TestWorkflowResponseConverterServiceApiTruncation: account.id = "test_user_id" return account - def create_test_system_variables(self) -> SystemVariable: + def create_test_system_variables(self): """Create test system variables.""" - return SystemVariable() + return build_system_variables() def create_test_converter(self, invoke_from: InvokeFrom) -> WorkflowResponseConverter: """Create WorkflowResponseConverter with specified invoke_from.""" diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py index eec95b7f39..b2fba7a388 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -284,7 +284,12 @@ def test_run_normal_path_builds_graph(mocker): return_value=SimpleNamespace(belong_to_node_id="start", variable="input1"), ) mocker.patch.object(module, "RAGPipelineVariableInput", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) - 
mocker.patch.object(module, "VariablePool", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + class FakeVariablePool: + def add(self, selector, value): + return None + + mocker.patch.object(module, "VariablePool", return_value=FakeVariablePool()) workflow_entry = MagicMock() workflow_entry.graph_engine = MagicMock() diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index a3ced02394..4d303ce01e 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,5 +1,3 @@ -from unittest.mock import MagicMock - import pytest from core.app.apps.base_app_generator import BaseAppGenerator @@ -403,11 +401,11 @@ class TestBaseAppGeneratorExtras: monkeypatch.setattr( "core.app.apps.base_app_generator.file_factory.build_from_mapping", - lambda mapping, tenant_id, config, strict_type_validation=False: "file-object", + lambda mapping, tenant_id, config, strict_type_validation=False, access_controller=None: "file-object", ) monkeypatch.setattr( "core.app.apps.base_app_generator.file_factory.build_from_mappings", - lambda mappings, tenant_id, config: ["file-1", "file-2"], + lambda mappings, tenant_id, config, access_controller=None: ["file-1", "file-2"], ) user_inputs = { @@ -489,7 +487,6 @@ class TestBaseAppGeneratorExtras: factory = base_app_generator._get_draft_var_saver_factory(InvokeFrom.DEBUGGER, account) saver = factory( - session=MagicMock(), app_id="app-id", node_id="node-id", node_type=BuiltinNodeTypes.START, diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py b/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py index c6dc20ffc6..842d14bbd2 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py @@ -59,3 +59,18 @@ class TestBaseAppQueueManager: bad = SimpleNamespace(_sa_instance_state=True) with pytest.raises(TypeError): manager._check_for_sqlalchemy_models(bad) + + def test_stop_listen_defers_graph_runtime_state_cleanup_until_listener_exits(self): + with patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + mock_redis.get.return_value = None + manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API) + + runtime_state = SimpleNamespace(name="runtime-state") + manager.graph_runtime_state = runtime_state + + manager.stop_listen() + + assert manager.graph_runtime_state is runtime_state + assert list(manager.listen()) == [] + assert manager.graph_runtime_state is None diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index 2f73a8cda8..410136728b 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -8,6 +8,7 @@ from core.app.apps.advanced_chat import app_generator as adv_app_gen_module from core.app.apps.workflow import app_generator as wf_app_gen_module from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.entities.pause_reason import 
SchedulingPause @@ -29,7 +30,6 @@ from dify_graph.nodes.base.node import Node from dify_graph.nodes.end.entities import EndNodeData from dify_graph.nodes.start.entities import StartNodeData from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params if "core.ops.ops_trace_manager" not in sys.modules: @@ -162,11 +162,11 @@ def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> G def _build_runtime_state(run_id: str) -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"), user_inputs={}, conversation_variables=[], ) - variable_pool.system_variables.workflow_execution_id = run_id + variable_pool.add(("sys", "workflow_run_id"), run_id) return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index 3f1dd14569..b4ba1293e9 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -1,6 +1,6 @@ from __future__ import annotations -from datetime import datetime +from datetime import UTC, datetime from types import SimpleNamespace import pytest @@ -11,11 +11,16 @@ from core.app.entities.queue_entities import ( QueueAgentLogEvent, QueueIterationCompletedEvent, QueueLoopCompletedEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeRetryEvent, + QueueNodeSucceededEvent, QueueTextChunkEvent, QueueWorkflowPausedEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) +from core.workflow.system_variables import default_system_variables from dify_graph.entities.pause_reason import HumanInputRequired from dify_graph.enums import BuiltinNodeTypes from dify_graph.graph_events import ( @@ -23,13 +28,18 @@ from dify_graph.graph_events import ( GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunAgentLogEvent, + NodeRunExceptionEvent, + NodeRunFailedEvent, NodeRunIterationSucceededEvent, NodeRunLoopFailedEvent, + NodeRunRetryEvent, NodeRunStartedEvent, NodeRunStreamChunkEvent, + NodeRunSucceededEvent, ) +from dify_graph.node_events import NodeRunResult from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from dify_graph.variables.variables import StringVariable class TestWorkflowBasedAppRunner: @@ -44,7 +54,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) @@ -83,7 +93,7 @@ class TestWorkflowBasedAppRunner: def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch): runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) @@ -131,6 +141,96 @@ class TestWorkflowBasedAppRunner: assert graph is not None assert variable_pool is 
diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py
index 3f1dd14569..b4ba1293e9 100644
--- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py
+++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from datetime import datetime
+from datetime import UTC, datetime
 from types import SimpleNamespace
 
 import pytest
@@ -11,11 +11,16 @@ from core.app.entities.queue_entities import (
     QueueAgentLogEvent,
     QueueIterationCompletedEvent,
     QueueLoopCompletedEvent,
+    QueueNodeExceptionEvent,
+    QueueNodeFailedEvent,
+    QueueNodeRetryEvent,
+    QueueNodeSucceededEvent,
     QueueTextChunkEvent,
     QueueWorkflowPausedEvent,
     QueueWorkflowStartedEvent,
     QueueWorkflowSucceededEvent,
 )
+from core.workflow.system_variables import default_system_variables
 from dify_graph.entities.pause_reason import HumanInputRequired
 from dify_graph.enums import BuiltinNodeTypes
 from dify_graph.graph_events import (
@@ -23,13 +28,18 @@ from dify_graph.graph_events import (
     GraphRunStartedEvent,
     GraphRunSucceededEvent,
     NodeRunAgentLogEvent,
+    NodeRunExceptionEvent,
+    NodeRunFailedEvent,
     NodeRunIterationSucceededEvent,
     NodeRunLoopFailedEvent,
+    NodeRunRetryEvent,
     NodeRunStartedEvent,
     NodeRunStreamChunkEvent,
+    NodeRunSucceededEvent,
 )
+from dify_graph.node_events import NodeRunResult
 from dify_graph.runtime import GraphRuntimeState, VariablePool
-from dify_graph.system_variable import SystemVariable
+from dify_graph.variables.variables import StringVariable
 
 
 class TestWorkflowBasedAppRunner:
@@ -44,7 +54,7 @@ class TestWorkflowBasedAppRunner:
         runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
 
         runtime_state = GraphRuntimeState(
-            variable_pool=VariablePool(system_variables=SystemVariable.default()),
+            variable_pool=VariablePool(system_variables=default_system_variables()),
             start_at=0.0,
         )
 
@@ -83,7 +93,7 @@
     def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch):
         runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
         graph_runtime_state = GraphRuntimeState(
-            variable_pool=VariablePool(system_variables=SystemVariable.default()),
+            variable_pool=VariablePool(system_variables=default_system_variables()),
             start_at=0.0,
         )
 
@@ -131,6 +141,96 @@
         assert graph is not None
         assert variable_pool is graph_runtime_state.variable_pool
 
+    def test_get_graph_and_variable_pool_preloads_constructor_variables_before_graph_init(self, monkeypatch):
+        variable_loader = SimpleNamespace(
+            load_variables=lambda selectors: (
+                [
+                    StringVariable(
+                        name="conversation_id",
+                        value="conv-1",
+                        selector=["sys", "conversation_id"],
+                    )
+                ]
+                if selectors
+                else []
+            )
+        )
+        runner = WorkflowBasedAppRunner(
+            queue_manager=SimpleNamespace(),
+            variable_loader=variable_loader,
+            app_id="app",
+        )
+        graph_runtime_state = GraphRuntimeState(
+            variable_pool=VariablePool(system_variables=default_system_variables()),
+            start_at=0.0,
+        )
+
+        workflow = SimpleNamespace(
+            tenant_id="tenant",
+            id="workflow",
+            graph_dict={
+                "nodes": [
+                    {"id": "loop-node", "data": {"type": "loop", "version": "1", "title": "Loop"}},
+                    {
+                        "id": "llm-child",
+                        "data": {
+                            "type": "llm",
+                            "version": "1",
+                            "loop_id": "loop-node",
+                            "memory": object(),
+                        },
+                    },
+                ],
+                "edges": [],
+            },
+        )
+
+        class _LoopNodeCls:
+            @staticmethod
+            def extract_variable_selector_to_variable_mapping(graph_config, config):
+                return {}
+
+        def _validate_node_config(value):
+            return {"id": value["id"], "data": SimpleNamespace(**value["data"])}
+
+        def _graph_init(**kwargs):
+            variable_pool = graph_runtime_state.variable_pool
+            assert variable_pool.get(["sys", "conversation_id"]) is not None
+            return SimpleNamespace()
+
+        monkeypatch.setattr(
+            "core.app.apps.workflow_app_runner.NodeConfigDictAdapter.validate_python",
+            _validate_node_config,
+        )
+        monkeypatch.setattr(
+            "core.app.apps.workflow_app_runner.Graph.init",
+            _graph_init,
+        )
+        monkeypatch.setattr(
+            "core.app.apps.workflow_app_runner.resolve_workflow_node_class",
+            lambda **_kwargs: _LoopNodeCls,
+        )
+        monkeypatch.setattr(
+            "core.app.apps.workflow_app_runner.load_into_variable_pool",
+            lambda **kwargs: None,
+        )
+        monkeypatch.setattr(
+            "core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool",
+            lambda **kwargs: None,
+        )
+
+        graph, variable_pool = runner._get_graph_and_variable_pool_for_single_node_run(
+            workflow=workflow,
+            node_id="loop-node",
+            user_inputs={},
+            graph_runtime_state=graph_runtime_state,
+            node_type_filter_key="loop_id",
+            node_type_label="loop",
+        )
+
+        assert graph is not None
+        assert variable_pool.get(["sys", "conversation_id"]).value == "conv-1"
+
     def test_handle_graph_run_events_and_pause_notifications(self, monkeypatch):
         published: list[object] = []
 
@@ -140,7 +240,7 @@
         runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
         graph_runtime_state = GraphRuntimeState(
-            variable_pool=VariablePool(system_variables=SystemVariable.default()),
+            variable_pool=VariablePool(system_variables=default_system_variables()),
             start_at=0.0,
         )
         graph_runtime_state.register_paused_node("node-1")
@@ -183,7 +283,7 @@
         runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
         graph_runtime_state = GraphRuntimeState(
-            variable_pool=VariablePool(system_variables=SystemVariable.default()),
+            variable_pool=VariablePool(system_variables=default_system_variables()),
             start_at=0.0,
         )
         workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
@@ -195,7 +295,7 @@
                 node_id="node",
                 node_type=BuiltinNodeTypes.START,
                 node_title="Start",
-                start_at=datetime.utcnow(),
+                start_at=datetime.now(UTC),
             ),
         )
         runner._handle_event(
@@ -232,7 +332,7 @@
                 node_id="node",
                 node_type=BuiltinNodeTypes.LLM,
                 node_title="Iter",
-                start_at=datetime.utcnow(),
+                start_at=datetime.now(UTC),
                 inputs={},
                 outputs={"ok": True},
                 metadata={},
@@ -246,7 +346,7 @@
                 node_id="node",
                 node_type=BuiltinNodeTypes.LLM,
                 node_title="Loop",
-                start_at=datetime.utcnow(),
+                start_at=datetime.now(UTC),
                 inputs={},
                 outputs={},
                 metadata={},
@@ -259,3 +359,87 @@
         assert any(isinstance(event, QueueAgentLogEvent) for event in published)
         assert any(isinstance(event, QueueIterationCompletedEvent) for event in published)
         assert any(isinstance(event, QueueLoopCompletedEvent) for event in published)
+
+    @pytest.mark.parametrize(
+        ("event_factory", "queue_event_cls"),
+        [
+            (
+                lambda result, start_at, finished_at: NodeRunSucceededEvent(
+                    id="exec",
+                    node_id="node",
+                    node_type=BuiltinNodeTypes.START,
+                    start_at=start_at,
+                    finished_at=finished_at,
+                    node_run_result=result,
+                ),
+                QueueNodeSucceededEvent,
+            ),
+            (
+                lambda result, start_at, finished_at: NodeRunFailedEvent(
+                    id="exec",
+                    node_id="node",
+                    node_type=BuiltinNodeTypes.START,
+                    start_at=start_at,
+                    finished_at=finished_at,
+                    error="boom",
+                    node_run_result=result,
+                ),
+                QueueNodeFailedEvent,
+            ),
+            (
+                lambda result, start_at, finished_at: NodeRunExceptionEvent(
+                    id="exec",
+                    node_id="node",
+                    node_type=BuiltinNodeTypes.START,
+                    start_at=start_at,
+                    finished_at=finished_at,
+                    error="boom",
+                    node_run_result=result,
+                ),
+                QueueNodeExceptionEvent,
+            ),
+            (
+                lambda result, start_at, _finished_at: NodeRunRetryEvent(
+                    id="exec",
+                    node_id="node",
+                    node_type=BuiltinNodeTypes.START,
+                    node_title="Start",
+                    start_at=start_at,
+                    error="boom",
+                    retry_index=1,
+                    node_run_result=result,
+                ),
+                QueueNodeRetryEvent,
+            ),
+        ],
+    )
+    def test_handle_start_node_result_events_project_outputs(self, event_factory, queue_event_cls):
+        published: list[object] = []
+
+        class _QueueManager:
+            def publish(self, event, publish_from):
+                published.append(event)
+
+        runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app")
+        graph_runtime_state = GraphRuntimeState(
+            variable_pool=VariablePool(system_variables=default_system_variables()),
+            start_at=0.0,
+        )
+        workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state))
+        started_at = datetime.now(UTC)
+        finished_at = datetime.now(UTC)
+        result = NodeRunResult(
+            inputs={"question": "hello"},
+            outputs={
+                "question": "hello",
+                "sys.query": "hello",
+                "env.API_KEY": "secret",
+                "conversation.session_id": "session-1",
+            },
+        )
+
+        runner._handle_event(workflow_entry, event_factory(result, started_at, finished_at))
+
+        queue_event = published[-1]
+        assert isinstance(queue_event, queue_event_cls)
+        assert queue_event.outputs == {"question": "hello"}
diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py
index 178e26118e..1456262416 100644
--- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py
+++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py
@@ -9,15 +9,15 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.workflow.app_runner import WorkflowAppRunner
 from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
 from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
+from core.workflow.system_variables import default_system_variables
 from dify_graph.entities.graph_config import NodeConfigDictAdapter
 from dify_graph.runtime import GraphRuntimeState, VariablePool
-from dify_graph.system_variable import SystemVariable
 from models.workflow import Workflow
 
 
 def _make_graph_state():
     variable_pool = VariablePool(
-        system_variables=SystemVariable.default(),
+        system_variables=default_system_variables(),
         user_inputs={},
         environment_variables=[],
         conversation_variables=[],
diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py
index 65c6bd6654..699822ad0c 100644
--- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py
+++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py
@@ -10,13 +10,14 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner
 from core.app.entities.app_invoke_entities import InvokeFrom
 from core.app.entities.queue_entities import QueueWorkflowPausedEvent
 from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse
+from core.workflow.system_variables import build_system_variables
 from dify_graph.entities.pause_reason import HumanInputRequired
 from dify_graph.entities.workflow_start_reason import WorkflowStartReason
 from dify_graph.graph_events.graph import GraphRunPausedEvent
 from dify_graph.nodes.human_input.entities import FormInput, UserAction
 from dify_graph.nodes.human_input.enums import FormInputType
-from dify_graph.system_variable import SystemVariable
 from models.account import Account
+from models.human_input import RecipientType
 
 
 class _RecordingWorkflowAppRunner(WorkflowAppRunner):
@@ -74,7 +75,6 @@ def test_graph_run_paused_event_emits_queue_pause_event():
         actions=[],
         node_id="node-human",
         node_title="Human Step",
-        form_token="tok",
     )
     event = GraphRunPausedEvent(reasons=[reason], outputs={"foo": "bar"})
     workflow_entry = SimpleNamespace(
@@ -98,7 +98,7 @@ def _build_converter():
         invoke_from=InvokeFrom.SERVICE_API,
         app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"),
     )
-    system_variables = SystemVariable(
+    system_variables = build_system_variables(
         user_id="user",
         app_id="app-id",
         workflow_id="workflow-id",
@@ -128,7 +128,21 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon
 
     class _FakeSession:
         def execute(self, _stmt):
-            return [("form-1", expiration_time)]
+            return [("form-1", expiration_time, '{"display_in_ui": true}')]
+
+        def scalars(self, _stmt):
+            return [
+                SimpleNamespace(
+                    form_id="form-1",
+                    recipient_type=RecipientType.CONSOLE,
+                    access_token="console-token",
+                ),
+                SimpleNamespace(
+                    form_id="form-1",
+                    recipient_type=RecipientType.BACKSTAGE,
+                    access_token="backstage-token",
+                ),
+            ]
 
         def __enter__(self):
             return self
@@ -146,10 +160,8 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon
             FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None),
         ],
         actions=[UserAction(id="approve", title="Approve")],
-        display_in_ui=True,
         node_id="node-id",
         node_title="Human Step",
-        form_token="token",
     )
     queue_event = QueueWorkflowPausedEvent(
         reasons=[reason],
@@ -170,7 +182,6 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon
     assert pause_resp.data.paused_nodes == ["node-id"]
     assert pause_resp.data.outputs == {}
     assert pause_resp.data.reasons[0]["form_id"] == "form-1"
-    assert pause_resp.data.reasons[0]["display_in_ui"] is True
 
     assert isinstance(responses[0], HumanInputRequiredResponse)
     hi_resp = responses[0]
@@ -180,4 +191,5 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon
     assert hi_resp.data.inputs[0].output_variable_name == "field"
     assert hi_resp.data.actions[0].id == "approve"
     assert hi_resp.data.display_in_ui is True
+    assert hi_resp.data.form_token == "backstage-token"
     assert hi_resp.data.expiration_time == int(expiration_time.timestamp())
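Per the rewritten pause-event test, form tokens are no longer carried on the pause reason itself; each human-input form row carries one access token per recipient type, and the streamed response surfaces the BACKSTAGE token. A sketch of that selection, assuming the enum and field names used in the test; the helper itself is hypothetical.

    from enum import StrEnum

    class RecipientType(StrEnum):
        CONSOLE = "console"
        BACKSTAGE = "backstage"

    def pick_form_token(rows: list, recipient_type: RecipientType) -> str | None:
        # Rows mirror the SimpleNamespace fixtures: form_id, recipient_type, access_token.
        for row in rows:
            if row.recipient_type == recipient_type:
                return row.access_token
        return None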
diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py
index 5b23e71035..51472eb236 100644
--- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py
+++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py
@@ -7,11 +7,12 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
 from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
 from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
 from core.app.entities.queue_entities import QueueWorkflowStartedEvent
+from core.workflow.system_variables import build_system_variables
 from dify_graph.entities.workflow_start_reason import WorkflowStartReason
-from dify_graph.runtime import GraphRuntimeState, VariablePool
-from dify_graph.system_variable import SystemVariable
+from dify_graph.runtime import GraphRuntimeState
 from models.account import Account
 from models.model import AppMode
+from tests.workflow_test_utils import build_test_variable_pool
 
 
 def _build_workflow_app_config() -> WorkflowUIBasedAppConfig:
@@ -37,11 +38,7 @@ def _build_generate_entity(run_id: str) -> WorkflowAppGenerateEntity:
 
 
 def _build_runtime_state(run_id: str) -> GraphRuntimeState:
-    variable_pool = VariablePool(
-        system_variables=SystemVariable(workflow_execution_id=run_id),
-        user_inputs={},
-        conversation_variables=[],
-    )
+    variable_pool = build_test_variable_pool(variables=build_system_variables(workflow_execution_id=run_id))
     return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py
index f35710d207..cb7eb1169c 100644
--- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py
+++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py
@@ -44,11 +44,12 @@ from core.app.entities.task_entities import (
     WorkflowStartStreamResponse,
 )
 from core.base.tts.app_generator_tts_publisher import AudioTrunk
+from core.workflow.system_variables import build_system_variables, system_variables_to_mapping
 from dify_graph.enums import BuiltinNodeTypes, WorkflowExecutionStatus
 from dify_graph.runtime import GraphRuntimeState, VariablePool
-from dify_graph.system_variable import SystemVariable
 from models.enums import CreatorUserRole
 from models.model import AppMode, EndUser
+from tests.workflow_test_utils import build_test_variable_pool
 
 
 def _make_pipeline():
@@ -164,7 +165,7 @@ class TestWorkflowGenerateTaskPipeline:
     def test_handle_workflow_started_event_sets_run_id(self, monkeypatch):
         pipeline = _make_pipeline()
         pipeline._graph_runtime_state = GraphRuntimeState(
-            variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
+            variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")),
             start_at=0.0,
         )
         pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started"
@@ -205,7 +206,7 @@ class TestWorkflowGenerateTaskPipeline:
         pipeline = _make_pipeline()
         pipeline._workflow_execution_id = "run-id"
         pipeline._graph_runtime_state = GraphRuntimeState(
-            variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
+            variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")),
             start_at=0.0,
         )
         pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
@@ -257,7 +258,7 @@ class TestWorkflowGenerateTaskPipeline:
         pipeline = _make_pipeline()
         pipeline._workflow_execution_id = "run-id"
         pipeline._graph_runtime_state = GraphRuntimeState(
-            variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
+            variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
             start_at=0.0,
         )
         pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish"
@@ -451,7 +452,7 @@ class TestWorkflowGenerateTaskPipeline:
         )
 
         assert pipeline._created_by_role == CreatorUserRole.END_USER
-        assert pipeline._workflow_system_variables.user_id == "session-id"
+        assert system_variables_to_mapping(pipeline._workflow_system_variables)["user_id"] == "session-id"
 
     def test_process_returns_stream_and_blocking_variants(self):
         pipeline = _make_pipeline()
@@ -699,7 +700,7 @@ class TestWorkflowGenerateTaskPipeline:
         pipeline = _make_pipeline()
         pipeline._workflow_execution_id = "run-id"
         pipeline._graph_runtime_state = GraphRuntimeState(
-            variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
+            variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
             start_at=0.0,
         )
 
@@ -727,7 +728,7 @@ class TestWorkflowGenerateTaskPipeline:
         pipeline = _make_pipeline()
         pipeline._workflow_execution_id = "run-id"
         pipeline._graph_runtime_state = GraphRuntimeState(
-            variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
+            variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
             start_at=0.0,
         )
         pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"])
@@ -743,7 +744,7 @@ class TestWorkflowGenerateTaskPipeline:
     def test_process_stream_response_main_match_paths_and_cleanup(self):
         pipeline = _make_pipeline()
         pipeline._graph_runtime_state = GraphRuntimeState(
-            variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")),
+            variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")),
             start_at=0.0,
         )
         pipeline._base_task_pipeline.queue_manager.listen = lambda: iter(
@@ -815,7 +816,7 @@ class TestWorkflowGenerateTaskPipeline:
             pipeline._save_workflow_app_log(session=_Session(), workflow_run_id=None)
         assert len(added) == count_before
 
-    def test_save_output_for_event_writes_draft_variables(self, monkeypatch):
+    def test_save_output_for_event_writes_draft_variables(self):
         pipeline = _make_pipeline()
         saver_calls: list[tuple[object, object]] = []
         captured_factory_args: dict[str, object] = {}
@@ -828,29 +829,7 @@ class TestWorkflowGenerateTaskPipeline:
             captured_factory_args.update(kwargs)
             return _Saver()
 
-        class _Begin:
-            def __enter__(self):
-                return None
-
-            def __exit__(self, exc_type, exc, tb):
-                return False
-
-        class _Session:
-            def __init__(self, *args, **kwargs):
-                _ = args, kwargs
-
-            def __enter__(self):
-                return self
-
-            def __exit__(self, exc_type, exc, tb):
-                return False
-
-            def begin(self):
-                return _Begin()
-
         pipeline._draft_var_saver_factory = _factory
-        monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session)
-        monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object()))
 
         event = QueueNodeSucceededEvent(
             node_execution_id="exec-id",
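The removed Session and db monkeypatches, together with the dropped session argument in the base-generator test earlier, suggest the draft-variable saver now owns its session lifecycle instead of receiving one. A hedged sketch of that shape; _session and DraftVarSaverSketch are illustrative names, not the real API.

    from contextlib import contextmanager

    @contextmanager
    def _session():
        # Stand-in for a real SQLAlchemy session factory.
        yield object()

    class DraftVarSaverSketch:
        def __init__(self, app_id: str, node_id: str) -> None:
            self._app_id = app_id
            self._node_id = node_id

        def save(self, outputs: dict) -> None:
            # Persist draft variables inside a self-managed session, so
            # callers no longer thread a session through the factory.
            with _session():
                pass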
diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py
index bdc889d941..ba009ece6b 100644
--- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py
+++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py
@@ -3,16 +3,15 @@ from datetime import datetime
 from unittest.mock import Mock
 
 from core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer
-from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID
-from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus
+from core.workflow.system_variables import SystemVariableKey
+from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID
+from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
 from dify_graph.graph_engine.protocols.command_channel import CommandChannel
-from dify_graph.graph_events.node import NodeRunSucceededEvent
+from dify_graph.graph_events.node import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent
 from dify_graph.node_events import NodeRunResult
-from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
 from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState
-from dify_graph.system_variable import SystemVariable
 from dify_graph.variables import StringVariable
-from dify_graph.variables.segments import Segment
+from dify_graph.variables.segments import Segment, StringSegment
 
 
 class MockReadOnlyVariablePool:
@@ -36,31 +35,38 @@ def _build_graph_runtime_state(
     conversation_id: str | None = None,
 ) -> ReadOnlyGraphRuntimeState:
     graph_runtime_state = Mock(spec=ReadOnlyGraphRuntimeState)
+    if conversation_id is not None:
+        variable_pool._variables[("sys", SystemVariableKey.CONVERSATION_ID.value)] = StringSegment(
+            value=conversation_id
+        )
     graph_runtime_state.variable_pool = variable_pool
-    graph_runtime_state.system_variable = SystemVariable(conversation_id=conversation_id).as_view()
     return graph_runtime_state
 
 
-def _build_node_run_succeeded_event(
-    *,
-    node_type: NodeType,
-    outputs: dict[str, object] | None = None,
-    process_data: dict[str, object] | None = None,
-) -> NodeRunSucceededEvent:
+def _build_node_run_succeeded_event() -> NodeRunSucceededEvent:
     return NodeRunSucceededEvent(
         id="node-exec-id",
         node_id="assigner",
-        node_type=node_type,
+        node_type=BuiltinNodeTypes.LLM,
         start_at=datetime.utcnow(),
         node_run_result=NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
-            outputs=outputs or {},
-            process_data=process_data or {},
+            outputs={},
+            process_data={},
         ),
     )
 
 
-def test_persists_conversation_variables_from_assigner_output():
+def _build_variable_updated_event(variable: StringVariable) -> NodeRunVariableUpdatedEvent:
+    return NodeRunVariableUpdatedEvent(
+        id="node-exec-id",
+        node_id="assigner",
+        node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER,
+        variable=variable,
+    )
+
+
+def test_persists_conversation_variables_from_variable_update_event():
     conversation_id = "conv-123"
     variable = StringVariable(
         id="var-1",
@@ -68,55 +74,26 @@ def test_persists_conversation_variables_from_assigner_output():
         value="updated",
         selector=[CONVERSATION_VARIABLE_NODE_ID, "name"],
     )
-    process_data = common_helpers.set_updated_variables(
-        {}, [common_helpers.variable_to_processed_data(variable.selector, variable)]
-    )
-
-    variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable})
-
     updater = Mock()
     layer = ConversationVariablePersistenceLayer(updater)
-    layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
+    layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel))
 
-    event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, process_data=process_data)
+    event = _build_variable_updated_event(variable)
     layer.on_event(event)
 
     updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable)
-    updater.flush.assert_called_once()
 
 
-def test_skips_when_outputs_missing():
+def test_skips_non_variable_update_events():
     conversation_id = "conv-456"
-    variable = StringVariable(
-        id="var-2",
-        name="name",
-        value="updated",
-        selector=[CONVERSATION_VARIABLE_NODE_ID, "name"],
-    )
-
-    variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable})
-
     updater = Mock()
     layer = ConversationVariablePersistenceLayer(updater)
-    layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
+    layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel))
 
-    event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER)
+    event = _build_node_run_succeeded_event()
     layer.on_event(event)
 
     updater.update.assert_not_called()
-    updater.flush.assert_not_called()
-
-
-def test_skips_non_assigner_nodes():
-    updater = Mock()
-    layer = ConversationVariablePersistenceLayer(updater)
-    layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel))
-
-    event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.LLM)
-    layer.on_event(event)
-
-    updater.update.assert_not_called()
-    updater.flush.assert_not_called()
 
 
 def test_skips_non_conversation_variables():
@@ -127,18 +104,11 @@ def test_skips_non_conversation_variables():
         value="updated",
         selector=["environment", "name"],
     )
-    process_data = common_helpers.set_updated_variables(
-        {}, [common_helpers.variable_to_processed_data(non_conversation_variable.selector, non_conversation_variable)]
-    )
-
-    variable_pool = MockReadOnlyVariablePool()
-
     updater = Mock()
     layer = ConversationVariablePersistenceLayer(updater)
-    layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel))
+    layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel))
 
-    event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, process_data=process_data)
+    event = _build_variable_updated_event(non_conversation_variable)
    layer.on_event(event)
 
     updater.update.assert_not_called()
-    updater.flush.assert_not_called()
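The rewritten layer tests describe an event-driven design: persistence reacts to explicit NodeRunVariableUpdatedEvent payloads and filters on the conversation selector prefix, rather than scraping assigner process_data. A minimal sketch under those assumptions; the class shape is illustrative only.

    CONVERSATION_PREFIX = "conversation"

    class PersistenceLayerSketch:
        def __init__(self, updater, conversation_id: str | None) -> None:
            self._updater = updater
            self._conversation_id = conversation_id

        def on_event(self, event) -> None:
            variable = getattr(event, "variable", None)
            if variable is None or self._conversation_id is None:
                return  # only explicit variable-update events are persisted
            if variable.selector[0] != CONVERSATION_PREFIX:
                return  # environment and other scopes are skipped
            self._updater.update(conversation_id=self._conversation_id, variable=variable)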
diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py
index 035f0ee05c..efe952a54b 100644
--- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py
+++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py
@@ -13,6 +13,7 @@ from core.app.layers.pause_state_persist_layer import (
     _AdvancedChatAppGenerateEntityWrapper,
     _WorkflowGenerateEntityWrapper,
 )
+from core.workflow.system_variables import SystemVariableKey
 from dify_graph.entities.pause_reason import SchedulingPause
 from dify_graph.graph_engine.entities.commands import GraphEngineCommand
 from dify_graph.graph_engine.layers.base import GraphEngineLayerNotInitializedError
@@ -51,17 +52,6 @@ class TestDataFactory:
         return GraphRunFailedEvent(error=error, exceptions_count=exceptions_count)
 
 
-class MockSystemVariableReadOnlyView:
-    """Minimal read-only system variable view for testing."""
-
-    def __init__(self, workflow_execution_id: str | None = None) -> None:
-        self._workflow_execution_id = workflow_execution_id
-
-    @property
-    def workflow_execution_id(self) -> str | None:
-        return self._workflow_execution_id
-
-
 class MockReadOnlyVariablePool:
     """Mock implementation of ReadOnlyVariablePool for testing."""
 
@@ -76,13 +66,14 @@ class MockReadOnlyVariablePool:
             return None
         mock_segment = Mock(spec=Segment)
         mock_segment.value = value
+        mock_segment.text = value if isinstance(value, str) else None
         return mock_segment
 
     def get_all_by_node(self, node_id: str) -> dict[str, object]:
         return {key: value for (nid, key), value in self._variables.items() if nid == node_id}
 
     def get_by_prefix(self, prefix: str) -> dict[str, object]:
-        return {f"{nid}.{key}": value for (nid, key), value in self._variables.items() if nid.startswith(prefix)}
+        return {key: value for (nid, key), value in self._variables.items() if nid == prefix}
 
 
 class MockReadOnlyGraphRuntimeState:
@@ -105,12 +96,10 @@ class MockReadOnlyGraphRuntimeState:
         self._ready_queue_size = ready_queue_size
         self._exceptions_count = exceptions_count
         self._outputs = outputs or {}
-        self._variable_pool = MockReadOnlyVariablePool(variables)
-        self._system_variable = MockSystemVariableReadOnlyView(workflow_execution_id)
-
-    @property
-    def system_variable(self) -> MockSystemVariableReadOnlyView:
-        return self._system_variable
+        resolved_variables = dict(variables or {})
+        if workflow_execution_id is not None:
+            resolved_variables[("sys", SystemVariableKey.WORKFLOW_EXECUTION_ID.value)] = workflow_execution_id
+        self._variable_pool = MockReadOnlyVariablePool(resolved_variables)
 
     @property
     def variable_pool(self) -> ReadOnlyVariablePool:
@@ -161,7 +150,9 @@ class MockReadOnlyGraphRuntimeState:
                 "exceptions_count": self._exceptions_count,
                 "outputs": self._outputs,
                 "variables": {f"{k[0]}.{k[1]}": v for k, v in self._variable_pool._variables.items()},
-                "workflow_execution_id": self._system_variable.workflow_execution_id,
+                "workflow_execution_id": self._variable_pool._variables.get(
+                    ("sys", SystemVariableKey.WORKFLOW_EXECUTION_ID.value)
+                ),
             }
         )
diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py
index f9755061d6..1717ecf95e 100644
--- a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py
+++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py
@@ -3,7 +3,9 @@ from types import SimpleNamespace
 from unittest.mock import Mock, patch
 
 from core.app.layers.trigger_post_layer import TriggerPostLayer
+from core.workflow.system_variables import build_system_variables
 from dify_graph.graph_events.graph import GraphRunFailedEvent, GraphRunSucceededEvent
+from dify_graph.runtime import VariablePool
 from models.enums import WorkflowTriggerStatus
 
 
@@ -19,7 +21,7 @@ class TestTriggerPostLayer:
         )
         runtime_state = SimpleNamespace(
             outputs={"answer": "ok"},
-            system_variable=SimpleNamespace(workflow_execution_id="run-1"),
+            variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
             total_tokens=12,
         )
 
@@ -58,7 +60,7 @@ class TestTriggerPostLayer:
     def test_on_event_handles_missing_trigger_log(self):
         runtime_state = SimpleNamespace(
             outputs={},
-            system_variable=SimpleNamespace(workflow_execution_id="run-1"),
+            variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
             total_tokens=0,
         )
 
@@ -89,7 +91,7 @@ class TestTriggerPostLayer:
     def test_on_event_ignores_non_status_events(self):
         runtime_state = SimpleNamespace(
             outputs={},
-            system_variable=SimpleNamespace(workflow_execution_id="run-1"),
+            variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")),
             total_tokens=0,
         )
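Both layer rewrites above stop exposing a dedicated system_variable view; the run id comes from the variable pool itself. A tiny hedged helper showing that lookup; the key literal mirrors SystemVariableKey.WORKFLOW_EXECUTION_ID, and the segment .text read mirrors the attribute the mock pool now sets.

    def workflow_execution_id(variable_pool) -> str | None:
        # ("sys", <key>) selectors address system variables in the pool.
        segment = variable_pool.get(("sys", "workflow_execution_id"))
        return segment.text if segment is not None else None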
diff --git a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py
new file mode 100644
index 0000000000..033d22aa47
--- /dev/null
+++ b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py
@@ -0,0 +1,57 @@
+from types import SimpleNamespace
+from unittest.mock import patch
+
+from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager
+from core.app.app_config.entities import ModelConfigEntity
+from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
+from models.provider_ids import ModelProviderID
+
+
+def test_validate_and_set_defaults_reuses_single_model_assembly():
+    provider_name = str(ModelProviderID("openai"))
+    provider_entity = SimpleNamespace(provider=provider_name)
+    model = SimpleNamespace(model="gpt-4o-mini", model_properties={ModelPropertyKey.MODE: "chat"})
+    provider_configurations = SimpleNamespace(get_models=lambda **kwargs: [model])
+    assembly = SimpleNamespace(
+        model_provider_factory=SimpleNamespace(get_providers=lambda: [provider_entity]),
+        provider_manager=SimpleNamespace(get_configurations=lambda tenant_id: provider_configurations),
+    )
+    config = {
+        "model": {
+            "provider": "openai",
+            "name": "gpt-4o-mini",
+            "completion_params": {"stop": []},
+        }
+    }
+
+    with patch(
+        "core.app.app_config.easy_ui_based_app.model_config.manager.create_plugin_model_assembly",
+        return_value=assembly,
+    ) as mock_assembly:
+        result, keys = ModelConfigManager.validate_and_set_defaults("tenant-1", config)
+
+    assert result["model"]["provider"] == provider_name
+    assert result["model"]["mode"] == "chat"
+    assert keys == ["model"]
+    mock_assembly.assert_called_once_with(tenant_id="tenant-1")
+
+
+def test_convert_keeps_model_config_shape():
+    config = {
+        "model": {
+            "provider": "openai",
+            "name": "gpt-4o-mini",
+            "mode": "chat",
+            "completion_params": {"temperature": 0.3, "stop": ["END"]},
+        }
+    }
+
+    result = ModelConfigManager.convert(config)
+
+    assert result == ModelConfigEntity(
+        provider="openai",
+        model="gpt-4o-mini",
+        mode="chat",
+        parameters={"temperature": 0.3},
+        stop=["END"],
+    )
diff --git a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py
index 0f8a846d11..765accd5d3 100644
--- a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py
+++ b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py
@@ -8,7 +8,7 @@ from core.app.workflow.layers.persistence import (
     WorkflowPersistenceLayer,
     _NodeRuntimeSnapshot,
 )
-from dify_graph.enums import WorkflowNodeExecutionStatus, WorkflowType
+from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType
 from dify_graph.node_events import NodeRunResult
 
 
@@ -58,3 +58,42 @@ def test_update_node_execution_prefers_event_finished_at(monkeypatch: pytest.Mon
 
     assert node_execution.finished_at == event_finished_at
     assert node_execution.elapsed_time == 2.0
+
+
+def test_update_node_execution_projects_start_outputs() -> None:
+    layer = _build_layer()
+    node_execution = Mock()
+    node_execution.id = "node-exec-2"
+    node_execution.node_type = BuiltinNodeTypes.START
+    node_execution.created_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None)
+    node_execution.update_from_mapping = Mock()
+
+    layer._node_snapshots[node_execution.id] = _NodeRuntimeSnapshot(
+        node_id="start",
+        title="Start",
+        predecessor_node_id=None,
+        iteration_id=None,
+        loop_id=None,
+        created_at=node_execution.created_at,
+    )
+
+    layer._update_node_execution(
+        node_execution,
+        NodeRunResult(
+            status=WorkflowNodeExecutionStatus.SUCCEEDED,
+            inputs={"question": "hello"},
+            outputs={
+                "question": "hello",
+                "sys.query": "hello",
+                "env.API_KEY": "secret",
+            },
+        ),
+        WorkflowNodeExecutionStatus.SUCCEEDED,
+    )
+
+    node_execution.update_from_mapping.assert_called_once_with(
+        inputs={"question": "hello"},
+        process_data={},
+        outputs={"question": "hello"},
+        metadata={},
+    )
diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
index fb76f22a2a..529f25022f 100644
--- a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
+++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py
@@ -1,43 +1,370 @@
-from unittest.mock import patch
+from __future__ import annotations
 
+import base64
+import hashlib
+import hmac
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+from urllib.parse import parse_qs, urlparse
+
+import pytest
+
+from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
+from core.app.file_access import DatabaseFileAccessController, FileAccessScope
+from core.app.workflow import file_runtime
 from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime
+from core.workflow.file_reference import build_file_reference
+from dify_graph.file import File, FileTransferMethod, FileType
+from models import ToolFile, UploadFile
 
 
-class TestDifyWorkflowFileRuntime:
-    def test_runtime_properties_and_helpers(self, monkeypatch):
-        monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_URL", "http://files")
-        monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.INTERNAL_FILES_URL", "http://internal")
-        monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "secret")
-        monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_ACCESS_TIMEOUT", 123)
-        monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.MULTIMODAL_SEND_FORMAT", "url")
+def _build_file(
+    *,
+    transfer_method: FileTransferMethod,
+    reference: str | None = None,
+    remote_url: str | None = None,
+    extension: str | None = None,
+) -> File:
+    return File(
+        id="file-id",
+        type=FileType.IMAGE,
+        transfer_method=transfer_method,
+        reference=reference,
+        remote_url=remote_url,
+        filename="diagram.png",
+        extension=extension,
+        mime_type="image/png",
+        size=128,
+    )
 
-        runtime = DifyWorkflowFileRuntime()
-        assert runtime.files_url == "http://files"
-        assert runtime.internal_files_url == "http://internal"
-        assert runtime.secret_key == "secret"
-        assert runtime.files_access_timeout == 123
-        assert runtime.multimodal_send_format == "url"
 
+def _build_runtime() -> DifyWorkflowFileRuntime:
+    return DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController())
 
-        with patch("core.app.workflow.file_runtime.ssrf_proxy.get") as mock_get:
-            mock_get.return_value = "response"
-            assert runtime.http_get("http://example", follow_redirects=False) == "response"
-            mock_get.assert_called_once_with("http://example", follow_redirects=False)
 
-        with patch("core.app.workflow.file_runtime.storage.load") as mock_load:
-            mock_load.return_value = b"data"
-            assert runtime.storage_load("path", stream=True) == b"data"
-            mock_load.assert_called_once_with("path", stream=True)
+def test_resolve_file_url_returns_remote_url() -> None:
+    runtime = _build_runtime()
+    file = _build_file(
+        transfer_method=FileTransferMethod.REMOTE_URL,
+        remote_url="https://example.com/diagram.png",
+    )
 
-        with patch("core.app.workflow.file_runtime.sign_tool_file") as mock_sign:
-            mock_sign.return_value = "signed"
-            assert runtime.sign_tool_file(tool_file_id="id", extension=".txt", for_external=False) == "signed"
-            mock_sign.assert_called_once_with(tool_file_id="id", extension=".txt", for_external=False)
+    assert runtime.resolve_file_url(file=file) == "https://example.com/diagram.png"
 
-    def test_bind_runtime_registers_instance(self):
-        with patch("core.app.workflow.file_runtime.set_workflow_file_runtime") as mock_set:
-            bind_dify_workflow_file_runtime()
-            mock_set.assert_called_once()
-            runtime = mock_set.call_args[0][0]
-            assert isinstance(runtime, DifyWorkflowFileRuntime)
+
+def test_resolve_file_url_requires_file_reference() -> None:
+    runtime = _build_runtime()
+    file = SimpleNamespace(transfer_method=FileTransferMethod.LOCAL_FILE, reference=None)
+
+    with pytest.raises(ValueError, match="Missing file reference"):
+        runtime.resolve_file_url(file=file)
+
+
+def test_resolve_file_url_requires_extension_for_tool_files() -> None:
+    runtime = _build_runtime()
+    file = _build_file(
+        transfer_method=FileTransferMethod.TOOL_FILE,
+        reference=build_file_reference(record_id="tool-file-id"),
+        extension=None,
+    )
+
+    with pytest.raises(ValueError, match="Missing file extension"):
+        runtime.resolve_file_url(file=file)
+
+
+def test_resolve_file_url_uses_tool_signatures_for_tool_and_datasource_files(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    sign_tool_file = MagicMock(return_value="https://signed.example.com/file")
+    monkeypatch.setattr(file_runtime, "sign_tool_file", sign_tool_file)
+    runtime = _build_runtime()
+
+    tool_file = _build_file(
+        transfer_method=FileTransferMethod.TOOL_FILE,
+        reference=build_file_reference(record_id="tool-file-id"),
+        extension=".png",
+    )
+    datasource_file = _build_file(
+        transfer_method=FileTransferMethod.DATASOURCE_FILE,
+        reference=build_file_reference(record_id="datasource-file-id"),
+        extension=".png",
+    )
+
+    assert runtime.resolve_file_url(file=tool_file) == "https://signed.example.com/file"
+    assert runtime.resolve_file_url(file=datasource_file) == "https://signed.example.com/file"
+    assert sign_tool_file.call_count == 2
+
+
+def test_resolve_upload_file_url_signs_internal_urls_and_supports_attachments(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    monkeypatch.setattr("core.app.workflow.file_runtime.time.time", lambda: 1700000000)
+    monkeypatch.setattr("core.app.workflow.file_runtime.os.urandom", lambda _: b"\x01" * 16)
+    monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "unit-secret")
+    monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_URL", "https://files.example.com")
+    monkeypatch.setattr(
+        "core.app.workflow.file_runtime.dify_config.INTERNAL_FILES_URL",
+        "https://internal.example.com",
+    )
+
+    runtime = _build_runtime()
+    url = runtime.resolve_upload_file_url(
+        upload_file_id="upload-file-id",
+        as_attachment=True,
+        for_external=False,
+    )
+    parsed = urlparse(url)
+    query = parse_qs(parsed.query)
+
+    assert parsed.netloc == "internal.example.com"
+    assert parsed.path == "/files/upload-file-id/file-preview"
+    assert query["as_attachment"] == ["true"]
+    assert query["timestamp"] == ["1700000000"]
+
+
+def test_verify_preview_signature_validates_signature_and_expiration(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr("core.app.workflow.file_runtime.time.time", lambda: 1700000000)
+    monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "unit-secret")
+    monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_ACCESS_TIMEOUT", 60)
+    runtime = _build_runtime()
+    payload = "file-preview|upload-file-id|1700000000|nonce"
+    sign = base64.urlsafe_b64encode(hmac.new(b"unit-secret", payload.encode(), hashlib.sha256).digest()).decode()
+
+    assert (
+        runtime.verify_preview_signature(
+            preview_kind="file",
+            file_id="upload-file-id",
+            timestamp="1700000000",
+            nonce="nonce",
+            sign=sign,
+        )
+        is True
+    )
+    assert (
+        runtime.verify_preview_signature(
+            preview_kind="file",
+            file_id="upload-file-id",
+            timestamp="1700000000",
+            nonce="nonce",
+            sign="bad-signature",
+        )
+        is False
+    )
+
+    monkeypatch.setattr("core.app.workflow.file_runtime.time.time", lambda: 1700000100)
+    assert (
+        runtime.verify_preview_signature(
+            preview_kind="file",
+            file_id="upload-file-id",
+            timestamp="1700000000",
+            nonce="nonce",
+            sign=sign,
+        )
+        is False
+    )
+
+
+def test_load_file_bytes_returns_bytes_and_rejects_non_bytes(monkeypatch: pytest.MonkeyPatch) -> None:
+    runtime = _build_runtime()
+    file = _build_file(
+        transfer_method=FileTransferMethod.LOCAL_FILE,
+        reference=build_file_reference(record_id="upload-file-id"),
+    )
+    session = MagicMock()
+    session.get.return_value = SimpleNamespace(key="canonical-storage-key")
+
+    class _SessionContext:
+        def __enter__(self):
+            return session
+
+        def __exit__(self, exc_type, exc, tb):
+            return False
+
+    monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext())
+    monkeypatch.setattr(file_runtime.storage, "load", lambda *args, **kwargs: b"image-bytes")
+
+    assert runtime.load_file_bytes(file=file) == b"image-bytes"
+    session.get.assert_called_with(UploadFile, "upload-file-id")
+
+    monkeypatch.setattr(file_runtime.storage, "load", lambda *args, **kwargs: "not-bytes")
+    with pytest.raises(ValueError, match="is not a bytes object"):
+        runtime.load_file_bytes(file=file)
+
+
+def test_resolve_storage_key_ignores_encoded_reference_when_unscoped(monkeypatch: pytest.MonkeyPatch) -> None:
+    runtime = _build_runtime()
+    file = _build_file(
+        transfer_method=FileTransferMethod.LOCAL_FILE,
+        reference=build_file_reference(record_id="upload-file-id", storage_key="tampered-storage-key"),
+    )
+    session = MagicMock()
+    session.get.return_value = SimpleNamespace(key="canonical-storage-key")
+
+    class _SessionContext:
+        def __enter__(self):
+            return session
+
+        def __exit__(self, exc_type, exc, tb):
+            return False
+
+    monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext())
+
+    assert runtime._resolve_storage_key(file=file) == "canonical-storage-key"
+    session.get.assert_called_once_with(UploadFile, "upload-file-id")
+
+
+def test_resolve_storage_key_uses_canonical_record_when_scope_is_bound(monkeypatch: pytest.MonkeyPatch) -> None:
+    controller = MagicMock()
+    controller.current_scope.return_value = FileAccessScope(
+        tenant_id="tenant-id",
+        user_id="end-user-id",
+        user_from=UserFrom.END_USER,
+        invoke_from=InvokeFrom.WEB_APP,
+    )
+    controller.get_upload_file.return_value = SimpleNamespace(key="canonical-storage-key")
+    runtime = DifyWorkflowFileRuntime(file_access_controller=controller)
+    file = _build_file(
+        transfer_method=FileTransferMethod.LOCAL_FILE,
+        reference=build_file_reference(record_id="upload-file-id", storage_key="tampered-storage-key"),
+    )
+    session = MagicMock()
+
+    class _SessionContext:
+        def __enter__(self):
+            return session
+
+        def __exit__(self, exc_type, exc, tb):
+            return False
+
+    monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext())
+
+    assert runtime._resolve_storage_key(file=file) == "canonical-storage-key"
+    controller.get_upload_file.assert_called_once_with(session=session, file_id="upload-file-id")
+
+
+def test_resolve_upload_file_url_rejects_unauthorized_scoped_access(monkeypatch: pytest.MonkeyPatch) -> None:
+    controller = MagicMock()
+    controller.current_scope.return_value = FileAccessScope(
+        tenant_id="tenant-id",
+        user_id="end-user-id",
+        user_from=UserFrom.END_USER,
+        invoke_from=InvokeFrom.WEB_APP,
+    )
+    controller.get_upload_file.return_value = None
+    runtime = DifyWorkflowFileRuntime(file_access_controller=controller)
+    session = MagicMock()
+
+    class _SessionContext:
+        def __enter__(self):
+            return session
+
+        def __exit__(self, exc_type, exc, tb):
+            return False
+
+    monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext())
+
+    with pytest.raises(ValueError, match="Upload file upload-file-id not found"):
+        runtime.resolve_upload_file_url(upload_file_id="upload-file-id")
+
+
+@pytest.mark.parametrize(
+    ("transfer_method", "record_id", "expected_storage_key"),
+    [
+        (FileTransferMethod.LOCAL_FILE, "upload-file-id", "upload-storage-key"),
+        (FileTransferMethod.DATASOURCE_FILE, "upload-file-id", "upload-storage-key"),
+        (FileTransferMethod.TOOL_FILE, "tool-file-id", "tool-storage-key"),
+    ],
+)
+def test_resolve_storage_key_loads_database_records(
+    monkeypatch: pytest.MonkeyPatch,
+    transfer_method: FileTransferMethod,
+    record_id: str,
+    expected_storage_key: str,
+) -> None:
+    runtime = _build_runtime()
+    file = _build_file(
+        transfer_method=transfer_method,
+        reference=build_file_reference(record_id=record_id),
+        extension=".png",
+    )
+    session = MagicMock()
+
+    def get(model_class, value):
+        if transfer_method in {FileTransferMethod.LOCAL_FILE, FileTransferMethod.DATASOURCE_FILE}:
+            assert model_class is UploadFile
+            return SimpleNamespace(key="upload-storage-key")
+        assert model_class is ToolFile
+        return SimpleNamespace(file_key="tool-storage-key")
+
+    session.get.side_effect = get
+
+    class _SessionContext:
+        def __enter__(self):
+            return session
+
+        def __exit__(self, exc_type, exc, tb):
+            return False
+
+    monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext())
+
+    assert runtime._resolve_storage_key(file=file) == expected_storage_key
+
+
+@pytest.mark.parametrize(
+    ("transfer_method", "expected_message"),
+    [
+        (FileTransferMethod.LOCAL_FILE, "Upload file upload-file-id not found"),
+        (FileTransferMethod.TOOL_FILE, "Tool file tool-file-id not found"),
+    ],
+)
+def test_resolve_storage_key_raises_when_records_are_missing(
+    monkeypatch: pytest.MonkeyPatch,
+    transfer_method: FileTransferMethod,
+    expected_message: str,
+) -> None:
+    runtime = _build_runtime()
+    record_id = "upload-file-id" if transfer_method == FileTransferMethod.LOCAL_FILE else "tool-file-id"
+    file = _build_file(
+        transfer_method=transfer_method,
+        reference=build_file_reference(record_id=record_id),
+        extension=".png",
+    )
+    session = MagicMock()
+    session.get.return_value = None
+
+    class _SessionContext:
+        def __enter__(self):
+            return session
+
+        def __exit__(self, exc_type, exc, tb):
+            return False
+
+    monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext())
+
+    with pytest.raises(ValueError, match=expected_message):
+        runtime._resolve_storage_key(file=file)
+
+
+def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.MonkeyPatch) -> None:
+    monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.MULTIMODAL_SEND_FORMAT", "url")
+    runtime = _build_runtime()
+
+    assert runtime.multimodal_send_format == "url"
+
+    with patch.object(file_runtime.ssrf_proxy, "get", return_value="response") as mock_get:
+        assert runtime.http_get("http://example", follow_redirects=False) == "response"
+        mock_get.assert_called_once_with("http://example", follow_redirects=False)
+
+    with patch.object(file_runtime.storage, "load", return_value=b"data") as mock_load:
+        assert runtime.storage_load("path", stream=True) == b"data"
+        mock_load.assert_called_once_with("path", stream=True)
+
+
+def test_bind_dify_workflow_file_runtime_registers_runtime(monkeypatch: pytest.MonkeyPatch) -> None:
+    set_runtime = MagicMock()
+    monkeypatch.setattr(file_runtime, "set_workflow_file_runtime", set_runtime)
+
+    bind_dify_workflow_file_runtime()
+
+    set_runtime.assert_called_once()
+    assert isinstance(set_runtime.call_args.args[0], DifyWorkflowFileRuntime)
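The verify_preview_signature tests pin down the signing scheme exactly: an HMAC-SHA256 over "kind-preview|file_id|timestamp|nonce" with the secret key, urlsafe-base64 encoded, plus a freshness window. A sketch reconstructed from those test vectors; the free-function names are illustrative, not the runtime's API.

    import base64
    import hashlib
    import hmac
    import time

    def sign_preview(secret: bytes, kind: str, file_id: str, timestamp: str, nonce: str) -> str:
        # Matches the payload built in the test: "file-preview|upload-file-id|1700000000|nonce".
        payload = f"{kind}-preview|{file_id}|{timestamp}|{nonce}".encode()
        return base64.urlsafe_b64encode(hmac.new(secret, payload, hashlib.sha256).digest()).decode()

    def verify_preview(
        secret: bytes, kind: str, file_id: str, timestamp: str, nonce: str, sign: str, timeout: int
    ) -> bool:
        expected = sign_preview(secret, kind, file_id, timestamp, nonce)
        if not hmac.compare_digest(expected, sign):
            return False
        # Reject once the configured access timeout has elapsed.
        return int(time.time()) - int(timestamp) <= timeout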
diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py
index 9e742507c6..7f0d4a3014 100644
--- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py
+++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py
@@ -131,7 +131,7 @@ class TestDifyNodeFactory:
         node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.TEMPLATE_TRANSFORM}})
 
         assert isinstance(node, DummyTemplateTransformNode)
-        assert "template_renderer" in node.kwargs
+        assert "jinja2_template_renderer" in node.kwargs
 
     def test_create_node_http_request_branch(self, monkeypatch):
         factory = self._factory(monkeypatch)
diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py
index 45f6a0c7a1..cedac5b0fe 100644
--- a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py
+++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py
@@ -7,11 +7,11 @@ import pytest
 
 from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity
 from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
+from core.workflow.system_variables import SystemVariableKey, build_system_variables
 from dify_graph.entities.pause_reason import SchedulingPause
 from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution
 from dify_graph.enums import (
     BuiltinNodeTypes,
-    SystemVariableKey,
     WorkflowExecutionStatus,
     WorkflowNodeExecutionStatus,
     WorkflowType,
@@ -34,7 +34,6 @@ from dify_graph.graph_events.node import (
 )
 from dify_graph.node_events import NodeRunResult
 from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool
-from dify_graph.system_variable import SystemVariable
 
 
 class _RepoRecorder:
@@ -54,13 +53,16 @@ def _naive_utc_now() -> datetime:
 
 
 def _make_layer(
-    system_variable: SystemVariable | None = None,
+    system_variables: list | None = None,
     *,
     extras: dict | None = None,
     trace_manager: object | None = None,
 ):
-    system_variable = system_variable or SystemVariable(workflow_execution_id="run-id", conversation_id="conv-id")
-    runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variable), start_at=0.0)
+    system_variables = system_variables or build_system_variables(
+        workflow_execution_id="run-id",
+        conversation_id="conv-id",
+    )
+    runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variables), start_at=0.0)
     read_only_state = ReadOnlyGraphRuntimeStateWrapper(runtime_state)
 
     application_generate_entity = WorkflowAppGenerateEntity.model_construct(
@@ -115,8 +117,7 @@ class TestWorkflowPersistenceLayer:
         assert layer._node_sequence == 0
 
     def test_get_execution_id_requires_system_variable(self):
-        system_variable = SystemVariable(workflow_execution_id=None)
-        layer, _, _, _ = _make_layer(system_variable)
+        layer, _, _, _ = _make_layer(build_system_variables())
 
         with pytest.raises(ValueError, match="workflow_execution_id must be provided"):
             layer._get_execution_id()
diff --git a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py
index 3759b6aa37..eb1599bacc 100644
--- a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py
+++ b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py
@@ -28,10 +28,7 @@ def mock_model_instance(mocker):
 def mock_model_manager(mocker, mock_model_instance):
     manager = mocker.MagicMock()
     manager.get_default_model_instance.return_value = mock_model_instance
-    mocker.patch(
-        "core.base.tts.app_generator_tts_publisher.ModelManager",
-        return_value=manager,
-    )
+    mocker.patch("core.base.tts.app_generator_tts_publisher.ModelManager.for_tenant", return_value=manager)
     return manager
 
 
@@ -64,16 +61,14 @@ class TestInvoiceTTS:
         [None, "", " "],
     )
     def test_invoice_tts_empty_or_none_returns_none(self, text, mock_model_instance):
-        result = _invoice_tts(text, mock_model_instance, "tenant", "voice1")
+        result = _invoice_tts(text, mock_model_instance, "voice1")
         assert result is None
         mock_model_instance.invoke_tts.assert_not_called()
 
     def test_invoice_tts_valid_text(self, mock_model_instance):
-        result = _invoice_tts(" hello ", mock_model_instance, "tenant", "voice1")
+        result = _invoice_tts(" hello ", mock_model_instance, "voice1")
         mock_model_instance.invoke_tts.assert_called_once_with(
             content_text="hello",
-            user="responding_tts",
-            tenant_id="tenant",
             voice="voice1",
         )
         assert result == [b"audio1", b"audio2"]
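This file and the llm_generator tests later in the diff both patch ModelManager.for_tenant, which reads as a tenant-scoped constructor replacing direct instantiation. A hedged sketch of that classmethod shape; the internals are assumptions, not the real implementation.

    class ModelManagerSketch:
        def __init__(self, tenant_id: str) -> None:
            self._tenant_id = tenant_id

        @classmethod
        def for_tenant(cls, tenant_id: str) -> "ModelManagerSketch":
            # Binding the tenant once is why helpers such as _invoice_tts no
            # longer pass tenant_id through every invocation.
            return cls(tenant_id)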
diff --git a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py
index d5eeae912c..e01be9ed28 100644
--- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py
+++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py
@@ -7,6 +7,7 @@ from contexts.wrapper import RecyclableContextVar
 from core.datasource.datasource_manager import DatasourceManager
 from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType
 from core.datasource.errors import DatasourceProviderNotFoundError
+from core.workflow.file_reference import parse_file_reference
 from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
 from dify_graph.file import File
 from dify_graph.file.enums import FileTransferMethod, FileType
@@ -428,11 +429,8 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker):
         return fake_tool_file
 
     mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session())
-    mocker.patch(
-        "core.datasource.datasource_manager.file_factory.get_file_type_by_mime_type", return_value=FileType.IMAGE
-    )
+    mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE)
     built = File(
-        tenant_id="t1",
         type=FileType.IMAGE,
         transfer_method=FileTransferMethod.TOOL_FILE,
         related_id="tool_file_1",
@@ -533,7 +531,6 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc
     mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
 
     file_in = File(
-        tenant_id="t1",
         type=FileType.DOCUMENT,
         transfer_method=FileTransferMethod.TOOL_FILE,
         related_id="tf",
@@ -664,6 +661,8 @@ def test_get_upload_file_by_id_builds_file(mocker):
     f = DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1")
     assert f.related_id == "fid"
     assert f.extension == ".txt"
+    assert parse_file_reference(f.reference).storage_key is None
+    assert f.storage_key == "k"
 
 
 def test_get_upload_file_by_id_raises_when_missing(mocker):
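The assertions above rely on parse_file_reference(build_file_reference(...)) round-tripping a record id and an optional storage key. A minimal sketch of one possible encoding; the real reference format is not shown anywhere in this diff, so treat the colon scheme as an assumption.

    from dataclasses import dataclass

    @dataclass
    class FileReference:
        record_id: str
        storage_key: str | None = None

    def build_file_reference(*, record_id: str, storage_key: str | None = None) -> str:
        # Hypothetical encoding: record id, optionally followed by a storage key.
        return record_id if storage_key is None else f"{record_id}:{storage_key}"

    def parse_file_reference(reference: str) -> FileReference:
        record_id, _, storage_key = reference.partition(":")
        return FileReference(record_id=record_id, storage_key=storage_key or None)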
diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py
index 95d58757f1..23ae91fdf2 100644
--- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py
+++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py
@@ -350,7 +350,7 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None:
     mock_factory = Mock()
     mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"}
 
-    with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
+    with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
         with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"):
             with patch(
                 "core.entities.provider_configuration.encrypter.encrypt_token",
@@ -380,7 +380,9 @@ def test_validate_provider_credentials_opens_session_when_not_passed() -> None:
     with patch("core.entities.provider_configuration.db") as mock_db:
         mock_db.engine = Mock()
         mock_session_cls.return_value.__enter__.return_value = mock_session
-        with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
+        with patch(
+            "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
+        ):
             validated = configuration.validate_provider_credentials(credentials={"region": "us"})
 
     assert validated == {"region": "us"}
@@ -434,12 +436,16 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
     mock_factory.get_model_type_instance.return_value = mock_model_type_instance
     mock_factory.get_model_schema.return_value = mock_schema
 
-    with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
+    with patch(
+        "core.entities.provider_configuration.create_plugin_model_provider_factory",
+        return_value=mock_factory,
+    ) as mock_factory_builder:
         model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
         model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
 
     assert model_type_instance is mock_model_type_instance
     assert model_schema is mock_schema
+    assert mock_factory_builder.call_count == 2
     mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM)
     mock_factory.get_model_schema.assert_called_once_with(
         provider="openai",
@@ -449,6 +455,33 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None:
     )
 
 
+def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> None:
+    configuration = _build_provider_configuration()
+    bound_runtime = Mock()
+    configuration.bind_model_runtime(bound_runtime)
+
+    mock_factory = Mock()
+    mock_model_type_instance = Mock()
+    mock_schema = _build_ai_model("gpt-4o")
+    mock_factory.get_model_type_instance.return_value = mock_model_type_instance
+    mock_factory.get_model_schema.return_value = mock_schema
+
+    with (
+        patch(
+            "core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory
+        ) as mock_factory_cls,
+        patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder,
+    ):
+        model_type_instance = configuration.get_model_type_instance(ModelType.LLM)
+        model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"})
+
+    assert model_type_instance is mock_model_type_instance
+    assert model_schema is mock_schema
+    assert mock_factory_cls.call_count == 2
+    mock_factory_cls.assert_called_with(model_runtime=bound_runtime)
+    mock_factory_builder.assert_not_called()
+
+
 def test_get_provider_model_returns_none_when_model_not_found() -> None:
     configuration = _build_provider_configuration()
     fake_model = SimpleNamespace(model="other-model")
@@ -475,7 +508,7 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N
     mock_factory = Mock()
     mock_factory.get_provider_schema.return_value = provider_schema
 
-    with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
+    with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
         all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False)
         active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True)
 
@@ -689,7 +722,7 @@ def test_validate_provider_credentials_handles_invalid_original_json() -> None:
     mock_factory = Mock()
     mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"}
 
-    with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
+    with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
         with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"):
             validated = configuration.validate_provider_credentials(
                 credentials={"openai_api_key": HIDDEN_VALUE},
@@ -1034,7 +1067,7 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless
     mock_factory = Mock()
     mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"}
 
-    with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
+    with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
         with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"):
             with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
                 validated = configuration.validate_custom_model_credentials(
@@ -1050,7 +1083,9 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless
     mock_factory = Mock()
     mock_factory.model_credentials_validate.return_value = {"region": "us"}
     with _patched_session(session):
-        with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
+        with patch(
+            "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory
+        ):
             validated = configuration.validate_custom_model_credentials(
                 model_type=ModelType.LLM,
                 model="gpt-4o",
@@ -1540,7 +1575,7 @@ def test_validate_provider_credentials_uses_empty_original_when_record_missing()
     mock_factory = Mock()
     mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"}
 
-    with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
+    with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
         with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
             validated = configuration.validate_provider_credentials(
                 credentials={"openai_api_key": HIDDEN_VALUE},
@@ -1662,7 +1697,7 @@ def test_validate_custom_model_credentials_handles_invalid_original_json() -> No
     mock_factory = Mock()
     mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"}
 
-    with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory):
+    with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory):
         with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"):
             validated = configuration.validate_custom_model_credentials(
                 model_type=ModelType.LLM,
"datasource-file-789", } - # Should be able to create `File` object without raising an exception file = File.model_validate(data) - # The File object does not have tool_file_id, upload_file_id, or datasource_file_id as attributes. - # Instead, check it does not expose unrecognized legacy fields (should raise on getattr). - for legacy_field in ("tool_file_id", "upload_file_id", "datasource_file_id"): - assert not hasattr(file, legacy_field) + assert file.related_id == "test-related-id" + assert file.storage_key == "test-storage-key" + assert "tenant_id" not in file.model_dump() diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py index 5b7640696f..33a12528c0 100644 --- a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -13,7 +13,7 @@ from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, Inv class TestLLMGenerator: @pytest.fixture def mock_model_instance(self): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: instance = MagicMock() mock_manager.return_value.get_default_model_instance.return_value = instance mock_manager.return_value.get_model_instance.return_value = instance @@ -98,7 +98,7 @@ class TestLLMGenerator: assert questions[0] == "Question 1?" def test_generate_suggested_questions_after_answer_auth_error(self, mock_model_instance): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: mock_manager.return_value.get_default_model_instance.side_effect = InvokeAuthorizationError("Auth failed") questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") assert questions == [] @@ -528,7 +528,7 @@ class TestLLMGenerator: assert "An unexpected error occurred" in result["error"] def test_instruction_modify_common_other_node_type(self, mock_model_instance, model_config_entity): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: instance = MagicMock() mock_manager.return_value.get_model_instance.return_value = instance mock_response = MagicMock() diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py new file mode 100644 index 0000000000..766280dbc7 --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -0,0 +1,420 @@ +from unittest.mock import Mock + +import pytest + +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from dify_graph.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) +from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel +from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel +from 
dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel +from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory + + +def _build_model(model: str, model_type: ModelType) -> AIModelEntity: + return AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=model_type, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + ) + + +def _build_provider( + *, + provider: str, + provider_name: str, + supported_model_types: list[ModelType], + models: list[AIModelEntity] | None = None, + provider_credential_schema: ProviderCredentialSchema | None = None, + model_credential_schema: ModelCredentialSchema | None = None, +) -> ProviderEntity: + return ProviderEntity( + provider=provider, + provider_name=provider_name, + label=I18nObject(en_US=provider_name or provider), + supported_model_types=supported_model_types, + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=models or [], + provider_credential_schema=provider_credential_schema, + model_credential_schema=model_credential_schema, + ) + + +class _FakeModelRuntime: + def __init__(self, providers: list[ProviderEntity]) -> None: + self._providers = providers + self.validate_provider_credentials = Mock() + self.validate_model_credentials = Mock() + self.get_model_schema = Mock() + self.get_provider_icon = Mock() + + def fetch_model_providers(self) -> list[ProviderEntity]: + return self._providers + + +def test_model_provider_factory_resolves_runtime_provider_name() -> None: + provider = ProviderEntity( + provider="langgenius/openai/openai", + provider_name="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider])) + + provider_schema = factory.get_model_provider("openai") + + assert provider_schema.provider == "langgenius/openai/openai" + assert provider_schema.provider_name == "openai" + + +def test_model_provider_factory_resolves_canonical_short_name_independent_of_provider_order() -> None: + providers = [ + ProviderEntity( + provider="acme/openai/openai", + provider_name="", + label=I18nObject(en_US="Acme OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ProviderEntity( + provider="langgenius/openai/openai", + provider_name="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + provider_schema = factory.get_model_provider("openai") + + assert provider_schema.provider == "langgenius/openai/openai" + assert provider_schema.provider_name == "openai" + + +def test_model_provider_factory_requires_runtime() -> None: + with pytest.raises(ValueError, match="model_runtime is required"): + ModelProviderFactory(model_runtime=None) # type: ignore[arg-type] + + +def test_model_provider_factory_get_providers_returns_runtime_providers() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + factory = 
ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + result = factory.get_providers() + + assert list(result) == providers + assert result is not providers + + +def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup() -> None: + provider = _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider])) + + result = factory.get_provider_schema("openai") + + assert result is provider + + +def test_model_provider_factory_raises_for_unknown_provider() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Invalid provider: anthropic"): + factory.get_model_provider("anthropic") + + +def test_model_provider_factory_get_models_filters_provider_and_model_type() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM, ModelType.TTS], + models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)], + ), + _build_provider( + provider="langgenius/cohere/cohere", + provider_name="cohere", + supported_model_types=[ModelType.RERANK], + models=[_build_model("rerank-v3", ModelType.RERANK)], + ), + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + results = factory.get_models(provider="openai", model_type=ModelType.LLM) + + assert len(results) == 1 + assert results[0].provider == "langgenius/openai/openai" + assert [model.model for model in results[0].models] == ["gpt-4o-mini"] + + +def test_model_provider_factory_get_models_skips_providers_without_requested_model_type() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + models=[_build_model("gpt-4o-mini", ModelType.LLM)], + ), + _build_provider( + provider="langgenius/elevenlabs/elevenlabs", + provider_name="elevenlabs", + supported_model_types=[ModelType.TTS], + models=[_build_model("eleven_multilingual_v2", ModelType.TTS)], + ), + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + results = factory.get_models(model_type=ModelType.TTS) + + assert len(results) == 1 + assert results[0].provider == "langgenius/elevenlabs/elevenlabs" + assert [model.model for model in results[0].models] == ["eleven_multilingual_v2"] + + +def test_model_provider_factory_get_models_without_model_type_keeps_all_provider_models() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM, ModelType.TTS], + models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)], + ) + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + results = factory.get_models(provider="openai") + + assert len(results) == 1 + assert [model.model for model in results[0].models] == ["gpt-4o-mini", "tts-1"] + + +def test_model_provider_factory_validates_provider_credentials() -> None: + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + provider_credential_schema=ProviderCredentialSchema( 
+ credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API key"), + type=FormType.SECRET_INPUT, + required=True, + ) + ] + ), + ) + ] + ) + factory = ModelProviderFactory(model_runtime=runtime) + + filtered = factory.provider_credentials_validate( + provider="openai", + credentials={"api_key": "secret", "ignored": "value"}, + ) + + assert filtered == {"api_key": "secret"} + runtime.validate_provider_credentials.assert_called_once_with( + provider="langgenius/openai/openai", + credentials={"api_key": "secret"}, + ) + + +def test_model_provider_factory_provider_credentials_validate_requires_schema() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Provider openai does not have provider_credential_schema"): + factory.provider_credentials_validate(provider="openai", credentials={"api_key": "secret"}) + + +def test_model_provider_factory_validates_model_credentials() -> None: + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + model_credential_schema=ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model")), + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API key"), + type=FormType.SECRET_INPUT, + required=True, + ) + ], + ), + ) + ] + ) + factory = ModelProviderFactory(model_runtime=runtime) + + filtered = factory.model_credentials_validate( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret", "ignored": "value"}, + ) + + assert filtered == {"api_key": "secret"} + runtime.validate_model_credentials.assert_called_once_with( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + +def test_model_provider_factory_model_credentials_validate_requires_schema() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Provider openai does not have model_credential_schema"): + factory.model_credentials_validate( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + +def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider() -> None: + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + runtime.get_model_schema.return_value = "schema" + runtime.get_provider_icon.return_value = (b"icon", "image/png") + factory = ModelProviderFactory(model_runtime=runtime) + + assert ( + factory.get_model_schema( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials=None, + ) + == "schema" + ) + assert factory.get_provider_icon("openai", "icon_small", "en_US") == (b"icon", "image/png") + runtime.get_model_schema.assert_called_once_with( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={}, + ) + runtime.get_provider_icon.assert_called_once_with( + 
provider="langgenius/openai/openai", + icon_type="icon_small", + lang="en_US", + ) + + +@pytest.mark.parametrize( + ("model_type", "expected_type"), + [ + (ModelType.LLM, LargeLanguageModel), + (ModelType.TEXT_EMBEDDING, TextEmbeddingModel), + (ModelType.RERANK, RerankModel), + (ModelType.SPEECH2TEXT, Speech2TextModel), + (ModelType.MODERATION, ModerationModel), + (ModelType.TTS, TTSModel), + ], +) +def test_model_provider_factory_builds_model_type_instances( + model_type: ModelType, + expected_type: type[object], +) -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[model_type], + ) + ] + ) + ) + + instance = factory.get_model_type_instance("openai", model_type) + + assert isinstance(instance, expected_type) + + +def test_model_provider_factory_rejects_unsupported_model_type() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Unsupported model type: unsupported"): + factory.get_model_type_instance("openai", "unsupported") # type: ignore[arg-type] diff --git a/api/tests/unit_tests/core/moderation/test_content_moderation.py b/api/tests/unit_tests/core/moderation/test_content_moderation.py index e61cde22e7..3a97ad5c5d 100644 --- a/api/tests/unit_tests/core/moderation/test_content_moderation.py +++ b/api/tests/unit_tests/core/moderation/test_content_moderation.py @@ -324,7 +324,7 @@ class TestOpenAIModeration: with pytest.raises(ValueError, match="At least one of inputs_config or outputs_config must be enabled"): OpenAIModeration.validate_config("test-tenant", config) - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_no_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test input moderation when OpenAI API returns no violations.""" # Mock the model manager and instance @@ -341,7 +341,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Content flagged by OpenAI moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test input moderation when OpenAI API detects violations.""" # Mock the model manager to return violation @@ -358,7 +358,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Content flagged by OpenAI moderation." 
- @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_query_included(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test that query is included in moderation check with special key.""" mock_instance = MagicMock() @@ -385,7 +385,7 @@ class TestOpenAIModeration: assert "u" in moderated_text assert "e" in moderated_text - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_disabled(self, mock_model_manager: Mock): """Test input moderation when inputs_config is disabled.""" config = { @@ -400,7 +400,7 @@ class TestOpenAIModeration: # Should not call the API when disabled mock_model_manager.assert_not_called() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_no_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test output moderation when OpenAI API returns no violations.""" mock_instance = MagicMock() @@ -414,7 +414,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Response blocked by moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test output moderation when OpenAI API detects violations.""" mock_instance = MagicMock() @@ -427,7 +427,7 @@ class TestOpenAIModeration: assert result.flagged is True assert result.action == ModerationAction.DIRECT_OUTPUT - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_disabled(self, mock_model_manager: Mock): """Test output moderation when outputs_config is disabled.""" config = { @@ -441,7 +441,7 @@ class TestOpenAIModeration: assert result.flagged is False mock_model_manager.assert_not_called() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_model_manager_called_with_correct_params( self, mock_model_manager: Mock, openai_moderation: OpenAIModeration ): @@ -629,7 +629,7 @@ class TestPresetManagement: assert result.flagged is True assert result.preset_response == "Custom output blocked message" - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_preset_response_in_inputs(self, mock_model_manager: Mock): """Test preset response is properly returned for OpenAI input violations.""" mock_instance = MagicMock() @@ -650,7 +650,7 @@ class TestPresetManagement: assert result.flagged is True assert result.preset_response == "OpenAI input blocked" - 
@patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_preset_response_in_outputs(self, mock_model_manager: Mock): """Test preset response is properly returned for OpenAI output violations.""" mock_instance = MagicMock() @@ -989,7 +989,7 @@ class TestOpenAIModerationAdvanced: - Performance considerations """ - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_api_timeout_handling(self, mock_model_manager: Mock): """ Test graceful handling of OpenAI API timeouts. @@ -1012,7 +1012,7 @@ class TestOpenAIModerationAdvanced: with pytest.raises(TimeoutError): moderation.moderation_for_inputs({"text": "test"}, "") - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_api_rate_limit_handling(self, mock_model_manager: Mock): """ Test handling of OpenAI API rate limit errors. @@ -1035,7 +1035,7 @@ class TestOpenAIModerationAdvanced: with pytest.raises(Exception, match="Rate limit exceeded"): moderation.moderation_for_inputs({"text": "test"}, "") - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_with_multiple_input_fields(self, mock_model_manager: Mock): """ Test OpenAI moderation with multiple input fields. @@ -1079,7 +1079,7 @@ class TestOpenAIModerationAdvanced: assert "u" in moderated_text assert "e" in moderated_text - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_empty_text_handling(self, mock_model_manager: Mock): """ Test OpenAI moderation with empty text inputs. @@ -1103,7 +1103,7 @@ class TestOpenAIModerationAdvanced: assert result.flagged is False mock_instance.invoke_moderation.assert_called_once() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_model_instance_fetched_on_each_call(self, mock_model_manager: Mock): """ Test that ModelManager fetches a fresh model instance on each call. 
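The moderation suite above, like several suites later in this diff, repoints its patches from the ModelManager class itself to the tenant-scoped factory ModelManager.for_tenant, and the trace suites that follow rename repository stubs from get_by_workflow_run to get_by_workflow_execution. As a minimal, self-contained sketch of the patching pattern, assuming for_tenant is a classmethod factory taking a tenant_id keyword (inferred from these tests; the ModelManager class below is a hypothetical stand-in, not the real core.model_manager implementation):

from unittest.mock import MagicMock, patch


class ModelManager:
    """Hypothetical stand-in for core.model_manager.ModelManager."""

    @classmethod
    def for_tenant(cls, tenant_id: str) -> "ModelManager":
        # The real factory would build a tenant-scoped manager; this sketch
        # only exists to be patched, so it never runs unpatched.
        raise NotImplementedError("real implementation not available in this sketch")


def test_for_tenant_factory_is_patchable() -> None:
    with patch.object(ModelManager, "for_tenant") as for_tenant_mock:
        instance = MagicMock()
        for_tenant_mock.return_value.get_model_instance.return_value = instance

        # Code under test would resolve a manager through the factory:
        manager = ModelManager.for_tenant(tenant_id="tenant-1")

        assert manager.get_model_instance() is instance
        for_tenant_mock.assert_called_once_with(tenant_id="tenant-1")

Patching the factory rather than the class is what lets the updated suites assert on the tenant scope explicitly (for_tenant_mock.assert_called_once_with(tenant_id=...)) instead of only checking that some ModelManager was constructed.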
diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py index dfd61acfa7..98076c0b99 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py @@ -396,14 +396,14 @@ def test_get_workflow_node_executions_builds_repo_and_fetches( monkeypatch.setattr(aliyun_trace_module, "db", SimpleNamespace(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = ["node1"] + repo.get_by_workflow_execution.return_value = ["node1"] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr(aliyun_trace_module, "DifyCoreRepositoryFactory", mock_factory) result = trace_instance.get_workflow_node_executions(trace_info) assert result == ["node1"] - repo.get_by_workflow_run.assert_called_once_with(workflow_run_id=trace_info.workflow_run_id) + repo.get_by_workflow_execution.assert_called_once_with(workflow_execution_id=trace_info.workflow_run_id) def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): diff --git a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py index 1cee2f5b68..4ce9e22fd7 100644 --- a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -254,7 +254,7 @@ def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trac node1.id = "n1" node1.error = None - repo.get_by_workflow_run.return_value = [node1] + repo.get_by_workflow_execution.return_value = [node1] with patch.object(trace_instance, "get_service_account_with_tenant"): trace_instance.workflow_trace(info) diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py index 0ff135562c..33adfc139b 100644 --- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py @@ -174,7 +174,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): node_other.metadata = None repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other] + repo.get_by_workflow_execution.return_value = [node_llm, node_other] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -244,7 +244,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) @@ -680,7 +680,7 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat node.outputs = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node] + repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() 
mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py index f656f7435f..88706b88a4 100644 --- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py @@ -184,7 +184,7 @@ def test_workflow_trace(trace_instance, monkeypatch): node_retrieval.metadata = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other, node_retrieval] + repo.get_by_workflow_execution.return_value = [node_llm, node_other, node_retrieval] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -255,7 +255,7 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) @@ -565,7 +565,7 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl node_llm.metadata = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm] + repo.get_by_workflow_execution.return_value = [node_llm] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py index b2cb7d5109..575f2b1109 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py @@ -199,7 +199,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): node_other.metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS.value: 10} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other] + repo.get_by_workflow_execution.return_value = [node_llm, node_other] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -253,7 +253,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) @@ -657,7 +657,7 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch node.outputs = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node] + repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo 
monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py index f259e4639f..7fdfdc592a 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py @@ -413,7 +413,7 @@ class TestTencentDataTrace: with patch( "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository" ) as mock_repo: - mock_repo.return_value.get_by_workflow_run.return_value = mock_executions + mock_repo.return_value.get_by_workflow_execution.return_value = mock_executions results = tencent_data_trace._get_workflow_node_executions(trace_info) diff --git a/api/tests/unit_tests/core/ops/test_opik_trace.py b/api/tests/unit_tests/core/ops/test_opik_trace.py index 7660967183..ad9d0846be 100644 --- a/api/tests/unit_tests/core/ops/test_opik_trace.py +++ b/api/tests/unit_tests/core/ops/test_opik_trace.py @@ -130,7 +130,7 @@ class TestWorkflowTraceWithoutMessageId: def _run(self, trace_info: WorkflowTraceInfo, node_executions: list | None = None): instance = _make_opik_trace_instance() fake_repo = MagicMock() - fake_repo.get_by_workflow_run.return_value = node_executions or [] + fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( patch("core.ops.opik_trace.opik_trace.db") as mock_db, @@ -262,7 +262,7 @@ class TestWorkflowTraceWithMessageId: def _run(self, trace_info: WorkflowTraceInfo, node_executions: list | None = None): instance = _make_opik_trace_instance() fake_repo = MagicMock() - fake_repo.get_by_workflow_run.return_value = node_executions or [] + fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( patch("core.ops.opik_trace.opik_trace.db") as mock_db, diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py index 8057bbbad5..221ba49d3b 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py @@ -589,7 +589,7 @@ class TestWorkflowTrace: nodes = [] repo = MagicMock() - repo.get_by_workflow_run.return_value = nodes + repo.get_by_workflow_execution.return_value = nodes mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo diff --git a/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py new file mode 100644 index 0000000000..7491e79f30 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py @@ -0,0 +1,36 @@ +from unittest.mock import Mock, patch + +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly + + +def test_plugin_model_assembly_reuses_single_runtime_across_views(): + runtime = Mock(name="runtime") + provider_factory = Mock(name="provider_factory") + provider_manager = Mock(name="provider_manager") + model_manager = Mock(name="model_manager") + + with ( + patch( + "core.plugin.impl.model_runtime_factory.create_plugin_model_runtime", + return_value=runtime, + ) as mock_runtime_factory, + patch( + "core.plugin.impl.model_runtime_factory.ModelProviderFactory", + return_value=provider_factory, + ) as mock_provider_factory_cls, + patch("core.provider_manager.ProviderManager", return_value=provider_manager) as 
mock_provider_manager_cls, + patch("core.model_manager.ModelManager", return_value=model_manager) as mock_model_manager_cls, + ): + assembly = create_plugin_model_assembly(tenant_id="tenant-1", user_id="user-1") + + assert assembly.model_provider_factory is provider_factory + assert assembly.provider_manager is provider_manager + assert assembly.model_manager is model_manager + assert assembly.model_provider_factory is provider_factory + assert assembly.provider_manager is provider_manager + assert assembly.model_manager is model_manager + + mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_provider_factory_cls.assert_called_once_with(model_runtime=runtime) + mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime) + mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager) diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py new file mode 100644 index 0000000000..59a9c229d0 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py @@ -0,0 +1,61 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation +from core.plugin.entities.request import RequestInvokeSummary +from dify_graph.model_runtime.entities.message_entities import UserPromptMessage + + +def test_system_model_helpers_forward_user_id(): + with ( + patch( + "core.plugin.backwards_invocation.model.ModelInvocationUtils.get_max_llm_context_tokens", + return_value=4096, + ) as mock_max_tokens, + patch( + "core.plugin.backwards_invocation.model.ModelInvocationUtils.calculate_tokens", + return_value=7, + ) as mock_prompt_tokens, + ): + assert PluginModelBackwardsInvocation.get_system_model_max_tokens("tenant-1", user_id="user-1") == 4096 + assert ( + PluginModelBackwardsInvocation.get_prompt_tokens( + "tenant-1", + [UserPromptMessage(content="hello")], + user_id="user-1", + ) + == 7 + ) + + mock_max_tokens.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_prompt_tokens.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id="user-1", + ) + + +def test_invoke_summary_uses_same_user_scope_for_token_helpers(): + tenant = SimpleNamespace(id="tenant-1") + payload = RequestInvokeSummary(text="short", instruction="keep it concise") + + with ( + patch.object( + PluginModelBackwardsInvocation, + "get_system_model_max_tokens", + return_value=100, + ) as mock_max_tokens, + patch.object( + PluginModelBackwardsInvocation, + "get_prompt_tokens", + return_value=10, + ) as mock_prompt_tokens, + ): + assert PluginModelBackwardsInvocation.invoke_summary("user-1", tenant, payload) == "short" + + mock_max_tokens.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_prompt_tokens.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="short")], + user_id="user-1", + ) diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py new file mode 100644 index 0000000000..a2180bc6ba --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -0,0 +1,506 @@ +"""Unit tests for the plugin-backed model runtime adapter.""" + +import datetime +import uuid +from types import SimpleNamespace +from unittest.mock import Mock, 
sentinel + +import pytest + +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl import model_runtime as model_runtime_module +from core.plugin.impl.model import PluginModelClient +from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime +from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity + + +def _build_model_schema() -> AIModelEntity: + return AIModelEntity( + model="gpt-4o-mini", + label=I18nObject(en_US="GPT-4o mini"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + ) + + +class TestPluginModelRuntime: + """Validate the adapter keeps plugin-specific routing out of the runtime port.""" + + def test_fetch_model_providers_returns_runtime_entities(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + assert len(providers) == 1 + assert providers[0].provider == "langgenius/openai/openai" + assert providers[0].provider_name == "openai" + assert providers[0].label.en_US == "OpenAI" + client.fetch_model_providers.assert_called_once_with("tenant") + + def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="acme/openai/openai", + plugin_id="acme/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="Acme OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ), + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ), + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + provider_aliases = {provider.provider: provider.provider_name for provider in providers} + assert provider_aliases["acme/openai/openai"] == "" + assert provider_aliases["langgenius/openai/openai"] == "openai" + + def test_fetch_model_providers_keeps_google_alias_on_canonical_gemini_provider(self) -> None: + client = 
Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="google", + tenant_id="tenant", + plugin_unique_identifier="langgenius/gemini/google", + plugin_id="langgenius/gemini", + declaration=ProviderEntity( + provider="google", + label=I18nObject(en_US="Google"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + assert providers[0].provider == "langgenius/gemini/google" + assert providers[0].provider_name == "google" + + def test_validate_provider_credentials_resolves_plugin_fields(self) -> None: + client = Mock(spec=PluginModelClient) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + runtime.validate_provider_credentials( + provider="langgenius/openai/openai", + credentials={"api_key": "secret"}, + ) + + client.validate_provider_credentials.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + credentials={"api_key": "secret"}, + ) + + def test_invoke_llm_resolves_plugin_fields(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_llm.return_value = sentinel.result + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + result = runtime.invoke_llm( + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + ) + + assert result is sentinel.result + client.invoke_llm.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + ) + + def test_invoke_llm_rejects_per_call_user_override(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_llm.return_value = sentinel.result + runtime = PluginModelRuntime(tenant_id="tenant", user_id="bound-user", client=client) + + with pytest.raises(TypeError, match="unexpected keyword argument 'user_id'"): + runtime.invoke_llm( # type: ignore[call-arg] + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + user_id="request-user", + ) + + client.invoke_llm.assert_not_called() + + def test_invoke_tts_uses_bound_runtime_user_when_runtime_is_unbound(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_tts.return_value = iter([b"chunk"]) + runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=client) + + result = runtime.invoke_tts( + provider="langgenius/openai/openai", + model="tts-1", + credentials={"api_key": "secret"}, + content_text="hello", + voice="alloy", + ) + + assert list(result) == [b"chunk"] + client.invoke_tts.assert_called_once_with( + tenant_id="tenant", + user_id=None, + plugin_id="langgenius/openai", + provider="openai", + model="tts-1", + credentials={"api_key": "secret"}, + content_text="hello", + voice="alloy", + ) + + def test_fetch_model_providers_uses_bound_runtime_cache(self) -> None: + 
client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + runtime.fetch_model_providers() + runtime.fetch_model_providers() + + client.fetch_model_providers.assert_called_once_with("tenant") + + +def test_create_plugin_model_runtime_without_user_context() -> None: + runtime = create_plugin_model_runtime(tenant_id="tenant") + + assert runtime.user_id is None + + +def test_plugin_model_runtime_requires_client() -> None: + with pytest.raises(ValueError, match="client is required"): + PluginModelRuntime(tenant_id="tenant", user_id="user", client=None) # type: ignore[arg-type] + + +def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + schema = _build_model_schema() + monkeypatch.setattr( + model_runtime_module, + "redis_client", + SimpleNamespace( + get=Mock(return_value=schema.model_dump_json()), + delete=Mock(), + setex=Mock(), + ), + ) + + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + result = runtime.get_model_schema( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + assert result == schema + client.get_model_schema.assert_not_called() + + +def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + schema = _build_model_schema() + delete = Mock() + setex = Mock() + monkeypatch.setattr( + model_runtime_module, + "redis_client", + SimpleNamespace( + get=Mock(return_value="not-json"), + delete=delete, + setex=setex, + ), + ) + monkeypatch.setattr(model_runtime_module.dify_config, "PLUGIN_MODEL_SCHEMA_CACHE_TTL", 300) + client.get_model_schema.return_value = schema + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + result = runtime.get_model_schema( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + assert result == schema + delete.assert_called_once() + client.get_model_schema.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + model_type=ModelType.LLM.value, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + setex.assert_called_once() + + +def test_get_llm_num_tokens_returns_zero_when_plugin_counting_is_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + monkeypatch.setattr(model_runtime_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", False) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + assert ( + runtime.get_llm_num_tokens( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + prompt_messages=[], + tools=None, + ) + == 0 + ) + client.get_llm_num_tokens.assert_not_called() + + +def test_get_provider_icon_reads_requested_variant_and_detects_svg_mime(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + 
plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + icon_small=I18nObject(en_US="logo.svg"), + icon_small_dark=I18nObject(en_US="logo-dark.png"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + fetch_asset = Mock(return_value=b"") + monkeypatch.setattr(model_runtime_module.PluginAssetManager, "fetch_asset", fetch_asset) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + icon_bytes, mime_type = runtime.get_provider_icon( + provider="langgenius/openai/openai", + icon_type="icon_small", + lang="en_US", + ) + + assert icon_bytes == b"" + assert mime_type == "image/svg+xml" + fetch_asset.assert_called_once_with(tenant_id="tenant", id="logo.svg") + + +def test_get_provider_icon_rejects_unsupported_types_and_missing_variants() -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + with pytest.raises(ValueError, match="does not have small dark icon"): + runtime.get_provider_icon( + provider="langgenius/openai/openai", + icon_type="icon_small_dark", + lang="en_US", + ) + + with pytest.raises(ValueError, match="Unsupported icon type"): + runtime.get_provider_icon( + provider="langgenius/openai/openai", + icon_type="icon_large", + lang="en_US", + ) + + +def test_get_schema_cache_key_is_stable_across_credential_order() -> None: + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=Mock(spec=PluginModelClient)) + + first = runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"b": "2", "a": "1"}, + ) + second = runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1", "b": "2"}, + ) + + assert first == second + + +def test_get_schema_cache_key_separates_distinct_user_scopes() -> None: + first_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient)) + second_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-b", client=Mock(spec=PluginModelClient)) + + first = first_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + second = second_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + + assert first != second + + +def test_get_schema_cache_key_separates_tenant_scope_from_user_scope() -> None: + tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient)) + user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient)) + + tenant_key = tenant_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + 
model="gpt-4o-mini", + credentials={"a": "1"}, + ) + user_key = user_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + + assert tenant_key != user_key + assert f":{TENANT_SCOPE_SCHEMA_CACHE_USER_ID}" in tenant_key + + +def test_get_schema_cache_key_separates_tenant_scope_from_empty_string_user_scope() -> None: + tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient)) + empty_user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="", client=Mock(spec=PluginModelClient)) + + tenant_key = tenant_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={}, + ) + empty_user_key = empty_user_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={}, + ) + + assert tenant_key != empty_user_key + assert empty_user_key.endswith(":") + assert TENANT_SCOPE_SCHEMA_CACHE_USER_ID not in empty_user_key + + +def test_get_provider_schema_supports_short_alias_and_rejects_invalid_provider() -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + assert runtime._get_provider_schema("openai").provider == "langgenius/openai/openai" + + with pytest.raises(ValueError, match="Invalid provider"): + runtime._get_provider_schema("missing") diff --git a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py index 538457ccc8..87ee35faf8 100644 --- a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py +++ b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py @@ -1,7 +1,5 @@ from unittest.mock import MagicMock, patch -import pytest - from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.index_processor.constant.query_type import QueryType @@ -56,7 +54,6 @@ class TestDataPostProcessor: documents=original_documents, score_threshold=0.3, top_n=2, - user="user-1", query_type=QueryType.IMAGE_QUERY, ) @@ -65,7 +62,6 @@ class TestDataPostProcessor: original_documents, 0.3, 2, - "user-1", QueryType.IMAGE_QUERY, ) processor.reorder_runner.run.assert_called_once_with(reranked_documents) @@ -176,25 +172,24 @@ class TestDataPostProcessor: processor = DataPostProcessor.__new__(DataPostProcessor) assert processor._get_rerank_model_instance("tenant-1", None) is None - def test_get_rerank_model_instance_raises_key_error_for_incomplete_config(self): + def test_get_rerank_model_instance_returns_none_for_incomplete_config(self): processor = DataPostProcessor.__new__(DataPostProcessor) - with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls: - manager_instance = 
manager_cls.return_value - with pytest.raises(KeyError, match="reranking_model_name"): - processor._get_rerank_model_instance( - tenant_id="tenant-1", - reranking_model={"reranking_provider_name": "provider-x"}, - ) + with patch("core.rag.data_post_processor.data_post_processor.ModelManager.for_tenant") as for_tenant_mock: + result = processor._get_rerank_model_instance( + tenant_id="tenant-1", + reranking_model={"reranking_provider_name": "provider-x"}, + ) - manager_instance.get_model_instance.assert_not_called() + assert result is None + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") def test_get_rerank_model_instance_success(self): processor = DataPostProcessor.__new__(DataPostProcessor) model_instance = object() - with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls: - manager_instance = manager_cls.return_value + with patch("core.rag.data_post_processor.data_post_processor.ModelManager.for_tenant") as for_tenant_mock: + manager_instance = for_tenant_mock.return_value manager_instance.get_model_instance.return_value = model_instance result = processor._get_rerank_model_instance( @@ -206,6 +201,7 @@ class TestDataPostProcessor: ) assert result is model_instance + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") manager_instance.get_model_instance.assert_called_once_with( tenant_id="tenant-1", provider="provider-x", @@ -216,8 +212,8 @@ class TestDataPostProcessor: def test_get_rerank_model_instance_handles_authorization_error(self): processor = DataPostProcessor.__new__(DataPostProcessor) - with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls: - manager_instance = manager_cls.return_value + with patch("core.rag.data_post_processor.data_post_processor.ModelManager.for_tenant") as for_tenant_mock: + manager_instance = for_tenant_mock.return_value manager_instance.get_model_instance.side_effect = InvokeAuthorizationError("not authorized") result = processor._get_rerank_model_instance( @@ -229,6 +225,7 @@ class TestDataPostProcessor: ) assert result is None + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") class TestReorderRunner: diff --git a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py index 8c1e4e478b..63de4b8af2 100644 --- a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py +++ b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py @@ -399,7 +399,7 @@ class TestRetrievalServiceInternals: assert exceptions == [] vector_instance.search_by_file.assert_not_called() - @patch("core.rag.datasource.retrieval_service.ModelManager") + @patch("core.rag.datasource.retrieval_service.ModelManager.for_tenant") @patch("core.rag.datasource.retrieval_service.DataPostProcessor") @patch("core.rag.datasource.retrieval_service.Vector") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") @@ -451,9 +451,10 @@ class TestRetrievalServiceInternals: assert all_documents == reranked_docs assert exceptions == [] processor_instance.invoke.assert_called_once() + mock_model_manager_class.assert_called_once_with(tenant_id=internal_dataset.tenant_id) model_manager.check_model_support_vision.assert_called_once() - @patch("core.rag.datasource.retrieval_service.ModelManager") + @patch("core.rag.datasource.retrieval_service.ModelManager.for_tenant") @patch("core.rag.datasource.retrieval_service.DataPostProcessor") 
@patch("core.rag.datasource.retrieval_service.Vector") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") @@ -503,6 +504,7 @@ class TestRetrievalServiceInternals: assert all_documents == original_docs assert exceptions == [] + mock_model_manager_class.assert_called_once_with(tenant_id=internal_dataset.tenant_id) processor_instance.invoke.assert_not_called() @patch("core.rag.datasource.retrieval_service.DataPostProcessor") diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py index dd536af759..54ad6d330b 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py @@ -384,7 +384,8 @@ def test_get_embeddings_builds_cache_embedding(vector_factory_module, monkeypatc model_manager = MagicMock() model_manager.get_model_instance.return_value = "model-instance" - monkeypatch.setattr(vector_factory_module, "ModelManager", MagicMock(return_value=model_manager)) + for_tenant_mock = MagicMock(return_value=model_manager) + monkeypatch.setattr(vector_factory_module.ModelManager, "for_tenant", for_tenant_mock) monkeypatch.setattr(vector_factory_module, "CacheEmbedding", MagicMock(return_value="cached-embedding")) vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) @@ -397,6 +398,7 @@ def test_get_embeddings_builds_cache_embedding(vector_factory_module, monkeypatc result = vector._get_embeddings() assert result == "cached-embedding" + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") model_manager.get_model_instance.assert_called_once_with( tenant_id="tenant-1", provider="openai", diff --git a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py index 13285cdad0..3ba0628fe2 100644 --- a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py +++ b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py @@ -163,7 +163,7 @@ class TestDatasetDocumentStoreAddDocuments: with ( patch("core.rag.docstore.dataset_docstore.db") as mock_db, - patch("core.rag.docstore.dataset_docstore.ModelManager") as mock_manager_class, + patch("core.rag.docstore.dataset_docstore.ModelManager.for_tenant") as mock_manager_class, ): mock_session = MagicMock() mock_db.session = mock_session diff --git a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py index a0db25174d..a67cd0ddff 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py +++ b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py @@ -28,6 +28,7 @@ class TestCacheEmbeddingMultimodalDocuments: """Create a mock ModelInstance for testing.""" model_instance = Mock() model_instance.model = "vision-embedding-model" + model_instance.model_name = "vision-embedding-model" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} @@ -64,7 +65,7 @@ class TestCacheEmbeddingMultimodalDocuments: def test_embed_single_multimodal_document_cache_miss(self, mock_model_instance, sample_multimodal_result): """Test embedding a single multimodal document when cache is empty.""" - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) documents = [{"file_id": "file123", "content": "test content"}] with 
patch("core.rag.embedding.cached_embedding.db.session") as mock_session: @@ -316,13 +317,14 @@ class TestCacheEmbeddingMultimodalQuery: """Create a mock ModelInstance for testing.""" model_instance = Mock() model_instance.model = "vision-embedding-model" + model_instance.model_name = "vision-embedding-model" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} return model_instance def test_embed_multimodal_query_cache_miss(self, mock_model_instance): """Test embedding multimodal query when Redis cache is empty.""" - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) document = {"file_id": "file123"} vector = np.random.randn(1536) @@ -467,6 +469,7 @@ class TestCacheEmbeddingQueryErrors: """Create a mock ModelInstance for testing.""" model_instance = Mock() model_instance.model = "text-embedding-ada-002" + model_instance.model_name = "text-embedding-ada-002" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} return model_instance @@ -532,24 +535,13 @@ class TestCacheEmbeddingQueryErrors: class TestCacheEmbeddingInitialization: """Test suite for CacheEmbedding initialization.""" - def test_initialization_with_user(self): - """Test CacheEmbedding initialization with user parameter.""" - model_instance = Mock() - model_instance.model = "test-model" - model_instance.provider = "test-provider" - - cache_embedding = CacheEmbedding(model_instance, user="test-user") - - assert cache_embedding._model_instance == model_instance - assert cache_embedding._user == "test-user" - - def test_initialization_without_user(self): - """Test CacheEmbedding initialization without user parameter.""" + def test_initialization_sets_model_instance(self): + """Test CacheEmbedding initialization stores the provided model instance.""" model_instance = Mock() model_instance.model = "test-model" + model_instance.model_name = "test-model" model_instance.provider = "test-provider" cache_embedding = CacheEmbedding(model_instance) assert cache_embedding._model_instance == model_instance - assert cache_embedding._user is None diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 6e71f0c61f..e1c9671113 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -134,7 +134,7 @@ class TestCacheEmbeddingDocuments: - Correct return value """ # Arrange - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) texts = ["Python is a programming language"] # Mock database query to return no cached embedding (cache miss) @@ -156,7 +156,6 @@ class TestCacheEmbeddingDocuments: # Verify model was invoked with correct parameters mock_model_instance.invoke_text_embedding.assert_called_once_with( texts=texts, - user="test-user", input_type=EmbeddingInputType.DOCUMENT, ) @@ -612,7 +611,7 @@ class TestCacheEmbeddingQuery: - Correct return value """ # Arrange - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) query = "What is Python?" 
# Create embedding result @@ -651,7 +650,6 @@ class TestCacheEmbeddingQuery: # Verify model was invoked with QUERY input type mock_model_instance.invoke_text_embedding.assert_called_once_with( texts=[query], - user="test-user", input_type=EmbeddingInputType.QUERY, ) @@ -1568,25 +1566,16 @@ class TestEmbeddingEdgeCases: norm = np.linalg.norm(emb) assert abs(norm - 1.0) < 0.01 - def test_embed_query_with_user_context(self, mock_model_instance): - """Test query embedding with user context parameter. + def test_embed_query_uses_bound_model_instance(self, mock_model_instance): + """Test query embedding using the provided model instance. Verifies: - - User parameter is passed correctly to model - - User context is used for tracking/logging - - Embedding generation works with user context - - Context: - -------- - The user parameter is important for: - 1. Usage tracking per user - 2. Rate limiting per user - 3. Audit logging - 4. Personalization (in some models) + - Embedding generation works with the injected model instance + - Query input type is preserved + - No extra binding step is required at call time """ # Arrange - user_id = "user-12345" - cache_embedding = CacheEmbedding(mock_model_instance, user=user_id) + cache_embedding = CacheEmbedding(mock_model_instance) query = "What is machine learning?" # Create embedding @@ -1620,24 +1609,20 @@ class TestEmbeddingEdgeCases: assert isinstance(result, list) assert len(result) == 1536 - # Verify user parameter was passed to model mock_model_instance.invoke_text_embedding.assert_called_once_with( texts=[query], - user=user_id, input_type=EmbeddingInputType.QUERY, ) - def test_embed_documents_with_user_context(self, mock_model_instance): - """Test document embedding with user context parameter. + def test_embed_documents_uses_bound_model_instance(self, mock_model_instance): + """Test document embedding using the provided model instance. 
Verifies: - - User parameter is passed correctly for document embeddings - - Batch processing maintains user context - - User tracking works across batches + - Batch processing uses the injected model instance + - Document input type is preserved """ # Arrange - user_id = "user-67890" - cache_embedding = CacheEmbedding(mock_model_instance, user=user_id) + cache_embedding = CacheEmbedding(mock_model_instance) texts = ["Document 1", "Document 2"] # Create embeddings @@ -1673,10 +1658,8 @@ class TestEmbeddingEdgeCases: # Assert assert len(result) == 2 - # Verify user parameter was passed mock_model_instance.invoke_text_embedding.assert_called_once() call_args = mock_model_instance.invoke_text_embedding.call_args - assert call_args.kwargs["user"] == user_id assert call_args.kwargs["input_type"] == EmbeddingInputType.DOCUMENT diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index 2c234edd9a..5f992034fa 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -400,7 +400,9 @@ class TestParagraphIndexProcessor: model_instance.invoke_llm.return_value = self._llm_result("text summary") with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager" + ) as mock_provider_manager, patch( "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", return_value=model_instance, @@ -411,7 +413,7 @@ class TestParagraphIndexProcessor: ), patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, ): - mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock() summary, usage = ParagraphIndexProcessor.generate_summary( "tenant-1", "text content", @@ -434,7 +436,9 @@ class TestParagraphIndexProcessor: image_content = ImagePromptMessageContent(format="url", mime_type="image/png", url="http://example.com/a.png") with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager" + ) as mock_provider_manager, patch( "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", return_value=model_instance, @@ -449,7 +453,7 @@ class TestParagraphIndexProcessor: ), patch("core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota"), ): - mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock() summary, _ = ParagraphIndexProcessor.generate_summary( "tenant-1", "text content", @@ -470,7 +474,9 @@ class TestParagraphIndexProcessor: image_file = SimpleNamespace() with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager" + ) as mock_provider_manager, patch( "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", return_value=model_instance, @@ -487,7 +493,7 @@ class TestParagraphIndexProcessor: ), 
patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, ): - mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock() with pytest.raises(ValueError, match="Expected LLMResult"): ParagraphIndexProcessor.generate_summary( "tenant-1", diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index b54a74b69c..786b054c19 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -445,7 +445,7 @@ class TestIndexingRunnerTransform: """Mock all external dependencies for transform tests.""" with ( patch("core.indexing_runner.db") as mock_db, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, ): yield { "db": mock_db, @@ -482,7 +482,8 @@ class TestIndexingRunnerTransform: # Arrange runner = IndexingRunner() mock_embedding_instance = MagicMock() - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() transformed_docs = [ @@ -509,7 +510,7 @@ class TestIndexingRunnerTransform: assert len(result) == 2 assert result[0].page_content == "Chunk 1" assert result[1].page_content == "Chunk 2" - runner.model_manager.get_model_instance.assert_called_once_with( + model_manager.get_model_instance.assert_called_once_with( tenant_id=sample_dataset.tenant_id, provider=sample_dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, @@ -521,6 +522,7 @@ class TestIndexingRunnerTransform: """Test transformation with economy indexing (no embeddings).""" # Arrange runner = IndexingRunner() + model_manager = mock_dependencies["model_manager"].return_value sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_processor = MagicMock() @@ -539,14 +541,15 @@ class TestIndexingRunnerTransform: # Assert assert len(result) == 1 - runner.model_manager.get_model_instance.assert_not_called() + model_manager.get_model_instance.assert_not_called() def test_transform_with_custom_segmentation(self, mock_dependencies, sample_dataset, sample_text_docs): """Test transformation with custom segmentation rules.""" # Arrange runner = IndexingRunner() mock_embedding_instance = MagicMock() - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() transformed_docs = [Document(page_content="Custom chunk", metadata={"doc_id": "custom1", "doc_hash": "hash1"})] @@ -586,7 +589,7 @@ class TestIndexingRunnerLoad: """Mock all external dependencies for load tests.""" with ( patch("core.indexing_runner.db") as mock_db, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, patch("core.indexing_runner.current_app") as mock_app, patch("core.indexing_runner.threading.Thread") as mock_thread, patch("core.indexing_runner.concurrent.futures.ThreadPoolExecutor") as mock_executor, @@ -645,7 +648,8 @@ class TestIndexingRunnerLoad: runner = IndexingRunner() 
mock_embedding_instance = MagicMock() mock_embedding_instance.get_text_embedding_num_tokens.return_value = 100 - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() @@ -664,7 +668,7 @@ class TestIndexingRunnerLoad: runner._load(mock_processor, sample_dataset, sample_dataset_document, sample_documents) # Assert - runner.model_manager.get_model_instance.assert_called_once() + model_manager.get_model_instance.assert_called_once() # Verify executor was used for parallel processing assert mock_executor_instance.submit.called @@ -714,7 +718,8 @@ class TestIndexingRunnerLoad: mock_embedding_instance = MagicMock() mock_embedding_instance.get_text_embedding_num_tokens.return_value = 50 - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() @@ -754,7 +759,7 @@ class TestIndexingRunnerRun: with ( patch("core.indexing_runner.db") as mock_db, patch("core.indexing_runner.IndexProcessorFactory") as mock_factory, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, patch("core.indexing_runner.storage") as mock_storage, patch("core.indexing_runner.threading.Thread") as mock_thread, ): diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index b150d677f1..cd9d403073 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -57,7 +57,7 @@ class TestRerankModelRunner: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -352,12 +352,14 @@ class TestRerankModelRunner: # Assert: Empty result is returned assert len(result) == 0 - def test_user_parameter_passed_to_model(self, rerank_runner, mock_model_instance, sample_documents): - """Test that user parameter is passed to model invocation. + def test_run_uses_bound_model_instance( + self, rerank_runner, mock_model_instance, sample_documents, mock_model_manager + ): + """Test that rerank uses the bound model instance directly. Verifies: - - User ID is correctly forwarded to the model - - Model receives all expected parameters + - The injected model instance is used for invocation + - No late rebinding occurs through ModelManager.get_model_instance """ # Arrange: Mock rerank result mock_rerank_result = RerankResult( @@ -368,16 +370,18 @@ class TestRerankModelRunner: ) mock_model_instance.invoke_rerank.return_value = mock_rerank_result - # Act: Run reranking with user parameter + # Act: Run reranking result = rerank_runner.run( query="test", documents=sample_documents, - user="user123", ) - # Assert: User parameter is passed to model + # Assert: The injected model instance is invoked directly. 
+ assert len(result) == 1 + mock_model_manager.return_value.get_model_instance.assert_not_called() call_kwargs = mock_model_instance.invoke_rerank.call_args.kwargs - assert call_kwargs["user"] == "user123" + assert call_kwargs["query"] == "test" + assert "user" not in call_kwargs class _ForwardingBaseRerankRunner(BaseRerankRunner): @@ -387,7 +391,6 @@ class _ForwardingBaseRerankRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: return super().run( @@ -395,7 +398,6 @@ class _ForwardingBaseRerankRunner(BaseRerankRunner): documents=documents, score_threshold=score_threshold, top_n=top_n, - user=user, query_type=query_type, ) @@ -424,7 +426,7 @@ class TestRerankModelRunnerMultimodal: Document(page_content="doc", metadata={"doc_id": "doc1"}, provider="dify"), ] - with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant") as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False result = rerank_runner.run(query="image-file-id", documents=documents, query_type=QueryType.IMAGE_QUERY) @@ -441,7 +443,7 @@ class TestRerankModelRunnerMultimodal: ) with ( - patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm, + patch("core.rag.rerank.rerank_model.ModelManager.for_tenant") as mock_mm, patch.object( rerank_runner, "fetch_multimodal_rerank", @@ -539,8 +541,10 @@ class TestRerankModelRunnerMultimodal: ) mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result + session = MagicMock() + session.query.return_value = query_chain with ( - patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain), + patch("core.rag.rerank.rerank_model.db.session", session), patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"), ): result, unique_documents = rerank_runner.fetch_multimodal_rerank( @@ -548,7 +552,6 @@ class TestRerankModelRunnerMultimodal: documents=[text_doc], score_threshold=0.2, top_n=2, - user="user-1", query_type=QueryType.IMAGE_QUERY, ) @@ -557,7 +560,7 @@ class TestRerankModelRunnerMultimodal: invoke_kwargs = mock_model_instance.invoke_multimodal_rerank.call_args.kwargs assert invoke_kwargs["query"]["content_type"] == DocType.IMAGE assert invoke_kwargs["docs"][0]["content"] == "text-content" - assert invoke_kwargs["user"] == "user-1" + assert "user" not in invoke_kwargs def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner): query_chain = Mock() @@ -595,7 +598,7 @@ class TestWeightRerankRunner: @pytest.fixture def mock_model_manager(self): """Mock ModelManager for embedding model.""" - with patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager: + with patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager: yield mock_manager @pytest.fixture @@ -1145,7 +1148,7 @@ class TestRerankIntegration: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1257,7 +1260,7 @@ class TestRerankEdgeCases: @pytest.fixture(autouse=True) def 
mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1527,7 +1530,7 @@ class TestRerankEdgeCases: # Mock dependencies with ( patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() @@ -1598,7 +1601,7 @@ class TestRerankPerformance: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1673,7 +1676,7 @@ class TestRerankPerformance: with ( patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() @@ -1715,7 +1718,7 @@ class TestRerankErrorHandling: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1824,7 +1827,7 @@ class TestRerankErrorHandling: with ( patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index a34ca330ca..2e299201ef 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -38,6 +38,7 @@ from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrieval from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.model_runtime.entities.model_entities import ModelFeature from models.dataset import Dataset +from models.enums import CreatorUserRole # ==================== Helper Functions ==================== @@ -3747,6 +3748,24 @@ class TestDatasetRetrievalAdditionalHelpers: mock_session.add_all.assert_called() mock_session.commit.assert_called() + def test_on_query_normalizes_workflow_end_user_role(self, retrieval: DatasetRetrieval) -> None: + with 
patch("core.rag.retrieval.dataset_retrieval.db.session") as mock_session: + retrieval._on_query( + query="python", + attachment_ids=None, + dataset_ids=["d1"], + app_id="a1", + user_from="end-user", + user_id="u1", + ) + + mock_session.add_all.assert_called_once() + added_queries = mock_session.add_all.call_args.args[0] + + assert len(added_queries) == 1 + assert added_queries[0].created_by_role == CreatorUserRole.END_USER + mock_session.commit.assert_called_once() + def test_handle_invoke_result(self, retrieval: DatasetRetrieval) -> None: usage = LLMUsage.empty_usage() chunk_1 = SimpleNamespace( @@ -3836,7 +3855,7 @@ class TestDatasetRetrievalAdditionalHelpers: model_instance.model_type_instance.get_model_schema.return_value = Mock() with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_manager, patch("core.rag.retrieval.dataset_retrieval.ModelConfigWithCredentialsEntity") as mock_cfg_entity, ): mock_manager.return_value.get_model_instance.return_value = model_instance @@ -4222,11 +4241,12 @@ class TestKnowledgeRetrievalCoverage: with ( patch.object(retrieval, "_check_knowledge_rate_limit"), patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="dataset-1")]), - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, ): mock_model_manager.return_value.get_model_instance.return_value = model_instance with pytest.raises(Exception) as exc_info: retrieval.knowledge_retrieval(request) + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") assert error_cls in type(exc_info.value).__name__ @@ -4279,9 +4299,13 @@ class TestRetrieveCoverage: ), ) model_config = self._build_model_config() - model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None - with patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager: - mock_model_manager.return_value.get_model_instance.return_value = Mock() + model_instance = Mock() + model_instance.model_name = "gpt-4" + model_instance.credentials = {"api_key": "secret"} + model_instance.provider_model_bundle = Mock() + model_instance.model_type_instance.get_model_schema.return_value = None + with patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager: + mock_model_manager.return_value.get_model_instance.return_value = model_instance result = retrieval.retrieve( app_id="app-1", user_id="user-1", @@ -4294,8 +4318,58 @@ class TestRetrieveCoverage: hit_callback=Mock(), message_id="m1", ) + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") assert result == (None, []) + def test_retrieve_uses_bound_model_instance_schema_and_updates_model_config( + self, retrieval: DatasetRetrieval + ) -> None: + config = DatasetEntity( + dataset_ids=["d1"], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, + metadata_filtering_mode="disabled", + ), + ) + model_config = self._build_model_config(features=[]) + model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None + bound_schema = SimpleNamespace(features=[ModelFeature.TOOL_CALL]) + bound_bundle = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {"api_key": 
"secret"} + bound_model_instance.provider_model_bundle = bound_bundle + bound_model_instance.model_type_instance.get_model_schema.return_value = bound_schema + + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), + patch.object(retrieval, "single_retrieve", return_value=[]) as mock_single_retrieve, + ): + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance + context, files = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_single_retrieve.assert_called_once() + assert mock_single_retrieve.call_args.args[8] == PlanningStrategy.ROUTER + assert model_config.provider_model_bundle is bound_bundle + assert model_config.credentials == {"api_key": "secret"} + assert model_config.model_schema is bound_schema + assert context == "" + assert files == [] + def test_single_strategy_with_external_documents(self, retrieval: DatasetRetrieval) -> None: retrieve_config = DatasetRetrieveConfigEntity( retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, @@ -4312,12 +4386,17 @@ class TestRetrieveCoverage: extra={"title": "External", "dataset_name": "External DS"}, ) with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), patch.object(retrieval, "single_retrieve", return_value=[external_doc]), ): - mock_model_manager.return_value.get_model_instance.return_value = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {} + bound_model_instance.provider_model_bundle = Mock() + bound_model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[]) + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance context, files = retrieval.retrieve( app_id="app-1", user_id="user-1", @@ -4402,7 +4481,7 @@ class TestRetrieveCoverage: hit_callback = Mock() with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), patch.object(retrieval, "multiple_retrieve", return_value=[external_doc, dify_doc]), @@ -4413,7 +4492,14 @@ class TestRetrieveCoverage: patch("core.rag.retrieval.dataset_retrieval.sign_upload_file", return_value="https://signed"), patch("core.rag.retrieval.dataset_retrieval.db.session.execute") as mock_execute, ): - mock_model_manager.return_value.get_model_instance.return_value = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {} + 
bound_model_instance.provider_model_bundle = Mock() + bound_model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace( + features=[ModelFeature.TOOL_CALL] + ) + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance mock_execute.side_effect = [execute_attachments, execute_docs, execute_datasets] context, files = retrieval.retrieve( app_id="app-1", diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py index e429563739..206540de1b 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py @@ -5,6 +5,7 @@ from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFini from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.model_runtime.entities.message_entities import PromptMessageRole +from dify_graph.model_runtime.entities.model_entities import ModelType class TestReactMultiDatasetRouter: @@ -87,6 +88,7 @@ class TestReactMultiDatasetRouter: model_config = Mock() model_config.mode = "chat" model_config.parameters = {"temperature": 0.1} + model_instance = Mock() usage = LLMUsage.empty_usage() tools = [Mock(name="dataset-1"), Mock(name="dataset-2")] tools[0].name = "dataset-1" @@ -108,13 +110,14 @@ class TestReactMultiDatasetRouter: dataset_id, returned_usage = router._react_invoke( query="python", model_config=model_config, - model_instance=Mock(), + model_instance=model_instance, tools=tools, user_id="u1", tenant_id="t1", ) mock_chat_prompt.assert_called_once() + assert mock_prompt_transform.return_value.get_prompt.call_args.kwargs["model_instance"] is model_instance assert dataset_id == "dataset-2" assert returned_usage == usage @@ -162,7 +165,11 @@ class TestReactMultiDatasetRouter: model_instance = Mock() model_instance.invoke_llm.return_value = iter([chunk]) - with patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct: + with ( + patch("core.rag.retrieval.router.multi_dataset_react_route.ModelManager.for_tenant") as mock_manager, + patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct, + ): + mock_manager.return_value.get_model_instance.return_value = model_instance text, returned_usage = router._invoke_llm( completion_param={"temperature": 0.1}, model_instance=model_instance, @@ -174,6 +181,13 @@ class TestReactMultiDatasetRouter: assert text == "part" assert returned_usage == usage + mock_manager.assert_called_once_with(tenant_id="t1", user_id="u1") + mock_manager.return_value.get_model_instance.assert_called_once_with( + tenant_id="t1", + provider=model_instance.provider, + model_type=ModelType.LLM, + model=model_instance.model_name, + ) mock_deduct.assert_called_once() def test_handle_invoke_result_with_empty_usage(self) -> None: diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index 2a83a4e802..00814e1bc5 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -11,12 +11,12 @@ from uuid import uuid4 import pytest from 
core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig from dify_graph.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) from dify_graph.enums import BuiltinNodeTypes -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -181,10 +181,10 @@ class TestCeleryWorkflowNodeExecutionRepository: repo.save(sample_workflow_node_execution) @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") - def test_get_by_workflow_run_from_cache( + def test_get_by_workflow_execution_from_cache( self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution ): - """Test that get_by_workflow_run retrieves executions from cache.""" + """Test that get_by_workflow_execution retrieves executions from cache.""" repo = CeleryWorkflowNodeExecutionRepository( session_factory=mock_session_factory, user=mock_account, @@ -195,18 +195,18 @@ class TestCeleryWorkflowNodeExecutionRepository: # Save execution to cache first repo.save(sample_workflow_node_execution) - workflow_run_id = sample_workflow_node_execution.workflow_execution_id + workflow_execution_id = sample_workflow_node_execution.workflow_execution_id order_config = OrderConfig(order_by=["index"], order_direction="asc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) # Verify results were retrieved from cache assert len(result) == 1 assert result[0].id == sample_workflow_node_execution.id assert result[0] is sample_workflow_node_execution - def test_get_by_workflow_run_without_order_config(self, mock_session_factory, mock_account): - """Test get_by_workflow_run without order configuration.""" + def test_get_by_workflow_execution_without_order_config(self, mock_session_factory, mock_account): + """Test get_by_workflow_execution without order configuration.""" repo = CeleryWorkflowNodeExecutionRepository( session_factory=mock_session_factory, user=mock_account, @@ -214,7 +214,7 @@ class TestCeleryWorkflowNodeExecutionRepository: triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - result = repo.get_by_workflow_run("workflow-run-id") + result = repo.get_by_workflow_execution("workflow-run-id") # Should return empty list since nothing in cache assert len(result) == 0 @@ -236,7 +236,7 @@ class TestCeleryWorkflowNodeExecutionRepository: assert sample_workflow_node_execution.id in repo._execution_cache # Test retrieving from cache - result = repo.get_by_workflow_run(sample_workflow_node_execution.workflow_execution_id) + result = repo.get_by_workflow_execution(sample_workflow_node_execution.workflow_execution_id) assert len(result) == 1 assert result[0].id == sample_workflow_node_execution.id @@ -251,12 +251,12 @@ class TestCeleryWorkflowNodeExecutionRepository: ) # Create multiple executions for the same workflow - workflow_run_id = str(uuid4()) + workflow_execution_id = str(uuid4()) exec1 = WorkflowNodeExecution( id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=1, node_id="node1", node_type=BuiltinNodeTypes.START, @@ -269,7 +269,7 @@ class 
TestCeleryWorkflowNodeExecutionRepository: id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=2, node_id="node2", node_type=BuiltinNodeTypes.LLM, @@ -285,10 +285,10 @@ class TestCeleryWorkflowNodeExecutionRepository: # Verify both are cached and mapped assert len(repo._execution_cache) == 2 - assert len(repo._workflow_execution_mapping[workflow_run_id]) == 2 + assert len(repo._workflow_execution_mapping[workflow_execution_id]) == 2 # Test retrieval - result = repo.get_by_workflow_run(workflow_run_id) + result = repo.get_by_workflow_execution(workflow_execution_id) assert len(result) == 2 @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") @@ -302,12 +302,12 @@ class TestCeleryWorkflowNodeExecutionRepository: ) # Create executions with different indices - workflow_run_id = str(uuid4()) + workflow_execution_id = str(uuid4()) exec1 = WorkflowNodeExecution( id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=2, node_id="node2", node_type=BuiltinNodeTypes.START, @@ -320,7 +320,7 @@ class TestCeleryWorkflowNodeExecutionRepository: id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=1, node_id="node1", node_type=BuiltinNodeTypes.LLM, @@ -336,14 +336,14 @@ class TestCeleryWorkflowNodeExecutionRepository: # Test ascending order order_config = OrderConfig(order_by=["index"], order_direction="asc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) assert len(result) == 2 assert result[0].index == 1 assert result[1].index == 2 # Test descending order order_config = OrderConfig(order_by=["index"], order_direction="desc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) assert len(result) == 2 assert result[0].index == 2 assert result[1].index == 1 diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py index fe9eed0307..48327c3913 100644 --- a/api/tests/unit_tests/core/repositories/test_factory.py +++ b/api/tests/unit_tests/core/repositories/test_factory.py @@ -11,9 +11,12 @@ import pytest from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.repositories.factory import ( + DifyCoreRepositoryFactory, + RepositoryImportError, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 9af4d12664..7dde4c5c77 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ 
b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -14,13 +14,15 @@ from core.repositories.human_input_repository import ( HumanInputFormSubmissionRepository, _WorkspaceMemberInfo, ) -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, MemberRecipient, +) +from dify_graph.nodes.human_input.entities import ( + FormDefinition, UserAction, ) from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus @@ -89,9 +91,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="member-1"), + MemberRecipient(reference_id="member-1"), ExternalRecipient(email="external@example.com"), ], ), @@ -125,9 +127,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="missing-member"), + MemberRecipient(reference_id="missing-member"), ExternalRecipient(email="external@example.com"), ], ), @@ -156,7 +158,7 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=True, + include_bound_group=True, items=[], ), ) @@ -182,7 +184,7 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ ExternalRecipient(email="external@example.com"), ExternalRecipient(email="external@example.com"), @@ -212,9 +214,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="member-1"), + MemberRecipient(reference_id="member-1"), ExternalRecipient(email="shared@example.com"), ], ), @@ -243,7 +245,7 @@ class TestHumanInputFormRepositoryImplHelpers: method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=True, + include_bound_group=True, items=[ExternalRecipient(email="external@example.com")], ), subject="subject", @@ -421,22 +423,22 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, [recipient]]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.id == form.id - assert entity.web_app_token == "token-123" + assert entity.submission_token == "token-123" assert len(entity.recipients) == 1 assert entity.recipients[0].token == "token-123" def test_get_form_returns_none_when_missing(self, monkeypatch: pytest.MonkeyPatch): session = _FakeSession(scalars_results=[None]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id="run-1") - assert repo.get_form("run-1", "node-1") is None + assert 
repo.get_form("node-1") is None def test_get_form_returns_unsubmitted_state(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( @@ -451,9 +453,9 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, []]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.submitted is False @@ -476,9 +478,9 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, []]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.submitted is True diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py index 4116e8b4a5..4db29b1b2e 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py @@ -11,6 +11,8 @@ from unittest.mock import MagicMock import pytest from core.repositories.human_input_repository import ( + FormCreateParams, + FormNotFoundError, HumanInputFormRecord, HumanInputFormRepositoryImpl, HumanInputFormSubmissionRepository, @@ -19,18 +21,16 @@ from core.repositories.human_input_repository import ( _InvalidTimeoutStatusError, _WorkspaceMemberInfo, ) -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, MemberRecipient, - UserAction, WebAppDeliveryMethod, ) +from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus -from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError from libs.datetime_utils import naive_utc_now from models.human_input import HumanInputFormRecipient, RecipientType @@ -212,7 +212,7 @@ def test_recipient_entity_id_and_token_success() -> None: assert entity.token == "tok" -def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None: +def test_form_entity_submission_token_prefers_console_then_webapp_then_none() -> None: form = _DummyForm( id="f1", workflow_run_id="run", @@ -229,13 +229,13 @@ def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> No ) entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console]) # type: ignore[arg-type] - assert entity.web_app_token == "ctok" + assert entity.submission_token == "ctok" entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp]) # type: ignore[arg-type] - assert entity.web_app_token == "wtok" + assert entity.submission_token == "wtok" entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type] - assert entity.web_app_token is None + assert entity.submission_token is None def test_form_entity_submitted_data_parsed() -> None: @@ 
-364,8 +364,8 @@ def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, - items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], + include_bound_group=False, + items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")], ), subject="s", body="b", @@ -388,7 +388,7 @@ def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatc session=MagicMock(), form_id="f", delivery_id="d", - recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]), + recipients_config=EmailRecipients(include_bound_group=True, items=[ExternalRecipient(email="e@example.com")]), ) assert recipients == ["ok"] @@ -407,8 +407,8 @@ def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(m form_id="f", delivery_id="d", recipients_config=EmailRecipients( - whole_workspace=False, - items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], + include_bound_group=False, + items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")], ), ) assert recipients == ["ok"] @@ -416,8 +416,8 @@ def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(m def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: _patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None])) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant") - assert repo.get_form("run", "node") is None + repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run") + assert repo.get_form("node") is None form = _DummyForm( id="f1", @@ -437,8 +437,8 @@ def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.Monke ) session = _FakeSession(scalars_results=[form, [recipient]]) _patch_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant") - entity = repo.get_form("run", "node") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run") + entity = repo.get_form("node") assert entity is not None assert entity.id == "f1" assert entity.recipients[0].id == "r1" @@ -454,7 +454,13 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M session = _FakeSession() _patch_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + repo = HumanInputFormRepositoryImpl( + tenant_id="tenant", + app_id="app", + workflow_execution_id="run", + invoke_source="debugger", + submission_actor_id="acc-1", + ) form_config = HumanInputNodeData( title="Title", @@ -464,8 +470,7 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M user_actions=[UserAction(id="submit", title="Submit")], ) params = FormCreateParams( - app_id="app", - workflow_execution_id="run", + workflow_execution_id=None, node_id="node", form_config=form_config, rendered_content="
hello
", @@ -473,16 +478,13 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M display_in_ui=True, resolved_default_values={}, form_kind=HumanInputFormKind.RUNTIME, - console_recipient_required=True, - console_creator_account_id="acc-1", - backstage_recipient_required=True, ) entity = repo.create_form(params) assert entity.id == "form-id" assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout) # Console token should take precedence when console recipient is present. - assert entity.web_app_token == "token-console" + assert entity.submission_token == "token-console" assert len(entity.recipients) == 3 diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py index 73de15e2cf..481487971a 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -15,6 +15,7 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker from configs import dify_config +from core.repositories.factory import OrderConfig from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, _deterministic_json_dump, @@ -28,7 +29,6 @@ from dify_graph.enums import ( WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig from models import Account, EndUser from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom @@ -768,5 +768,5 @@ def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> lambda max_workers: FakeExecutor(), ) - result = repo.get_by_workflow_run("run", order_config=None) + result = repo.get_by_workflow_execution("run", order_config=None) assert result == ["domain:db1", "domain:db2"] diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 69567c54eb..6caf7bc004 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -1,12 +1,26 @@ -from unittest.mock import Mock, PropertyMock, patch +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, PropertyMock, patch import pytest +from pytest_mock import MockerFixture from core.entities.provider_entities import ModelSettings from core.provider_manager import ProviderManager from dify_graph.model_runtime.entities.common_entities import I18nObject from dify_graph.model_runtime.entities.model_entities import ModelType -from models.provider import LoadBalancingModelConfig, ProviderModelSetting +from models.provider import LoadBalancingModelConfig, ProviderModelSetting, TenantDefaultModel +from models.provider_ids import ModelProviderID + + +def _build_provider_manager(mocker: MockerFixture) -> ProviderManager: + return ProviderManager(model_runtime=mocker.Mock()) + + +def _build_session_context(session: Mock) -> MagicMock: + session_cm = MagicMock() + session_cm.__enter__.return_value = session + session_cm.__exit__.return_value = False + return session_cm @pytest.fixture @@ -28,7 +42,7 @@ def mock_provider_entity(): return mock_entity -def 
test__to_model_settings(mock_provider_entity): +def test__to_model_settings(mocker: MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( tenant_id="tenant_id", @@ -69,7 +83,7 @@ def test__to_model_settings(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -89,7 +103,7 @@ def test__to_model_settings(mock_provider_entity): assert result[0].load_balancing_configs[1].name == "first" -def test__to_model_settings_only_one_lb(mock_provider_entity): +def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( @@ -119,7 +133,7 @@ def test__to_model_settings_only_one_lb(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -137,7 +151,7 @@ def test__to_model_settings_only_one_lb(mock_provider_entity): assert len(result[0].load_balancing_configs) == 0 -def test__to_model_settings_lb_disabled(mock_provider_entity): +def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( tenant_id="tenant_id", @@ -176,7 +190,7 @@ def test__to_model_settings_lb_disabled(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -194,7 +208,7 @@ def test__to_model_settings_lb_disabled(mock_provider_entity): assert len(result[0].load_balancing_configs) == 0 -def test_get_default_model_uses_first_available_active_model(): +def test_get_default_model_uses_first_available_active_model(mocker: MockerFixture): mock_session = Mock() mock_session.scalar.return_value = None @@ -204,7 +218,7 @@ def test_get_default_model_uses_first_available_active_model(): Mock(model="gpt-4", provider=Mock(provider="openai")), ] - manager = ProviderManager() + manager = _build_provider_manager(mocker) with ( patch("core.provider_manager.db.session", mock_session), patch.object(manager, "get_configurations", return_value=provider_configurations), @@ -228,3 +242,345 @@ def test_get_default_model_uses_first_available_active_model(): assert saved_default_model.model_name == "gpt-3.5-turbo" assert saved_default_model.provider_name == "openai" mock_session.commit.assert_called_once() + + +def test_get_default_model_returns_none_when_no_default_or_active_models(mocker: MockerFixture): + mock_session = Mock() + mock_session.scalar.return_value = None + provider_configurations = Mock() + provider_configurations.get_models.return_value = [] + manager = _build_provider_manager(mocker) + + with ( + patch("core.provider_manager.db.session", mock_session), + patch.object(manager, "get_configurations", return_value=provider_configurations), + patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls, + ): + result = manager.get_default_model("tenant-id", ModelType.LLM) + + assert result is None + 
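Context for the new ProviderManager tests: every test now goes through `_build_provider_manager`, because the manager takes its model runtime by constructor injection instead of constructing one itself. A minimal sketch of the assumed wiring, using only names that appear in this diff (the Mock runtime is a test-only stand-in):

    from unittest.mock import Mock

    from core.provider_manager import ProviderManager

    model_runtime = Mock()  # injected collaborator standing in for the real runtime
    manager = ProviderManager(model_runtime=model_runtime)
    # The tests assert that downstream factories reuse the same instance:
    # ModelProviderFactory(model_runtime=manager._model_runtime)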
provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) + mock_factory_cls.assert_not_called() + mock_session.add.assert_not_called() + mock_session.commit.assert_not_called() + + +def test_get_default_model_uses_injected_runtime_for_existing_default_record(mocker: MockerFixture): + existing_default_model = TenantDefaultModel( + tenant_id="tenant-id", + provider_name="openai", + model_name="gpt-4", + model_type=ModelType.LLM.to_origin_model_type(), + ) + mock_session = Mock() + mock_session.scalar.return_value = existing_default_model + manager = _build_provider_manager(mocker) + + with ( + patch("core.provider_manager.db.session", mock_session), + patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls, + ): + mock_factory_cls.return_value.get_provider_schema.return_value = Mock( + provider="openai", + label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), + icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"), + supported_model_types=[ModelType.LLM], + ) + + result = manager.get_default_model("tenant-id", ModelType.LLM) + + mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + assert result is not None + assert result.model == "gpt-4" + assert result.provider.provider == "openai" + + +def test_get_configurations_uses_injected_runtime_and_adds_provider_aliases(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_records = {"openai": [SimpleNamespace(provider_name="openai")]} + provider_model_records = {"openai": [SimpleNamespace(provider_name="openai")]} + preferred_provider_records = {"openai": SimpleNamespace(preferred_provider_type="system")} + + with ( + patch.object(manager, "_get_all_providers", return_value=provider_records), + patch.object(manager, "_init_trial_provider_records", return_value=provider_records), + patch.object(manager, "_get_all_provider_models", return_value=provider_model_records), + patch.object(manager, "_get_all_preferred_model_providers", return_value=preferred_provider_records), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls, + ): + mock_factory_cls.return_value.get_providers.return_value = [] + + result = manager.get_configurations("tenant-id") + + expected_alias = str(ModelProviderID("openai")) + mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + assert result.tenant_id == "tenant-id" + assert expected_alias in provider_records + assert expected_alias in provider_model_records + assert expected_alias in preferred_provider_records + + +@pytest.mark.parametrize( + ("provider_name", "expected_provider_names"), + [ + ("openai", ["openai", "langgenius/openai/openai"]), + ("langgenius/openai/openai", ["langgenius/openai/openai", "openai"]), + ("langgenius/gemini/google", ["langgenius/gemini/google", "google"]), + ], +) +def test_get_provider_names_returns_short_and_full_aliases(provider_name: str, expected_provider_names: list[str]): + assert ProviderManager._get_provider_names(provider_name) == expected_provider_names + + +def test_get_provider_model_bundle_raises_for_unknown_provider(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + + with patch.object(manager, "get_configurations", return_value={}): + with 
pytest.raises(ValueError, match="Provider openai does not exist."): + manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM) + + +def test_get_configurations_binds_manager_runtime_to_provider_configuration( + mocker: MockerFixture, mock_provider_entity +): + manager = _build_provider_manager(mocker) + provider_configuration = Mock() + provider_factory = Mock() + provider_factory.get_providers.return_value = [mock_provider_entity] + custom_configuration = SimpleNamespace(provider=None, models=[]) + system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None) + + with ( + patch.object(manager, "_get_all_providers", return_value={"openai": []}), + patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}), + patch.object(manager, "_get_all_provider_models", return_value={"openai": []}), + patch.object(manager, "_get_all_preferred_model_providers", return_value={}), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch.object(manager, "_to_custom_configuration", return_value=custom_configuration), + patch.object(manager, "_to_system_configuration", return_value=system_configuration), + patch.object(manager, "_to_model_settings", return_value=[]), + patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory), + patch("core.provider_manager.ProviderConfiguration", return_value=provider_configuration), + ): + manager.get_configurations("tenant-id") + + provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime) + + +def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configuration = Mock() + model_type_instance = Mock() + provider_configuration.get_model_type_instance.return_value = model_type_instance + expected_bundle = Mock() + + with ( + patch.object(manager, "get_configurations", return_value={"openai": provider_configuration}), + patch("core.provider_manager.ProviderModelBundle", return_value=expected_bundle) as mock_bundle, + ): + result = manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM) + + provider_configuration.get_model_type_instance.assert_called_once_with(ModelType.LLM) + mock_bundle.assert_called_once_with( + configuration=provider_configuration, + model_type_instance=model_type_instance, + ) + assert result is expected_bundle + + +def test_get_first_provider_first_model_returns_none_when_no_models(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = Mock() + provider_configurations.get_models.return_value = [] + + with patch.object(manager, "get_configurations", return_value=provider_configurations): + result = manager.get_first_provider_first_model("tenant-id", ModelType.LLM) + + assert result == (None, None) + provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=False) + + +def test_get_first_provider_first_model_returns_first_model_and_provider(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = Mock() + provider_configurations.get_models.return_value = [ + Mock(model="gpt-4", provider=Mock(provider="openai")), + Mock(model="gpt-4o", provider=Mock(provider="openai")), + ] + + with 
patch.object(manager, "get_configurations", return_value=provider_configurations): + result = manager.get_first_provider_first_model("tenant-id", ModelType.LLM) + + assert result == ("openai", "gpt-4") + + +def test_update_default_model_record_raises_for_unknown_provider(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + + with patch.object(manager, "get_configurations", return_value={}): + with pytest.raises(ValueError, match="Provider openai does not exist."): + manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-4") + + +def test_update_default_model_record_raises_for_unknown_model(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = MagicMock() + provider_configurations.__contains__.return_value = True + provider_configurations.get_models.return_value = [Mock(model="gpt-4")] + + with patch.object(manager, "get_configurations", return_value=provider_configurations): + with pytest.raises(ValueError, match="Model gpt-3.5-turbo does not exist."): + manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-3.5-turbo") + + provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) + + +def test_update_default_model_record_updates_existing_record(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = MagicMock() + provider_configurations.__contains__.return_value = True + provider_configurations.get_models.return_value = [Mock(model="gpt-3.5-turbo")] + existing_default_model = TenantDefaultModel( + tenant_id="tenant-id", + provider_name="anthropic", + model_name="claude-3-sonnet", + model_type=ModelType.LLM.to_origin_model_type(), + ) + mock_session = Mock() + mock_session.scalar.return_value = existing_default_model + + with ( + patch.object(manager, "get_configurations", return_value=provider_configurations), + patch("core.provider_manager.db.session", mock_session), + ): + result = manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-3.5-turbo") + + assert result is existing_default_model + assert existing_default_model.provider_name == "openai" + assert existing_default_model.model_name == "gpt-3.5-turbo" + mock_session.commit.assert_called_once() + mock_session.add.assert_not_called() + + +def test_update_default_model_record_creates_record_with_origin_model_type(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = MagicMock() + provider_configurations.__contains__.return_value = True + provider_configurations.get_models.return_value = [Mock(model="gpt-4")] + mock_session = Mock() + mock_session.scalar.return_value = None + + with ( + patch.object(manager, "get_configurations", return_value=provider_configurations), + patch("core.provider_manager.db.session", mock_session), + ): + result = manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-4") + + mock_session.add.assert_called_once() + created_default_model = mock_session.add.call_args.args[0] + assert result is created_default_model + assert created_default_model.tenant_id == "tenant-id" + assert created_default_model.provider_name == "openai" + assert created_default_model.model_name == "gpt-4" + assert created_default_model.model_type == ModelType.LLM.to_origin_model_type() + mock_session.commit.assert_called_once() + + +def test_get_all_providers_normalizes_provider_names_with_model_provider_id() -> None: + session = Mock() + openai_provider = 
SimpleNamespace(provider_name="openai") + gemini_provider = SimpleNamespace(provider_name="langgenius/gemini/google") + session.scalars.return_value = [openai_provider, gemini_provider] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = ProviderManager._get_all_providers("tenant-id") + + assert list(result[str(ModelProviderID("openai"))]) == [openai_provider] + assert list(result[str(ModelProviderID("langgenius/gemini/google"))]) == [gemini_provider] + + +@pytest.mark.parametrize( + "method_name", + [ + "_get_all_provider_models", + "_get_all_provider_model_settings", + "_get_all_provider_model_credentials", + ], +) +def test_provider_grouping_helpers_group_records_by_provider_name(method_name: str) -> None: + session = Mock() + openai_primary = SimpleNamespace(provider_name="openai") + openai_secondary = SimpleNamespace(provider_name="openai") + anthropic_record = SimpleNamespace(provider_name="anthropic") + session.scalars.return_value = [openai_primary, openai_secondary, anthropic_record] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = getattr(ProviderManager, method_name)("tenant-id") + + assert list(result["openai"]) == [openai_primary, openai_secondary] + assert list(result["anthropic"]) == [anthropic_record] + + +def test_get_all_preferred_model_providers_returns_mapping_by_provider_name() -> None: + session = Mock() + openai_preference = SimpleNamespace(provider_name="openai") + anthropic_preference = SimpleNamespace(provider_name="anthropic") + session.scalars.return_value = [openai_preference, anthropic_preference] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = ProviderManager._get_all_preferred_model_providers("tenant-id") + + assert result == { + "openai": openai_preference, + "anthropic": anthropic_preference, + } + + +def test_get_all_provider_load_balancing_configs_returns_empty_when_cached_flag_is_disabled() -> None: + with ( + patch("core.provider_manager.redis_client.get", return_value=b"False"), + patch("core.provider_manager.FeatureService.get_features") as mock_get_features, + patch("core.provider_manager.Session") as mock_session_cls, + ): + result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id") + + assert result == {} + mock_get_features.assert_not_called() + mock_session_cls.assert_not_called() + + +def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_configs() -> None: + session = Mock() + openai_config = SimpleNamespace(provider_name="openai") + anthropic_config = SimpleNamespace(provider_name="anthropic") + session.scalars.return_value = [openai_config, anthropic_config] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.redis_client.get", return_value=None), + patch("core.provider_manager.redis_client.setex") as mock_setex, + patch( + "core.provider_manager.FeatureService.get_features", + return_value=SimpleNamespace(model_load_balancing_enabled=True), + ), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id") + + 
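The two load-balancing tests here pin down a small cache contract: a cached b"False" flag returns early without touching FeatureService or the database, while a cache miss reads the feature flag, writes it back under `tenant:<id>:model_load_balancing_enabled` with a 120-second TTL, and only then groups configs by provider name. A hedged sketch of that flag check (the helper name and the FeatureService signature are assumptions; the key format, TTL, and value come from the assertions):

    from extensions.ext_redis import redis_client
    from services.feature_service import FeatureService

    def _load_balancing_enabled(tenant_id: str) -> bool:
        # Mirrors the cache behaviour the tests assert, not the real function body.
        cache_key = f"tenant:{tenant_id}:model_load_balancing_enabled"
        cached = redis_client.get(cache_key)
        if cached is not None:
            return cached == b"True"
        enabled = FeatureService.get_features(tenant_id).model_load_balancing_enabled
        redis_client.setex(cache_key, 120, str(enabled))  # "True"/"False", 120s TTL
        return enabled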
mock_setex.assert_called_once_with("tenant:tenant-id:model_load_balancing_enabled", 120, "True") + assert list(result["openai"]) == [openai_config] + assert list(result["anthropic"]) == [anthropic_config] diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py index f123f60a34..ccf6ddccaf 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py @@ -27,12 +27,12 @@ class _BuiltinDummyTool(BuiltinTool): yield self.create_text_message("ok") -def _build_tool() -> _BuiltinDummyTool: +def _build_tool(user_id: str | None = None) -> _BuiltinDummyTool: entity = ToolEntity( identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"), parameters=[], ) - runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + runtime = ToolRuntime(tenant_id="tenant-1", user_id=user_id, invoke_from=InvokeFrom.DEBUGGER) return _BuiltinDummyTool(provider="provider-a", entity=entity, runtime=runtime) @@ -45,7 +45,7 @@ def test_builtin_tool_fork_and_provider_type(): def test_invoke_model_calls_model_invocation_utils_invoke(): - tool = _build_tool() + tool = _build_tool(user_id="runtime-user") with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.invoke", return_value="result") as mock_invoke: assert ( tool.invoke_model( @@ -55,19 +55,47 @@ def test_invoke_model_calls_model_invocation_utils_invoke(): ) == "result" ) - mock_invoke.assert_called_once() + mock_invoke.assert_called_once_with( + user_id="u1", + tenant_id="tenant-1", + tool_type=ToolProviderType.BUILT_IN, + tool_name="tool-a", + prompt_messages=[UserPromptMessage(content="hello")], + caller_user_id="runtime-user", + ) def test_get_max_tokens_returns_value(): - tool = _build_tool() - with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096): + tool = _build_tool(user_id="runtime-user") + with patch( + "core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096 + ) as mock_get: assert tool.get_max_tokens() == 4096 + mock_get.assert_called_once_with(tenant_id="tenant-1", user_id="runtime-user") def test_get_prompt_tokens_returns_value(): - tool = _build_tool() - with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7): + tool = _build_tool(user_id="runtime-user") + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate: assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7 + mock_calculate.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id="runtime-user", + ) + + +def test_get_prompt_tokens_falls_back_to_tenant_scope_when_runtime_user_id_missing(): + tool = _build_tool() + + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate: + assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7 + + mock_calculate.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id=None, + ) def test_runtime_none_raises(): diff --git a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py index 62cfb6ce5b..a3b03fefd6 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py 
+++ b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py @@ -1,6 +1,8 @@ from __future__ import annotations +import calendar import math +from datetime import date from types import SimpleNamespace import pytest @@ -98,7 +100,13 @@ def test_timezone_conversion_tool(): def test_weekday_tool(): weekday_tool = _build_builtin_tool(WeekdayTool) valid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 1, "day": 1}))[0].message.text - assert "January 1, 2024" in valid + expected_date = date(2024, 1, 1) + expected_message = ( + f"{calendar.month_name[expected_date.month]} " + f"{expected_date.day}, {expected_date.year} " + f"is {calendar.day_name[expected_date.weekday()]}." + ) + assert valid == expected_message invalid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 2, "day": 31}))[ 0 ].message.text @@ -186,13 +194,19 @@ def test_asr_invalid_file(): def test_asr_valid_file_invocation(monkeypatch): asr = _build_builtin_tool(ASRTool) - model_instance = type("M", (), {"invoke_speech2text": lambda self, file, user: "transcript"})() + model_instance = type("M", (), {"invoke_speech2text": lambda self, file: "transcript"})() model_manager = type("Mgr", (), {"get_model_instance": lambda *a, **k: model_instance})() monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.download", lambda file: b"audio-bytes") - monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.ModelManager", lambda: model_manager) + captured_manager_kwargs = {} + + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.asr.ModelManager.for_tenant", + lambda **kwargs: captured_manager_kwargs.update(kwargs) or model_manager, + ) audio_file = SimpleNamespace(type=FileType.AUDIO) ok = list(asr.invoke(user_id="u", tool_parameters={"audio_file": audio_file, "model": "p#m"}))[0].message.text assert ok == "transcript" + assert captured_manager_kwargs == {"tenant_id": "tenant-1", "user_id": "u"} def test_asr_available_models_and_runtime_parameters(monkeypatch): @@ -208,6 +222,7 @@ def test_asr_available_models_and_runtime_parameters(monkeypatch): def test_tts_invoke_returns_messages(monkeypatch): tts = _build_builtin_tool(TTSTool) + captured_manager_kwargs = {} voices_model_instance = type( "TTSM", (), @@ -217,11 +232,15 @@ def test_tts_invoke_returns_messages(monkeypatch): }, )() monkeypatch.setattr( - "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", - lambda: type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})(), + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager.for_tenant", + lambda **kwargs: ( + captured_manager_kwargs.update(kwargs) + or type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})() + ), ) messages = list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.BLOB] + assert captured_manager_kwargs == {"tenant_id": "tenant-1", "user_id": "u"} def test_tts_get_available_models_requires_runtime(): @@ -254,8 +273,8 @@ def test_tts_tool_raises_when_voice_unavailable(monkeypatch, voices): }, )() monkeypatch.setattr( - "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", - lambda: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(), + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager.for_tenant", + lambda **_: type("Manager", (), {"get_model_instance": 
lambda *args, **kwargs: model_without_voice})(), ) with pytest.raises(ValueError, match="no voice available"): list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) diff --git a/api/tests/unit_tests/core/tools/test_signature.py b/api/tests/unit_tests/core/tools/test_signature.py index a5242a78c5..353988d7a6 100644 --- a/api/tests/unit_tests/core/tools/test_signature.py +++ b/api/tests/unit_tests/core/tools/test_signature.py @@ -6,7 +6,13 @@ from urllib.parse import parse_qs, urlparse import pytest -from core.tools.signature import sign_tool_file, sign_upload_file, verify_tool_file_signature +from core.tools.signature import ( + get_signed_file_url_for_plugin, + sign_tool_file, + sign_upload_file, + verify_plugin_file_signature, + verify_tool_file_signature, +) def test_sign_tool_file_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None: @@ -117,3 +123,82 @@ def test_sign_upload_file_uses_files_url_fallback(monkeypatch: pytest.MonkeyPatc assert query["timestamp"][0] assert query["nonce"][0] assert query["sign"][0] + + +def test_get_signed_file_url_for_plugin_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x06" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 60) + + url = get_signed_file_url_for_plugin( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + ) + parsed = urlparse(url) + query = parse_qs(parsed.query) + + assert parsed.netloc == "internal.example.com" + assert parsed.path == "/files/upload/for-plugin" + assert query["tenant_id"] == ["tenant-id"] + assert query["user_id"] == ["user-id"] + assert ( + verify_plugin_file_signature( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + timestamp=query["timestamp"][0], + nonce=query["nonce"][0], + sign=query["sign"][0], + ) + is True + ) + + +def test_verify_plugin_file_signature_rejects_invalid_signatures(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x07" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 30) + + url = get_signed_file_url_for_plugin( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + ) + query = parse_qs(urlparse(url).query) + + assert ( + verify_plugin_file_signature( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + timestamp=query["timestamp"][0], + nonce=query["nonce"][0], + sign="bad-signature", + ) + is False + ) + + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000100) + assert ( + verify_plugin_file_signature( + filename="report.pdf", + 
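The final assertion in this signature test advances the patched `time.time` from 1700000000 to 1700000100 with `FILES_ACCESS_TIMEOUT` at 30, so an otherwise valid signature must verify as False. The shape being exercised is a timestamped HMAC; a generic sketch under assumptions (the actual message layout in `core.tools.signature` may differ):

    import hashlib
    import hmac
    import time

    def verify(secret: str, message: str, timestamp: str, nonce: str, sign: str, timeout: int) -> bool:
        # Recompute the signature over message + timestamp + nonce, then enforce expiry.
        payload = f"{message}:{timestamp}:{nonce}".encode()
        expected = hmac.new(secret.encode(), payload, hashlib.sha256).hexdigest()
        if not hmac.compare_digest(expected, sign):
            return False
        return time.time() - int(timestamp) <= timeout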
mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + timestamp=query["timestamp"][0], + nonce=query["nonce"][0], + sign=query["sign"][0], + ) + is False + ) diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py index cca8254dd6..b94834d8bc 100644 --- a/api/tests/unit_tests/core/tools/test_tool_file_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py @@ -14,6 +14,7 @@ import httpx import pytest from core.tools.tool_file_manager import ToolFileManager +from dify_graph.file import FileTransferMethod def _setup_tool_file_signing(monkeypatch: pytest.MonkeyPatch) -> dict[str, str]: @@ -232,7 +233,14 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None: def test_get_file_generator_returns_stream_when_found() -> None: # Arrange manager = ToolFileManager() - tool_file = SimpleNamespace(file_key="k2", mimetype="image/png") + tool_file = SimpleNamespace( + id="tool123", + file_key="k2", + mimetype="image/png", + original_url=None, + name="image.png", + size=12, + ) session = Mock() session.query.return_value.where.return_value.first.return_value = tool_file @@ -240,10 +248,10 @@ def test_get_file_generator_returns_stream_when_found() -> None: with patch("core.tools.tool_file_manager.storage") as storage: stream = iter([b"a", b"b"]) storage.load_stream.return_value = stream - with ( - _patch_session_factory(session), - patch("core.tools.tool_file_manager.ToolFilePydanticModel.model_validate", return_value="validated-file"), - ): + with _patch_session_factory(session): result_stream, result_file = manager.get_file_generator_by_tool_file_id("tool123") assert list(result_stream) == [b"a", b"b"] - assert result_file == "validated-file" + assert result_file is not None + assert result_file.related_id == "tool123" + assert result_file.mime_type == "image/png" + assert result_file.transfer_method == FileTransferMethod.TOOL_FILE diff --git a/api/tests/unit_tests/core/tools/test_tool_manager.py b/api/tests/unit_tests/core/tools/test_tool_manager.py index 0f73e22654..844bc01e29 100644 --- a/api/tests/unit_tests/core/tools/test_tool_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_manager.py @@ -15,6 +15,7 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( ApiProviderAuthType, + ToolInvokeFrom, ToolParameter, ToolProviderType, ) @@ -421,7 +422,7 @@ def test_get_agent_runtime_apply_runtime_parameters(): tool_runtime = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_runtime.get_merged_runtime_parameters = Mock(return_value=[parameter]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime): + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime) as mock_get_tool_runtime: with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "hello"}): manager = Mock() manager.decrypt_tool_parameters.return_value = {"query": "decrypted"} @@ -437,12 +438,23 @@ def test_get_agent_runtime_apply_runtime_parameters(): tenant_id="tenant-1", app_id="app-1", agent_tool=agent_tool, + user_id="user-1", invoke_from=InvokeFrom.DEBUGGER, variable_pool=None, ) assert result is tool_runtime assert tool_runtime.runtime.runtime_parameters["query"] == "decrypted" + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + 
provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.DEBUGGER, + tool_invoke_from=ToolInvokeFrom.AGENT, + credential_id=None, + ) def test_get_workflow_runtime_apply_runtime_parameters(): @@ -463,7 +475,7 @@ def test_get_workflow_runtime_apply_runtime_parameters(): ) tool_runtime2 = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_runtime2.get_merged_runtime_parameters = Mock(return_value=[parameter]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2): + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2) as mock_get_tool_runtime: with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "workflow"}): manager = Mock() manager.decrypt_tool_parameters.return_value = {"query": "workflow-dec"} @@ -473,12 +485,23 @@ def test_get_workflow_runtime_apply_runtime_parameters(): app_id="app-1", node_id="node-1", workflow_tool=workflow_tool, + user_id="user-1", invoke_from=InvokeFrom.DEBUGGER, variable_pool=None, ) assert workflow_result is tool_runtime2 assert tool_runtime2.runtime.runtime_parameters["query"] == "workflow-dec" + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.DEBUGGER, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + credential_id=None, + ) def test_get_agent_runtime_raises_when_runtime_missing(): @@ -520,17 +543,28 @@ def test_get_tool_runtime_from_plugin_only_uses_form_parameters(): tool_entity = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_entity.get_merged_runtime_parameters = Mock(return_value=[form_param, llm_param]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity): + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity) as mock_get_tool_runtime: result = ToolManager.get_tool_runtime_from_plugin( tool_type=ToolProviderType.API, tenant_id="tenant-1", provider="api-1", tool_name="search", tool_parameters={"q": "hello", "llm": "ignore"}, + user_id="user-1", ) assert result is tool_entity assert tool_entity.runtime.runtime_parameters == {"q": "hello"} + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + tool_invoke_from=ToolInvokeFrom.PLUGIN, + credential_id=None, + ) def test_hardcoded_provider_icon_success(): diff --git a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py index af3cdddd5f..6454a5bcd1 100644 --- a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py @@ -84,3 +84,24 @@ def test_transform_tool_invoke_messages_mimetype_key_present_but_none(): # meta is preserved (still contains mime_type: None) assert "mime_type" in (o.meta or {}) assert o.meta["mime_type"] is None + assert o.meta["tool_file_id"] == "fake-tool-file-id" + + +def test_transform_tool_invoke_messages_parses_existing_tool_file_link_meta(): + msg = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=ToolInvokeMessage.TextMessage(text="/files/tools/existing-tool-file.png"), + meta={}, + ) + + out = list( + 
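All three ToolManager tests above now assert the full `get_tool_runtime` call, locking in that the caller's `user_id` and an explicit `tool_invoke_from` discriminator are threaded through rather than inferred. The asserted call shape, with values from the agent-tool test (the workflow variant passes `ToolInvokeFrom.WORKFLOW`, and the plugin variant `InvokeFrom.SERVICE_API` with `ToolInvokeFrom.PLUGIN`):

    from core.app.entities.app_invoke_entities import InvokeFrom
    from core.tools.entities.tool_entities import ToolInvokeFrom, ToolProviderType
    from core.tools.tool_manager import ToolManager

    ToolManager.get_tool_runtime(
        provider_type=ToolProviderType.API,
        provider_id="api-1",
        tool_name="search",
        tenant_id="tenant-1",
        user_id="user-1",
        invoke_from=InvokeFrom.DEBUGGER,
        tool_invoke_from=ToolInvokeFrom.AGENT,
        credential_id=None,
    )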
mt.ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=_gen([msg]), + user_id="u1", + tenant_id="t1", + conversation_id="c1", + ) + ) + + assert len(out) == 1 + assert out[0].meta["tool_file_id"] == "existing-tool-file" diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py index 2acae889b2..8aba05ab4c 100644 --- a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py +++ b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py @@ -60,20 +60,23 @@ def test_get_max_llm_context_tokens_branches(model_instance, expected, error_mat manager = Mock() manager.get_default_model_instance.return_value = model_instance - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: if error_match: with pytest.raises(InvokeModelError, match=error_match): - ModelInvocationUtils.get_max_llm_context_tokens("tenant") + ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1") else: - assert ModelInvocationUtils.get_max_llm_context_tokens("tenant") == expected + assert ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1") == expected + + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="user-1") def test_calculate_tokens_handles_missing_model(): manager = Mock() manager.get_default_model_instance.return_value = None - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with pytest.raises(InvokeModelError, match="Model not found"): ModelInvocationUtils.calculate_tokens("tenant", []) + mock_factory.assert_called_once_with(tenant_id="tenant", user_id=None) def test_invoke_success_and_error_mappings(): @@ -98,7 +101,7 @@ def test_invoke_success_and_error_mappings(): db_mock = SimpleNamespace(session=Mock()) - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): with patch("core.tools.utils.model_invocation_utils.db", db_mock): response = ModelInvocationUtils.invoke( @@ -107,11 +110,13 @@ def test_invoke_success_and_error_mappings(): tool_type="builtin", tool_name="tool-a", prompt_messages=[], + caller_user_id="caller-1", ) assert response.message.content == "ok" assert db_mock.session.add.call_count == 1 assert db_mock.session.commit.call_count == 2 + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="caller-1") @pytest.mark.parametrize( @@ -145,7 +150,7 @@ def test_invoke_error_mappings(exc, expected): db_mock = SimpleNamespace(session=Mock()) - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): with patch("core.tools.utils.model_invocation_utils.db", db_mock): with pytest.raises(InvokeModelError, match=expected): @@ -156,3 +161,4 @@ def test_invoke_error_mappings(exc, expected): tool_name="tool-a", 
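The model-invocation tests patch `ModelManager.for_tenant` rather than the `ModelManager` constructor, reflecting a classmethod factory that binds tenant and acting user up front. The factory call shape the assertions pin down is shown below; how the returned manager is used afterwards is unchanged, and `user_id` may be None, in which case the utilities fall back to tenant scope:

    from core.model_manager import ModelManager

    manager = ModelManager.for_tenant(tenant_id="tenant", user_id="caller-1")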
prompt_messages=[], ) + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="u1") diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index cc00f79698..ff401ab358 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -24,7 +24,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.errors import ToolInvokeError from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.file import FILE_MODEL_IDENTITY +from dify_graph.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType class StubScalars: @@ -439,6 +439,32 @@ def _setup_transform_args_tool(monkeypatch: pytest.MonkeyPatch) -> WorkflowTool: def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch): """Transform args into parameters and files payloads.""" tool = _setup_transform_args_tool(monkeypatch) + build_file_from_stored_mapping = MagicMock( + side_effect=[ + SimpleNamespace( + transfer_method=FileTransferMethod.TOOL_FILE, + type=FileType.IMAGE, + reference="tool-1", + generate_url=lambda: None, + ), + SimpleNamespace( + transfer_method=FileTransferMethod.LOCAL_FILE, + type=FileType.DOCUMENT, + reference="upload-1", + generate_url=lambda: None, + ), + SimpleNamespace( + transfer_method=FileTransferMethod.REMOTE_URL, + type=FileType.DOCUMENT, + reference=None, + generate_url=lambda: "https://example.com/a.pdf", + ), + ] + ) + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.build_file_from_stored_mapping", + build_file_from_stored_mapping, + ) params, files = tool._transform_args( { @@ -470,6 +496,8 @@ def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch): assert any(file_item.get("tool_file_id") == "tool-1" for file_item in files) assert any(file_item.get("upload_file_id") == "upload-1" for file_item in files) assert any(file_item.get("url") == "https://example.com/a.pdf" for file_item in files) + assert build_file_from_stored_mapping.call_count == 3 + assert all(call.kwargs["tenant_id"] == "test_tool" for call in build_file_from_stored_mapping.call_args_list) def test_transform_args_invalid_files(monkeypatch: pytest.MonkeyPatch): diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 91259c9a45..0c7e5899de 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -5,9 +5,10 @@ import pytest from pydantic import BaseModel from core.helper import encrypter +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables.segment_group import SegmentGroup from dify_graph.variables.segments import ( ArrayAnySegment, @@ -48,14 +49,28 @@ from dify_graph.variables.variables import ( ) +def _build_variable_pool( + *, + system_variables: list[Variable] | None = None, + environment_variables: list[Variable] | None = None, +) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variables or [], + environment_variables=environment_variables or [], + ), + ) + return 
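The `_build_variable_pool` helper in the segment tests shows the new two-step bootstrap: `VariablePool` is created empty and then seeded via the builder helpers, instead of taking `system_variables=SystemVariable(...)` in its constructor. A sketch of the flow, using only names imported in this diff:

    from core.workflow.system_variables import build_bootstrap_variables, build_system_variables
    from core.workflow.variable_pool_initializer import add_variables_to_pool
    from dify_graph.runtime import VariablePool

    pool = VariablePool()
    add_variables_to_pool(
        pool,
        build_bootstrap_variables(
            system_variables=build_system_variables(user_id="fake-user-id"),
            environment_variables=[],
        ),
    )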
variable_pool + + def test_segment_group_to_text(): - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="fake-user-id"), - user_inputs={}, + variable_pool = _build_variable_pool( + system_variables=build_system_variables(user_id="fake-user-id"), environment_variables=[ SecretVariable(name="secret_key", value="fake-secret-key"), ], - conversation_variables=[], ) variable_pool.add(("node_id", "custom_query"), "fake-user-query") template = ( @@ -71,11 +86,8 @@ def test_segment_group_to_text(): def test_convert_constant_to_segment_group(): - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], + variable_pool = _build_variable_pool( + system_variables=build_system_variables(user_id="1", app_id="1", workflow_id="1"), ) template = "Hello, world!" segments_group = variable_pool.convert_template(template) @@ -84,12 +96,7 @@ def test_convert_constant_to_segment_group(): def test_convert_variable_to_segment_group(): - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="fake-user-id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], - ) + variable_pool = _build_variable_pool(system_variables=build_system_variables(user_id="fake-user-id")) template = "{{#sys.user_id#}}" segments_group = variable_pool.convert_template(template) assert segments_group.text == "fake-user-id" @@ -116,7 +123,6 @@ def create_test_file( ) -> File: """Factory function to create File objects for testing""" return File( - tenant_id="test-tenant", type=file_type, transfer_method=transfer_method, filename=filename, @@ -190,7 +196,6 @@ class TestSegmentDumpAndLoad: loaded_file = loaded_segment.value assert isinstance(orig_file, File) assert isinstance(loaded_file, File) - assert loaded_file.tenant_id == orig_file.tenant_id assert loaded_file.type == orig_file.type assert loaded_file.filename == orig_file.filename else: @@ -234,7 +239,6 @@ class TestSegmentDumpAndLoad: loaded_file = loaded_variable.value assert isinstance(orig_file, File) assert isinstance(loaded_file, File) - assert loaded_file.tenant_id == orig_file.tenant_id assert loaded_file.type == orig_file.type assert loaded_file.filename == orig_file.filename else: diff --git a/api/tests/unit_tests/core/workflow/context/test_execution_context.py b/api/tests/unit_tests/core/workflow/context/test_execution_context.py index d09b8397c3..3ce4bb753b 100644 --- a/api/tests/unit_tests/core/workflow/context/test_execution_context.py +++ b/api/tests/unit_tests/core/workflow/context/test_execution_context.py @@ -9,7 +9,7 @@ from unittest.mock import MagicMock import pytest from pydantic import BaseModel -from dify_graph.context.execution_context import ( +from context.execution_context import ( AppContext, ExecutionContext, ExecutionContextBuilder, @@ -286,7 +286,7 @@ class TestCaptureCurrentContext: def test_capture_current_context_returns_context(self): """Test that capture_current_context returns a valid context.""" - from dify_graph.context.execution_context import capture_current_context + from context.execution_context import capture_current_context result = capture_current_context() @@ -303,7 +303,7 @@ class TestCaptureCurrentContext: test_var = contextvars.ContextVar("capture_test_var") test_var.set("test_value_123") - from dify_graph.context.execution_context import capture_current_context + from context.execution_context import capture_current_context result = 
capture_current_context() @@ -313,12 +313,12 @@ class TestCaptureCurrentContext: class TestTenantScopedContextRegistry: def setup_method(self): - from dify_graph.context import reset_context_provider + from context import reset_context_provider reset_context_provider() def teardown_method(self): - from dify_graph.context import reset_context_provider + from context import reset_context_provider reset_context_provider() @@ -333,7 +333,7 @@ class TestTenantScopedContextRegistry: assert read_context("workflow.sandbox", tenant_id="t2").base_url == "http://t2" def test_missing_provider_raises_keyerror(self): - from dify_graph.context import ContextProviderNotFoundError + from context import ContextProviderNotFoundError with pytest.raises(ContextProviderNotFoundError): read_context("missing", tenant_id="unknown") diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py index 22792eb5b3..6331edc819 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch import pytest -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from dify_graph.variables.variables import StringVariable @@ -23,6 +23,17 @@ class StubCoordinator: class TestGraphRuntimeState: + def test_execution_context_defaults_to_empty_context(self): + state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) + + with state.execution_context: + assert state.execution_context is not None + + state.execution_context = None + + with state.execution_context: + assert state.execution_context is not None + def test_property_getters_and_setters(self): # FIXME(-LAN-): Mock VariablePool if needed variable_pool = VariablePool() diff --git a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py index 6100ebede5..0515b9d93a 100644 --- a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py @@ -126,7 +126,7 @@ class TestVariablePoolGetNotModifyVariableDictionary: def test_get_should_not_modify_variable_dictionary(self): pool = VariablePool.empty() pool.get([self._NODE_ID, self._VAR_NAME]) - assert len(pool.variable_dictionary) == 1 # only contains `sys` node id + assert len(pool.variable_dictionary) == 0 assert "start" not in pool.variable_dictionary pool = VariablePool.empty() diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py index 75de07bd8b..b61b8d59c3 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py @@ -5,11 +5,11 @@ from typing import Any import pytest from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import default_system_variables from dify_graph.graph import Graph from dify_graph.graph.validation import GraphValidationError from dify_graph.nodes import BuiltinNodeTypes from dify_graph.runtime import 
GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params @@ -63,7 +63,7 @@ def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory: ) graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, environment_variables=[], ), diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py index e94ad74eb0..a42b4e9c8e 100644 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py @@ -6,6 +6,7 @@ from dataclasses import dataclass import pytest +from core.workflow.system_variables import build_system_variables from dify_graph.entities import GraphInitParams from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeExecutionType, NodeType @@ -13,7 +14,6 @@ from dify_graph.graph import Graph from dify_graph.graph.validation import GraphValidationError from dify_graph.nodes.base.node import Node from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params @@ -96,7 +96,7 @@ def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]: invoke_from="service-api", call_depth=0, ) - variable_pool = VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}) + variable_pool = VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}) runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) factory = _SimpleNodeFactory(graph_init_params=init_params, graph_runtime_state=runtime_state) return factory, graph_config diff --git a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py index fc8133f5e1..f3db1f5e0c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py @@ -7,13 +7,13 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.repositories.human_input_form_repository import ( +from core.repositories.human_input_repository import ( FormCreateParams, HumanInputFormEntity, HumanInputFormRecipientEntity, HumanInputFormRepository, ) +from dify_graph.nodes.human_input.enums import HumanInputFormStatus from libs.datetime_utils import naive_utc_now @@ -49,7 +49,7 @@ class _InMemoryFormEntity(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return self.token @property @@ -88,24 +88,24 @@ class InMemoryHumanInputFormRepository(HumanInputFormRepository): self._form_counter = 0 self.created_params: list[FormCreateParams] = [] self.created_forms: list[_InMemoryFormEntity] = [] - self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {} + self._forms_by_node_id: dict[str, _InMemoryFormEntity] = {} def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: 
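    # The in-memory fake mirrors the production repository change above: the
    # repository is scoped to a single workflow execution at construction time,
    # so forms are stored and looked up by node_id alone (get_form(node_id))
    # rather than by the (workflow_execution_id, node_id) pair.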
self.created_params.append(params) self._form_counter += 1 form_id = f"form-{self._form_counter}" - token = f"console-{form_id}" if params.console_recipient_required else f"token-{form_id}" + token = f"token-{form_id}" entity = _InMemoryFormEntity( form_id=form_id, rendered=params.rendered_content, token=token, ) self.created_forms.append(entity) - self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity + self._forms_by_node_id[params.node_id] = entity return entity - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - return self._forms_by_key.get((workflow_execution_id, node_id)) + def get_form(self, node_id: str) -> HumanInputFormEntity | None: + return self._forms_by_node_id.get(node_id) # Convenience helpers for tests ------------------------------------- diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 2a36f712fd..4dba88932f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -1,9 +1,12 @@ import threading from datetime import datetime +from types import SimpleNamespace from unittest.mock import MagicMock, patch +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, UserFrom from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.errors.error import QuotaExceededError +from core.model_manager import ModelInstance from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.graph_engine.entities.commands import CommandType from dify_graph.graph_events.node import NodeRunSucceededEvent @@ -11,6 +14,16 @@ from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.node_events import NodeRunResult +def _build_dify_context() -> DifyRunContext: + return DifyRunContext( + tenant_id="tenant-id", + app_id="app-id", + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + def _build_succeeded_event() -> NodeRunSucceededEvent: return NodeRunSucceededEvent( id="execution-id", @@ -25,6 +38,11 @@ def _build_succeeded_event() -> NodeRunSucceededEvent: ) +def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]: + raw_model_instance = ModelInstance.__new__(ModelInstance) + return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance + + def test_deduct_quota_called_for_successful_llm_node() -> None: layer = LLMQuotaLayer() node = MagicMock() @@ -32,8 +50,8 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node.require_run_context_value.return_value = _build_dify_context() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() result_event = _build_succeeded_event() with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: @@ -41,7 +59,7 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: mock_deduct.assert_called_once_with( tenant_id="tenant-id", - model_instance=node.model_instance, + model_instance=raw_model_instance, usage=result_event.node_run_result.llm_usage, ) @@ -53,8 +71,8 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: 
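    # As in the preceding LLM-node test: the quota layer now pulls tenant/user
    # scope from require_run_context_value() (returning a DifyRunContext) and
    # deducts against the unwrapped ModelInstance (_model_instance), not the
    # node's wrapper attribute.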
node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node.require_run_context_value.return_value = _build_dify_context() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() result_event = _build_succeeded_event() with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: @@ -62,7 +80,7 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: mock_deduct.assert_called_once_with( tenant_id="tenant-id", - model_instance=node.model_instance, + model_instance=raw_model_instance, usage=result_event.node_run_result.llm_usage, ) @@ -74,7 +92,7 @@ def test_non_llm_node_is_ignored() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.START node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" + node.require_run_context_value.return_value = _build_dify_context() node._model_instance = object() result_event = _build_succeeded_event() @@ -91,7 +109,7 @@ def test_quota_error_is_handled_in_layer() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" + node.require_run_context_value.return_value = _build_dify_context() node.model_instance = object() result_event = _build_succeeded_event() @@ -113,8 +131,8 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node.require_run_context_value.return_value = _build_dify_context() + node.model_instance, _ = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -141,7 +159,7 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None: node = MagicMock() node.id = "llm-node-id" node.node_type = BuiltinNodeTypes.LLM - node.model_instance = object() + node.model_instance, _ = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -167,7 +185,7 @@ def test_quota_precheck_passes_without_abort() -> None: node = MagicMock() node.id = "llm-node-id" node.node_type = BuiltinNodeTypes.LLM - node.model_instance = object() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -175,5 +193,5 @@ def test_quota_precheck_passes_without_abort() -> None: layer.on_node_run_start(node) assert not stop_event.is_set() - mock_check.assert_called_once_with(model_instance=node.model_instance) + mock_check.assert_called_once_with(model_instance=raw_model_instance) layer.command_channel.send_command.assert_not_called() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index 765c4deba3..d1a7e2da51 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -3,8 +3,8 @@ import time from unittest.mock import MagicMock -from core.app.entities.app_invoke_entities import InvokeFrom, 
UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.entities.pause_reason import SchedulingPause from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig @@ -73,9 +73,8 @@ def test_abort_command(): config=GraphEngineConfig(), ) - # Send abort command before starting - abort_command = AbortCommand(reason="Test abort") - command_channel.send_command(abort_command) + # Queue an abort request before starting. + engine.request_abort("Test abort") # Run engine and collect events events = list(engine.run()) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py index d54f0be190..bc5d6c2c45 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py @@ -1,6 +1,7 @@ import time from collections.abc import Mapping +from core.workflow.system_variables import build_system_variables from dify_graph.entities import GraphInitParams from dify_graph.enums import NodeState from dify_graph.graph import Graph @@ -20,7 +21,6 @@ from dify_graph.nodes.llm.entities import ( from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params from .test_mock_config import MockConfig @@ -29,7 +29,7 @@ from .test_mock_nodes import MockLLMNode def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py index 538f53c603..9ffd1c7168 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py @@ -4,6 +4,9 @@ from collections.abc import Iterable from unittest import mock from unittest.mock import MagicMock +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.graph import Graph from dify_graph.graph_events import ( GraphRunPausedEvent, @@ -30,9 +33,7 @@ from dify_graph.nodes.llm.entities import ( ) from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -60,7 +61,7 @@ def _build_branching_graph( if graph_runtime_state is None: variable_pool = VariablePool( - 
system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -125,6 +126,7 @@ def _build_branching_graph( graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") @@ -246,7 +248,7 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: mock_create_repo.get_form.return_value = None mock_form_entity = MagicMock(spec=HumanInputFormEntity) mock_form_entity.id = "test_form_id" - mock_form_entity.web_app_token = "test_web_app_token" + mock_form_entity.submission_token = "test_web_app_token" mock_form_entity.recipients = [] mock_form_entity.rendered_content = "rendered" mock_form_entity.submitted = False @@ -302,7 +304,7 @@ def test_human_input_llm_streaming_across_multiple_branches() -> None: mock_get_repo = MagicMock(spec=HumanInputFormRepository) submitted_form = MagicMock(spec=HumanInputFormEntity) submitted_form.id = mock_form_entity.id - submitted_form.web_app_token = mock_form_entity.web_app_token + submitted_form.submission_token = mock_form_entity.submission_token submitted_form.recipients = [] submitted_form.rendered_content = mock_form_entity.rendered_content submitted_form.submitted = True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py index 36bba6deb6..dc9db6d725 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py @@ -3,6 +3,9 @@ import time from unittest import mock from unittest.mock import MagicMock +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.graph import Graph from dify_graph.graph_events import ( GraphRunPausedEvent, @@ -29,9 +32,7 @@ from dify_graph.nodes.llm.entities import ( ) from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -59,7 +60,7 @@ def _build_llm_human_llm_graph( if graph_runtime_state is None: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id," ), user_inputs={}, @@ -121,6 +122,7 @@ def _build_llm_human_llm_graph( graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt") @@ -191,7 +193,7 @@ def test_human_input_llm_streaming_order_across_pause() -> None: mock_create_repo.get_form.return_value = None mock_form_entity = 
MagicMock(spec=HumanInputFormEntity) mock_form_entity.id = "test_form_id" - mock_form_entity.web_app_token = "test_web_app_token" + mock_form_entity.submission_token = "test_web_app_token" mock_form_entity.recipients = [] mock_form_entity.rendered_content = "rendered" mock_form_entity.submitted = False @@ -260,7 +262,7 @@ def test_human_input_llm_streaming_order_across_pause() -> None: mock_get_repo = MagicMock(spec=HumanInputFormRepository) submitted_form = MagicMock(spec=HumanInputFormEntity) submitted_form.id = mock_form_entity.id - submitted_form.web_app_token = mock_form_entity.web_app_token + submitted_form.submission_token = mock_form_entity.submission_token submitted_form.recipients = [] submitted_form.rendered_content = mock_form_entity.rendered_content submitted_form.submitted = True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py index 8da179c15e..39759a7555 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py @@ -1,6 +1,7 @@ import time from unittest import mock +from core.workflow.system_variables import build_system_variables from dify_graph.graph import Graph from dify_graph.graph_events import ( GraphRunStartedEvent, @@ -26,7 +27,6 @@ from dify_graph.nodes.llm.entities import ( from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.utils.condition.entities import Condition from tests.workflow_test_utils import build_test_graph_init_params @@ -44,7 +44,7 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr ) variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"), user_inputs={}, conversation_variables=[], ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py index 733fd53bc8..a8bc96117f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py @@ -14,6 +14,7 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) from .test_mock_config import MockConfigBuilder @@ -50,6 +51,7 @@ def test_loop_contains_answer(): NodeRunLoopStartedEvent, # Variable assigner NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, NodeRunStreamChunkEvent, # 1 NodeRunStreamChunkEvent, # \n NodeRunSucceededEvent, @@ -60,6 +62,7 @@ def test_loop_contains_answer(): NodeRunLoopNextEvent, # Variable assigner NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, NodeRunStreamChunkEvent, # 2 NodeRunStreamChunkEvent, # \n NodeRunSucceededEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py index 6ff2722f78..7c33a89fe0 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py @@ -7,6 +7,7 
@@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) from .test_mock_config import MockConfigBuilder @@ -44,12 +45,16 @@ def test_loop_with_tool(): NodeRunStartedEvent, NodeRunSucceededEvent, NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, + NodeRunVariableUpdatedEvent, NodeRunSucceededEvent, NodeRunLoopNextEvent, # 2024 NodeRunStartedEvent, NodeRunSucceededEvent, NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, + NodeRunVariableUpdatedEvent, NodeRunSucceededEvent, # LOOP END NodeRunLoopSucceededEvent, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 93010eea54..450d11754e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -111,7 +111,7 @@ class MockNodeFactory(DifyNodeFactory): mock_config=self.mock_config, http_request_config=self._http_request_config, http_client=self._http_request_http_client, - tool_file_manager_factory=self._http_request_tool_file_manager_factory, + tool_file_manager_factory=self._bound_tool_file_manager_factory, file_manager=self._http_request_file_manager, ) elif node_type in { diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py index 3e4247f33f..4fa210bee1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py @@ -2,7 +2,7 @@ Simple test to verify MockNodeFactory works with iteration nodes. 
""" -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import BuiltinNodeTypes from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 454263bef9..b2740e91ba 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock from core.model_manager import ModelInstance +from core.workflow.node_runtime import DifyToolNodeRuntime from core.workflow.nodes.agent import AgentNode from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -20,16 +21,15 @@ from dify_graph.nodes.code import CodeNode from dify_graph.nodes.document_extractor import DocumentExtractorNode from dify_graph.nodes.http_request import HttpRequestNode from dify_graph.nodes.llm import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer +from dify_graph.nodes.llm.file_saver import LLMFileSaver +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from dify_graph.nodes.parameter_extractor import ParameterExtractorNode -from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol +from dify_graph.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol from dify_graph.nodes.question_classifier import QuestionClassifierNode from dify_graph.nodes.template_transform import TemplateTransformNode -from dify_graph.nodes.template_transform.template_renderer import ( - Jinja2TemplateRenderer, - TemplateRenderError, -) from dify_graph.nodes.tool import ToolNode +from dify_graph.template_rendering import Jinja2TemplateRenderer, TemplateRenderError if TYPE_CHECKING: from dify_graph.entities import GraphInitParams @@ -66,20 +66,26 @@ class MockNodeMixin: kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance)) + kwargs.setdefault("prompt_message_serializer", MagicMock(spec=PromptMessageSerializerProtocol)) # LLM-like nodes now require an http_client; provide a mock by default for tests. 
kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol)) - if isinstance(self, (LLMNode, QuestionClassifierNode)): - kwargs.setdefault("template_renderer", MagicMock(spec=TemplateRenderer)) + + if isinstance(self, (LLMNode, QuestionClassifierNode)): + kwargs.setdefault("llm_file_saver", MagicMock(spec=LLMFileSaver)) + + if isinstance(self, HttpRequestNode): + kwargs.setdefault("file_reference_factory", MagicMock(spec=FileReferenceFactoryProtocol)) # Ensure TemplateTransformNode receives a renderer now required by constructor if isinstance(self, TemplateTransformNode): - kwargs.setdefault("template_renderer", _TestJinja2Renderer()) + kwargs.setdefault("jinja2_template_renderer", _TestJinja2Renderer()) # Provide default tool_file_manager_factory for ToolNode subclasses from dify_graph.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles if isinstance(self, _ToolNode): kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol)) + kwargs.setdefault("runtime", DifyToolNodeRuntime(graph_init_params.run_context)) if isinstance(self, AgentNode): presentation_provider = MagicMock() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py index a8398e8f79..87284c6740 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py @@ -6,7 +6,7 @@ to ensure they work correctly with the TableTestRunner. """ from configs import dify_config -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.nodes.code.limits import CodeNodeLimits from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig @@ -60,7 +60,7 @@ class TestMockTemplateTransformNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -123,7 +123,7 @@ class TestMockTemplateTransformNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -187,7 +187,7 @@ class TestMockTemplateTransformNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -249,7 +249,7 @@ class TestMockTemplateTransformNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -318,7 +318,7 @@ class TestMockCodeNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -384,7 +384,7 @@ class TestMockCodeNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -458,7 +458,7 @@ class TestMockCodeNode: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -534,7 +534,7 @@ class TestMockNodeFactory: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -579,7 +579,7 @@ class TestMockNodeFactory: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) @@ -634,7 +634,7 @@ class TestMockNodeFactory: ) variable_pool = VariablePool( - system_variables={}, + system_variables=[], user_inputs={}, ) diff --git 
a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py index 5b35b3310a..d93c0a59b1 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py @@ -4,7 +4,7 @@ Simple test to validate the auto-mock system without external dependencies. import sys -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import BuiltinNodeTypes from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index e681b39cc7..08d30315ac 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,6 +4,13 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -23,13 +30,7 @@ from dify_graph.nodes.human_input.enums import HumanInputFormStatus from dify_graph.nodes.human_input.human_input_node import HumanInputNode from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -67,7 +68,7 @@ class StaticForm(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return "token" @property @@ -103,7 +104,7 @@ class StaticRepo(HumanInputFormRepository): def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: self._forms_by_node_id = dict(forms_by_node_id) - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: return self._forms_by_node_id.get(node_id) def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: @@ -112,7 +113,7 @@ class StaticRepo(HumanInputFormRepository): def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -159,6 +160,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor 
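The `_build_graph` hunks below thread a `DifyHumanInputNodeRuntime` built from `graph_init_params.run_context` into every `HumanInputNode`. A rough sketch of what such a run-context-backed runtime could look like; the real class lives in `core.workflow.node_runtime` and its constructor shape is assumed from these call sites only:

from typing import Any, Mapping


class _RunContextBackedRuntime:
    """Carries the run context a node needs at execution time."""

    def __init__(self, run_context: Mapping[str, Any]) -> None:
        # Copy the context captured at graph init; nodes no longer reach into
        # global framework state during execution.
        self._run_context = dict(run_context)

    def require(self, key: str) -> Any:
        # Fail fast if graph init did not capture the expected context entry.
        try:
            return self._run_context[key]
        except KeyError as exc:
            raise RuntimeError(f"missing run-context entry: {key}") from exc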
graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) human_b_config = {"id": "human_b", "data": human_data.model_dump()} @@ -168,6 +170,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) end_data = EndNodeData( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py index 60167c0441..65317292ed 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py @@ -4,6 +4,13 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -30,13 +37,7 @@ from dify_graph.nodes.llm.entities import ( ) from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -59,7 +60,7 @@ class StaticForm(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return "token" @property @@ -95,7 +96,7 @@ class StaticRepo(HumanInputFormRepository): def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: self._forms_by_node_id = dict(forms_by_node_id) - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: return self._forms_by_node_id.get(node_id) def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: @@ -115,7 +116,7 @@ class DelayedHumanInputNode(HumanInputNode): def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -162,6 +163,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) human_b_config = {"id": "human_b", "data": human_data.model_dump()} @@ -171,6 +173,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, 
form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), delay_seconds=0.2, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py index b954a4faac..ba9460e08e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py @@ -15,6 +15,7 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.system_variables import build_system_variables from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig @@ -28,7 +29,6 @@ from dify_graph.graph_events import ( from dify_graph.node_events import NodeRunResult, StreamCompletedEvent from dify_graph.nodes.llm.node import LLMNode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params from .test_table_runner import TableTestRunner @@ -98,7 +98,7 @@ def test_parallel_streaming_workflow(): ) # Create variable pool with system variables - system_variables = SystemVariable( + system_variables = build_system_variables( user_id="test_user", app_id="test_app", workflow_id=init_params.workflow_id, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py index 7328ce443f..399c73b2ac 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py @@ -4,6 +4,13 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -31,13 +38,7 @@ from dify_graph.nodes.llm.entities import ( ) from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -60,7 +61,7 @@ class StaticForm(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return "token" @property @@ -96,7 +97,7 @@ class StaticRepo(HumanInputFormRepository): def __init__(self, form: HumanInputFormEntity) -> None: self._form = form - def get_form(self, 
workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: if node_id != "human_pause": return None return self._form @@ -107,7 +108,7 @@ class StaticRepo(HumanInputFormRepository): def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -201,6 +202,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) end_human_data = EndNodeData(title="End Human", outputs=[], desc=None) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py index 15a7de3c52..a528412682 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py @@ -3,6 +3,12 @@ import time from typing import Any from unittest.mock import MagicMock +from core.repositories.human_input_repository import ( + HumanInputFormEntity, + HumanInputFormRepository, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from dify_graph.entities.workflow_start_reason import WorkflowStartReason from dify_graph.graph import Graph from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel @@ -22,19 +28,14 @@ from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction from dify_graph.nodes.human_input.human_input_node import HumanInputNode from dify_graph.nodes.start.entities import StartNodeData from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - HumanInputFormEntity, - HumanInputFormRepository, -) from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params def _build_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -50,7 +51,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = True @@ -65,7 +66,7 @@ def _mock_form_repository_without_submission() -> HumanInputFormRepository: repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = False @@ -112,6 +113,7 @@ def _build_human_input_graph( graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, + 
runtime=DifyHumanInputNodeRuntime(params.run_context), ) end_data = EndNodeData( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py index 4f1741d4fb..f8aaea6424 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py @@ -4,6 +4,7 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, + NodeRunVariableUpdatedEvent, ) from .test_mock_config import MockConfigBuilder @@ -33,6 +34,7 @@ def test_streaming_conversation_variables(): NodeRunSucceededEvent, # Variable Assigner node NodeRunStartedEvent, + NodeRunVariableUpdatedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent, # ANSWER node diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index ab8fb346b8..75e0484225 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -12,28 +12,28 @@ This module provides a robust table-driven testing framework with support for: import logging import time -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path -from typing import Any, cast +from typing import Any -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.tools.utils.yaml_utils import _load_yaml_file from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.graph import Graph from dify_graph.graph_engine import GraphEngine, GraphEngineConfig from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.layers.base import GraphEngineLayer from dify_graph.graph_events import ( GraphEngineEvent, GraphRunStartedEvent, GraphRunSucceededEvent, ) from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables import ( ArrayNumberVariable, ArrayObjectVariable, @@ -60,20 +60,28 @@ class _TableTestChildEngineBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> GraphEngine: + child_graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool if variable_pool is not None else parent_graph_runtime_state.variable_pool, + start_at=time.perf_counter(), + execution_context=parent_graph_runtime_state.execution_context, + ) if self._use_mock_factory: node_factory 
= MockNodeFactory( graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, mock_config=self._mock_config, ) else: - node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=child_graph_runtime_state, + ) + graph_config = graph_init_params.graph_config child_graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id) if not child_graph: raise ValueError("child graph not found") @@ -81,13 +89,11 @@ class _TableTestChildEngineBuilder: child_engine = GraphEngine( workflow_id=workflow_id, graph=child_graph, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, command_channel=InMemoryChannel(), config=GraphEngineConfig(), child_engine_builder=self, ) - for layer in layers: - child_engine.layer(cast(GraphEngineLayer, layer)) return child_engine @@ -206,14 +212,15 @@ class WorkflowRunner: call_depth=0, ) - system_variables = SystemVariable( + system_variables = build_system_variables( user_id="test_user", app_id="test_app", workflow_id=graph_init_params.workflow_id, files=[], query=query, ) - user_inputs = inputs if inputs is not None else {} + root_node_inputs = dict(inputs or {}) + root_node_inputs.setdefault("query", query) # Extract conversation variables from workflow config conversation_variables = [] @@ -242,11 +249,16 @@ class WorkflowRunner: ) conversation_variables.append(var) - variable_pool = VariablePool( - system_variables=system_variables, - user_inputs=user_inputs, - conversation_variables=conversation_variables, + root_node_id = get_default_root_node_id(graph_config) + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variables, + conversation_variables=conversation_variables, + ), ) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=root_node_inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) @@ -260,7 +272,7 @@ class WorkflowRunner: graph = Graph.init( graph_config=graph_config, node_factory=node_factory, - root_node_id=get_default_root_node_id(graph_config), + root_node_id=root_node_id, ) return graph, graph_runtime_state diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py new file mode 100644 index 0000000000..d8ec3c7037 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_update_events.py @@ -0,0 +1,129 @@ +import time +import uuid +from uuid import uuid4 + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool +from dify_graph.entities import GraphInitParams +from dify_graph.graph import Graph +from dify_graph.graph_engine import GraphEngine, GraphEngineConfig +from dify_graph.graph_engine.command_channels import InMemoryChannel +from dify_graph.graph_engine.layers.base import GraphEngineLayer +from dify_graph.graph_events import NodeRunVariableUpdatedEvent +from dify_graph.runtime import GraphRuntimeState, VariablePool +from 
dify_graph.variables import StringVariable + +DEFAULT_NODE_ID = "node_id" + + +class CaptureVariableUpdateLayer(GraphEngineLayer): + def __init__(self) -> None: + super().__init__() + self.events: list[NodeRunVariableUpdatedEvent] = [] + self.observed_values: list[object | None] = [] + + def on_graph_start(self) -> None: + pass + + def on_event(self, event) -> None: + if not isinstance(event, NodeRunVariableUpdatedEvent): + return + + current_value = self.graph_runtime_state.variable_pool.get(event.variable.selector) + self.events.append(event) + self.observed_values.append(None if current_value is None else current_value.value) + + def on_graph_end(self, error: Exception | None) -> None: + pass + + +def test_graph_engine_applies_variable_updates_before_notifying_layers(): + graph_config = { + "edges": [ + { + "id": "start-source-assigner-target", + "source": "start", + "target": "assigner", + }, + ], + "nodes": [ + {"data": {"type": "start", "title": "Start"}, "id": "start"}, + { + "data": { + "type": "assigner", + "title": "Variable Assigner", + "assigned_variable_selector": ["conversation", "test_conversation_variable"], + "write_mode": "over-write", + "input_variable_selector": ["node_id", "test_string_variable"], + }, + "id": "assigner", + }, + ], + } + + init_params = GraphInitParams( + workflow_id="1", + graph_config=graph_config, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "1", + "app_id": "1", + "user_id": "1", + "user_from": UserFrom.ACCOUNT, + "invoke_from": InvokeFrom.DEBUGGER, + } + }, + call_depth=0, + ) + + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=build_system_variables(conversation_id=str(uuid.uuid4())), + conversation_variables=[ + StringVariable( + id=str(uuid4()), + name="test_conversation_variable", + value="the first value", + ) + ], + ), + ) + variable_pool.add( + [DEFAULT_NODE_ID, "test_string_variable"], + StringVariable( + id=str(uuid4()), + name="test_string_variable", + value="the second value", + ), + ) + + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state) + graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") + + engine = GraphEngine( + workflow_id="workflow-id", + graph=graph, + graph_runtime_state=graph_runtime_state, + command_channel=InMemoryChannel(), + config=GraphEngineConfig(), + ) + capture_layer = CaptureVariableUpdateLayer() + engine.layer(capture_layer) + + events = list(engine.run()) + + update_events = [event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)] + assert len(update_events) == 1 + assert update_events[0].variable.value == "the second value" + + current_value = graph_runtime_state.variable_pool.get(["conversation", "test_conversation_variable"]) + assert current_value is not None + assert current_value.value == "the second value" + + assert len(capture_layer.events) == 1 + assert capture_layer.observed_values == ["the second value"] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py b/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py index bc00b49fba..2b10ed0407 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py @@ -12,7 +12,8 @@ from dify_graph.graph_events import NodeRunFailedEvent, 
NodeRunStartedEvent def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None: fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - mocker.patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=fixed_time) + mock_datetime = mocker.patch("dify_graph.graph_engine.worker.datetime") + mock_datetime.now.return_value = fixed_time.replace(tzinfo=UTC) worker = Worker( ready_queue=InMemoryReadyQueue(), @@ -75,7 +76,8 @@ def test_worker_fallback_failure_event_reuses_observed_start_time() -> None: worker._event_queue.put.side_effect = put_side_effect - with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time): + with patch("dify_graph.graph_engine.worker.datetime") as mock_datetime: + mock_datetime.now.return_value = failure_time.replace(tzinfo=UTC) worker.run() fallback_event = captured_events[-1] @@ -135,7 +137,8 @@ def test_worker_fallback_failure_event_ignores_nested_iteration_child_start_time worker._event_queue.put.side_effect = put_side_effect - with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time): + with patch("dify_graph.graph_engine.worker.datetime") as mock_datetime: + mock_datetime.now.return_value = failure_time.replace(tzinfo=UTC) worker.run() fallback_event = captured_events[-1] diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py new file mode 100644 index 0000000000..4c2fdabd0b --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py @@ -0,0 +1,33 @@ +from unittest.mock import patch + +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer +from dify_graph.enums import BuiltinNodeTypes + + +def test_transform_passes_conversation_id_to_tool_file_message_transformer() -> None: + messages = iter(()) + transformer = AgentMessageTransformer() + + with patch.object(ToolFileMessageTransformer, "transform_tool_invoke_messages", return_value=iter(())) as transform: + result = list( + transformer.transform( + messages=messages, + tool_info={}, + parameters_for_log={}, + user_id="user-id", + tenant_id="tenant-id", + conversation_id="conversation-id", + node_type=BuiltinNodeTypes.AGENT, + node_id="node-id", + node_execution_id="execution-id", + ) + ) + + assert len(result) == 2 + transform.assert_called_once_with( + messages=messages, + user_id="user-id", + tenant_id="tenant-id", + conversation_id="conversation-id", + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py new file mode 100644 index 0000000000..fea5e24cf6 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py @@ -0,0 +1,49 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport +from dify_graph.model_runtime.entities.model_entities import ModelType + + +def test_fetch_model_reuses_single_model_assembly(): + provider_configuration = SimpleNamespace( + get_current_credentials=Mock(return_value={"api_key": "x"}), + provider=SimpleNamespace(provider="openai"), + ) + model_type_instance = SimpleNamespace(get_model_schema=Mock(return_value="schema")) + provider_model_bundle = SimpleNamespace( + 
configuration=provider_configuration, + model_type_instance=model_type_instance, + ) + model_instance = Mock() + assembly = SimpleNamespace( + provider_manager=Mock(), + model_manager=Mock(), + ) + assembly.provider_manager.get_provider_model_bundle.return_value = provider_model_bundle + assembly.model_manager.get_model_instance.return_value = model_instance + + with patch( + "core.workflow.nodes.agent.runtime_support.create_plugin_model_assembly", + return_value=assembly, + ) as mock_assembly: + resolved_instance, resolved_schema = AgentRuntimeSupport().fetch_model( + tenant_id="tenant-1", + user_id="user-1", + value={"provider": "openai", "model": "gpt-4o-mini", "model_type": "llm"}, + ) + + assert resolved_instance is model_instance + assert resolved_schema == "schema" + mock_assembly.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + assembly.provider_manager.get_provider_model_bundle.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=ModelType.LLM, + ) + assembly.model_manager.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index fd563d1be2..3c7017cd54 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -4,11 +4,11 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.graph import Graph from dify_graph.nodes.answer.answer_node import AnswerNode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params @@ -48,7 +48,7 @@ def test_execute_answer(): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index 859115ceb3..ab536bbf4b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,5 +1,5 @@ +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index cea7195417..93d876fb26 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ 
b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -2,6 +2,7 @@ import pytest from configs import dify_config from core.helper.ssrf_proxy import ssrf_proxy +from core.workflow.system_variables import default_system_variables from dify_graph.file.file_manager import file_manager from dify_graph.nodes.http_request import ( BodyData, @@ -14,7 +15,6 @@ from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout from dify_graph.nodes.http_request.exc import AuthorizationConfigError from dify_graph.nodes.http_request.executor import Executor from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -30,7 +30,7 @@ HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( def test_executor_with_json_body_and_number_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "number"], 42) @@ -86,7 +86,7 @@ def test_executor_with_json_body_and_number_variable(): def test_executor_with_json_body_and_object_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) @@ -144,7 +144,7 @@ def test_executor_with_json_body_and_object_variable(): def test_executor_with_json_body_and_nested_object_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) @@ -201,7 +201,7 @@ def test_executor_with_json_body_and_nested_object_variable(): def test_extract_selectors_from_template_with_newline(): - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) variable_pool.add(("node_id", "custom_query"), "line1\nline2") node_data = HttpRequestNodeData( title="Test JSON Body with Nested Object Variable", @@ -231,7 +231,7 @@ def test_extract_selectors_from_template_with_newline(): def test_executor_with_form_data(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "text_field"], "Hello, World!") @@ -320,7 +320,7 @@ def test_init_headers(): node_data=node_data, timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), http_client=ssrf_proxy, file_manager=file_manager, ) @@ -357,7 +357,7 @@ def test_init_params(): node_data=node_data, timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), http_client=ssrf_proxy, file_manager=file_manager, ) @@ -390,7 +390,7 @@ def test_init_params(): def test_empty_api_key_raises_error_bearer(): """Test that empty API key raises AuthorizationConfigError for bearer 
auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -417,7 +417,7 @@ def test_empty_api_key_raises_error_bearer(): def test_empty_api_key_raises_error_basic(): """Test that empty API key raises AuthorizationConfigError for basic auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -444,7 +444,7 @@ def test_empty_api_key_raises_error_basic(): def test_empty_api_key_raises_error_custom(): """Test that empty API key raises AuthorizationConfigError for custom auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -471,7 +471,7 @@ def test_empty_api_key_raises_error_custom(): def test_whitespace_only_api_key_raises_error(): """Test that whitespace-only API key raises AuthorizationConfigError.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -498,7 +498,7 @@ def test_whitespace_only_api_key_raises_error(): def test_valid_api_key_works(): """Test that valid API key works correctly for bearer auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -537,7 +537,7 @@ def test_executor_with_json_body_and_unquoted_uuid_variable(): test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "uuid"], test_uuid) @@ -584,7 +584,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "uuid"], test_uuid) @@ -625,7 +625,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): def test_executor_with_json_body_preserves_numbers_and_strings(): """Test that numbers are preserved and string values are properly quoted.""" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["node", "count"], 42) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 5e34bf1d94..afeb78fb2c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -7,12 +7,13 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager +from core.workflow.node_runtime import DifyFileReferenceFactory +from 
core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.file.file_manager import file_manager from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout, Response from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( @@ -109,7 +110,7 @@ def _build_http_node( call_depth=0, ) graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), start_at=time.perf_counter(), ) return HttpRequestNode( @@ -121,6 +122,7 @@ def _build_http_node( http_client=ssrf_proxy, tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(graph_init_params.run_context), ) diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py index d52dfa2a65..1a731726b7 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -1,4 +1,4 @@ -from dify_graph.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients +from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients from dify_graph.runtime import VariablePool diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index 55aa62a1c0..b62d42f2a8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -8,35 +8,38 @@ from unittest.mock import MagicMock import pytest from pydantic import ValidationError -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.node_events import PauseRequestedEvent -from dify_graph.node_events.node import StreamCompletedEvent -from dify_graph.nodes.human_input.entities import ( +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from core.repositories.human_input_repository import HumanInputFormRepository +from core.workflow.human_input_compat import ( + DeliveryMethodType, EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, + EmailRecipientType, ExternalRecipient, - FormInput, - FormInputDefault, - HumanInputNodeData, MemberRecipient, - UserAction, WebAppDeliveryMethod, _WebAppDeliveryConfig, ) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from dify_graph.entities import GraphInitParams +from dify_graph.node_events import PauseRequestedEvent +from dify_graph.node_events.node import StreamCompletedEvent +from dify_graph.nodes.human_input.entities import ( + FormInput, + FormInputDefault, + HumanInputNodeData, + UserAction, +) from dify_graph.nodes.human_input.enums import ( ButtonStyle, - DeliveryMethodType, - EmailRecipientType, 
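
The recipient entities exercised below rename `whole_workspace` to `include_bound_group` and `user_id` to `reference_id`, and `test_legacy_recipient_keys_are_rejected` pins down the asymmetry: the old container key is still aliased, while the old member key is rejected. A minimal pydantic v2 sketch of that pattern (class bodies are illustrative, not the real entities):

```python
# Sketch only: demonstrates alias-plus-forbid, not the actual compat module.
from pydantic import AliasChoices, BaseModel, ConfigDict, Field


class MemberRecipient(BaseModel):
    model_config = ConfigDict(extra="forbid")  # legacy user_id= raises ValidationError
    type: str
    reference_id: str


class EmailRecipients(BaseModel):
    # The legacy whole_workspace key is accepted and mapped onto the new name.
    include_bound_group: bool = Field(
        default=False,
        validation_alias=AliasChoices("include_bound_group", "whole_workspace"),
    )
    items: list[MemberRecipient] = Field(default_factory=list)


assert EmailRecipients(whole_workspace=True).include_bound_group is True
```
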
FormInputType, PlaceholderType, TimeoutUnit, ) from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormRepository from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository @@ -54,9 +57,9 @@ class TestDeliveryMethod: def test_email_delivery_method(self): """Test email delivery method creation.""" recipients = EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(type=EmailRecipientType.MEMBER, user_id="test-user-123"), + MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="test-user-123"), ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com"), ], ) @@ -193,7 +196,7 @@ class TestHumanInputNodeData: EmailDeliveryMethod( enabled=False, # Disabled method should be fine config=EmailDeliveryConfig( - subject="Hi there", body="", recipients=EmailRecipients(whole_workspace=True) + subject="Hi there", body="", recipients=EmailRecipients(include_bound_group=True) ), ), ] @@ -212,7 +215,7 @@ class TestHumanInputNodeData: assert node_data.title == "Test Node" assert node_data.desc is None - assert node_data.delivery_methods == [] + assert node_data.model_dump().get("delivery_methods") is None assert node_data.form_content == "" assert node_data.inputs == [] assert node_data.user_actions == [] @@ -261,10 +264,10 @@ class TestRecipients: def test_member_recipient(self): """Test member recipient creation.""" - recipient = MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123") + recipient = MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123") assert recipient.type == EmailRecipientType.MEMBER - assert recipient.user_id == "user-123" + assert recipient.reference_id == "user-123" def test_external_recipient(self): """Test external recipient creation.""" @@ -273,37 +276,46 @@ class TestRecipients: assert recipient.type == EmailRecipientType.EXTERNAL assert recipient.email == "test@example.com" - def test_email_recipients_whole_workspace(self): - """Test email recipients with whole workspace enabled.""" + def test_email_recipients_bound_group(self): + """Test email recipients with the bound group enabled.""" recipients = EmailRecipients( - whole_workspace=True, items=[MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")] + include_bound_group=True, + items=[MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123")], ) - assert recipients.whole_workspace is True - assert len(recipients.items) == 1 # Items are preserved even when whole_workspace is True + assert recipients.include_bound_group is True + assert len(recipients.items) == 1 # Items are preserved even when include_bound_group is True def test_email_recipients_specific_users(self): """Test email recipients with specific users.""" recipients = EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123"), + MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123"), ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="external@example.com"), ], ) - assert recipients.whole_workspace is False + assert recipients.include_bound_group is False assert len(recipients.items) == 2 - assert recipients.items[0].user_id == "user-123" + assert 
recipients.items[0].reference_id == "user-123" assert recipients.items[1].email == "external@example.com" + def test_legacy_recipient_keys_are_rejected(self): + with pytest.raises(ValidationError): + MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123") + + recipients = EmailRecipients(whole_workspace=True, items=[]) + assert recipients.include_bound_group is True + assert recipients.items == [] + class TestHumanInputNodeVariableResolution: """Tests for resolving variable-based defaults in HumanInputNode.""" def test_resolves_variable_defaults(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -353,17 +365,19 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-1", rendered_content="Provide your name", - web_app_token="token", + submission_token="token", recipients=[], submitted=False, ) + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, + runtime=runtime, ) run_result = node._run() @@ -378,7 +392,7 @@ class TestHumanInputNodeVariableResolution: def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -416,28 +430,96 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-2", rendered_content="Provide your name", - web_app_token="console-token", + submission_token="console-token", recipients=[SimpleNamespace(token="recipient-token")], submitted=False, ) + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, + runtime=runtime, ) run_result = node._run() pause_event = next(run_result) assert isinstance(pause_event, PauseRequestedEvent) - assert pause_event.reason.form_token == "console-token" + assert not hasattr(pause_event.reason, "form_token") + + def test_webapp_runtime_keeps_form_visible_in_ui_when_webapp_delivery_is_enabled(self): + variable_pool = VariablePool( + system_variables=build_system_variables( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-4", + ), + user_inputs={}, + conversation_variables=[], + ) + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + graph_init_params = GraphInitParams( + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "end-user-1", + "user_from": "end-user", + "invoke_from": "web-app", + } + }, + call_depth=0, + ) + + config = { + "id": "human", + "data": { + "type": "human-input", + "title": "Human Input", + "form_content": "Provide your name", + "inputs": [], + "user_actions": [{"id": "submit", "title": "Submit"}], + "delivery_methods": [{"enabled": True, "type": "webapp", "config": {}}], + }, + } + + mock_repo = 
MagicMock(spec=HumanInputFormRepository) + mock_repo.get_form.return_value = None + mock_repo.create_form.return_value = SimpleNamespace( + id="form-4", + rendered_content="Provide your name", + submission_token="token", + recipients=[], + submitted=False, + ) + + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] + node = HumanInputNode( + id=config["id"], + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + runtime=runtime, + ) + + run_result = node._run() + pause_event = next(run_result) + + assert isinstance(pause_event, PauseRequestedEvent) + params = mock_repo.create_form.call_args.args[0] + assert params.display_in_ui is True def test_debugger_debug_mode_overrides_email_recipients(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user-123", app_id="app", workflow_id="workflow", @@ -472,7 +554,7 @@ class TestHumanInputNodeVariableResolution: enabled=True, config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="target@example.com")], ), subject="Subject", @@ -489,17 +571,19 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-3", rendered_content="Provide your name", - web_app_token="token", + submission_token="token", recipients=[], submitted=False, ) + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, + runtime=runtime, ) run_result = node._run() @@ -511,11 +595,11 @@ class TestHumanInputNodeVariableResolution: method = params.delivery_methods[0] assert isinstance(method, EmailDeliveryMethod) assert method.config.debug_mode is True - assert method.config.recipients.whole_workspace is False + assert method.config.recipients.include_bound_group is False assert len(method.config.recipients.items) == 1 recipient = method.config.recipients.items[0] assert isinstance(recipient, MemberRecipient) - assert recipient.user_id == "user-123" + assert recipient.reference_id == "user-123" class TestValidation: @@ -552,7 +636,7 @@ class TestHumanInputNodeRenderedContent: def test_replaces_outputs_placeholders_after_submission(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -591,12 +675,14 @@ class TestHumanInputNodeRenderedContent: config = {"id": "human", "data": node_data.model_dump()} form_repository = InMemoryHumanInputFormRepository() + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=form_repository) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=form_repository, + runtime=runtime, ) pause_gen = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index 
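
The HumanInputNode tests above no longer receive `form_repository=` directly; they build a `DifyHumanInputNodeRuntime` from the run context and monkeypatch its private `_build_form_repository` hook. A rough sketch of that seam (the class body is an assumption; only the hook name comes from the tests):

```python
# Assumed shape of the runtime seam; the real class is
# core.workflow.node_runtime.DifyHumanInputNodeRuntime.
from unittest.mock import MagicMock


class HumanInputNodeRuntimeSketch:
    def __init__(self, run_context):
        self._run_context = run_context

    def _build_form_repository(self):
        # Production code would build a DB-backed HumanInputFormRepository from the
        # run context; unit tests replace this hook to avoid any database I/O.
        raise NotImplementedError


# Usage mirroring the tests: swap the hook for a mock repository.
runtime = HumanInputNodeRuntimeSketch(run_context={})
runtime._build_form_repository = MagicMock(return_value=MagicMock())  # type: ignore[method-assign]
```
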
b0ed47158d..26de1833da 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -1,8 +1,10 @@ import datetime from types import SimpleNamespace -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import default_system_variables +from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.enums import BuiltinNodeTypes from dify_graph.graph_events import ( NodeRunHumanInputFormFilledEvent, @@ -12,7 +14,6 @@ from dify_graph.graph_events import ( from dify_graph.nodes.human_input.enums import HumanInputFormStatus from dify_graph.nodes.human_input.human_input_node import HumanInputNode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from libs.datetime_utils import naive_utc_now @@ -25,7 +26,7 @@ class _FakeFormRepository: def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode: - system_variables = SystemVariable.default() + system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), start_at=0.0, @@ -85,11 +86,12 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) def _build_timeout_node() -> HumanInputNode: - system_variables = SystemVariable.default() + system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), start_at=0.0, @@ -149,6 +151,7 @@ def _build_timeout_node() -> HumanInputNode: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py new file mode 100644 index 0000000000..969a21ed7f --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_abort_propagation.py @@ -0,0 +1,201 @@ +from threading import Event +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest + +from dify_graph.enums import WorkflowNodeExecutionStatus +from dify_graph.graph_events import GraphRunAbortedEvent +from dify_graph.model_runtime.entities.llm_entities import LLMUsage +from dify_graph.node_events import IterationFailedEvent, IterationStartedEvent, StreamCompletedEvent +from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData +from dify_graph.nodes.iteration.exc import ChildGraphAbortedError +from dify_graph.nodes.iteration.iteration_node import IterationNode +from tests.workflow_test_utils import build_test_variable_pool + + +def _usage_with_tokens(total_tokens: int) -> LLMUsage: + 
usage = LLMUsage.empty_usage() + usage.total_tokens = total_tokens + return usage + + +class _AbortOnRequestGraphEngine: + def __init__(self, *, index: int, total_tokens: int) -> None: + variable_pool = build_test_variable_pool() + variable_pool.add(["iteration-node", "index"], index) + + self.started = Event() + self.abort_requested = Event() + self.finished = Event() + self.abort_reason: str | None = None + self.graph_runtime_state = SimpleNamespace( + variable_pool=variable_pool, + llm_usage=_usage_with_tokens(total_tokens), + ) + + def request_abort(self, reason: str | None = None) -> None: + self.abort_reason = reason + self.abort_requested.set() + + def run(self): + self.started.set() + assert self.abort_requested.wait(1), "parallel sibling never received an abort request" + self.finished.set() + yield GraphRunAbortedEvent(reason=self.abort_reason) + + +def _build_immediate_abort_graph_engine( + *, + index: int, + total_tokens: int, + wait_before_abort: Event | None = None, +) -> SimpleNamespace: + variable_pool = build_test_variable_pool() + variable_pool.add(["iteration-node", "index"], index) + + started = Event() + finished = Event() + + def run(): + started.set() + if wait_before_abort is not None: + assert wait_before_abort.wait(1), "parallel sibling never started" + finished.set() + yield GraphRunAbortedEvent(reason="quota exceeded") + + return SimpleNamespace( + graph_runtime_state=SimpleNamespace( + variable_pool=variable_pool, + llm_usage=_usage_with_tokens(total_tokens), + ), + run=run, + request_abort=lambda reason=None: None, + started=started, + finished=finished, + ) + + +def _build_iteration_node( + *, + error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED, + is_parallel: bool = False, +) -> IterationNode: + node = IterationNode.__new__(IterationNode) + node._node_id = "iteration-node" + node._node_data = IterationNodeData( + title="Iteration", + iterator_selector=["start", "items"], + output_selector=["iteration-node", "output"], + start_node_id="child-start", + is_parallel=is_parallel, + parallel_nums=2, + error_handle_mode=error_handle_mode, + ) + + variable_pool = build_test_variable_pool() + variable_pool.add(["start", "items"], ["first", "second"]) + node.graph_runtime_state = SimpleNamespace( + variable_pool=variable_pool, + llm_usage=LLMUsage.empty_usage(), + ) + return node + + +def test_run_single_iter_raises_child_graph_aborted_error_on_abort_event() -> None: + node = _build_iteration_node() + variable_pool = build_test_variable_pool() + variable_pool.add(["iteration-node", "index"], 0) + graph_engine = SimpleNamespace( + run=lambda: iter([GraphRunAbortedEvent(reason="quota exceeded")]), + ) + + with pytest.raises(ChildGraphAbortedError, match="quota exceeded"): + list( + node._run_single_iter( + variable_pool=variable_pool, + outputs=[], + graph_engine=graph_engine, + ) + ) + + +def test_iteration_run_fails_on_sequential_child_abort() -> None: + node = _build_iteration_node(error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR) + graph_engine = SimpleNamespace( + graph_runtime_state=SimpleNamespace( + variable_pool=build_test_variable_pool(), + llm_usage=LLMUsage.empty_usage(), + ) + ) + node._create_graph_engine = MagicMock(return_value=graph_engine) + node._run_single_iter = MagicMock(side_effect=ChildGraphAbortedError("quota exceeded")) + + events = list(node._run()) + + assert isinstance(events[0], IterationStartedEvent) + assert isinstance(events[-2], IterationFailedEvent) + assert events[-2].error == "quota exceeded" + assert 
isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED + assert events[-1].node_run_result.error == "quota exceeded" + node._create_graph_engine.assert_called_once() + node._run_single_iter.assert_called_once() + + +def test_iteration_run_merges_child_usage_before_failing_on_sequential_child_abort() -> None: + node = _build_iteration_node(error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR) + graph_engine = SimpleNamespace( + graph_runtime_state=SimpleNamespace( + variable_pool=build_test_variable_pool(), + llm_usage=_usage_with_tokens(7), + ) + ) + node._create_graph_engine = MagicMock(return_value=graph_engine) + node._run_single_iter = MagicMock(side_effect=ChildGraphAbortedError("quota exceeded")) + + events = list(node._run()) + + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.llm_usage.total_tokens == 7 + assert node.graph_runtime_state.llm_usage.total_tokens == 7 + + +@pytest.mark.parametrize( + "error_handle_mode", + [ + ErrorHandleMode.CONTINUE_ON_ERROR, + ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, + ], +) +def test_iteration_run_fails_on_parallel_child_abort_regardless_of_error_mode( + error_handle_mode: ErrorHandleMode, +) -> None: + node = _build_iteration_node( + error_handle_mode=error_handle_mode, + is_parallel=True, + ) + blocking_engine = _AbortOnRequestGraphEngine(index=1, total_tokens=5) + aborting_engine = _build_immediate_abort_graph_engine( + index=0, + total_tokens=3, + wait_before_abort=blocking_engine.started, + ) + node._create_graph_engine = MagicMock( + side_effect=lambda index, item: {0: aborting_engine, 1: blocking_engine}[index] + ) + + events = list(node._run()) + + assert isinstance(events[0], IterationStartedEvent) + assert isinstance(events[-2], IterationFailedEvent) + assert events[-2].error == "quota exceeded" + assert isinstance(events[-1], StreamCompletedEvent) + assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.FAILED + assert events[-1].node_run_result.error == "quota exceeded" + assert events[-1].node_run_result.llm_usage.total_tokens == 8 + assert node.graph_runtime_state.llm_usage.total_tokens == 8 + assert blocking_engine.started.is_set() + assert blocking_engine.abort_requested.is_set() + assert blocking_engine.finished.is_set() + assert blocking_engine.abort_reason == "quota exceeded" diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py index 2eb4feef5f..1a6938b4c8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -1,8 +1,9 @@ -from collections.abc import Mapping, Sequence +from collections.abc import Mapping from typing import Any import pytest +from core.workflow.system_variables import default_system_variables from dify_graph.entities import GraphInitParams from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError from dify_graph.nodes.iteration.iteration_node import IterationNode @@ -12,7 +13,6 @@ from dify_graph.runtime import ( GraphRuntimeState, VariablePool, ) -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params @@ -22,17 +22,16 @@ class _MissingGraphBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: 
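
The abort-propagation tests above expect a child graph's `GraphRunAbortedEvent` to surface as `ChildGraphAbortedError`, failing the whole iteration even under permissive error-handle modes and, in parallel mode, fanning the abort out to sibling engines. A hedged sketch of the loop shape they imply (the real `IterationNode._run_single_iter` also collects outputs and merges usage):

```python
# Sketch of the abort handling implied by the tests; not the actual implementation.
from dify_graph.graph_events import GraphRunAbortedEvent
from dify_graph.nodes.iteration.exc import ChildGraphAbortedError


def run_single_iter_sketch(graph_engine):
    for event in graph_engine.run():
        if isinstance(event, GraphRunAbortedEvent):
            # Re-raise so the iteration fails regardless of ErrorHandleMode; the
            # parallel path additionally calls request_abort() on sibling engines.
            raise ChildGraphAbortedError(event.reason or "child graph aborted")
        yield event
```
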
GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> object: raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") def _build_runtime_state() -> GraphRuntimeState: return GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default(), user_inputs={}), + variable_pool=VariablePool(system_variables=default_system_variables(), user_inputs={}), start_at=0.0, ) @@ -69,8 +68,6 @@ def test_graph_runtime_state_raises_specific_error_when_child_builder_is_missing runtime_state.create_child_engine( workflow_id="workflow", graph_init_params=graph_init_params, - graph_runtime_state=_build_runtime_state(), - graph_config={}, root_node_id="root", ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py index 8660449032..9931eaf39a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py @@ -1,5 +1,4 @@ import time -from contextlib import nullcontext from datetime import UTC, datetime import pytest @@ -21,11 +20,17 @@ def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None: parallel_nums=2, error_handle_mode=ErrorHandleMode.TERMINATED, ) - node._capture_execution_context = lambda: nullcontext() - node._sync_conversation_variables_from_snapshot = lambda snapshot: None node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new) - def fake_execute_single_iteration_parallel(*, index: int, item: object, execution_context: object): + def fake_execute_tracked_iteration_parallel( + *, + index: int, + item: object, + started_child_engines: dict[int, object], + started_child_engines_lock: object, + ): + _ = started_child_engines + _ = started_child_engines_lock return ( 0.1 + (index * 0.1), [ @@ -37,11 +42,10 @@ def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None: ), ], f"output-{item}", - {}, LLMUsage.empty_usage(), ) - node._execute_single_iteration_parallel = fake_execute_single_iteration_parallel + node._execute_tracked_iteration_parallel = fake_execute_tracked_iteration_parallel outputs: list[object] = [] iter_run_map: dict[str, float] = {} diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index feb560bbc3..15a7807bda 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -15,9 +15,9 @@ from core.workflow.nodes.knowledge_index.protocols import ( PreviewItem, SummaryIndexServiceProtocol, ) -from dify_graph.enums import SystemVariableKey, WorkflowNodeExecutionStatus +from core.workflow.system_variables import SystemVariableKey, build_system_variables +from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables.segments import StringSegment from tests.workflow_test_utils import build_test_graph_init_params @@ -41,7 +41,7 @@ def mock_graph_init_params(): 
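
Per the updated `_MissingGraphBuilder` signature, the child-engine builder now takes the parent runtime state plus an optional variable pool instead of a `graph_runtime_state`/`graph_config` pair. Spelled out as a keyword-only protocol (the `Protocol` itself is an assumption; the keywords come from the test):

```python
# Assumed protocol; keyword names mirror _MissingGraphBuilder.__call__ above.
from typing import Any, Protocol

from dify_graph.entities import GraphInitParams
from dify_graph.runtime import GraphRuntimeState, VariablePool


class ChildEngineBuilder(Protocol):
    def __call__(
        self,
        *,
        workflow_id: str,
        graph_init_params: GraphInitParams,
        parent_graph_runtime_state: GraphRuntimeState,
        root_node_id: str,
        variable_pool: VariablePool | None = None,
    ) -> Any: ...
```
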
def mock_graph_runtime_state(): """Create mock GraphRuntimeState.""" variable_pool = VariablePool( - system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]), + system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index 99997db6b2..d61cb222b9 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -16,10 +16,10 @@ from core.workflow.nodes.knowledge_retrieval.entities import ( from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source +from core.workflow.system_variables import build_system_variables from dify_graph.enums import WorkflowNodeExecutionStatus from dify_graph.model_runtime.entities.llm_entities import LLMUsage from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables import StringSegment from tests.workflow_test_utils import build_test_graph_init_params @@ -43,7 +43,7 @@ def mock_graph_init_params(): def mock_graph_runtime_state(): """Create mock GraphRuntimeState.""" variable_pool = VariablePool( - system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]), + system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index d71e0921c1..b944c6e785 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -2,8 +2,8 @@ from unittest.mock import MagicMock import pytest +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus from dify_graph.nodes.list_operator.node import ListOperatorNode from dify_graph.runtime import GraphRuntimeState diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py index b0f0fd428b..ba433b27b5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py @@ -6,17 +6,14 @@ from unittest.mock import MagicMock import httpx import pytest -from core.helper import ssrf_proxy -from core.tools import signature -from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file import FileTransferMethod, FileType, models +from dify_graph.file import FileTransferMethod, FileType from dify_graph.nodes.llm.file_saver import ( FileSaverImpl, _extract_content_type_and_extension, _get_extension, _validate_extension_override, ) -from models import ToolFile +from dify_graph.nodes.protocols import 
ToolFileManagerProtocol _PNG_DATA = b"\x89PNG\r\n\x1a\n" @@ -27,58 +24,45 @@ def _gen_id(): class TestFileSaverImpl: def test_save_binary_string(self, monkeypatch: pytest.MonkeyPatch): - user_id = _gen_id() - tenant_id = _gen_id() file_type = FileType.IMAGE mime_type = "image/png" - mock_signed_url = "https://example.com/image.png" - mock_tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - file_key="test-file-key", - mimetype=mime_type, - original_url=None, - name=f"{_gen_id()}.png", - size=len(_PNG_DATA), - ) + mock_tool_file = MagicMock() mock_tool_file.id = _gen_id() - mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager) - + mock_tool_file.name = f"{_gen_id()}.png" + mock_tool_file.file_key = "test-file-key" + mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManagerProtocol) mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file - monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager) - # Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here. - mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file) - # Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here. - monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file) - mocked_sign_file.return_value = mock_signed_url + file_reference = MagicMock() + file_reference_factory = MagicMock() + file_reference_factory.build_from_mapping.return_value = file_reference http_client = MagicMock() - storage_file_manager = FileSaverImpl( - user_id=user_id, - tenant_id=tenant_id, + file_saver = FileSaverImpl( + tool_file_manager=mocked_tool_file_manager, + file_reference_factory=file_reference_factory, http_client=http_client, ) - file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type) - assert file.tenant_id == tenant_id - assert file.type == file_type - assert file.transfer_method == FileTransferMethod.TOOL_FILE - assert file.extension == ".png" - assert file.mime_type == mime_type - assert file.size == len(_PNG_DATA) - assert file.related_id == mock_tool_file.id - - assert file.generate_url() == mock_signed_url + file = file_saver.save_binary_string(_PNG_DATA, mime_type, file_type) + assert file is file_reference mocked_tool_file_manager.create_file_by_raw.assert_called_once_with( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, file_binary=_PNG_DATA, mimetype=mime_type, ) - mocked_sign_file.assert_called_once_with(tool_file_id=mock_tool_file.id, extension=".png", for_external=True) + file_reference_factory.build_from_mapping.assert_called_once_with( + mapping={ + "type": file_type, + "transfer_method": FileTransferMethod.TOOL_FILE, + "filename": mock_tool_file.name, + "extension": ".png", + "mime_type": mime_type, + "size": len(_PNG_DATA), + "tool_file_id": mock_tool_file.id, + "related_id": mock_tool_file.id, + "storage_key": mock_tool_file.file_key, + } + ) def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch): _TEST_URL = "https://example.com/image.png" @@ -91,8 +75,8 @@ class TestFileSaverImpl: http_client.get.return_value = mock_response file_saver = FileSaverImpl( - user_id=_gen_id(), - tenant_id=_gen_id(), + tool_file_manager=MagicMock(), + file_reference_factory=MagicMock(), http_client=http_client, ) @@ -104,8 +88,6 @@ class TestFileSaverImpl: def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch): _TEST_URL = "https://example.com/image.png" 
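
The rewritten file-saver test injects a `ToolFileManagerProtocol` and a file-reference factory instead of tenant/user ids, so its assertions fully determine the happy path of `save_binary_string`. A sketch consistent with those assertions (extension inference is an assumption):

```python
# Sketch reconstructed from the test's asserted calls; not the shipped FileSaverImpl.
import mimetypes

from dify_graph.file import FileTransferMethod, FileType


def save_binary_string_sketch(tool_file_manager, file_reference_factory, data: bytes,
                              mime_type: str, file_type: FileType):
    tool_file = tool_file_manager.create_file_by_raw(file_binary=data, mimetype=mime_type)
    extension = mimetypes.guess_extension(mime_type) or ".bin"  # assumed inference
    return file_reference_factory.build_from_mapping(
        mapping={
            "type": file_type,
            "transfer_method": FileTransferMethod.TOOL_FILE,
            "filename": tool_file.name,
            "extension": extension,
            "mime_type": mime_type,
            "size": len(data),
            "tool_file_id": tool_file.id,
            "related_id": tool_file.id,
            "storage_key": tool_file.file_key,
        }
    )
```
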
mime_type = "image/png" - user_id = _gen_id() - tenant_id = _gen_id() mock_request = httpx.Request("GET", _TEST_URL) mock_response = httpx.Response( @@ -117,21 +99,13 @@ class TestFileSaverImpl: http_client = MagicMock() http_client.get.return_value = mock_response - file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id, http_client=http_client) - mock_tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - file_key="test-file-key", - mimetype=mime_type, - original_url=None, - name=f"{_gen_id()}.png", - size=len(_PNG_DATA), + file_saver = FileSaverImpl( + tool_file_manager=MagicMock(), + file_reference_factory=MagicMock(), + http_client=http_client, ) - mock_tool_file.id = _gen_id() - mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) - monkeypatch.setattr(ssrf_proxy, "get", mock_get) - mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file) + expected_file = MagicMock() + mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=expected_file) monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string) file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) @@ -141,7 +115,7 @@ class TestFileSaverImpl: FileType.IMAGE, extension_override=".png", ) - assert file == mock_tool_file + assert file is expected_file def test_validate_extension_override(): diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py index acecbf4944..1d51050d19 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py @@ -3,16 +3,85 @@ from unittest import mock import pytest from core.model_manager import ModelInstance +from dify_graph.file import FileTransferMethod, FileType +from dify_graph.file.models import File from dify_graph.model_runtime.entities import ( ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent, ) -from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage +from dify_graph.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from dify_graph.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from dify_graph.nodes.base.entities import VariableSelector from dify_graph.nodes.llm import llm_utils -from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage -from dify_graph.nodes.llm.exc import NoPromptFoundError +from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig +from dify_graph.nodes.llm.exc import ( + InvalidVariableTypeError, + MemoryRolePrefixRequiredError, + NoPromptFoundError, + TemplateTypeNotSupportError, +) from dify_graph.runtime import VariablePool +from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment + + +def _build_model_schema( + *, + features: list[ModelFeature] | None = None, + model_properties: dict[ModelPropertyKey, object] | None = None, + parameter_rules: list[ParameterRule] | None = None, +) -> AIModelEntity: + return AIModelEntity( + model="gpt-3.5-turbo", + label={"en_US": "GPT-3.5 Turbo"}, + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + features=features, + model_properties=model_properties or {}, 
+ parameter_rules=parameter_rules or [], + ) + + +def _build_model_instance(*, model_schema: AIModelEntity | None = None) -> mock.MagicMock: + model_instance = mock.MagicMock(spec=ModelInstance) + model_instance.model_name = "gpt-3.5-turbo" + model_instance.parameters = {} + model_instance.get_model_schema.return_value = model_schema or _build_model_schema(features=[]) + model_instance.get_llm_num_tokens.return_value = 0 + return model_instance + + +def _build_image_file( + *, + file_id: str, + related_id: str, + remote_url: str, + extension: str = ".png", + mime_type: str = "image/png", +) -> File: + return File( + id=file_id, + type=FileType.IMAGE, + filename=f"{file_id}{extension}", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=remote_url, + related_id=related_id, + extension=extension, + mime_type=mime_type, + storage_key="", + ) @pytest.fixture @@ -270,3 +339,700 @@ def test_fetch_prompt_messages_keeps_list_content_when_multiple_supported_items_ ] ) ] + + +def test_fetch_model_schema_raises_when_model_schema_is_missing(): + model_instance = _build_model_instance() + model_instance.get_model_schema.return_value = None + + with pytest.raises(ValueError, match="Model schema not found for gpt-3.5-turbo"): + llm_utils.fetch_model_schema(model_instance=model_instance) + + +def test_fetch_files_supports_known_segments_and_rejects_invalid_types(): + file = _build_image_file(file_id="image", related_id="image-related", remote_url="https://example.com/image.png") + variable_pool = VariablePool.empty() + variable_pool.add(["input", "file"], file) + variable_pool.add(["input", "files"], ArrayFileSegment(value=[file])) + variable_pool.add(["input", "none"], NoneSegment()) + variable_pool.add(["input", "empty"], ArrayAnySegment(value=[])) + variable_pool.add(["input", "invalid"], {"a": 1}) + + assert llm_utils.fetch_files(variable_pool, ["input", "file"]) == [file] + assert llm_utils.fetch_files(variable_pool, ["input", "files"]) == [file] + assert llm_utils.fetch_files(variable_pool, ["input", "none"]) == [] + assert llm_utils.fetch_files(variable_pool, ["input", "empty"]) == [] + + with pytest.raises(InvalidVariableTypeError, match="Invalid variable type"): + llm_utils.fetch_files(variable_pool, ["input", "invalid"]) + + +def test_fetch_files_returns_empty_for_missing_variable(): + assert llm_utils.fetch_files(VariablePool.empty(), ["input", "missing"]) == [] + + +def test_convert_history_messages_to_text_skips_system_messages_and_formats_images(): + history_text = llm_utils.convert_history_messages_to_text( + history_messages=[ + SystemPromptMessage(content="skip"), + UserPromptMessage( + content=[ + TextPromptMessageContent(data="Question"), + ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/png", + ), + ] + ), + AssistantPromptMessage(content="Answer"), + ], + human_prefix="Human", + ai_prefix="Assistant", + ) + + assert history_text == "Human: Question\n[image]\nAssistant: Answer" + + +def test_fetch_memory_text_uses_prompt_memory_interface(): + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [UserPromptMessage(content="Question")] + + memory_text = llm_utils.fetch_memory_text( + memory=memory, + max_token_limit=321, + message_limit=2, + human_prefix="Human", + ai_prefix="Assistant", + ) + + assert memory_text == "Human: Question" + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=321, message_limit=2) + + +def test_handle_list_messages_renders_jinja2_messages(): + 
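
`test_convert_history_messages_to_text_skips_system_messages_and_formats_images` below fixes the exact rendering: system messages are dropped, user/assistant turns get `Human:` / `Assistant:` prefixes, image chunks render as `[image]`, and everything is newline-joined. A self-contained sketch of that formatting (simplified message type, not the real `PromptMessage`):

```python
# Simplified reimplementation matching the asserted output; illustrative only.
from dataclasses import dataclass


@dataclass
class Msg:
    role: str         # "system" | "user" | "assistant"
    parts: list[str]  # text chunks; image chunks already rendered as "[image]"


def history_to_text(messages: list[Msg], human_prefix: str = "Human", ai_prefix: str = "Assistant") -> str:
    lines = []
    for m in messages:
        if m.role == "system":
            continue  # system messages never appear in rendered history
        prefix = human_prefix if m.role == "user" else ai_prefix
        lines.append(f"{prefix}: " + "\n".join(m.parts))
    return "\n".join(lines)


assert history_to_text(
    [Msg("system", ["skip"]), Msg("user", ["Question", "[image]"]), Msg("assistant", ["Answer"])]
) == "Human: Question\n[image]\nAssistant: Answer"
```
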
variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + prompt_messages = llm_utils.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="Hello {{ name }}", + role=PromptMessageRole.SYSTEM, + edition_type="jinja2", + ) + ], + context="", + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])], + variable_pool=variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + template_renderer=renderer, + ) + + assert prompt_messages == [SystemPromptMessage(content=[TextPromptMessageContent(data="Hello Dify")])] + renderer.render_template.assert_called_once_with("Hello {{ name }}", {"name": "Dify"}) + + +def test_handle_list_messages_splits_text_and_file_content(): + variable_pool = VariablePool.empty() + image_file = _build_image_file( + file_id="image-file", + related_id="image-related", + remote_url="https://example.com/file.png", + ) + variable_pool.add(["input", "image"], image_file) + + with mock.patch( + "dify_graph.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + return_value=ImagePromptMessageContent( + format="png", + url="https://example.com/file.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ) as mock_to_prompt: + prompt_messages = llm_utils.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="Analyze {{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages == [ + UserPromptMessage(content=[TextPromptMessageContent(data="Analyze ")]), + UserPromptMessage( + content=[ + ImagePromptMessageContent( + format="png", + url="https://example.com/file.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + ] + ), + ] + mock_to_prompt.assert_called_once() + + +def test_handle_list_messages_supports_array_file_segments(): + variable_pool = VariablePool.empty() + first_file = _build_image_file(file_id="first", related_id="first-related", remote_url="https://example.com/1.png") + second_file = _build_image_file( + file_id="second", + related_id="second-related", + remote_url="https://example.com/2.png", + ) + variable_pool.add(["input", "images"], ArrayFileSegment(value=[first_file, second_file])) + + first_prompt = ImagePromptMessageContent( + format="png", + url="https://example.com/1.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + second_prompt = ImagePromptMessageContent( + format="png", + url="https://example.com/2.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + + with mock.patch( + "dify_graph.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + side_effect=[first_prompt, second_prompt], + ): + prompt_messages = llm_utils.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="{{#input.images#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages == [UserPromptMessage(content=[first_prompt, second_prompt])] + + +def test_render_jinja2_message_handles_empty_template_success_and_missing_renderer(): + variable_pool = VariablePool.empty() + 
variable_pool.add(["input", "name"], "Dify") + variables = [VariableSelector(variable="name", value_selector=["input", "name"])] + + assert ( + llm_utils.render_jinja2_message( + template="", + jinja2_variables=variables, + variable_pool=variable_pool, + template_renderer=None, + ) + == "" + ) + + with pytest.raises(ValueError, match="template_renderer is required"): + llm_utils.render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + template_renderer=None, + ) + + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + assert ( + llm_utils.render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + template_renderer=renderer, + ) + == "Hello Dify" + ) + + +def test_handle_completion_template_supports_basic_and_jinja2_templates(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + basic_messages = llm_utils.handle_completion_template( + template=LLMNodeCompletionModelPromptTemplate( + text="Summarize {{#context#}}", + edition_type="basic", + ), + context="the docs", + jinja2_variables=[], + variable_pool=variable_pool, + ) + jinja_messages = llm_utils.handle_completion_template( + template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + jinja2_text="Hello {{ name }}", + edition_type="jinja2", + ), + context="", + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])], + variable_pool=variable_pool, + template_renderer=renderer, + ) + + assert basic_messages == [ + UserPromptMessage(content=[TextPromptMessageContent(data="Summarize the docs")]), + ] + assert jinja_messages == [ + UserPromptMessage(content=[TextPromptMessageContent(data="Hello Dify")]), + ] + + +def test_combine_message_content_with_role_handles_all_supported_roles(): + contents = [TextPromptMessageContent(data="hello")] + + assert llm_utils.combine_message_content_with_role(contents=contents, role=PromptMessageRole.USER) == ( + UserPromptMessage(content=contents) + ) + assert llm_utils.combine_message_content_with_role(contents=contents, role=PromptMessageRole.ASSISTANT) == ( + AssistantPromptMessage(content=contents) + ) + assert llm_utils.combine_message_content_with_role(contents=contents, role=PromptMessageRole.SYSTEM) == ( + SystemPromptMessage(content=contents) + ) + + with pytest.raises(NotImplementedError, match="Role custom is not supported"): + llm_utils.combine_message_content_with_role(contents=contents, role="custom") # type: ignore[arg-type] + + +def test_calculate_rest_token_uses_context_size_and_template_alias(): + model_instance = _build_model_instance( + model_schema=_build_model_schema( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096}, + parameter_rules=[ + ParameterRule( + name="output_limit", + use_template="max_tokens", + label={"en_US": "Output Limit"}, + type=ParameterType.INT, + ) + ], + ) + ) + model_instance.parameters = {"max_tokens": 512} + model_instance.get_llm_num_tokens.return_value = 256 + + assert ( + llm_utils.calculate_rest_token( + prompt_messages=[UserPromptMessage(content="hello")], + model_instance=model_instance, + ) + == 3328 + ) + + +def test_handle_memory_chat_mode_returns_empty_without_memory_and_uses_window_when_present(): + model_instance = _build_model_instance() + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = 
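
The `calculate_rest_token` expectation above is plain budget arithmetic: a 4096-token context window, minus the 512 completion tokens reserved through the `max_tokens` template alias, minus 256 prompt tokens, leaves 3328. As a one-liner (clamping at zero is an assumption):

```python
# Assumed formula behind the 3328 expectation; clamping at zero is a guess.
def rest_tokens(context_size: int, reserved_completion_tokens: int, prompt_tokens: int) -> int:
    return max(0, context_size - reserved_completion_tokens - prompt_tokens)


assert rest_tokens(4096, 512, 256) == 3328
```
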
[UserPromptMessage(content="Question")] + + assert ( + llm_utils.handle_memory_chat_mode( + memory=None, + memory_config=None, + model_instance=model_instance, + ) + == [] + ) + + with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=123) as mock_rest: + messages = llm_utils.handle_memory_chat_mode( + memory=memory, + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=True, size=2)), + model_instance=model_instance, + ) + + assert messages == [UserPromptMessage(content="Question")] + mock_rest.assert_called_once() + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=123, message_limit=2) + + +def test_handle_memory_completion_mode_validates_role_prefix_and_formats_history(): + model_instance = _build_model_instance() + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="Question"), + AssistantPromptMessage(content="Answer"), + ] + + assert ( + llm_utils.handle_memory_completion_mode( + memory=None, + memory_config=None, + model_instance=model_instance, + ) + == "" + ) + + with ( + mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=456), + pytest.raises(MemoryRolePrefixRequiredError, match="Memory role prefix is required"), + ): + llm_utils.handle_memory_completion_mode( + memory=memory, + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=True, size=2)), + model_instance=model_instance, + ) + + with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=456): + history_text = llm_utils.handle_memory_completion_mode( + memory=memory, + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=False), + ), + model_instance=model_instance, + ) + + assert history_text == "Human: Question\nAssistant: Answer" + memory.get_history_prompt_messages.assert_called_with(max_token_limit=456, message_limit=None) + + +def test_append_file_prompts_merges_with_existing_user_content_or_appends_new_message(): + file = _build_image_file(file_id="image", related_id="image-related", remote_url="https://example.com/image.png") + file_prompt = ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + prompt_messages = [UserPromptMessage(content=[TextPromptMessageContent(data="Question")])] + + with mock.patch( + "dify_graph.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + return_value=file_prompt, + ): + llm_utils._append_file_prompts( + prompt_messages=prompt_messages, + files=[file], + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages == [ + UserPromptMessage(content=[file_prompt, TextPromptMessageContent(data="Question")]), + ] + + prompt_messages = [SystemPromptMessage(content="System prompt")] + with mock.patch( + "dify_graph.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + return_value=file_prompt, + ): + llm_utils._append_file_prompts( + prompt_messages=prompt_messages, + files=[file], + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages[-1] == UserPromptMessage(content=[file_prompt]) + + +def test_fetch_prompt_messages_chat_mode_includes_query_memory_and_supported_files(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[ModelFeature.VISION])) + 
memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [AssistantPromptMessage(content="history")] + sys_file = _build_image_file(file_id="sys", related_id="sys-related", remote_url="https://example.com/sys.png") + context_file = _build_image_file( + file_id="context", + related_id="context-related", + remote_url="https://example.com/context.png", + ) + file_prompts = [ + ImagePromptMessageContent( + format="png", + url="https://example.com/sys.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ImagePromptMessageContent( + format="png", + url="https://example.com/context.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ] + + with mock.patch( + "dify_graph.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + side_effect=file_prompts, + ): + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query="current question", + sys_files=[sys_file], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=[ + LLMNodeChatModelMessage( + text="Before query", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + stop=("STOP",), + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)), + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + context_files=[context_file], + ) + + assert stop == ("STOP",) + assert prompt_messages[0] == UserPromptMessage(content="Before query") + assert prompt_messages[1] == AssistantPromptMessage(content="history") + assert prompt_messages[2] == UserPromptMessage( + content=[ + file_prompts[1], + file_prompts[0], + TextPromptMessageContent(data="current question"), + ] + ) + + +def test_fetch_prompt_messages_completion_mode_updates_list_content_with_histories_and_query(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[])) + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="previous question"), + AssistantPromptMessage(content="previous answer"), + ] + + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="Prompt header\n#histories#", + edition_type="basic", + ), + stop=("HALT",), + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=2), + ), + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("HALT",) + assert prompt_messages == [ + UserPromptMessage( + content="latest question\nPrompt header\nHuman: previous question\nAssistant: previous answer" + ) + ] + + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="another question"), + AssistantPromptMessage(content="another answer"), + ] + + prompt_messages, _ = llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="Prompt header", + edition_type="basic", + ), + stop=None, + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + 
window=MemoryConfig.WindowConfig(enabled=True, size=2), + ), + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert prompt_messages == [ + UserPromptMessage(content="latest question\nHuman: another question\nAssistant: another answer\nPrompt header") + ] + + +def test_fetch_prompt_messages_filters_content_unsupported_by_model_features(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[ModelFeature.DOCUMENT])) + prompt_template = [ + LLMNodeChatModelMessage( + text="You are a classifier.", + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ) + ] + + with ( + mock.patch( + "dify_graph.nodes.llm.llm_utils.handle_list_messages", + return_value=[ + SystemPromptMessage( + content=[ + TextPromptMessageContent(data="You are a classifier."), + ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/png", + ), + ] + ) + ], + ), + mock.patch("dify_graph.nodes.llm.llm_utils.handle_memory_chat_mode", return_value=[]), + ): + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=prompt_template, + stop=("END",), + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("END",) + assert prompt_messages == [SystemPromptMessage(content="You are a classifier.")] + + +def test_fetch_prompt_messages_completion_mode_supports_string_content_and_invalid_template_type(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[])) + + with ( + mock.patch( + "dify_graph.nodes.llm.llm_utils.handle_completion_template", + return_value=[UserPromptMessage(content="Prefix #histories# and #sys.query#")], + ), + mock.patch( + "dify_graph.nodes.llm.llm_utils.handle_memory_completion_mode", + return_value="history text", + ), + ): + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + edition_type="basic", + ), + stop=("HALT",), + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("HALT",) + assert prompt_messages == [UserPromptMessage(content="Prefix history text and latest question")] + + with pytest.raises(TemplateTypeNotSupportError): + llm_utils.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=object(), # type: ignore[arg-type] + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + invalid_prompt = mock.MagicMock() + invalid_prompt.content = object() + with ( + mock.patch( + "dify_graph.nodes.llm.llm_utils.handle_completion_template", + return_value=[invalid_prompt], + ), + mock.patch( + "dify_graph.nodes.llm.llm_utils.handle_memory_completion_mode", + return_value="history text", + ), + pytest.raises(ValueError, match="Invalid prompt content type"), + ): + llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + 
context="", + memory=None, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + edition_type="basic", + ), + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + with ( + mock.patch( + "dify_graph.nodes.llm.llm_utils.handle_completion_template", + return_value=[UserPromptMessage(content="Prefix only")], + ), + mock.patch( + "dify_graph.nodes.llm.llm_utils.handle_memory_completion_mode", + return_value="history text", + ), + ): + prompt_messages, _ = llm_utils.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + edition_type="basic", + ), + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert prompt_messages == [UserPromptMessage(content="history text\nPrefix only")] diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index fc96088af1..8b0c2dfa38 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -5,39 +5,79 @@ from unittest import mock import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom -from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom +from core.app.llm.model_access import ( + DifyCredentialsProvider, + DifyModelFactory, + build_dify_model_access, + fetch_model_config, +) from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, SystemConfiguration -from core.model_manager import ModelInstance +from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.workflow.system_variables import default_system_variables from dify_graph.entities import GraphInitParams from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.llm_entities import ( + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, + LLMUsage, +) from dify_graph.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, PromptMessageRole, + SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from dify_graph.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from dify_graph.node_events import ModelInvokeCompletedEvent, RunRetrieverResourceEvent, StreamChunkEvent +from dify_graph.nodes.base.entities import VariableSelector from 
dify_graph.nodes.llm import llm_utils from dify_graph.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, LLMNodeData, ModelConfig, + PromptConfig, VisionConfig, VisionConfigOptions, ) +from dify_graph.nodes.llm.exc import ( + InvalidContextStructureError, + LLMNodeError, + NoPromptFoundError, + VariableNotFoundError, +) from dify_graph.nodes.llm.file_saver import LLMFileSaver -from dify_graph.nodes.llm.node import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer +from dify_graph.nodes.llm.node import ( + LLMNode, + _calculate_rest_token, + _handle_completion_template, + _handle_memory_chat_mode, + _handle_memory_completion_mode, + _render_jinja2_message, +) +from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory +from dify_graph.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from dify_graph.template_rendering import TemplateRenderError from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment from models.provider import ProviderType from tests.workflow_test_utils import build_test_graph_init_params @@ -55,6 +95,62 @@ class MockTokenBufferMemory: return self.history_messages +def _build_prepared_llm_mock() -> mock.MagicMock: + model_instance = mock.MagicMock() + model_instance.provider = "openai" + model_instance.model_name = "gpt-3.5-turbo" + model_instance.parameters = {} + model_instance.stop = () + model_instance.get_llm_num_tokens.return_value = 0 + model_instance.get_model_schema.return_value = AIModelEntity( + model="gpt-3.5-turbo", + label=I18nObject(en_US="GPT-3.5 Turbo"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={}, + ) + model_instance.is_structured_output_parse_error.return_value = False + return model_instance + + +def _build_model_schema( + *, + features: list[ModelFeature] | None = None, + model_properties: dict[ModelPropertyKey, object] | None = None, + parameter_rules: list[ParameterRule] | None = None, +) -> AIModelEntity: + return AIModelEntity( + model="gpt-3.5-turbo", + label=I18nObject(en_US="GPT-3.5 Turbo"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + features=features, + model_properties=model_properties or {}, + parameter_rules=parameter_rules or [], + ) + + +def _build_image_file( + *, + file_id: str, + related_id: str, + remote_url: str, + extension: str = ".png", + mime_type: str = "image/png", +) -> File: + return File( + id=file_id, + type=FileType.IMAGE, + filename=f"{file_id}{extension}", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=remote_url, + related_id=related_id, + extension=extension, + mime_type=mime_type, + storage_key="", + ) + + @pytest.fixture def llm_node_data() -> LLMNodeData: return LLMNodeData( @@ -91,7 +187,7 @@ def graph_init_params() -> GraphInitParams: @pytest.fixture def graph_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) return GraphRuntimeState( @@ -107,7 +203,7 @@ def llm_node( mock_file_saver = mock.MagicMock(spec=LLMFileSaver) mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) - mock_template_renderer = mock.MagicMock(spec=TemplateRenderer) + 
mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) node_config = { "id": "1", "data": llm_node_data.model_dump(), @@ -120,9 +216,9 @@ def llm_node( graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, - model_instance=mock.MagicMock(spec=ModelInstance), + model_instance=_build_prepared_llm_mock(), llm_file_saver=mock_file_saver, - template_renderer=mock_template_renderer, + prompt_message_serializer=mock_prompt_message_serializer, http_client=http_client, ) return node @@ -132,28 +228,31 @@ def llm_node( def model_config(monkeypatch): from tests.integration_tests.model_runtime.__mock.plugin_model import MockModelClass - def mock_plugin_model_providers(_self): - providers = MockModelClass().fetch_model_providers("test") - for provider in providers: - provider.declaration.provider = f"{provider.plugin_id}/{provider.declaration.provider}" + def mock_model_providers(_self): + providers = [] + for provider in MockModelClass().fetch_model_providers("test"): + provider_schema = provider.declaration.model_copy(deep=True) + provider_schema.provider = f"{provider.plugin_id}/{provider.provider}" + provider_schema.provider_name = provider.provider + providers.append(provider_schema) return providers monkeypatch.setattr( ModelProviderFactory, - "get_plugin_model_providers", - mock_plugin_model_providers, + "get_model_providers", + mock_model_providers, ) # Create actual provider and model type instances - model_provider_factory = ModelProviderFactory(tenant_id="test") - provider_instance = model_provider_factory.get_plugin_model_provider("openai") + model_provider_factory = ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id="test")) + provider_instance = model_provider_factory.get_model_provider("openai") model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM) # Create a ProviderModelBundle provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( tenant_id="1", - provider=provider_instance.declaration, + provider=provider_instance, preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, system_configuration=SystemConfiguration(enabled=False), @@ -181,13 +280,18 @@ def model_config(monkeypatch): ) -def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsEntity): +def test_fetch_model_config_hydrates_model_instance_runtime_settings(model_config: ModelConfigWithCredentialsEntity): mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) - mock_model_factory = mock.MagicMock(spec=ModelFactory) + mock_model_factory = mock.MagicMock(spec=DifyModelFactory) provider_model_bundle = model_config.provider_model_bundle model_type_instance = provider_model_bundle.model_type_instance provider_model = mock.MagicMock() + completion_params = { + "temperature": 0.7, + "max_tokens": 256, + "stop": ["Observation:", "Human:"], + } model_instance = mock.MagicMock( model_type_instance=model_type_instance, @@ -208,12 +312,36 @@ def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsE model_type_instance.__class__, "get_model_schema", return_value=model_config.model_schema, autospec=True ), ): - fetch_model_config( - node_data_model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + hydrated_model_instance, model_config_with_credentials = fetch_model_config( + node_data_model=ModelConfig( + provider="openai", + 
name="gpt-3.5-turbo", + mode="chat", + completion_params=completion_params, + ), credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, ) + assert hydrated_model_instance is model_instance + assert hydrated_model_instance.provider == "openai" + assert hydrated_model_instance.model_name == "gpt-3.5-turbo" + assert hydrated_model_instance.credentials == {"api_key": "test"} + assert hydrated_model_instance.parameters == { + "temperature": 0.7, + "max_tokens": 256, + } + assert hydrated_model_instance.stop == ("Observation:", "Human:") + assert model_config_with_credentials.parameters == { + "temperature": 0.7, + "max_tokens": 256, + } + assert model_config_with_credentials.stop == ["Observation:", "Human:"] + assert completion_params == { + "temperature": 0.7, + "max_tokens": 256, + "stop": ["Observation:", "Human:"], + } mock_credentials_provider.fetch.assert_called_once_with("openai", "gpt-3.5-turbo") mock_model_factory.init_model_instance.assert_called_once_with("openai", "gpt-3.5-turbo") provider_model.raise_for_status.assert_called_once() @@ -230,12 +358,20 @@ def test_dify_model_access_adapters_call_managers(): mock_provider_configuration.get_provider_model.return_value = mock_provider_model mock_provider_configuration.get_current_credentials.return_value = {"api_key": "test"} - credentials_provider = DifyCredentialsProvider( + run_context = DifyRunContext( tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + credentials_provider = DifyCredentialsProvider( + run_context=run_context, provider_manager=mock_provider_manager, ) model_factory = DifyModelFactory( - tenant_id="tenant", + run_context=run_context, model_manager=mock_model_manager, ) @@ -255,18 +391,18 @@ def test_dify_model_access_adapters_call_managers(): model="gpt-3.5-turbo", ) mock_provider_model.raise_for_status.assert_called_once() - mock_model_manager.get_model_instance.assert_called_once_with( - tenant_id="tenant", - provider="openai", - model_type=ModelType.LLM, - model="gpt-3.5-turbo", - ) + mock_model_manager.get_model_instance.assert_called_once() + assert mock_model_manager.get_model_instance.call_args.kwargs == { + "tenant_id": "tenant", + "provider": "openai", + "model_type": ModelType.LLM, + "model": "gpt-3.5-turbo", + } def test_fetch_files_with_file_segment(): file = File( id="1", - tenant_id="test", type=FileType.IMAGE, filename="test.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -284,7 +420,6 @@ def test_fetch_files_with_array_file_segment(): files = [ File( id="1", - tenant_id="test", type=FileType.IMAGE, filename="test1.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -293,7 +428,6 @@ def test_fetch_files_with_array_file_segment(): ), File( id="2", - tenant_id="test", type=FileType.IMAGE, filename="test2.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -343,7 +477,6 @@ def test_fetch_files_with_non_existent_variable(): # files = [ # File( # id="1", -# tenant_id="test", # type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -448,7 +581,6 @@ def test_fetch_files_with_non_existent_variable(): # sys_query=fake_query, # sys_files=[ # File( -# tenant_id="test", # type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -524,7 +656,6 @@ def test_fetch_files_with_non_existent_variable(): # + [UserPromptMessage(content=fake_query)], # file_variables={ # "input.image": File( -# tenant_id="test", # 
type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -569,7 +700,7 @@ def test_fetch_files_with_non_existent_variable(): def test_handle_list_messages_basic(llm_node): messages = [ LLMNodeChatModelMessage( - text="Hello, {#context#}", + text="Hello, {{#context#}}", role=PromptMessageRole.USER, edition_type="basic", ) @@ -592,32 +723,414 @@ def test_handle_list_messages_basic(llm_node): assert result[0].content == [TextPromptMessageContent(data="Hello, world")] -def test_handle_list_messages_jinja2_uses_template_renderer(llm_node): - llm_node._template_renderer.render_jinja2.return_value = "Hello, world" +def test_handle_list_messages_replaces_double_brace_context_placeholder(llm_node): messages = [ LLMNodeChatModelMessage( - text="", - jinja2_text="Hello, {{ name }}", - role=PromptMessageRole.USER, - edition_type="jinja2", + text="Answer user's question with the following context:\n\n{{#context#}}", + role=PromptMessageRole.SYSTEM, + edition_type="basic", ) ] + context = "## Overview\nSends a JSON request." result = llm_node.handle_list_messages( messages=messages, - context=None, + context=context, jinja2_variables=[], variable_pool=llm_node.graph_runtime_state.variable_pool, vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, - template_renderer=llm_node._template_renderer, ) - assert result == [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])] - llm_node._template_renderer.render_jinja2.assert_called_once_with( - template="Hello, {{ name }}", - inputs={}, + assert len(result) == 1 + assert isinstance(result[0].content, list) + assert result[0].content == [ + TextPromptMessageContent( + data="Answer user's question with the following context:\n\n## Overview\nSends a JSON request." 
+ ) + ] + + +def test_handle_list_messages_renders_jinja2_messages(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["input", "name"], "Dify") + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + prompt_messages = llm_node.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="Hello {{ name }}", + role=PromptMessageRole.SYSTEM, + edition_type="jinja2", + ) + ], + context="", + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])], + variable_pool=llm_node.graph_runtime_state.variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + jinja2_template_renderer=renderer, ) + assert prompt_messages == [ + SystemPromptMessage(content=[TextPromptMessageContent(data="Hello Dify")]), + ] + renderer.render_template.assert_called_once_with("Hello {{ name }}", {"name": "Dify"}) + + +def test_transform_chat_messages_prefers_jinja2_text(llm_node): + completion_template = LLMNodeCompletionModelPromptTemplate( + text="ignored", + jinja2_text="completion prompt", + edition_type="jinja2", + ) + chat_messages = [ + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="chat prompt", + role=PromptMessageRole.USER, + edition_type="jinja2", + ), + LLMNodeChatModelMessage( + text="keep original", + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ), + ] + + transformed_completion = llm_node._transform_chat_messages(completion_template) + transformed_messages = llm_node._transform_chat_messages(chat_messages) + + assert transformed_completion.text == "completion prompt" + assert transformed_messages[0].text == "chat prompt" + assert transformed_messages[1].text == "keep original" + + +def test_fetch_jinja_inputs_serializes_supported_segment_types(llm_node): + llm_node.graph_runtime_state.variable_pool.add( + ["input", "items"], + ["alpha", {"metadata": {"_source": "knowledge"}, "content": "beta"}, 3], + ) + llm_node.graph_runtime_state.variable_pool.add( + ["input", "context_doc"], + {"metadata": {"_source": "knowledge"}, "content": "context body"}, + ) + llm_node.graph_runtime_state.variable_pool.add(["input", "payload"], {"a": 1}) + + node_data = llm_node.node_data.model_copy( + update={ + "prompt_config": PromptConfig( + jinja2_variables=[ + VariableSelector(variable="items", value_selector=["input", "items"]), + VariableSelector(variable="context_doc", value_selector=["input", "context_doc"]), + VariableSelector(variable="payload", value_selector=["input", "payload"]), + ] + ) + } + ) + + assert llm_node._fetch_jinja_inputs(node_data) == { + "items": "alpha\nbeta\n3", + "context_doc": "context body", + "payload": '{"a": 1}', + } + + +def test_fetch_jinja_inputs_raises_for_missing_variable(llm_node): + node_data = llm_node.node_data.model_copy( + update={ + "prompt_config": PromptConfig( + jinja2_variables=[VariableSelector(variable="missing", value_selector=["input", "missing"])] + ) + } + ) + + with pytest.raises(VariableNotFoundError, match="Variable missing not found"): + llm_node._fetch_jinja_inputs(node_data) + + +def test_fetch_inputs_collects_prompt_and_memory_variables(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["input", "name"], "Dify") + llm_node.graph_runtime_state.variable_pool.add(["input", "payload"], {"active": True}) + + node_data = llm_node.node_data.model_copy( + update={ + "prompt_template": [ + LLMNodeChatModelMessage( + text="Hello {{#input.name#}} with {{#input.payload#}}", + role=PromptMessageRole.USER, + 
edition_type="basic", + ) + ], + "memory": MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=1), + query_prompt_template="Repeat {{#input.name#}}", + ), + } + ) + + assert llm_node._fetch_inputs(node_data) == { + "#input.name#": "Dify", + "#input.payload#": {"active": True}, + } + + +def test_fetch_context_emits_string_context_event(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["context", "value"], "retrieved context") + node_data = llm_node.node_data.model_copy( + update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])} + ) + + events = list(llm_node._fetch_context(node_data)) + + assert events == [ + RunRetrieverResourceEvent(retriever_resources=[], context="retrieved context", context_files=[]), + ] + + +def test_fetch_context_collects_retriever_resources_and_attachments(llm_node): + attachment = _build_image_file( + file_id="attachment", + related_id="attachment-related", + remote_url="https://example.com/attachment.png", + ) + llm_node._retriever_attachment_loader = mock.MagicMock() + llm_node._retriever_attachment_loader.load.return_value = [attachment] + + llm_node.graph_runtime_state.variable_pool.add( + ["context", "value"], + [ + { + "content": "chunk body", + "summary": "chunk summary", + "files": [{"id": "file-1"}], + "metadata": { + "_source": "knowledge", + "dataset_id": "dataset-1", + "segment_id": "segment-1", + "segment_word_count": 12, + }, + }, + "tail text", + ], + ) + node_data = llm_node.node_data.model_copy( + update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])} + ) + + events = list(llm_node._fetch_context(node_data)) + + assert len(events) == 1 + event = events[0] + assert event.context == "chunk summary\nchunk body\ntail text" + assert event.context_files == [attachment] + assert event.retriever_resources == [ + { + "position": None, + "dataset_id": "dataset-1", + "dataset_name": None, + "document_id": None, + "document_name": None, + "data_source_type": None, + "segment_id": "segment-1", + "retriever_from": None, + "score": None, + "hit_count": None, + "word_count": 12, + "segment_position": None, + "index_node_hash": None, + "content": "chunk body", + "page": None, + "doc_metadata": None, + "files": [{"id": "file-1"}], + "summary": "chunk summary", + } + ] + llm_node._retriever_attachment_loader.load.assert_called_once_with(segment_id="segment-1") + + +def test_fetch_context_rejects_invalid_context_structure(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["context", "value"], [{"summary": "missing content"}]) + node_data = llm_node.node_data.model_copy( + update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])} + ) + + with pytest.raises(InvalidContextStructureError, match="Invalid context structure"): + list(llm_node._fetch_context(node_data)) + + +def test_fetch_prompt_messages_chat_mode_appends_memory_query_and_files(): + model_instance = _build_prepared_llm_mock() + model_instance.get_model_schema.return_value = _build_model_schema(features=[ModelFeature.VISION]) + + memory = mock.MagicMock(spec=MockTokenBufferMemory) + memory.get_history_prompt_messages.return_value = [AssistantPromptMessage(content="history answer")] + + sys_file = _build_image_file(file_id="sys-file", related_id="sys-related", remote_url="https://example.com/sys.png") + context_file = _build_image_file( + file_id="context-file", + related_id="context-related", + 
remote_url="https://example.com/context.png", + ) + + prompt_content_side_effect = [ + ImagePromptMessageContent( + url="https://example.com/sys.png", + format="png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ImagePromptMessageContent( + url="https://example.com/context.png", + format="png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ] + + with mock.patch("dify_graph.nodes.llm.node.file_manager.to_prompt_message_content") as mock_to_prompt: + mock_to_prompt.side_effect = prompt_content_side_effect + prompt_messages, stop = LLMNode.fetch_prompt_messages( + sys_query="current question", + sys_files=[sys_file], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=[ + LLMNodeChatModelMessage( + text="Before query", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + stop=("STOP",), + memory_config=MemoryConfig( + window=MemoryConfig.WindowConfig(enabled=False), + ), + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + context_files=[context_file], + ) + + assert stop == ("STOP",) + assert prompt_messages[0] == UserPromptMessage(content="Before query") + assert prompt_messages[1] == AssistantPromptMessage(content="history answer") + assert isinstance(prompt_messages[2], UserPromptMessage) + assert isinstance(prompt_messages[2].content, list) + assert isinstance(prompt_messages[2].content[0], ImagePromptMessageContent) + assert isinstance(prompt_messages[2].content[1], ImagePromptMessageContent) + assert isinstance(prompt_messages[2].content[2], TextPromptMessageContent) + assert prompt_messages[2].content[0].url == "https://example.com/context.png" + assert prompt_messages[2].content[1].url == "https://example.com/sys.png" + assert prompt_messages[2].content[2].data == "current question" + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=2000, message_limit=None) + + +def test_fetch_prompt_messages_completion_mode_injects_histories_and_query(): + model_instance = _build_prepared_llm_mock() + model_instance.get_model_schema.return_value = _build_model_schema(features=[]) + + memory = mock.MagicMock(spec=MockTokenBufferMemory) + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="previous question"), + AssistantPromptMessage(content="previous answer"), + ] + + prompt_messages, stop = LLMNode.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="Prompt header\n#histories#", + edition_type="basic", + ), + stop=("HALT",), + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=2), + ), + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("HALT",) + assert prompt_messages == [ + UserPromptMessage( + content="latest question\nPrompt header\nHuman: previous question\nAssistant: previous answer" + ) + ] + + +def test_fetch_prompt_messages_raises_when_only_unsupported_content_remains(): + model_instance = _build_prepared_llm_mock() + model_instance.get_model_schema.return_value = _build_model_schema(features=[]) + + variable_pool = VariablePool.empty() + variable_pool.add( + ["input", "image"], + 
_build_image_file(file_id="image-file", related_id="image-related", remote_url="https://example.com/file.png"), + ) + + with ( + mock.patch( + "dify_graph.nodes.llm.node.file_manager.to_prompt_message_content", + return_value=ImagePromptMessageContent( + url="https://example.com/file.png", + format="png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ), + pytest.raises(NoPromptFoundError, match="No prompt found"), + ): + LLMNode.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=[ + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=variable_pool, + jinja2_variables=[], + ) + + +def test_handle_completion_template_replaces_double_brace_context_placeholder(llm_node): + prompt_messages = _handle_completion_template( + template=LLMNodeCompletionModelPromptTemplate( + text="Summarize the following context:\n{{#context#}}", + edition_type="basic", + ), + context="## Overview\nSends a JSON request.", + jinja2_variables=[], + variable_pool=llm_node.graph_runtime_state.variable_pool, + jinja2_template_renderer=None, + ) + + assert prompt_messages == [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data="Summarize the following context:\n## Overview\nSends a JSON request.") + ] + ) + ] + def test_handle_memory_completion_mode_uses_prompt_message_interface(): memory = mock.MagicMock(spec=MockTokenBufferMemory) @@ -635,15 +1148,15 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface(): AssistantPromptMessage(content="first answer"), ] - model_instance = mock.MagicMock(spec=ModelInstance) + model_instance = _build_prepared_llm_mock() memory_config = MemoryConfig( role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), window=MemoryConfig.WindowConfig(enabled=True, size=3), ) - with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=2000) as mock_rest_token: - memory_text = llm_utils.handle_memory_completion_mode( + with mock.patch("dify_graph.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token: + memory_text = _handle_memory_completion_mode( memory=memory, memory_config=memory_config, model_instance=model_instance, @@ -659,7 +1172,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver) mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) - mock_template_renderer = mock.MagicMock(spec=TemplateRenderer) + mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) node_config = { "id": "1", "data": llm_node_data.model_dump(), @@ -672,9 +1185,9 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, - model_instance=mock.MagicMock(spec=ModelInstance), + model_instance=_build_prepared_llm_mock(), llm_file_saver=mock_file_saver, - template_renderer=mock_template_renderer, + prompt_message_serializer=mock_prompt_message_serializer, http_client=http_client, ) return node, mock_file_saver @@ -690,7 +1203,6 @@ class TestLLMNodeSaveMultiModalImageOutput: ) 
mock_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), @@ -721,7 +1233,6 @@ class TestLLMNodeSaveMultiModalImageOutput: ) mock_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), @@ -776,7 +1287,6 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: mock_saved_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, filename="test.png", @@ -906,3 +1416,322 @@ class TestReasoningFormat: assert clean_text == text_with_think assert reasoning_content == "" + + +@pytest.mark.parametrize( + ("structured_output_enabled", "structured_output"), + [ + (False, None), + (True, {"schema": {"type": "object", "properties": {"answer": {"type": "string"}}}}), + ], +) +def test_invoke_llm_dispatches_to_expected_model_method(structured_output_enabled, structured_output): + model_instance = _build_prepared_llm_mock() + prompt_messages = [UserPromptMessage(content="hello")] + file_saver = mock.MagicMock(spec=LLMFileSaver) + + model_instance.invoke_llm.return_value = iter([]) + model_instance.invoke_llm_with_structured_output.return_value = iter([]) + + with ( + mock.patch.object(LLMNode, "handle_invoke_result", return_value=iter(["handled"])) as mock_handle, + mock.patch("dify_graph.nodes.llm.node.time.perf_counter", return_value=10.0), + ): + result = list( + LLMNode.invoke_llm( + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=("STOP",), + structured_output_enabled=structured_output_enabled, + structured_output=structured_output, + file_saver=file_saver, + file_outputs=[], + node_id="node-1", + node_type=LLMNode.node_type, + reasoning_format="separated", + ) + ) + + assert result == ["handled"] + if structured_output_enabled: + model_instance.invoke_llm_with_structured_output.assert_called_once_with( + prompt_messages=prompt_messages, + json_schema={"type": "object", "properties": {"answer": {"type": "string"}}}, + model_parameters={}, + stop=("STOP",), + stream=True, + ) + model_instance.invoke_llm.assert_not_called() + else: + model_instance.invoke_llm.assert_called_once_with( + prompt_messages=prompt_messages, + model_parameters={}, + tools=None, + stop=("STOP",), + stream=True, + ) + model_instance.invoke_llm_with_structured_output.assert_not_called() + + assert mock_handle.call_args.kwargs["request_start_time"] == 10.0 + + +def test_handle_invoke_result_streaming_collects_text_metrics_and_structured_output(): + usage = LLMUsage.from_metadata({"prompt_tokens": 12, "completion_tokens": 4, "total_tokens": 16}) + first_chunk = LLMResultChunkWithStructuredOutput( + model="gpt-3.5-turbo", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=[TextPromptMessageContent(data="plan")]), + ), + structured_output={"draft": True}, + ) + final_chunk = LLMResultChunk( + model="gpt-3.5-turbo", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=1, + message=AssistantPromptMessage(content=[TextPromptMessageContent(data="answer")]), + usage=usage, + finish_reason="stop", + ), + ) + + with mock.patch("dify_graph.nodes.llm.node.time.perf_counter", side_effect=[2.0, 5.0]): + events = list( + LLMNode.handle_invoke_result( + invoke_result=iter([first_chunk, final_chunk]), + file_saver=mock.MagicMock(spec=LLMFileSaver), + file_outputs=[], + node_id="node-1", + 
node_type=LLMNode.node_type, + model_instance=_build_prepared_llm_mock(), + reasoning_format="separated", + request_start_time=1.0, + ) + ) + + assert events[0] == first_chunk + assert events[1] == StreamChunkEvent(selector=["node-1", "text"], chunk="plan", is_final=False) + assert events[2] == StreamChunkEvent(selector=["node-1", "text"], chunk="answer", is_final=False) + + completed = events[3] + assert isinstance(completed, ModelInvokeCompletedEvent) + assert completed.text == "answer" + assert completed.reasoning_content == "plan" + assert completed.structured_output == {"draft": True} + assert completed.finish_reason == "stop" + assert completed.usage.total_tokens == 16 + assert completed.usage.latency == 4.0 + assert completed.usage.time_to_first_token == 1.0 + assert completed.usage.time_to_generate == 3.0 + + +def test_handle_invoke_result_wraps_structured_output_parse_errors(): + model_instance = _build_prepared_llm_mock() + model_instance.is_structured_output_parse_error.return_value = True + + def broken_stream(): + raise ValueError("bad json") + yield + + with pytest.raises(LLMNodeError, match="Failed to parse structured output: bad json"): + list( + LLMNode.handle_invoke_result( + invoke_result=broken_stream(), + file_saver=mock.MagicMock(spec=LLMFileSaver), + file_outputs=[], + node_id="node-1", + node_type=LLMNode.node_type, + model_instance=model_instance, + ) + ) + + +def test_handle_blocking_result_extracts_reasoning_and_structured_output(): + invoke_result = LLMResultWithStructuredOutput( + model="gpt-3.5-turbo", + prompt_messages=[], + message=AssistantPromptMessage(content="<think>reasoning</think>final answer"), + usage=LLMUsage.empty_usage(), + structured_output={"answer": "final answer"}, + ) + + event = LLMNode.handle_blocking_result( + invoke_result=invoke_result, + saver=mock.MagicMock(spec=LLMFileSaver), + file_outputs=[], + reasoning_format="separated", + request_latency=1.2345, + ) + + assert event.text == "final answer" + assert event.reasoning_content == "reasoning" + assert event.structured_output == {"answer": "final answer"} + assert event.usage.latency == 1.234 + + +def test_fetch_structured_output_schema_validates_payload(): + assert LLMNode.fetch_structured_output_schema(structured_output={"schema": {"type": "object"}}) == { + "type": "object" + } + + with pytest.raises(LLMNodeError, match="Please provide a valid structured output schema"): + LLMNode.fetch_structured_output_schema(structured_output={}) + + with pytest.raises(LLMNodeError, match="structured_output_schema must be a JSON object"): + LLMNode.fetch_structured_output_schema(structured_output={"schema": ["not", "an", "object"]}) + + +def test_extract_variable_selector_to_variable_mapping_includes_runtime_selectors(): + node_data = LLMNodeData( + title="Test LLM", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + prompt_template=[ + LLMNodeChatModelMessage( + text="Hello {{#input.name#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="Hello {{ name }}", + role=PromptMessageRole.SYSTEM, + edition_type="jinja2", + ), + ], + prompt_config=PromptConfig( + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])] + ), + memory=MemoryConfig( + window=MemoryConfig.WindowConfig(enabled=True, size=1), + query_prompt_template="Repeat {{#sys.query#}}", + ), + context=ContextConfig(enabled=True, variable_selector=["context", "value"]), +
vision=VisionConfig(enabled=True), + ) + + mapping = LLMNode._extract_variable_selector_to_variable_mapping( + graph_config={}, + node_id="llm-1", + node_data=node_data, + ) + + assert mapping == { + "llm-1.#input.name#": ["input", "name"], + "llm-1.#sys.query#": ["sys", "query"], + "llm-1.#context#": ["context", "value"], + "llm-1.#files#": ["sys", "files"], + "llm-1.name": ["input", "name"], + } + + +def test_render_jinja2_message_requires_renderer_and_passes_inputs(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + variables = [VariableSelector(variable="name", value_selector=["input", "name"])] + + with pytest.raises( + TemplateRenderError, + match="LLMNode requires an injected jinja2_template_renderer for jinja2 prompts", + ): + _render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + jinja2_template_renderer=None, + ) + + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + assert ( + _render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + jinja2_template_renderer=renderer, + ) + == "Hello Dify" + ) + renderer.render_template.assert_called_once_with("Hello {{ name }}", {"name": "Dify"}) + + +def test_calculate_rest_token_uses_context_size_and_max_tokens(): + model_instance = _build_prepared_llm_mock() + model_instance.parameters = {"max_tokens": 512} + model_instance.get_model_schema.return_value = _build_model_schema( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096}, + parameter_rules=[ + ParameterRule( + name="max_tokens", + label=I18nObject(en_US="Max Tokens"), + type=ParameterType.INT, + ) + ], + ) + model_instance.get_llm_num_tokens.return_value = 1000 + + assert ( + _calculate_rest_token( + prompt_messages=[UserPromptMessage(content="hello")], + model_instance=model_instance, + ) + == 2584 + ) + + +def test_handle_memory_chat_mode_uses_calculated_token_budget(): + memory = mock.MagicMock(spec=MockTokenBufferMemory) + history = [UserPromptMessage(content="question")] + memory.get_history_prompt_messages.return_value = history + + with mock.patch("dify_graph.nodes.llm.node._calculate_rest_token", return_value=321) as mock_rest_token: + result = _handle_memory_chat_mode( + memory=memory, + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=True, size=2)), + model_instance=_build_prepared_llm_mock(), + ) + + assert result == history + mock_rest_token.assert_called_once() + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=321, message_limit=2) + + +def test_dify_model_access_adapters_skip_runtime_build_when_managers_are_injected(): + run_context = DifyRunContext( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with mock.patch("core.app.llm.model_access.create_plugin_provider_manager") as mock_provider_manager_factory: + DifyCredentialsProvider(run_context=run_context, provider_manager=mock.MagicMock()) + DifyModelFactory(run_context=run_context, model_manager=mock.MagicMock()) + + mock_provider_manager_factory.assert_not_called() + + +def test_build_dify_model_access_binds_run_context_user_id_once(): + run_context = DifyRunContext( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with mock.patch("core.app.llm.model_access.create_plugin_provider_manager") as mock_provider_manager: + 
build_dify_model_access(run_context) + + mock_provider_manager.assert_called_once_with(tenant_id="tenant", user_id="user") + + +def test_dify_model_access_requires_run_context_argument(): + with pytest.raises(TypeError): + DifyCredentialsProvider() + + with pytest.raises(TypeError): + DifyModelFactory() diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index 332a8761f9..05c857188f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -5,9 +5,11 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus from dify_graph.graph import Graph -from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError +from dify_graph.nodes.base.entities import VariableSelector +from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode from dify_graph.runtime import GraphRuntimeState +from dify_graph.template_rendering import TemplateRenderError from tests.workflow_test_utils import build_test_graph_init_params @@ -62,7 +64,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM @@ -78,7 +80,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_title() == "Template Transform" @@ -91,7 +93,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_description() == "Transform data using template" @@ -111,7 +113,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_error_strategy() == ErrorStrategy.FAIL_BRANCH @@ -130,6 +132,26 @@ class TestTemplateTransformNode: """Test version class method.""" assert TemplateTransformNode.version() == "1" + @pytest.mark.parametrize("max_output_length", [0, -1]) + def test_node_initialization_rejects_non_positive_max_output_length( + self, + basic_node_data, + mock_graph_runtime_state, + graph_init_params, + max_output_length, + ): + mock_renderer = MagicMock() + + with pytest.raises(ValueError, match="max_output_length must be a positive integer"): + TemplateTransformNode( + id="test_node", + config={"id": "test_node", "data": basic_node_data}, + graph_init_params=graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + jinja2_template_renderer=mock_renderer, + max_output_length=max_output_length, + ) + def 
test_run_simple_template(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _run with simple template transformation using injected renderer.""" # Setup mock variable pool @@ -153,7 +175,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -181,7 +203,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -201,7 +223,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -221,7 +243,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, max_output_length=10, ) @@ -230,6 +252,28 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Output length exceeds" in result.error + def test_run_output_length_equal_to_limit_succeeds( + self, basic_node_data, mock_graph_runtime_state, graph_init_params + ): + mock_graph_runtime_state.variable_pool.get.return_value = MagicMock() + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "1234567890" + + node = TemplateTransformNode( + id="test_node", + config={"id": "test_node", "data": basic_node_data}, + graph_init_params=graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + jinja2_template_renderer=mock_renderer, + max_output_length=10, + ) + + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs["output"] == "1234567890" + def test_run_with_complex_jinja2_template(self, mock_graph_runtime_state, graph_init_params): """Test _run with complex Jinja2 template including loops and conditions.""" node_data = { @@ -263,7 +307,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -291,6 +335,69 @@ class TestTemplateTransformNode: assert mapping["node_123.var1"] == ["sys", "input1"] assert mapping["node_123.var2"] == ["sys", "input2"] + def test_extract_variable_selector_to_variable_mapping_accepts_validated_node_data(self): + node_data = TemplateTransformNodeData( + title="Test", + variables=[VariableSelector(variable="var1", value_selector=["sys", "input1"])], + template="{{ var1 }}", + ) + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == {"node_123.var1": ["sys", "input1"]} + + def test_extract_variable_selector_to_variable_mapping_returns_empty_mapping_without_variables(self): + node_data = { + "title": "Test", + "template": "{{ missing }}", + } + + mapping = 
TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == {} + + def test_extract_variable_selector_to_variable_mapping_accepts_sequence_value_selectors(self): + node_data = { + "title": "Test", + "variables": [ + {"variable": "var1", "value_selector": ("sys", "input1")}, + {"variable": "empty_selector", "value_selector": ()}, + ], + "template": "{{ var1 }}", + } + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == { + "node_123.var1": ["sys", "input1"], + "node_123.empty_selector": [], + } + + def test_extract_variable_selector_to_variable_mapping_ignores_invalid_entries(self): + node_data = { + "title": "Test", + "variables": [ + {"variable": "var1", "value_selector": ["sys", "input1"]}, + {"variable": "missing_selector"}, + ["not", "a", "mapping"], + {"variable": 1, "value_selector": ["sys", "input2"]}, + {"variable": "invalid_selector", "value_selector": ["sys", 2]}, + ], + "template": "{{ var1 }}", + } + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == {"node_123.var1": ["sys", "input1"]} + def test_run_with_empty_variables(self, mock_graph_runtime_state, graph_init_params): """Test _run with no variables (static template).""" node_data = { @@ -307,7 +414,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -346,7 +453,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -375,7 +482,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -405,7 +512,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py new file mode 100644 index 0000000000..0e783bec26 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py @@ -0,0 +1,74 @@ +from unittest.mock import MagicMock + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from dify_graph.nodes.base.entities import VariableSelector +from dify_graph.nodes.template_transform.template_transform_node import ( + DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH, + TemplateTransformNode, +) +from dify_graph.runtime import GraphRuntimeState +from tests.workflow_test_utils import build_test_graph_init_params + +from .template_transform_node_spec import TestTemplateTransformNode # noqa: F401 + + 
+@pytest.fixture +def graph_init_params(): + return build_test_graph_init_params( + workflow_id="test_workflow", + graph_config={}, + tenant_id="test_tenant", + app_id="test_app", + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + +@pytest.fixture +def mock_graph_runtime_state(): + mock_state = MagicMock(spec=GraphRuntimeState) + mock_state.variable_pool = MagicMock() + return mock_state + + +def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state): + node = TemplateTransformNode( + id="test_node", + config={ + "id": "test_node", + "data": { + "title": "Template Transform", + "variables": [], + "template": "hello", + }, + }, + graph_init_params=graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + jinja2_template_renderer=MagicMock(), + ) + + assert node._max_output_length == DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH + + +def test_extract_variable_selector_to_variable_mapping_accepts_mixed_valid_entries(): + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={"ignored": True}, + node_id="node_123", + node_data={ + "variables": [ + VariableSelector(variable="validated", value_selector=["sys", "input1"]), + {"variable": "raw", "value_selector": ("sys", "input2")}, + {"variable": "invalid_selector", "value_selector": ["sys", 3]}, + ["not", "a", "mapping"], + ] + }, + ) + + assert mapping == { + "node_123.validated": ["sys", "input1"], + "node_123.raw": ["sys", "input2"], + } diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 2b0205fb7b..69365c227d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -3,13 +3,14 @@ from collections.abc import Mapping import pytest from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.workflow.node_runtime import resolve_dify_run_context +from core.workflow.system_variables import build_system_variables from dify_graph.entities import GraphInitParams from dify_graph.entities.base_node_data import BaseNodeData from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.enums import BuiltinNodeTypes from dify_graph.nodes.base.node import Node from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from tests.workflow_test_utils import build_test_graph_init_params @@ -35,7 +36,7 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, invoke_from="debugger", ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), start_at=0.0, ) return init_params, runtime_state @@ -67,7 +68,7 @@ def test_node_hydrates_data_during_initialization(): assert node.node_data.foo == "bar" assert node.title == "Sample" - dify_ctx = node.require_dify_context() + dify_ctx = resolve_dify_run_context(node.run_context) assert dify_ctx.user_from == "account" assert dify_ctx.invoke_from == "debugger" @@ -80,7 +81,7 @@ def test_node_accepts_invoke_from_enum(): invoke_from=InvokeFrom.DEBUGGER, ) runtime_state = GraphRuntimeState( - 
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
index 2b0205fb7b..69365c227d 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py
@@ -3,13 +3,14 @@ from collections.abc import Mapping
 import pytest

 from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
+from core.workflow.node_runtime import resolve_dify_run_context
+from core.workflow.system_variables import build_system_variables
 from dify_graph.entities import GraphInitParams
 from dify_graph.entities.base_node_data import BaseNodeData
 from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
 from dify_graph.enums import BuiltinNodeTypes
 from dify_graph.nodes.base.node import Node
 from dify_graph.runtime import GraphRuntimeState, VariablePool
-from dify_graph.system_variable import SystemVariable
 from tests.workflow_test_utils import build_test_graph_init_params

@@ -35,7 +36,7 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams,
         invoke_from="debugger",
     )
     runtime_state = GraphRuntimeState(
-        variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}),
+        variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}),
         start_at=0.0,
     )
     return init_params, runtime_state
@@ -67,7 +68,7 @@ def test_node_hydrates_data_during_initialization():
     assert node.node_data.foo == "bar"
     assert node.title == "Sample"

-    dify_ctx = node.require_dify_context()
+    dify_ctx = resolve_dify_run_context(node.run_context)
     assert dify_ctx.user_from == "account"
     assert dify_ctx.invoke_from == "debugger"

@@ -80,7 +81,7 @@ def test_node_accepts_invoke_from_enum():
         invoke_from=InvokeFrom.DEBUGGER,
     )
     runtime_state = GraphRuntimeState(
-        variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}),
+        variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}),
         start_at=0.0,
     )

@@ -91,7 +92,7 @@ def test_node_accepts_invoke_from_enum():
         graph_runtime_state=runtime_state,
     )

-    dify_ctx = node.require_dify_context()
+    dify_ctx = resolve_dify_run_context(node.run_context)
     assert dify_ctx.user_from == UserFrom.ACCOUNT
     assert dify_ctx.invoke_from == InvokeFrom.DEBUGGER
     assert node.get_run_context_value("missing") is None
@@ -127,3 +128,29 @@ def test_base_node_data_keeps_dict_style_access_compatibility():
     assert node_data["foo"] == "bar"
     assert node_data.get("foo") == "bar"
     assert node_data.get("missing", "fallback") == "fallback"
+
+
+def test_node_hydration_preserves_compatibility_extra_fields():
+    graph_config: dict[str, object] = {}
+    init_params, runtime_state = _build_context(graph_config)
+    node_config = NodeConfigDictAdapter.validate_python(
+        {
+            "id": "node-1",
+            "data": {
+                "type": BuiltinNodeTypes.ANSWER,
+                "title": "Sample",
+                "foo": "bar",
+                "compat_flag": True,
+            },
+        }
+    )
+
+    node = _SampleNode(
+        id="node-1",
+        config=node_config,
+        graph_init_params=init_params,
+        graph_runtime_state=runtime_state,
+    )
+
+    assert node.node_data.foo == "bar"
+    assert node.node_data.get("compat_flag") is True
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
index c746a945fe..1bc0bb8cb5 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py
@@ -4,16 +4,15 @@ from unittest.mock import MagicMock, Mock

 import pytest

-from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom
 from core.workflow.node_factory import DifyNodeFactory
-from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
+from core.workflow.system_variables import build_system_variables
 from dify_graph.enums import WorkflowNodeExecutionStatus
 from dify_graph.file import File, FileTransferMethod, FileType
 from dify_graph.graph import Graph
 from dify_graph.nodes.if_else.entities import IfElseNodeData
 from dify_graph.nodes.if_else.if_else_node import IfElseNode
 from dify_graph.runtime import GraphRuntimeState, VariablePool
-from dify_graph.system_variable import SystemVariable
 from dify_graph.utils.condition.entities import Condition, SubCondition, SubVariableCondition
 from dify_graph.variables import ArrayFileSegment
 from extensions.ext_database import db
@@ -35,7 +34,7 @@ def test_execute_if_else_result_true():
     )

     # construct variable pool
-    pool = VariablePool(system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={})
+    pool = VariablePool(system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={})
     pool.add(["start", "array_contains"], ["ab", "def"])
     pool.add(["start", "array_not_contains"], ["ac", "def"])
     pool.add(["start", "contains"], "cabcde")
@@ -142,7 +141,7 @@ def test_execute_if_else_result_false():

     # construct variable pool
     pool = VariablePool(
-        system_variables=SystemVariable(user_id="aaa", files=[]),
+        system_variables=build_system_variables(user_id="aaa", files=[]),
         user_inputs={},
         environment_variables=[],
     )
@@ -253,7 +252,6 @@ def test_array_file_contains_file_name():
     node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
         value=[
             File(
-                tenant_id="1",
                 type=FileType.IMAGE,
                 transfer_method=FileTransferMethod.LOCAL_FILE,
                 related_id="1",
@@ -316,7 +314,7 @@ def test_execute_if_else_boolean_conditions(condition: Condition):

     # construct variable pool with boolean values
     pool = VariablePool(
-        system_variables=SystemVariable(files=[], user_id="aaa"),
+        system_variables=build_system_variables(files=[], user_id="aaa"),
     )
     pool.add(["start", "bool_true"], True)
     pool.add(["start", "bool_false"], False)
@@ -371,7 +369,7 @@ def test_execute_if_else_boolean_false_conditions():

     # construct variable pool with boolean values
     pool = VariablePool(
-        system_variables=SystemVariable(files=[], user_id="aaa"),
+        system_variables=build_system_variables(files=[], user_id="aaa"),
     )
     pool.add(["start", "bool_true"], True)
     pool.add(["start", "bool_false"], False)
@@ -440,7 +438,7 @@ def test_execute_if_else_boolean_cases_structure():

     # construct variable pool with boolean values
     pool = VariablePool(
-        system_variables=SystemVariable(files=[], user_id="aaa"),
+        system_variables=build_system_variables(files=[], user_id="aaa"),
     )
     pool.add(["start", "bool_true"], True)
     pool.add(["start", "bool_false"], False)
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
index 6ca72b64b2..d1c25da489 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py
@@ -2,8 +2,7 @@ from unittest.mock import MagicMock

 import pytest

-from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
-from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
+from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom
 from dify_graph.enums import WorkflowNodeExecutionStatus
 from dify_graph.file import File, FileTransferMethod, FileType
 from dify_graph.nodes.list_operator.entities import (
@@ -72,7 +71,6 @@ def test_filter_files_by_type(list_operator_node):
         File(
             filename="image1.jpg",
             type=FileType.IMAGE,
-            tenant_id="tenant1",
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="related1",
             storage_key="",
@@ -80,7 +78,6 @@ def test_filter_files_by_type(list_operator_node):
         File(
             filename="document1.pdf",
             type=FileType.DOCUMENT,
-            tenant_id="tenant1",
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="related2",
             storage_key="",
@@ -88,7 +85,6 @@ def test_filter_files_by_type(list_operator_node):
         File(
             filename="image2.png",
             type=FileType.IMAGE,
-            tenant_id="tenant1",
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="related3",
             storage_key="",
@@ -96,7 +92,6 @@ def test_filter_files_by_type(list_operator_node):
         File(
             filename="audio1.mp3",
             type=FileType.AUDIO,
-            tenant_id="tenant1",
             transfer_method=FileTransferMethod.LOCAL_FILE,
             related_id="related4",
             storage_key="",
@@ -120,14 +115,12 @@ def test_filter_files_by_type(list_operator_node):
         {
             "filename": "document1.pdf",
             "type": FileType.DOCUMENT,
-            "tenant_id": "tenant1",
             "transfer_method": FileTransferMethod.LOCAL_FILE,
             "related_id": "related2",
         },
         {
             "filename": "image2.png",
             "type": FileType.IMAGE,
-            "tenant_id": "tenant1",
             "transfer_method": FileTransferMethod.LOCAL_FILE,
             "related_id": "related3",
         },
@@ -136,7 +129,6 @@ def test_filter_files_by_type(list_operator_node):
     for expected_file, result_file in zip(expected_files, result.outputs["result"].value):
         assert expected_file["filename"] == result_file.filename
         assert expected_file["type"] == result_file.type
-        assert expected_file["tenant_id"] == result_file.tenant_id
         assert expected_file["transfer_method"] == result_file.transfer_method
         assert expected_file["related_id"] == result_file.related_id
@@ -144,7 +136,6 @@ def test_get_file_extract_string_func():
     # Create a File object
     file = File(
-        tenant_id="test_tenant",
         type=FileType.DOCUMENT,
         transfer_method=FileTransferMethod.LOCAL_FILE,
         filename="test_file.txt",
@@ -165,7 +156,6 @@ def test_get_file_extract_string_func():

     # Test with empty values
     empty_file = File(
-        tenant_id="test_tenant",
         type=FileType.DOCUMENT,
         transfer_method=FileTransferMethod.LOCAL_FILE,
         filename=None,
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py
index 6372583839..77ec5ac128 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py
@@ -1,6 +1,22 @@
+from types import SimpleNamespace
+from unittest.mock import MagicMock
+
+import pytest
+
 from dify_graph.entities.graph_config import NodeConfigDictAdapter
+from dify_graph.enums import WorkflowNodeExecutionStatus
+from dify_graph.graph_events import GraphRunAbortedEvent
+from dify_graph.model_runtime.entities.llm_entities import LLMUsage
+from dify_graph.node_events import LoopFailedEvent, LoopStartedEvent, StreamCompletedEvent
 from dify_graph.nodes.loop.entities import LoopNodeData
 from dify_graph.nodes.loop.loop_node import LoopNode
+from tests.workflow_test_utils import build_test_variable_pool
+
+
+def _usage_with_tokens(total_tokens: int) -> LLMUsage:
+    usage = LLMUsage.empty_usage()
+    usage.total_tokens = total_tokens
+    return usage


 def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None:
@@ -50,3 +66,85 @@ def test_extract_variable_selector_to_variable_mapping_validates_child_node_conf
     )

     assert seen_configs == [child_node_config]
+
+
+def test_run_single_loop_raises_on_child_abort_event() -> None:
+    node = LoopNode.__new__(LoopNode)
+    node._node_id = "loop-node"
+    node._node_data = LoopNodeData(
+        title="Loop",
+        loop_count=1,
+        break_conditions=[],
+        logical_operator="and",
+        start_node_id="child-start",
+    )
+
+    graph_engine = SimpleNamespace(
+        run=lambda: iter([GraphRunAbortedEvent(reason="quota exceeded")]),
+    )
+
+    with pytest.raises(RuntimeError, match="quota exceeded"):
+        list(node._run_single_loop(graph_engine=graph_engine, current_index=0))
+
+
+def test_loop_run_fails_on_child_abort_and_stops_subsequent_rounds() -> None:
+    node = LoopNode.__new__(LoopNode)
+    node._node_id = "loop-node"
+    node._node_data = LoopNodeData(
+        title="Loop",
+        loop_count=2,
+        break_conditions=[],
+        logical_operator="and",
+        start_node_id="child-start",
+    )
+    node.graph_config = {"nodes": [], "edges": []}
+    node.graph_runtime_state = SimpleNamespace(
+        variable_pool=build_test_variable_pool(),
+        llm_usage=LLMUsage.empty_usage(),
+    )
+
+    aborting_engine = SimpleNamespace(
+        graph_runtime_state=SimpleNamespace(outputs={}, llm_usage=LLMUsage.empty_usage()),
+    )
+    create_graph_engine = MagicMock(return_value=aborting_engine)
+    node._create_graph_engine = create_graph_engine
+    node._run_single_loop = lambda *, graph_engine, current_index: (_ for _ in ()).throw(RuntimeError("quota exceeded"))
+
+    events = list(node._run())
+
+    assert isinstance(events[0], LoopStartedEvent)
+    assert isinstance(events[1], LoopFailedEvent)
+    assert events[1].error == "quota exceeded"
+    assert isinstance(events[2], StreamCompletedEvent)
+    assert events[2].node_run_result.status == WorkflowNodeExecutionStatus.FAILED
+    assert events[2].node_run_result.error == "quota exceeded"
+    create_graph_engine.assert_called_once()
+
+
+def test_loop_run_merges_child_usage_before_failing_on_child_abort() -> None:
+    node = LoopNode.__new__(LoopNode)
+    node._node_id = "loop-node"
+    node._node_data = LoopNodeData(
+        title="Loop",
+        loop_count=1,
+        break_conditions=[],
+        logical_operator="and",
+        start_node_id="child-start",
+    )
+    node.graph_config = {"nodes": [], "edges": []}
+    node.graph_runtime_state = SimpleNamespace(
+        variable_pool=build_test_variable_pool(),
+        llm_usage=LLMUsage.empty_usage(),
+    )
+
+    aborting_engine = SimpleNamespace(
+        graph_runtime_state=SimpleNamespace(outputs={}, llm_usage=_usage_with_tokens(7)),
+    )
+    node._create_graph_engine = MagicMock(return_value=aborting_engine)
+    node._run_single_loop = lambda *, graph_engine, current_index: (_ for _ in ()).throw(RuntimeError("quota exceeded"))
+
+    events = list(node._run())
+
+    assert isinstance(events[-1], StreamCompletedEvent)
+    assert events[-1].node_run_result.llm_usage.total_tokens == 7
+    assert node.graph_runtime_state.llm_usage.total_tokens == 7
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py
index c5a02e87e4..fc7176de46 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py
@@ -2,12 +2,13 @@ from types import SimpleNamespace
 from unittest.mock import MagicMock

 from dify_graph.model_runtime.entities import ImagePromptMessageContent
-from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
+from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
 from dify_graph.nodes.protocols import HttpClientProtocol
 from dify_graph.nodes.question_classifier import (
     QuestionClassifierNode,
     QuestionClassifierNodeData,
 )
+from dify_graph.template_rendering import Jinja2TemplateRenderer
 from tests.workflow_test_utils import build_test_graph_init_params

@@ -86,7 +87,7 @@ def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(mon
             "instruction": "This is a test instruction",
         }
     )
-    template_renderer = MagicMock(spec=TemplateRenderer)
+    template_renderer = MagicMock(spec=Jinja2TemplateRenderer)
     node = QuestionClassifierNode(
         id="node-id",
         config={"id": "node-id", "data": node_data.model_dump(mode="json")},
diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
index b8f0e25e91..80ee3858ae 100644
--- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
+++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py
@@ -4,19 +4,22 @@ import time
 import pytest
 from pydantic import ValidationError as PydanticValidationError

+from core.workflow.system_variables import build_system_variables
+from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID
 from dify_graph.nodes.start.entities import StartNodeData
 from dify_graph.nodes.start.start_node import StartNode
-from dify_graph.runtime import GraphRuntimeState, VariablePool
-from dify_graph.system_variable import SystemVariable
+from dify_graph.runtime import GraphRuntimeState
+from dify_graph.variables import build_segment, segment_to_variable
 from dify_graph.variables.input_entities import VariableEntity, VariableEntityType
-from tests.workflow_test_utils import build_test_graph_init_params
+from dify_graph.variables.variables import Variable
+from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool


 def make_start_node(user_inputs, variables):
-    variable_pool = VariablePool(
-        system_variables=SystemVariable(),
-        user_inputs=user_inputs,
-        conversation_variables=[],
+    variable_pool = build_test_variable_pool(
+        variables=build_system_variables(),
+        node_id="start",
+        inputs=user_inputs,
     )

     config = {
@@ -232,3 +235,64 @@ def test_json_object_optional_variable_not_provided():
     # Current implementation raises a validation error even when the variable is optional
     with pytest.raises(ValueError, match="profile is required in input form"):
         node._run()
+
+
+def test_start_node_outputs_full_variable_pool_snapshot():
+    variable_pool = build_test_variable_pool(
+        variables=[
+            *build_system_variables(query="hello", workflow_run_id="run-123"),
+            _build_prefixed_variable(ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY", "secret"),
+            _build_prefixed_variable(CONVERSATION_VARIABLE_NODE_ID, "session_id", "conversation-1"),
+        ],
+        node_id="start",
+        inputs={"profile": {"age": 20, "name": "Tom"}},
+    )
+
+    config = {
+        "id": "start",
+        "data": StartNodeData(
+            title="Start",
+            variables=[
+                VariableEntity(
+                    variable="profile",
+                    label="profile",
+                    type=VariableEntityType.JSON_OBJECT,
+                    required=True,
+                )
+            ],
+        ).model_dump(),
+    }
+
+    graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+    node = StartNode(
+        id="start",
+        config=config,
+        graph_init_params=build_test_graph_init_params(
+            workflow_id="wf",
+            graph_config={},
+            tenant_id="tenant",
+            app_id="app",
+            user_id="u",
+            user_from="account",
+            invoke_from="debugger",
+            call_depth=0,
+        ),
+        graph_runtime_state=graph_runtime_state,
+    )
+
+    result = node._run()
+
+    assert result.inputs == {"profile": {"age": 20, "name": "Tom"}}
+    assert result.outputs["profile"] == {"age": 20, "name": "Tom"}
+    assert result.outputs["sys.query"] == "hello"
+    assert result.outputs["sys.workflow_run_id"] == "run-123"
+    assert result.outputs["env.API_KEY"] == "secret"
+    assert result.outputs["conversation.session_id"] == "conversation-1"
+
+
+def _build_prefixed_variable(node_id: str, name: str, value: object) -> Variable:
+    return segment_to_variable(
+        segment=build_segment(value),
+        selector=(node_id, name),
+        name=name,
+    )
diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
index 3cbd96dfef..2a954fbaf8 100644
--- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
+++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py
@@ -3,18 +3,18 @@ from __future__ import annotations
 import sys
 import types
 from collections.abc import Generator
+from types import SimpleNamespace
 from typing import TYPE_CHECKING, Any
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock

 import pytest

-from core.tools.entities.tool_entities import ToolInvokeMessage
-from core.tools.utils.message_transformer import ToolFileMessageTransformer
+from core.workflow.system_variables import build_system_variables
 from dify_graph.file import File, FileTransferMethod, FileType
 from dify_graph.model_runtime.entities.llm_entities import LLMUsage
 from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent
+from dify_graph.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage
 from dify_graph.runtime import GraphRuntimeState, VariablePool
-from dify_graph.system_variable import SystemVariable
 from dify_graph.variables.segments import ArrayFileSegment
 from tests.workflow_test_utils import build_test_graph_init_params

@@ -22,6 +22,38 @@ if TYPE_CHECKING:  # pragma: no cover - imported for type checking only
     from dify_graph.nodes.tool.tool_node import ToolNode


+class _StubToolRuntime:
+    def get_runtime(self, *, node_id: str, node_data: Any, variable_pool: Any) -> ToolRuntimeHandle:
+        raise NotImplementedError
+
+    def get_runtime_parameters(self, *, tool_runtime: ToolRuntimeHandle) -> list[Any]:
+        return []
+
+    def invoke(
+        self,
+        *,
+        tool_runtime: ToolRuntimeHandle,
+        tool_parameters: dict[str, Any],
+        workflow_call_depth: int,
+        provider_name: str,
+    ) -> Generator[ToolRuntimeMessage, None, None]:
+        yield from ()
+
+    def get_usage(self, *, tool_runtime: ToolRuntimeHandle) -> LLMUsage:
+        return LLMUsage.empty_usage()
+
+    def build_file_reference(self, *, mapping: dict[str, Any]) -> Any:
+        return mapping
+
+    def resolve_provider_icons(
+        self,
+        *,
+        provider_name: str,
+        default_icon: str | None = None,
+    ) -> tuple[str | None, str | None]:
+        return default_icon, None
+
+
 @pytest.fixture
 def tool_node(monkeypatch) -> ToolNode:
     module_name = "core.ops.ops_trace_manager"
@@ -66,13 +98,14 @@ def tool_node(monkeypatch) -> ToolNode:
         call_depth=0,
     )

-    variable_pool = VariablePool(system_variables=SystemVariable(user_id="user-id"))
+    variable_pool = VariablePool(system_variables=build_system_variables(user_id="user-id"))
     graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)

     config = graph_config["nodes"][0]

     # Provide a stub ToolFileManager to satisfy the updated ToolNode constructor
     tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
+    runtime = _StubToolRuntime()

     node = ToolNode(
         id="node-instance",
@@ -80,6 +113,7 @@ def tool_node(monkeypatch) -> ToolNode:
         graph_init_params=init_params,
         graph_runtime_state=graph_runtime_state,
         tool_file_manager_factory=tool_file_manager_factory,
+        runtime=runtime,
     )
     return node

@@ -93,29 +127,19 @@ def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]:
     return events, stop.value


-def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]:
-    def _identity_transform(messages, *_args, **_kwargs):
-        return messages
-
-    tool_runtime = MagicMock()
-    with patch.object(
-        ToolFileMessageTransformer, "transform_tool_invoke_messages", side_effect=_identity_transform, autospec=True
-    ):
-        generator = tool_node._transform_message(
-            messages=iter([message]),
-            tool_info={"provider_type": "builtin", "provider_id": "provider"},
-            parameters_for_log={},
-            user_id="user-id",
-            tenant_id="tenant-id",
-            node_id=tool_node._node_id,
-            tool_runtime=tool_runtime,
-        )
-        return _collect_events(generator)
+def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[list[Any], LLMUsage]:
+    generator = tool_node._transform_message(
+        messages=iter([message]),
+        tool_info={"provider_type": "builtin", "provider_id": "provider"},
+        parameters_for_log={},
+        node_id=tool_node._node_id,
+        tool_runtime=ToolRuntimeHandle(raw=object()),
+    )
+    return _collect_events(generator)


 def test_link_messages_with_file_populate_files_output(tool_node: ToolNode):
     file_obj = File(
-        tenant_id="tenant-id",
         type=FileType.DOCUMENT,
         transfer_method=FileTransferMethod.TOOL_FILE,
         related_id="file-id",
@@ -125,9 +149,9 @@ def test_link_messages_with_file_populate_files_output(tool_node: ToolNode):
         size=123,
         storage_key="file-key",
     )
-    message = ToolInvokeMessage(
-        type=ToolInvokeMessage.MessageType.LINK,
-        message=ToolInvokeMessage.TextMessage(text="/files/tools/file-id.pdf"),
+    message = ToolRuntimeMessage(
+        type=ToolRuntimeMessage.MessageType.LINK,
+        message=ToolRuntimeMessage.TextMessage(text="/files/tools/file-id.pdf"),
         meta={"file": file_obj},
     )

@@ -150,9 +174,9 @@ def test_link_messages_with_file_populate_files_output(tool_node: ToolNode):


 def test_plain_link_messages_remain_links(tool_node: ToolNode):
-    message = ToolInvokeMessage(
-        type=ToolInvokeMessage.MessageType.LINK,
-        message=ToolInvokeMessage.TextMessage(text="https://dify.ai"),
+    message = ToolRuntimeMessage(
+        type=ToolRuntimeMessage.MessageType.LINK,
+        message=ToolRuntimeMessage.TextMessage(text="https://dify.ai"),
         meta=None,
     )

@@ -167,3 +191,35 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode):
     files_segment = completed_events[0].node_run_result.outputs["files"]
     assert isinstance(files_segment, ArrayFileSegment)
     assert files_segment.value == []
+
+
+def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode):
+    file_obj = File(
+        type=FileType.DOCUMENT,
+        transfer_method=FileTransferMethod.TOOL_FILE,
+        related_id="file-id",
+        filename="demo.pdf",
+        extension=".pdf",
+        mime_type="application/pdf",
+        size=123,
+        storage_key="file-key",
+    )
+    tool_node._tool_file_manager_factory.get_file_generator_by_tool_file_id.return_value = (
+        None,
+        SimpleNamespace(mime_type="application/pdf"),
+    )
+    tool_node._runtime.build_file_reference = MagicMock(return_value=file_obj)
+    message = ToolRuntimeMessage(
+        type=ToolRuntimeMessage.MessageType.IMAGE_LINK,
+        message=ToolRuntimeMessage.TextMessage(text="/files/tools/file-id.pdf"),
+        meta={"tool_file_id": "file-id"},
+    )
+
+    events, _ = _run_transform(tool_node, message)
+
+    tool_node._tool_file_manager_factory.get_file_generator_by_tool_file_id.assert_called_once_with("file-id")
+    completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)]
+    assert len(completed_events) == 1
+    files_segment = completed_events[0].node_run_result.outputs["files"]
+    assert isinstance(files_segment, ArrayFileSegment)
+    assert files_segment.value == [file_obj]
diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py
new file mode 100644
index 0000000000..0671db3683
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py
@@ -0,0 +1,276 @@
+from __future__ import annotations
+
+import sys
+import types
+from types import SimpleNamespace
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
+from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError
+from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType
+from core.tools.errors import ToolInvokeError
+from core.tools.tool_engine import ToolEngine
+from core.tools.tool_manager import ToolManager
+from core.tools.utils.message_transformer import ToolFileMessageTransformer
+from core.workflow.node_runtime import DifyToolNodeRuntime
+from core.workflow.system_variables import build_system_variables
+from dify_graph.model_runtime.entities.llm_entities import LLMUsage
+from dify_graph.nodes.tool.entities import ToolNodeData, ToolProviderType
+from dify_graph.nodes.tool.exc import ToolRuntimeInvocationError
+from dify_graph.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage
+from dify_graph.runtime import VariablePool
+from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool
+
+
+@pytest.fixture
+def runtime(monkeypatch) -> DifyToolNodeRuntime:
+    module_name = "core.ops.ops_trace_manager"
+    if module_name not in sys.modules:
+        ops_stub = types.ModuleType(module_name)
+        ops_stub.TraceQueueManager = object  # pragma: no cover - stub attribute
+        ops_stub.TraceTask = object  # pragma: no cover - stub attribute
+        monkeypatch.setitem(sys.modules, module_name, ops_stub)
+
+    init_params = build_test_graph_init_params(
+        workflow_id="workflow-id",
+        graph_config={"nodes": [], "edges": []},
+        tenant_id="tenant-id",
+        app_id="app-id",
+        user_id="user-id",
+        user_from="account",
+        invoke_from="debugger",
+        call_depth=0,
+    )
+    return DifyToolNodeRuntime(init_params.run_context)
+
+
+def _build_tool_node_data() -> ToolNodeData:
+    return ToolNodeData.model_validate(
+        {
+            "type": "tool",
+            "title": "Tool",
+            "provider_id": "provider",
+            "provider_type": ToolProviderType.BUILT_IN,
+            "provider_name": "provider",
+            "tool_name": "lookup",
+            "tool_label": "Lookup",
+            "tool_configurations": {},
+            "tool_parameters": {},
+        }
+    )
+
+
+def test_invoke_creates_callback_and_converts_messages(runtime: DifyToolNodeRuntime) -> None:
+    core_message = ToolInvokeMessage(
+        type=ToolInvokeMessage.MessageType.LINK,
+        message=ToolInvokeMessage.TextMessage(text="https://dify.ai"),
+        meta=None,
+    )
+    variable_pool: VariablePool = build_test_variable_pool(
+        variables=build_system_variables(conversation_id="conversation-id")
+    )
+    workflow_tool = MagicMock()
+
+    with (
+        patch.object(ToolManager, "get_workflow_tool_runtime", return_value=workflow_tool),
+        patch.object(ToolEngine, "generic_invoke", return_value=iter([core_message])) as generic_invoke_mock,
+        patch.object(
+            ToolFileMessageTransformer,
+            "transform_tool_invoke_messages",
+            side_effect=lambda *, messages, **_: messages,
+        ) as transform_tool_messages,
+    ):
+        tool_runtime = runtime.get_runtime(
+            node_id="node-id",
+            node_data=_build_tool_node_data(),
+            variable_pool=variable_pool,
+        )
+        messages = list(
+            runtime.invoke(
+                tool_runtime=tool_runtime,
+                tool_parameters={},
+                workflow_call_depth=0,
+                provider_name="provider",
+            )
+        )
+
+    assert not hasattr(tool_runtime, "conversation_id")
+    assert len(messages) == 1
+    graph_message = messages[0]
+    assert graph_message.type == ToolRuntimeMessage.MessageType.LINK
+    assert isinstance(graph_message.message, ToolRuntimeMessage.TextMessage)
+    assert graph_message.message.text == "https://dify.ai"
+
+    callback = generic_invoke_mock.call_args.kwargs["workflow_tool_callback"]
+    assert isinstance(callback, DifyWorkflowCallbackHandler)
+    assert generic_invoke_mock.call_args.kwargs["conversation_id"] == "conversation-id"
+
+    transform_kwargs = transform_tool_messages.call_args.kwargs
+    assert transform_kwargs["conversation_id"] == "conversation-id"
+
+
+def test_invoke_maps_plugin_errors_to_graph_errors(runtime: DifyToolNodeRuntime) -> None:
+    invoke_error = PluginInvokeError('{"error_type":"RateLimit","message":"too many"}')
+
+    with patch.object(ToolEngine, "generic_invoke", side_effect=invoke_error):
+        with pytest.raises(ToolRuntimeInvocationError, match="An error occurred in the provider"):
+            runtime.invoke(
+                tool_runtime=ToolRuntimeHandle(raw=MagicMock()),
+                tool_parameters={},
+                workflow_call_depth=0,
+                provider_name="provider",
+            )
+
+
+def test_get_usage_normalizes_dict_payload(runtime: DifyToolNodeRuntime) -> None:
+    usage_payload = LLMUsage.empty_usage().model_dump()
+    usage_payload["total_tokens"] = 42
+
+    usage = runtime.get_usage(
+        tool_runtime=ToolRuntimeHandle(raw=SimpleNamespace(latest_usage=usage_payload)),
+    )
+
+    assert usage.total_tokens == 42
+
+
+def test_get_runtime_converts_graph_provider_type_for_tool_manager(runtime: DifyToolNodeRuntime) -> None:
+    node_data = _build_tool_node_data()
+
+    with patch.object(ToolManager, "get_workflow_tool_runtime", return_value=MagicMock()) as runtime_mock:
+        tool_runtime = runtime.get_runtime(node_id="node-id", node_data=node_data, variable_pool=None)
+
+    assert not hasattr(tool_runtime, "conversation_id")
+    workflow_tool = runtime_mock.call_args.args[3]
+    assert workflow_tool.provider_type == CoreToolProviderType.BUILT_IN
+
+
+def test_get_runtime_parameters_reads_required_flags(runtime: DifyToolNodeRuntime) -> None:
+    tool_runtime = ToolRuntimeHandle(
+        raw=SimpleNamespace(
+            get_merged_runtime_parameters=MagicMock(
+                return_value=[
+                    SimpleNamespace(name="city", required=True),
+                    SimpleNamespace(name="country", required=False),
+                ]
+            )
+        )
+    )
+
+    parameters = runtime.get_runtime_parameters(tool_runtime=tool_runtime)
+
+    assert [(parameter.name, parameter.required) for parameter in parameters] == [
+        ("city", True),
+        ("country", False),
+    ]
+
+
+def test_get_usage_returns_empty_usage_when_tool_has_no_usage(runtime: DifyToolNodeRuntime) -> None:
+    usage = runtime.get_usage(tool_runtime=ToolRuntimeHandle(raw=SimpleNamespace(latest_usage=None)))
+
+    assert usage == LLMUsage.empty_usage()
+
+
+@pytest.mark.parametrize(
+    ("payload", "expected_type"),
+    [
+        (ToolInvokeMessage.JsonMessage(json_object={"ok": True}, suppress_output=True), ToolRuntimeMessage.JsonMessage),
+        (ToolInvokeMessage.BlobMessage(blob=b"bytes"), ToolRuntimeMessage.BlobMessage),
+        (
+            ToolInvokeMessage.BlobChunkMessage(
+                id="blob-id",
+                sequence=1,
+                total_length=5,
+                blob=b"hello",
+                end=True,
+            ),
+            ToolRuntimeMessage.BlobChunkMessage,
+        ),
+        (ToolInvokeMessage.FileMessage(file_marker="marker"), ToolRuntimeMessage.FileMessage),
+        (
+            ToolInvokeMessage.VariableMessage(variable_name="city", variable_value="Tokyo", stream=True),
+            ToolRuntimeMessage.VariableMessage,
+        ),
+        (
+            ToolInvokeMessage.LogMessage(
+                id="log-id",
+                label="lookup",
+                status=ToolInvokeMessage.LogMessage.LogStatus.SUCCESS,
+                data={"count": 1},
+                metadata={"source": "tool"},
+            ),
+            ToolRuntimeMessage.LogMessage,
+        ),
+    ],
+)
+def test_convert_message_payload_supports_runtime_message_types(
+    runtime: DifyToolNodeRuntime,
+    payload: object,
+    expected_type: type[object],
+) -> None:
+    message = runtime._convert_message_payload(payload)
+
+    assert isinstance(message, expected_type)
+
+
+def test_convert_message_payload_rejects_unknown_types(runtime: DifyToolNodeRuntime) -> None:
+    with pytest.raises(TypeError, match="unsupported tool message payload"):
+        runtime._convert_message_payload(object())
+
+
+def test_resolve_provider_icons_prefers_builtin_tool_icons(runtime: DifyToolNodeRuntime) -> None:
+    plugin = SimpleNamespace(
+        plugin_id="langgenius/tools",
+        name="search",
+        declaration=SimpleNamespace(icon={"plugin": "icon"}),
+    )
+    builtin_tool = SimpleNamespace(
+        name="langgenius/tools/search",
+        icon={"builtin": "icon"},
+        icon_dark={"builtin": "dark"},
+    )
+
+    with (
+        patch("core.workflow.node_runtime.PluginInstaller") as installer_cls,
+        patch("core.workflow.node_runtime.BuiltinToolManageService.list_builtin_tools", return_value=[builtin_tool]),
+    ):
+        installer_cls.return_value.list_plugins.return_value = [plugin]
+
+        icon, icon_dark = runtime.resolve_provider_icons(provider_name="langgenius/tools/search")
+
+    assert icon == {"builtin": "icon"}
+    assert icon_dark == {"builtin": "dark"}
+
+
+def test_resolve_provider_icons_returns_default_when_provider_is_unknown(runtime: DifyToolNodeRuntime) -> None:
+    with (
+        patch("core.workflow.node_runtime.PluginInstaller") as installer_cls,
+        patch("core.workflow.node_runtime.BuiltinToolManageService.list_builtin_tools", return_value=[]),
+    ):
+        installer_cls.return_value.list_plugins.return_value = []
+
+        icon, icon_dark = runtime.resolve_provider_icons(provider_name="unknown", default_icon="fallback")
+
+    assert icon == "fallback"
+    assert icon_dark is None
+
+
+@pytest.mark.parametrize(
+    ("exc", "message"),
+    [
+        (PluginDaemonClientSideError("bad request"), "Failed to invoke tool, error: bad request"),
+        (ToolInvokeError("broken"), "Failed to invoke tool provider: broken"),
+        (RuntimeError("unexpected"), "unexpected"),
+    ],
+)
+def test_map_invocation_exception_normalizes_runtime_errors(
+    runtime: DifyToolNodeRuntime,
+    exc: Exception,
+    message: str,
+) -> None:
+    error = runtime._map_invocation_exception(exc, provider_name="provider")
+
+    assert isinstance(error, ToolRuntimeInvocationError)
+    assert str(error) == message
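The parametrized table in the last test above fixes the error-normalization contract; a minimal sketch of the mapping it encodes follows. The real DifyToolNodeRuntime._map_invocation_exception may differ in structure, and the provider_name handling here is an assumption.

    from core.plugin.impl.exc import PluginDaemonClientSideError
    from core.tools.errors import ToolInvokeError
    from dify_graph.nodes.tool.exc import ToolRuntimeInvocationError


    def sketch_map_invocation_exception(exc: Exception, provider_name: str) -> ToolRuntimeInvocationError:
        # Hypothetical sketch derived from the expected messages in the test table.
        if isinstance(exc, PluginDaemonClientSideError):
            return ToolRuntimeInvocationError(f"Failed to invoke tool, error: {exc}")
        if isinstance(exc, ToolInvokeError):
            return ToolRuntimeInvocationError(f"Failed to invoke tool provider: {exc}")
        return ToolRuntimeInvocationError(str(exc))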
+ icon={"builtin": "icon"}, + icon_dark={"builtin": "dark"}, + ) + + with ( + patch("core.workflow.node_runtime.PluginInstaller") as installer_cls, + patch("core.workflow.node_runtime.BuiltinToolManageService.list_builtin_tools", return_value=[builtin_tool]), + ): + installer_cls.return_value.list_plugins.return_value = [plugin] + + icon, icon_dark = runtime.resolve_provider_icons(provider_name="langgenius/tools/search") + + assert icon == {"builtin": "icon"} + assert icon_dark == {"builtin": "dark"} + + +def test_resolve_provider_icons_returns_default_when_provider_is_unknown(runtime: DifyToolNodeRuntime) -> None: + with ( + patch("core.workflow.node_runtime.PluginInstaller") as installer_cls, + patch("core.workflow.node_runtime.BuiltinToolManageService.list_builtin_tools", return_value=[]), + ): + installer_cls.return_value.list_plugins.return_value = [] + + icon, icon_dark = runtime.resolve_provider_icons(provider_name="unknown", default_icon="fallback") + + assert icon == "fallback" + assert icon_dark is None + + +@pytest.mark.parametrize( + ("exc", "message"), + [ + (PluginDaemonClientSideError("bad request"), "Failed to invoke tool, error: bad request"), + (ToolInvokeError("broken"), "Failed to invoke tool provider: broken"), + (RuntimeError("unexpected"), "unexpected"), + ], +) +def test_map_invocation_exception_normalizes_runtime_errors( + runtime: DifyToolNodeRuntime, + exc: Exception, + message: str, +) -> None: + error = runtime._map_invocation_exception(exc, provider_name="provider") + + assert isinstance(error, ToolRuntimeInvocationError) + assert str(error) == message diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py index 9aeab0409e..44f919efec 100644 --- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py @@ -2,12 +2,12 @@ from collections.abc import Mapping from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode +from core.workflow.system_variables import build_system_variables from dify_graph.entities import GraphInitParams from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from tests.workflow_test_utils import build_test_graph_init_params +from dify_graph.runtime import GraphRuntimeState +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]: @@ -17,9 +17,10 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, invoke_from="debugger", ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool( - system_variables=SystemVariable(user_id="user", files=[]), - user_inputs={"payload": "value"}, + variable_pool=build_test_variable_pool( + variables=build_system_variables(user_id="user", files=[]), + node_id="node-1", + inputs={"payload": "value"}, ), start_at=0.0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py 
b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index e69c05dc0b..1d09324927 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -2,22 +2,38 @@ import time import uuid from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.graph import Graph -from dify_graph.graph_events.node import NodeRunSucceededEvent +from dify_graph.graph_events.node import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent from dify_graph.nodes.variable_assigner.common import helpers as common_helpers from dify_graph.nodes.variable_assigner.v1 import VariableAssignerNode from dify_graph.nodes.variable_assigner.v1.node_data import WriteMode from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables import ArrayStringVariable, StringVariable DEFAULT_NODE_ID = "node_id" +def _build_variable_pool( + *, + conversation_id: str, + conversation_variables: list[StringVariable | ArrayStringVariable], +) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=build_system_variables(conversation_id=conversation_id), + conversation_variables=conversation_variables, + ), + ) + return variable_pool + + def test_overwrite_string_variable(): graph_config = { "edges": [ @@ -71,10 +87,8 @@ def test_overwrite_string_variable(): conversation_id = str(uuid.uuid4()) # construct variable pool - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - environment_variables=[], + variable_pool = _build_variable_pool( + conversation_id=conversation_id, conversation_variables=[conversation_variable], ) @@ -108,16 +122,14 @@ def test_overwrite_string_variable(): ) events = list(node.run()) + updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) assert updated_variables is not None assert updated_variables[0].name == conversation_variable.name assert updated_variables[0].new_value == input_variable.value - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.value == "the second value" - assert got.to_object() == "the second value" + assert updated_event.variable.value == "the second value" + assert tuple(updated_event.variable.selector) == ("conversation", conversation_variable.name) def test_append_variable_to_array(): @@ -172,10 +184,8 @@ def test_append_variable_to_array(): ) conversation_id = str(uuid.uuid4()) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - 
environment_variables=[], + variable_pool = _build_variable_pool( + conversation_id=conversation_id, conversation_variables=[conversation_variable], ) variable_pool.add( @@ -208,15 +218,13 @@ def test_append_variable_to_array(): ) events = list(node.run()) + updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) assert updated_variables is not None assert updated_variables[0].name == conversation_variable.name assert updated_variables[0].new_value == ["the first value", "the second value"] - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["the first value", "the second value"] + assert updated_event.variable.value == ["the first value", "the second value"] def test_clear_array(): @@ -265,10 +273,8 @@ def test_clear_array(): ) conversation_id = str(uuid.uuid4()) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - environment_variables=[], + variable_pool = _build_variable_pool( + conversation_id=conversation_id, conversation_variables=[conversation_variable], ) @@ -297,12 +303,10 @@ def test_clear_array(): ) events = list(node.run()) + updated_event = next(event for event in events if isinstance(event, NodeRunVariableUpdatedEvent)) succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) assert updated_variables is not None assert updated_variables[0].name == conversation_variable.name assert updated_variables[0].new_value == [] - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] + assert updated_event.variable.value == [] diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index 6874f3fef1..f61aea98bc 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -2,20 +2,33 @@ import time import uuid from uuid import uuid4 -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY from dify_graph.graph import Graph +from dify_graph.graph_events import NodeRunVariableUpdatedEvent from dify_graph.nodes.variable_assigner.v2 import VariableAssignerNode from dify_graph.nodes.variable_assigner.v2.enums import InputType, Operation from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables import ArrayStringVariable DEFAULT_NODE_ID = "node_id" +def _build_variable_pool(*, conversation_variables: 
list[ArrayStringVariable]) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=build_system_variables(conversation_id="conversation_id"), + conversation_variables=conversation_variables, + ), + ) + return variable_pool + + def test_handle_item_directly(): """Test the _handle_item method directly for remove operations.""" # Create variables @@ -106,12 +119,7 @@ def test_remove_first_from_array(): selector=["conversation", "test_conversation_variable"], ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) + variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( @@ -146,11 +154,8 @@ def test_remove_first_from_array(): # Run the node result = list(node.run()) - # Completed run - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["second", "third"] + updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) + assert updated_event.variable.value == ["second", "third"] def test_remove_last_from_array(): @@ -194,12 +199,7 @@ def test_remove_last_from_array(): selector=["conversation", "test_conversation_variable"], ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) + variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( @@ -231,11 +231,9 @@ def test_remove_last_from_array(): config=node_config, ) - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["first", "second"] + result = list(node.run()) + updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) + assert updated_event.variable.value == ["first", "second"] def test_remove_first_from_empty_array(): @@ -279,12 +277,7 @@ def test_remove_first_from_empty_array(): selector=["conversation", "test_conversation_variable"], ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) + variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( @@ -316,11 +309,9 @@ def test_remove_first_from_empty_array(): config=node_config, ) - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] + result = list(node.run()) + updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) + assert updated_event.variable.value == [] def test_remove_last_from_empty_array(): @@ -364,12 +355,7 @@ def test_remove_last_from_empty_array(): selector=["conversation", "test_conversation_variable"], ) - 
variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) + variable_pool = _build_variable_pool(conversation_variables=[conversation_variable]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( @@ -401,11 +387,9 @@ def test_remove_last_from_empty_array(): config=node_config, ) - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] + result = list(node.run()) + updated_event = next(event for event in result if isinstance(event, NodeRunVariableUpdatedEvent)) + assert updated_event.variable.value == [] def test_node_factory_creates_variable_assigner_node(): @@ -433,12 +417,7 @@ def test_node_factory_creates_variable_assigner_node(): }, call_depth=0, ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], - ) + variable_pool = _build_variable_pool(conversation_variables=[]) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) node_factory = DifyNodeFactory( diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index 78dd7ce0f3..9a0487de5a 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -8,7 +8,7 @@ when passing files to downstream LLM nodes. 
from unittest.mock import Mock, patch -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.nodes.trigger_webhook.entities import ( ContentType, Method, @@ -16,11 +16,12 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookData, ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from core.workflow.system_variables import default_system_variables +from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.runtime.graph_runtime_state import GraphRuntimeState from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable +from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node( @@ -96,6 +97,18 @@ def create_test_file_dict( } +def build_webhook_variable_pool(inputs: dict) -> VariablePool: + return build_test_variable_pool( + variables=default_system_variables(), + node_id="webhook-node-1", + inputs=inputs, + ) + + +def expected_factory_mapping(file_dict: dict) -> dict: + return {**file_dict, "upload_file_id": file_dict["related_id"]} + + def test_webhook_node_file_conversion_to_file_variable(): """Test that webhook node converts file dictionaries to FileVariable objects.""" # Create test file dictionary (as it comes from webhook service) @@ -111,9 +124,8 @@ def test_webhook_node_file_conversion_to_file_variable(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -122,14 +134,14 @@ def test_webhook_node_file_conversion_to_file_variable(): "image_upload": file_dict, }, } - }, + } ) node = create_webhook_node(data, variable_pool) - # Mock the file factory and variable factory + # Mock the file reference boundary and variable factory with ( - patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory, patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): @@ -153,8 +165,7 @@ def test_webhook_node_file_conversion_to_file_variable(): # Verify file factory was called with correct parameters mock_file_factory.assert_called_once_with( - mapping=file_dict, - tenant_id="test-tenant", + mapping=expected_factory_mapping(file_dict), ) # Verify segment factory was called to create FileSegment @@ -184,16 +195,15 @@ def test_webhook_node_file_conversion_with_missing_files(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, "body": {}, "files": {}, # No files } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -219,9 +229,8 @@ def test_webhook_node_file_conversion_with_none_file(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -230,7 +239,7 @@ def 
test_webhook_node_file_conversion_with_none_file(): "file": None, }, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -256,9 +265,8 @@ def test_webhook_node_file_conversion_with_non_dict_file(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -267,7 +275,7 @@ def test_webhook_node_file_conversion_with_non_dict_file(): "file": "not_a_dict", # Wrapped to match node expectation }, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -300,9 +308,8 @@ def test_webhook_node_file_conversion_mixed_parameters(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -315,13 +322,13 @@ def test_webhook_node_file_conversion_mixed_parameters(): "file_param": file_dict, }, } - }, + } ) node = create_webhook_node(data, variable_pool) with ( - patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory, patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): @@ -350,8 +357,7 @@ def test_webhook_node_file_conversion_mixed_parameters(): # Verify file conversion was called mock_file_factory.assert_called_once_with( - mapping=file_dict, - tenant_id="test-tenant", + mapping=expected_factory_mapping(file_dict), ) @@ -370,9 +376,8 @@ def test_webhook_node_different_file_types(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -383,13 +388,13 @@ def test_webhook_node_different_file_types(): "video": create_test_file_dict("video.mp4", "video"), }, } - }, + } ) node = create_webhook_node(data, variable_pool) with ( - patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory, patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): @@ -430,9 +435,8 @@ def test_webhook_node_file_conversion_with_non_dict_wrapper(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -441,7 +445,7 @@ def test_webhook_node_file_conversion_with_non_dict_wrapper(): "file": "just a string", }, } - }, + } ) node = create_webhook_node(data, variable_pool) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 139f65d6c3..b19fc9f29f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -2,7 +2,7 @@ from unittest.mock import patch import pytest -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.trigger.constants import 
TRIGGER_WEBHOOK_NODE_TYPE from core.workflow.nodes.trigger_webhook.entities import ( ContentType, @@ -12,13 +12,14 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookParameter, ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams +from core.workflow.system_variables import default_system_variables +from dify_graph.entities.graph_init_params import GraphInitParams from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus from dify_graph.file import File, FileTransferMethod, FileType from dify_graph.runtime.graph_runtime_state import GraphRuntimeState from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable from dify_graph.variables import FileVariable, StringVariable +from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode: @@ -62,6 +63,14 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) return node +def build_webhook_variable_pool(inputs: dict) -> VariablePool: + return build_test_variable_pool( + variables=default_system_variables(), + node_id="1", + inputs=inputs, + ) + + def test_webhook_node_basic_initialization(): """Test basic webhook node initialization and configuration.""" data = WebhookData( @@ -76,10 +85,7 @@ def test_webhook_node_basic_initialization(): timeout=30, ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - ) + variable_pool = build_webhook_variable_pool({}) node = create_webhook_node(data, variable_pool) @@ -119,9 +125,8 @@ def test_webhook_node_run_with_headers(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": { "Authorization": "Bearer token123", @@ -132,7 +137,7 @@ def test_webhook_node_run_with_headers(): "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -155,9 +160,8 @@ def test_webhook_node_run_with_query_params(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": { @@ -167,7 +171,7 @@ def test_webhook_node_run_with_query_params(): "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -191,9 +195,8 @@ def test_webhook_node_run_with_body_params(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -205,7 +208,7 @@ def test_webhook_node_run_with_body_params(): }, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -222,7 +225,6 @@ def test_webhook_node_run_with_file_params(): """Test webhook node execution with file parameter extraction.""" # Create mock file objects file1 = File( - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", @@ -232,7 +234,6 @@ def test_webhook_node_run_with_file_params(): ) file2 = File( - tenant_id="1", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file2", @@ -250,9 +251,8 @@ def test_webhook_node_run_with_file_params(): ], ) - 
@@ -76,10 +85,7 @@ def test_webhook_node_basic_initialization():
         timeout=30,
     )
 
-    variable_pool = VariablePool(
-        system_variables=SystemVariable.default(),
-        user_inputs={},
-    )
+    variable_pool = build_webhook_variable_pool({})
 
     node = create_webhook_node(data, variable_pool)
 
@@ -119,9 +125,8 @@ def test_webhook_node_run_with_headers():
         ],
     )
 
-    variable_pool = VariablePool(
-        system_variables=SystemVariable.default(),
-        user_inputs={
+    variable_pool = build_webhook_variable_pool(
+        {
             "webhook_data": {
                 "headers": {
                     "Authorization": "Bearer token123",
@@ -132,7 +137,7 @@ def test_webhook_node_run_with_headers():
                 "body": {},
                 "files": {},
             }
-        },
+        }
     )
 
     node = create_webhook_node(data, variable_pool)
@@ -155,9 +160,8 @@ def test_webhook_node_run_with_query_params():
         ],
     )
 
-    variable_pool = VariablePool(
-        system_variables=SystemVariable.default(),
-        user_inputs={
+    variable_pool = build_webhook_variable_pool(
+        {
             "webhook_data": {
                 "headers": {},
                 "query_params": {
@@ -167,7 +171,7 @@ def test_webhook_node_run_with_query_params():
                 "body": {},
                 "files": {},
             }
-        },
+        }
     )
 
     node = create_webhook_node(data, variable_pool)
@@ -191,9 +195,8 @@ def test_webhook_node_run_with_body_params():
         ],
     )
 
-    variable_pool = VariablePool(
-        system_variables=SystemVariable.default(),
-        user_inputs={
+    variable_pool = build_webhook_variable_pool(
+        {
             "webhook_data": {
                 "headers": {},
                 "query_params": {},
@@ -205,7 +208,7 @@ def test_webhook_node_run_with_body_params():
                 },
                 "files": {},
             }
-        },
+        }
     )
 
     node = create_webhook_node(data, variable_pool)
@@ -222,7 +225,6 @@ def test_webhook_node_run_with_file_params():
     """Test webhook node execution with file parameter extraction."""
     # Create mock file objects
     file1 = File(
-        tenant_id="1",
         type=FileType.IMAGE,
         transfer_method=FileTransferMethod.LOCAL_FILE,
         related_id="file1",
@@ -232,7 +234,6 @@ def test_webhook_node_run_with_file_params():
     )
 
     file2 = File(
-        tenant_id="1",
         type=FileType.DOCUMENT,
         transfer_method=FileTransferMethod.LOCAL_FILE,
         related_id="file2",
@@ -250,9 +251,8 @@ def test_webhook_node_run_with_file_params():
         ],
     )
 
-    variable_pool = VariablePool(
-        system_variables=SystemVariable.default(),
-        user_inputs={
+    variable_pool = build_webhook_variable_pool(
+        {
             "webhook_data": {
                 "headers": {},
                 "query_params": {},
@@ -262,14 +262,14 @@ def test_webhook_node_run_with_file_params():
                     "document": file2.to_dict(),
                 },
             }
-        },
+        }
     )
 
     node = create_webhook_node(data, variable_pool)
 
-    # Mock the file factory to avoid DB-dependent validation on upload_file_id
-    with patch("factories.file_factory.build_from_mapping") as mock_file_factory:
+    # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id
+    with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory:
 
-        def _to_file(mapping, tenant_id, config=None, strict_type_validation=False):
+        def _to_file(*, mapping):
             return File.model_validate(mapping)
 
         mock_file_factory.side_effect = _to_file
@@ -284,7 +284,6 @@ def test_webhook_node_run_mixed_parameters():
     """Test webhook node execution with mixed parameter types."""
 
     file_obj = File(
-        tenant_id="1",
         type=FileType.IMAGE,
         transfer_method=FileTransferMethod.LOCAL_FILE,
         related_id="file1",
@@ -303,23 +302,22 @@ def test_webhook_node_run_mixed_parameters():
         ],
     )
 
-    variable_pool = VariablePool(
-        system_variables=SystemVariable.default(),
-        user_inputs={
+    variable_pool = build_webhook_variable_pool(
+        {
             "webhook_data": {
                 "headers": {"Authorization": "Bearer token"},
                 "query_params": {"version": "v1"},
                 "body": {"message": "Test message"},
                 "files": {"upload": file_obj.to_dict()},
             }
-        },
+        }
     )
 
     node = create_webhook_node(data, variable_pool)
 
-    # Mock the file factory to avoid DB-dependent validation on upload_file_id
-    with patch("factories.file_factory.build_from_mapping") as mock_file_factory:
+    # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id
+    with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory:
 
-        def _to_file(mapping, tenant_id, config=None, strict_type_validation=False):
+        def _to_file(*, mapping):
             return File.model_validate(mapping)
 
         mock_file_factory.side_effect = _to_file
@@ -343,10 +341,7 @@ def test_webhook_node_run_empty_webhook_data():
         body=[WebhookBodyParameter(name="message", type="string", required=False)],
     )
 
-    variable_pool = VariablePool(
-        system_variables=SystemVariable.default(),
-        user_inputs={},  # No webhook_data
-    )
+    variable_pool = build_webhook_variable_pool({})  # No webhook_data
 
     node = create_webhook_node(data, variable_pool)
     result = node._run()
@@ -369,9 +364,8 @@ def test_webhook_node_run_case_insensitive_headers():
         ],
     )
 
-    variable_pool = VariablePool(
-        system_variables=SystemVariable.default(),
-        user_inputs={
+    variable_pool = build_webhook_variable_pool(
+        {
             "webhook_data": {
                 "headers": {
                     "content-type": "application/json",  # lowercase
@@ -382,7 +376,7 @@ def test_webhook_node_run_case_insensitive_headers():
                 "body": {},
                 "files": {},
             }
-        },
+        }
     )
 
     node = create_webhook_node(data, variable_pool)
@@ -399,12 +393,11 @@ def test_webhook_node_variable_pool_user_inputs():
     data = WebhookData(title="Test Webhook")
 
     # Add some additional variables to the pool
-    variable_pool = VariablePool(
-        system_variables=SystemVariable.default(),
-        user_inputs={
+    variable_pool = build_webhook_variable_pool(
+        {
             "webhook_data": {"headers": {}, "query_params": {}, "body": {}, "files": {}},
             "other_var": "should_be_included",
-        },
+        }
     )
 
     variable_pool.add(["node1", "extra"], StringVariable(name="extra", value="extra_value"))
@@ -430,16 +423,15 @@ def test_webhook_node_different_methods(method):
         method=method,
     )
 
-    variable_pool = VariablePool(
-        system_variables=SystemVariable.default(),
-        user_inputs={
+    variable_pool = build_webhook_variable_pool(
+        {
             "webhook_data": {
                 "headers": {},
                 "query_params": {},
                 "body": {},
                 "files": {},
             }
-        },
+        }
     )
 
     node = create_webhook_node(data, variable_pool)
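# Editor's note (not part of the patch): the _to_file stubs above moved from a
# positional (mapping, tenant_id, config, strict_type_validation) signature to
# keyword-only (*, mapping), mirroring the tenant-scoped factory call. A small
# illustration of why a keyword-only stub catches call-site regressions; the
# mock wiring below is illustrative, not the production code path.
from unittest.mock import MagicMock

mock_factory = MagicMock()


def _to_file(*, mapping):
    # Accept only the keyword the node now passes; any legacy positional
    # call shape surfaces immediately as a TypeError instead of passing.
    return dict(mapping)


mock_factory.build_from_mapping.side_effect = _to_file
assert mock_factory.build_from_mapping(mapping={"id": "f1"}) == {"id": "f1"}
try:
    mock_factory.build_from_mapping({"id": "f1"}, "tenant-1")  # old positional shape
except TypeError:
    pass  # the keyword-only stub rejects it, as intended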
diff --git a/api/tests/unit_tests/core/workflow/test_human_input_compat.py b/api/tests/unit_tests/core/workflow/test_human_input_compat.py
new file mode 100644
index 0000000000..50b03645ef
--- /dev/null
+++ b/api/tests/unit_tests/core/workflow/test_human_input_compat.py
@@ -0,0 +1,184 @@
+from types import SimpleNamespace
+
+from pydantic import BaseModel
+
+from core.workflow.human_input_compat import (
+    DeliveryMethodType,
+    EmailDeliveryConfig,
+    EmailDeliveryMethod,
+    EmailRecipients,
+    WebAppDeliveryMethod,
+    _WebAppDeliveryConfig,
+    is_human_input_webapp_enabled,
+    normalize_human_input_node_data_for_graph,
+    normalize_node_config_for_graph,
+    normalize_node_data_for_graph,
+    parse_human_input_delivery_methods,
+)
+from dify_graph.enums import BuiltinNodeTypes
+
+
+def test_email_delivery_config_helpers_render_and_sanitize_text() -> None:
+    variable_pool = SimpleNamespace(
+        convert_template=lambda body: SimpleNamespace(text=body.replace("{{#node.value#}}", "42"))
+    )
+
+    rendered = EmailDeliveryConfig.render_body_template(
+        body="Open {{#url#}} and use {{#node.value#}}",
+        url="https://example.com",
+        variable_pool=variable_pool,
+    )
+    sanitized = EmailDeliveryConfig.sanitize_subject("Hello<script>alert(1)</script>\r\n Team")
+    html = EmailDeliveryConfig.render_markdown_body(
+        "**Hello** [mail](mailto:test@example.com)"
+    )
+
+    assert rendered == "Open https://example.com and use 42"
+    assert sanitized == "Hello alert(1) Team"
+    assert "<strong>Hello</strong>" in html
+    assert "