fix(api): align graphon node constructors

This commit is contained in:
-LAN-
2026-05-14 16:31:28 +08:00
parent 13597e421f
commit 5d90cd41a6
8 changed files with 31 additions and 14 deletions

View File

@ -431,7 +431,7 @@ class DifyNodeFactory(NodeFactory):
include_jinja2_template_renderer=False,
),
BuiltinNodeTypes.TOOL: lambda: {
"tool_file_manager_factory": self._bound_tool_file_manager_factory(),
"tool_file_manager": self._bound_tool_file_manager_factory(),
"runtime": self._tool_runtime,
},
BuiltinNodeTypes.AGENT: lambda: {

View File

@ -29,7 +29,11 @@ from core.workflow.node_factory import (
get_node_type_classes_mapping,
is_start_node_type,
)
from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient
from core.workflow.node_runtime import (
DifyFileReferenceFactory,
DifyHumanInputNodeRuntime,
apply_dify_debug_email_recipient,
)
from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables
from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool
from core.workflow.workflow_entry import WorkflowEntry
@ -1259,6 +1263,7 @@ class WorkflowService:
data=node_data,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
file_reference_factory=DifyFileReferenceFactory(graph_init_params.run_context),
runtime=DifyHumanInputNodeRuntime(run_context),
)
return node

View File

@ -60,14 +60,14 @@ def init_tool_node(config: dict):
graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start")
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
tool_file_manager = MagicMock(spec=ToolFileManagerProtocol)
node = ToolNode(
node_id=str(uuid.uuid4()),
data=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,
tool_file_manager=tool_file_manager,
runtime=DifyToolNodeRuntime(init_params.run_context),
)
return node

View File

@ -13,7 +13,7 @@ from core.app.workflow.layers import PersistenceWorkflowInfo, WorkflowPersistenc
from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.workflow.node_runtime import DifyHumanInputNodeRuntime
from core.workflow.node_runtime import DifyFileReferenceFactory, DifyHumanInputNodeRuntime
from core.workflow.system_variables import build_system_variables
from graphon.enums import WorkflowType
from graphon.graph import Graph
@ -121,6 +121,7 @@ def _build_graph(
graph_init_params=params,
graph_runtime_state=runtime_state,
form_repository=form_repository,
file_reference_factory=DifyFileReferenceFactory(params.run_context),
runtime=DifyHumanInputNodeRuntime(params.run_context),
)

View File

@ -81,11 +81,11 @@ class MockNodeMixin:
if isinstance(self, TemplateTransformNode):
kwargs.setdefault("jinja2_template_renderer", _TestJinja2Renderer())
# Provide default tool_file_manager_factory for ToolNode subclasses
# Provide default ToolNode dependencies for ToolNode subclasses.
from graphon.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles
if isinstance(self, _ToolNode):
kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol))
kwargs.setdefault("tool_file_manager", MagicMock(spec=ToolFileManagerProtocol))
kwargs.setdefault("runtime", DifyToolNodeRuntime(graph_init_params.run_context))
if isinstance(self, AgentNode):

View File

@ -111,8 +111,8 @@ def tool_node(monkeypatch) -> ToolNode:
config = graph_config["nodes"][0]
# Provide a stub ToolFileManager to satisfy the updated ToolNode constructor
tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol)
# Provide a stub ToolFileManager to satisfy the ToolNode constructor.
tool_file_manager = MagicMock(spec=ToolFileManagerProtocol)
runtime = _StubToolRuntime()
node = ToolNode(
@ -120,7 +120,7 @@ def tool_node(monkeypatch) -> ToolNode:
data=ToolNodeData.model_validate(config["data"]),
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
tool_file_manager_factory=tool_file_manager_factory,
tool_file_manager=tool_file_manager,
runtime=runtime,
)
return node
@ -215,7 +215,7 @@ def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode):
size=123,
storage_key="file-key",
)
tool_node._tool_file_manager_factory.get_file_generator_by_tool_file_id.return_value = (
tool_node._tool_file_manager.get_file_generator_by_tool_file_id.return_value = (
None,
SimpleNamespace(mime_type="application/pdf"),
)
@ -228,7 +228,7 @@ def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode):
events, _ = _run_transform(tool_node, message)
tool_node._tool_file_manager_factory.get_file_generator_by_tool_file_id.assert_called_once_with("file-id")
tool_node._tool_file_manager.get_file_generator_by_tool_file_id.assert_called_once_with("file-id")
completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)]
assert len(completed_events) == 1
files_segment = completed_events[0].node_run_result.outputs["files"]

