Refactor: centralize node data hydration (#27771)

This commit is contained in:
-LAN-
2025-11-27 15:41:56 +08:00
committed by GitHub
parent 1b733abe82
commit 13bf6547ee
58 changed files with 381 additions and 899 deletions

View File

@ -69,10 +69,6 @@ def init_code_node(code_config: dict):
graph_runtime_state=graph_runtime_state,
)
# Initialize node data
if "data" in code_config:
node.init_node_data(code_config["data"])
return node

View File

@ -65,10 +65,6 @@ def init_http_node(config: dict):
graph_runtime_state=graph_runtime_state,
)
# Initialize node data
if "data" in config:
node.init_node_data(config["data"])
return node
@ -709,10 +705,6 @@ def test_nested_object_variable_selector(setup_http_mock):
graph_runtime_state=graph_runtime_state,
)
# Initialize node data
if "data" in graph_config["nodes"][1]:
node.init_node_data(graph_config["nodes"][1]["data"])
result = node._run()
assert result.process_data is not None
data = result.process_data.get("request", "")

View File

@ -82,10 +82,6 @@ def init_llm_node(config: dict) -> LLMNode:
graph_runtime_state=graph_runtime_state,
)
# Initialize node data
if "data" in config:
node.init_node_data(config["data"])
return node

View File

@ -85,7 +85,6 @@ def init_parameter_extractor_node(config: dict):
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(config.get("data", {}))
return node

View File

@ -82,7 +82,6 @@ def test_execute_code(setup_code_executor_mock):
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(config.get("data", {}))
# execute node
result = node._run()

View File

@ -62,7 +62,6 @@ def init_tool_node(config: dict):
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(config.get("data", {}))
return node

View File

@ -3,7 +3,6 @@ from __future__ import annotations
import time
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any
import pytest
@ -12,14 +11,19 @@ from core.workflow.entities import GraphInitParams
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
from core.workflow.graph import Graph
from core.workflow.graph.validation import GraphValidationError
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.node import Node
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
class _TestNode(Node):
class _TestNodeData(BaseNodeData):
type: NodeType | str | None = None
execution_type: NodeExecutionType | str | None = None
class _TestNode(Node[_TestNodeData]):
node_type = NodeType.ANSWER
execution_type = NodeExecutionType.EXECUTABLE
@ -41,31 +45,8 @@ class _TestNode(Node):
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
data = config.get("data", {})
if isinstance(data, Mapping):
execution_type = data.get("execution_type")
if isinstance(execution_type, str):
self.execution_type = NodeExecutionType(execution_type)
self._base_node_data = BaseNodeData(title=str(data.get("title", self.id)))
self.data: dict[str, object] = {}
def init_node_data(self, data: Mapping[str, object]) -> None:
title = str(data.get("title", self.id))
desc = data.get("description")
error_strategy_value = data.get("error_strategy")
error_strategy: ErrorStrategy | None = None
if isinstance(error_strategy_value, ErrorStrategy):
error_strategy = error_strategy_value
elif isinstance(error_strategy_value, str):
error_strategy = ErrorStrategy(error_strategy_value)
self._base_node_data = BaseNodeData(
title=title,
desc=str(desc) if desc is not None else None,
error_strategy=error_strategy,
)
self.data = dict(data)
node_type_value = data.get("type")
node_type_value = self.data.get("type")
if isinstance(node_type_value, NodeType):
self.node_type = node_type_value
elif isinstance(node_type_value, str):
@ -77,23 +58,19 @@ class _TestNode(Node):
def _run(self):
raise NotImplementedError
def _get_error_strategy(self) -> ErrorStrategy | None:
return self._base_node_data.error_strategy
def post_init(self) -> None:
super().post_init()
self._maybe_override_execution_type()
self.data = dict(self.node_data.model_dump())
def _get_retry_config(self) -> RetryConfig:
return self._base_node_data.retry_config
def _get_title(self) -> str:
return self._base_node_data.title
def _get_description(self) -> str | None:
return self._base_node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._base_node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._base_node_data
def _maybe_override_execution_type(self) -> None:
execution_type_value = self.node_data.execution_type
if execution_type_value is None:
return
if isinstance(execution_type_value, NodeExecutionType):
self.execution_type = execution_type_value
else:
self.execution_type = NodeExecutionType(execution_type_value)
@dataclass(slots=True)
@ -109,7 +86,6 @@ class _SimpleNodeFactory:
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
)
node.init_node_data(node_config.get("data", {}))
return node

View File

