diff --git a/api/core/workflow/node_factory.py b/api/core/workflow/node_factory.py index 3480c0fdd6..c3fbc836d6 100644 --- a/api/core/workflow/node_factory.py +++ b/api/core/workflow/node_factory.py @@ -431,7 +431,7 @@ class DifyNodeFactory(NodeFactory): include_jinja2_template_renderer=False, ), BuiltinNodeTypes.TOOL: lambda: { - "tool_file_manager_factory": self._bound_tool_file_manager_factory(), + "tool_file_manager": self._bound_tool_file_manager_factory(), "runtime": self._tool_runtime, }, BuiltinNodeTypes.AGENT: lambda: { diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index eb78e0a68b..1b0e10d784 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -29,7 +29,11 @@ from core.workflow.node_factory import ( get_node_type_classes_mapping, is_start_node_type, ) -from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient +from core.workflow.node_runtime import ( + DifyFileReferenceFactory, + DifyHumanInputNodeRuntime, + apply_dify_debug_email_recipient, +) from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry @@ -1259,6 +1263,7 @@ class WorkflowService: data=node_data, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, + file_reference_factory=DifyFileReferenceFactory(graph_init_params.run_context), runtime=DifyHumanInputNodeRuntime(run_context), ) return node diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 78c12e7ea5..c109be9fae 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -60,14 +60,14 @@ def init_tool_node(config: dict): graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) + tool_file_manager = MagicMock(spec=ToolFileManagerProtocol) node = ToolNode( node_id=str(uuid.uuid4()), data=ToolNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - tool_file_manager_factory=tool_file_manager_factory, + tool_file_manager=tool_file_manager, runtime=DifyToolNodeRuntime(init_params.run_context), ) return node diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index a69dd99adc..103fe88df7 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -13,7 +13,7 @@ from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenc from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.node_runtime import DifyFileReferenceFactory, DifyHumanInputNodeRuntime from core.workflow.system_variables import build_system_variables from graphon.enums import WorkflowType from graphon.graph import Graph @@ -121,6 +121,7 @@ def _build_graph( graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, + file_reference_factory=DifyFileReferenceFactory(params.run_context), runtime=DifyHumanInputNodeRuntime(params.run_context), ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index e0eb4e7361..c3e6f5d76c 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -81,11 +81,11 @@ class MockNodeMixin: if isinstance(self, TemplateTransformNode): kwargs.setdefault("jinja2_template_renderer", _TestJinja2Renderer()) - # Provide default tool_file_manager_factory for ToolNode subclasses + # Provide default ToolNode dependencies for ToolNode subclasses. from graphon.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles if isinstance(self, _ToolNode): - kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol)) + kwargs.setdefault("tool_file_manager", MagicMock(spec=ToolFileManagerProtocol)) kwargs.setdefault("runtime", DifyToolNodeRuntime(graph_init_params.run_context)) if isinstance(self, AgentNode): diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 4d30746e5c..0ee70256d7 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -111,8 +111,8 @@ def tool_node(monkeypatch) -> ToolNode: config = graph_config["nodes"][0] - # Provide a stub ToolFileManager to satisfy the updated ToolNode constructor - tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) + # Provide a stub ToolFileManager to satisfy the ToolNode constructor. + tool_file_manager = MagicMock(spec=ToolFileManagerProtocol) runtime = _StubToolRuntime() node = ToolNode( @@ -120,7 +120,7 @@ def tool_node(monkeypatch) -> ToolNode: data=ToolNodeData.model_validate(config["data"]), graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - tool_file_manager_factory=tool_file_manager_factory, + tool_file_manager=tool_file_manager, runtime=runtime, ) return node @@ -215,7 +215,7 @@ def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode): size=123, storage_key="file-key", ) - tool_node._tool_file_manager_factory.get_file_generator_by_tool_file_id.return_value = ( + tool_node._tool_file_manager.get_file_generator_by_tool_file_id.return_value = ( None, SimpleNamespace(mime_type="application/pdf"), ) @@ -228,7 +228,7 @@ def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode): events, _ = _run_transform(tool_node, message) - tool_node._tool_file_manager_factory.get_file_generator_by_tool_file_id.assert_called_once_with("file-id") + tool_node._tool_file_manager.get_file_generator_by_tool_file_id.assert_called_once_with("file-id") completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)] assert len(completed_events) == 1 files_segment = completed_events[0].node_run_result.outputs["files"] diff --git a/api/tests/unit_tests/core/workflow/test_node_factory.py b/api/tests/unit_tests/core/workflow/test_node_factory.py index 0821419067..ccb63f36d3 100644 --- a/api/tests/unit_tests/core/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/workflow/test_node_factory.py @@ -418,7 +418,7 @@ class TestDifyNodeFactoryCreateNode: factory._jinja2_template_renderer = sentinel.jinja2_template_renderer factory._template_transform_max_output_length = 2048 factory._http_request_http_client = sentinel.http_client - factory._bound_tool_file_manager_factory = sentinel.tool_file_manager_factory + factory._bound_tool_file_manager_factory = MagicMock(return_value=sentinel.tool_file_manager) factory._file_reference_factory = sentinel.file_reference_factory factory._prompt_message_serializer = sentinel.prompt_message_serializer factory._retriever_attachment_loader = sentinel.retriever_attachment_loader @@ -505,6 +505,7 @@ class TestDifyNodeFactoryCreateNode: (BuiltinNodeTypes.TEMPLATE_TRANSFORM, "TemplateTransformNode"), (BuiltinNodeTypes.HTTP_REQUEST, "HttpRequestNode"), (BuiltinNodeTypes.HUMAN_INPUT, "HumanInputNode"), + (BuiltinNodeTypes.TOOL, "ToolNode"), (KNOWLEDGE_INDEX_NODE_TYPE, "KnowledgeIndexNode"), (BuiltinNodeTypes.DATASOURCE, "DatasourceNode"), (BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, "KnowledgeRetrievalNode"), @@ -545,14 +546,19 @@ class TestDifyNodeFactoryCreateNode: elif constructor_name == "HttpRequestNode": assert kwargs["http_request_config"] is sentinel.http_request_config assert kwargs["http_client"] is sentinel.http_client - assert kwargs["tool_file_manager_factory"] is sentinel.tool_file_manager_factory + assert kwargs["tool_file_manager_factory"] is factory._bound_tool_file_manager_factory assert kwargs["file_manager"] is sentinel.file_manager assert kwargs["file_reference_factory"] is sentinel.file_reference_factory + factory._bound_tool_file_manager_factory.assert_not_called() elif constructor_name == "HumanInputNode": assert kwargs["form_repository"] is form_repository assert kwargs["file_reference_factory"] is sentinel.file_reference_factory assert kwargs["runtime"] is factory._human_input_runtime factory._human_input_runtime.build_form_repository.assert_called_once_with() + elif constructor_name == "ToolNode": + assert kwargs["tool_file_manager"] is sentinel.tool_file_manager + assert kwargs["runtime"] is sentinel.tool_runtime + factory._bound_tool_file_manager_factory.assert_called_once_with() elif constructor_name == "DocumentExtractorNode": assert kwargs["unstructured_api_config"] is sentinel.unstructured_api_config assert kwargs["http_client"] is sentinel.http_client diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index e152ab923c..f105364094 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -2833,6 +2833,7 @@ class TestWorkflowServiceFreeNodeExecution: return_value=sentinel.adapted_node_data, ) as mock_adapt_node_data, patch("services.workflow_service.build_dify_run_context") as mock_build_dify_run_context, + patch("services.workflow_service.DifyFileReferenceFactory") as mock_file_reference_factory_cls, patch("services.workflow_service.DifyHumanInputNodeRuntime") as mock_runtime_cls, patch("services.workflow_service.HumanInputNode") as mock_node_cls, ): @@ -2851,10 +2852,14 @@ class TestWorkflowServiceFreeNodeExecution: mock_runtime_cls.assert_called_once_with(mock_build_dify_run_context.return_value) mock_adapt_node_data.assert_called_once_with(node_config["data"]) mock_node_cls.validate_node_data.assert_called_once_with(sentinel.adapted_node_data) + mock_file_reference_factory_cls.assert_called_once_with( + mock_graph_init_context_cls.return_value.to_graph_init_params.return_value.run_context + ) mock_node_cls.assert_called_once_with( node_id="n-1", data=sentinel.node_data, graph_init_params=mock_graph_init_context_cls.return_value.to_graph_init_params.return_value, graph_runtime_state=ANY, + file_reference_factory=mock_file_reference_factory_cls.return_value, runtime=mock_runtime_cls.return_value, )