View File

@ -418,7 +418,7 @@ class TestDifyNodeFactoryCreateNode:
factory._jinja2_template_renderer = sentinel.jinja2_template_renderer
factory._template_transform_max_output_length = 2048
factory._http_request_http_client = sentinel.http_client
factory._bound_tool_file_manager_factory = sentinel.tool_file_manager_factory
factory._bound_tool_file_manager_factory = MagicMock(return_value=sentinel.tool_file_manager)
factory._file_reference_factory = sentinel.file_reference_factory
factory._prompt_message_serializer = sentinel.prompt_message_serializer
factory._retriever_attachment_loader = sentinel.retriever_attachment_loader
@ -505,6 +505,7 @@ class TestDifyNodeFactoryCreateNode:
(BuiltinNodeTypes.TEMPLATE_TRANSFORM, "TemplateTransformNode"),
(BuiltinNodeTypes.HTTP_REQUEST, "HttpRequestNode"),
(BuiltinNodeTypes.HUMAN_INPUT, "HumanInputNode"),
(BuiltinNodeTypes.TOOL, "ToolNode"),
(KNOWLEDGE_INDEX_NODE_TYPE, "KnowledgeIndexNode"),
(BuiltinNodeTypes.DATASOURCE, "DatasourceNode"),
(BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, "KnowledgeRetrievalNode"),
@ -545,14 +546,19 @@ class TestDifyNodeFactoryCreateNode:
elif constructor_name == "HttpRequestNode":
assert kwargs["http_request_config"] is sentinel.http_request_config
assert kwargs["http_client"] is sentinel.http_client
assert kwargs["tool_file_manager_factory"] is sentinel.tool_file_manager_factory
assert kwargs["tool_file_manager_factory"] is factory._bound_tool_file_manager_factory
assert kwargs["file_manager"] is sentinel.file_manager
assert kwargs["file_reference_factory"] is sentinel.file_reference_factory
factory._bound_tool_file_manager_factory.assert_not_called()
elif constructor_name == "HumanInputNode":
assert kwargs["form_repository"] is form_repository
assert kwargs["file_reference_factory"] is sentinel.file_reference_factory
assert kwargs["runtime"] is factory._human_input_runtime
factory._human_input_runtime.build_form_repository.assert_called_once_with()
elif constructor_name == "ToolNode":
assert kwargs["tool_file_manager"] is sentinel.tool_file_manager
assert kwargs["runtime"] is sentinel.tool_runtime
factory._bound_tool_file_manager_factory.assert_called_once_with()
elif constructor_name == "DocumentExtractorNode":
assert kwargs["unstructured_api_config"] is sentinel.unstructured_api_config
assert kwargs["http_client"] is sentinel.http_client

View File

@ -2833,6 +2833,7 @@ class TestWorkflowServiceFreeNodeExecution:
return_value=sentinel.adapted_node_data,
) as mock_adapt_node_data,
patch("services.workflow_service.build_dify_run_context") as mock_build_dify_run_context,
patch("services.workflow_service.DifyFileReferenceFactory") as mock_file_reference_factory_cls,
patch("services.workflow_service.DifyHumanInputNodeRuntime") as mock_runtime_cls,
patch("services.workflow_service.HumanInputNode") as mock_node_cls,
):
@ -2851,10 +2852,14 @@ class TestWorkflowServiceFreeNodeExecution:
mock_runtime_cls.assert_called_once_with(mock_build_dify_run_context.return_value)
mock_adapt_node_data.assert_called_once_with(node_config["data"])
mock_node_cls.validate_node_data.assert_called_once_with(sentinel.adapted_node_data)
mock_file_reference_factory_cls.assert_called_once_with(
mock_graph_init_context_cls.return_value.to_graph_init_params.return_value.run_context
)
mock_node_cls.assert_called_once_with(
node_id="n-1",
data=sentinel.node_data,
graph_init_params=mock_graph_init_context_cls.return_value.to_graph_init_params.return_value,
graph_runtime_state=ANY,
file_reference_factory=mock_file_reference_factory_cls.return_value,
runtime=mock_runtime_cls.return_value,
)