@ -32,7 +32,7 @@ def test_abort_command():
# Create mock nodes with required attributes - using shared runtime state
start_node = StartNode(
id="start",
config={"id": "start"},
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
@ -45,7 +45,6 @@ def test_abort_command():
),
graph_runtime_state=shared_runtime_state,
)
start_node.init_node_data({"title": "start", "variables": []})
mock_graph.nodes["start"] = start_node
# Mock graph methods
@ -142,7 +141,7 @@ def test_pause_command():
start_node = StartNode(
id="start",
config={"id": "start"},
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
@ -155,7 +154,6 @@ def test_pause_command():
),
graph_runtime_state=shared_runtime_state,
)
start_node.init_node_data({"title": "start", "variables": []})
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])

View File

@ -63,7 +63,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
start_node.init_node_data(start_config["data"])
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
llm_data = LLMNodeData(
@ -88,7 +87,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
llm_node.init_node_data(llm_config["data"])
return llm_node
llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")
@ -105,7 +103,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
human_node.init_node_data(human_config["data"])
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")
@ -125,7 +122,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
end_primary.init_node_data(end_primary_config["data"])
end_secondary_data = EndNodeData(
title="End Secondary",
@ -142,7 +138,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
end_secondary.init_node_data(end_secondary_config["data"])
graph = (
Graph.new()

View File

@ -62,7 +62,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
start_node.init_node_data(start_config["data"])
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
llm_data = LLMNodeData(
@ -87,7 +86,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
llm_node.init_node_data(llm_config["data"])
return llm_node
llm_first = _create_llm_node("llm_initial", "Initial LLM", "Initial prompt")
@ -104,7 +102,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
human_node.init_node_data(human_config["data"])
llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt")
@ -123,7 +120,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
end_node.init_node_data(end_config["data"])
graph = (
Graph.new()

View File

@ -62,7 +62,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
start_node.init_node_data(start_config["data"])
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
llm_data = LLMNodeData(
@ -87,7 +86,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
llm_node.init_node_data(llm_config["data"])
return llm_node
llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")
@ -118,7 +116,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
if_else_node.init_node_data(if_else_config["data"])
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")
@ -138,7 +135,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
end_primary.init_node_data(end_primary_config["data"])
end_secondary_data = EndNodeData(
title="End Secondary",
@ -155,7 +151,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
end_secondary.init_node_data(end_secondary_config["data"])
graph = (
Graph.new()

View File

@ -111,9 +111,6 @@ class MockNodeFactory(DifyNodeFactory):
mock_config=self.mock_config,
)
# Initialize node with provided data
mock_instance.init_node_data(node_data)
return mock_instance
# For non-mocked node types, use parent implementation

View File

@ -142,6 +142,8 @@ def test_mock_loop_node_preserves_config():
"start_node_id": "node1",
"loop_variables": [],
"outputs": {},
"break_conditions": [],
"logical_operator": "and",
},
}

View File

@ -63,7 +63,6 @@ class TestMockTemplateTransformNode:
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
mock_node.init_node_data(node_config["data"])
# Run the node
result = mock_node._run()
@ -125,7 +124,6 @@ class TestMockTemplateTransformNode:
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
mock_node.init_node_data(node_config["data"])
# Run the node
result = mock_node._run()
@ -184,7 +182,6 @@ class TestMockTemplateTransformNode:
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
mock_node.init_node_data(node_config["data"])
# Run the node
result = mock_node._run()
@ -246,7 +243,6 @@ class TestMockTemplateTransformNode:
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
mock_node.init_node_data(node_config["data"])
# Run the node
result = mock_node._run()
@ -311,7 +307,6 @@ class TestMockCodeNode:
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
mock_node.init_node_data(node_config["data"])
# Run the node
result = mock_node._run()
@ -376,7 +371,6 @@ class TestMockCodeNode:
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
mock_node.init_node_data(node_config["data"])
# Run the node
result = mock_node._run()
@ -445,7 +439,6 @@ class TestMockCodeNode:
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
)
mock_node.init_node_data(node_config["data"])
# Run the node
result = mock_node._run()

View File

@ -83,9 +83,6 @@ def test_execute_answer():
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()

View File

@ -1,4 +1,7 @@
import pytest
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.node import Node
# Ensures that all node classes are imported.
@ -7,6 +10,12 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
_ = NODE_TYPE_CLASSES_MAPPING
class _TestNodeData(BaseNodeData):
    """Bare ``BaseNodeData`` subclass used as the generic argument for test nodes."""
def _get_all_subclasses(root: type[Node]) -> list[type[Node]]:
subclasses = []
queue = [root]
@ -34,3 +43,79 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
node_type_and_version = (node_type, node_version)
assert node_type_and_version not in type_version_set
type_version_set.add(node_type_and_version)
def test_extract_node_data_type_from_generic_extracts_type():
    """A class declared as ``Node[T]`` reports ``T`` as its node data type."""

    class _ConcreteNode(Node[_TestNodeData]):
        node_type = NodeType.CODE

        @staticmethod
        def version() -> str:
            return "1"

    # The extractor is invoked on the subclass itself, not on an instance.
    assert _ConcreteNode._extract_node_data_type_from_generic() is _TestNodeData
def test_extract_node_data_type_from_generic_returns_none_for_base_node():
    """Calling the extractor on the unparameterized ``Node`` base yields ``None``."""
    assert Node._extract_node_data_type_from_generic() is None
def test_extract_node_data_type_from_generic_raises_for_non_base_node_data():
    """Parameterizing ``Node`` with a type that is not a ``BaseNodeData`` subtype raises ``TypeError``."""
    expected_message = "must parameterize Node with a BaseNodeData subtype"

    # The class statement itself is expected to fail, so it lives inside the raises block.
    with pytest.raises(TypeError, match=expected_message):

        class _InvalidNode(Node[str]):  # type: ignore[type-arg]
            pass
def test_extract_node_data_type_from_generic_raises_for_non_type():
    """Parameterizing ``Node`` with an unresolved ``TypeVar`` raises ``TypeError``."""
    from typing import TypeVar

    T = TypeVar("T")
    expected_message = "must parameterize Node with a BaseNodeData subtype"

    # The class statement itself is expected to fail, so it lives inside the raises block.
    with pytest.raises(TypeError, match=expected_message):

        class _InvalidNode(Node[T]):  # type: ignore[type-arg]
            pass
def test_init_subclass_raises_without_generic_or_explicit_type():
    """Subclassing bare ``Node`` (no ``Node[T]`` parameter) raises ``TypeError``.

    The error is expected while the class is being created, so the class
    statement sits inside the ``pytest.raises`` block.
    """
    with pytest.raises(TypeError, match="must inherit from Node\\[T\\] with a BaseNodeData subtype"):

        class _InvalidNode(Node):
            pass
def test_init_subclass_rejects_explicit_node_data_type_without_generic():
    """Assigning ``_node_data_type`` by hand on a bare ``Node`` subclass still raises ``TypeError``."""
    expected_message = "must inherit from Node\\[T\\] with a BaseNodeData subtype"

    with pytest.raises(TypeError, match=expected_message):

        class _ExplicitNode(Node):
            # Explicit attribute in place of the Node[T] parameter; the test expects this to be rejected.
            _node_data_type = _TestNodeData
            node_type = NodeType.CODE

            @staticmethod
            def version() -> str:
                return "1"
def test_init_subclass_sets_node_data_type_from_generic():
    """Declaring ``Node[T]`` leaves ``_node_data_type`` set to ``T`` on the new class."""

    class _AutoNode(Node[_TestNodeData]):
        node_type = NodeType.CODE

        @staticmethod
        def version() -> str:
            return "1"

    populated_type = _AutoNode._node_data_type
    assert populated_type is _TestNodeData

View File

@ -111,8 +111,6 @@ def llm_node(
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
# Initialize node data
node.init_node_data(node_config["data"])
return node
@ -498,8 +496,6 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
# Initialize node data
node.init_node_data(node_config["data"])
return node, mock_file_saver

View File

@ -0,0 +1,74 @@
from collections.abc import Mapping
import pytest
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.node import Node
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
class _SampleNodeData(BaseNodeData):
    """Node data with one extra required field, used to observe hydration in tests."""

    # Required string field; the hydration test reads it back from ``node.node_data``.
    foo: str
class _SampleNode(Node[_SampleNodeData]):
    """Test-only node parameterized with ``_SampleNodeData``; it is never executed."""

    node_type = NodeType.ANSWER

    @classmethod
    def version(cls) -> str:
        """Return the fixed version label for this test node."""
        return "sample-test"

    def _run(self):
        """Not implemented: these tests only construct the node."""
        raise NotImplementedError
def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]:
    """Build the init params and runtime state needed to construct a node in tests.

    Args:
        graph_config: Graph configuration passed through to ``GraphInitParams``.

    Returns:
        A ``(GraphInitParams, GraphRuntimeState)`` pair populated with fixed
        placeholder identifiers and an empty-input variable pool.
    """
    init_params = GraphInitParams(
        tenant_id="tenant",
        app_id="app",
        workflow_id="workflow",
        graph_config=graph_config,
        user_id="user",
        user_from="account",
        invoke_from="debugger",
        call_depth=0,
    )

    system_variables = SystemVariable(user_id="user", files=[])
    variable_pool = VariablePool(system_variables=system_variables, user_inputs={})
    runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0)

    return init_params, runtime_state
def test_node_hydrates_data_during_initialization():
    """Constructing a node exposes typed ``node_data`` and ``title`` with no extra init call."""
    init_params, runtime_state = _build_context({})
    node_config = {"id": "node-1", "data": {"title": "Sample", "foo": "bar"}}

    node = _SampleNode(
        id="node-1",
        config=node_config,
        graph_init_params=init_params,
        graph_runtime_state=runtime_state,
    )

    # Both values come from the "data" mapping handed to the constructor.
    assert node.node_data.foo == "bar"
    assert node.title == "Sample"
def test_missing_generic_argument_raises_type_error():
    """Defining a subclass of bare ``Node`` (no ``Node[T]``) raises ``TypeError``.

    The error is expected while the class is being created, so the class
    statement sits inside the ``pytest.raises`` block and no node is ever
    instantiated.
    """
    with pytest.raises(TypeError):

        class _InvalidNode(Node):  # type: ignore[type-abstract]
            node_type = NodeType.ANSWER

            @classmethod
            def version(cls) -> str:
                return "1"

            def _run(self):
                raise NotImplementedError

View File

@ -50,8 +50,6 @@ def document_extractor_node(graph_init_params):
graph_init_params=graph_init_params,
graph_runtime_state=Mock(),
)
# Initialize node data
node.init_node_data(node_config["data"])
return node

View File

@ -114,9 +114,6 @@ def test_execute_if_else_result_true():
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()
@ -187,9 +184,6 @@ def test_execute_if_else_result_false():
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Mock db.session.close()
db.session.close = MagicMock()
@ -252,9 +246,6 @@ def test_array_file_contains_file_name():
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
value=[
File(
@ -347,7 +338,6 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
graph_runtime_state=graph_runtime_state,
config={"id": "if-else", "data": node_data},
)
node.init_node_data(node_data)
# Mock db.session.close()
db.session.close = MagicMock()
@ -417,7 +407,6 @@ def test_execute_if_else_boolean_false_conditions():
"data": node_data,
},
)
node.init_node_data(node_data)
# Mock db.session.close()
db.session.close = MagicMock()
@ -487,7 +476,6 @@ def test_execute_if_else_boolean_cases_structure():
graph_runtime_state=graph_runtime_state,
config={"id": "if-else", "data": node_data},
)
node.init_node_data(node_data)
# Mock db.session.close()
db.session.close = MagicMock()

View File

@ -57,8 +57,6 @@ def list_operator_node():
graph_init_params=graph_init_params,
graph_runtime_state=MagicMock(),
)
# Initialize node data
node.init_node_data(node_config["data"])
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.variable_pool = MagicMock()
return node

View File

@ -73,7 +73,6 @@ def tool_node(monkeypatch) -> "ToolNode":
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
)
node.init_node_data(config["data"])
return node

View File

@ -101,9 +101,6 @@ def test_overwrite_string_variable():
conv_var_updater_factory=mock_conv_var_updater_factory,
)
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run())
expected_var = StringVariable(
id=conversation_variable.id,
@ -203,9 +200,6 @@ def test_append_variable_to_array():
conv_var_updater_factory=mock_conv_var_updater_factory,
)
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run())
expected_value = list(conversation_variable.value)
expected_value.append(input_variable.value)
@ -296,9 +290,6 @@ def test_clear_array():
conv_var_updater_factory=mock_conv_var_updater_factory,
)
# Initialize node data
node.init_node_data(node_config["data"])
list(node.run())
expected_var = ArrayStringVariable(
id=conversation_variable.id,

View File

@ -139,11 +139,6 @@ def test_remove_first_from_array():
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
# Run the node
result = list(node.run())
@ -228,10 +223,6 @@ def test_remove_last_from_array():
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
list(node.run())
got = variable_pool.get(["conversation", conversation_variable.name])
@ -313,10 +304,6 @@ def test_remove_first_from_empty_array():
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
list(node.run())
got = variable_pool.get(["conversation", conversation_variable.name])
@ -398,10 +385,6 @@ def test_remove_last_from_empty_array():
config=node_config,
)
# Initialize node data
node.init_node_data(node_config["data"])
# Skip the mock assertion since we're in a test environment
list(node.run())
got = variable_pool.get(["conversation", conversation_variable.name])

View File

@ -47,7 +47,6 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
),
)
node.init_node_data(node_config["data"])
return node