mirror of
https://github.com/langgenius/dify.git
synced 2026-01-19 11:45:05 +08:00
Refactor: centralize node data hydration (#27771)
This commit is contained in:
@ -69,10 +69,6 @@ def init_code_node(code_config: dict):
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
if "data" in code_config:
|
||||
node.init_node_data(code_config["data"])
|
||||
|
||||
return node
|
||||
|
||||
|
||||
|
||||
@ -65,10 +65,6 @@ def init_http_node(config: dict):
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
if "data" in config:
|
||||
node.init_node_data(config["data"])
|
||||
|
||||
return node
|
||||
|
||||
|
||||
@ -709,10 +705,6 @@ def test_nested_object_variable_selector(setup_http_mock):
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
if "data" in graph_config["nodes"][1]:
|
||||
node.init_node_data(graph_config["nodes"][1]["data"])
|
||||
|
||||
result = node._run()
|
||||
assert result.process_data is not None
|
||||
data = result.process_data.get("request", "")
|
||||
|
||||
@ -82,10 +82,6 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
if "data" in config:
|
||||
node.init_node_data(config["data"])
|
||||
|
||||
return node
|
||||
|
||||
|
||||
|
||||
@ -85,7 +85,6 @@ def init_parameter_extractor_node(config: dict):
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node.init_node_data(config.get("data", {}))
|
||||
return node
|
||||
|
||||
|
||||
|
||||
@ -82,7 +82,6 @@ def test_execute_code(setup_code_executor_mock):
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node.init_node_data(config.get("data", {}))
|
||||
|
||||
# execute node
|
||||
result = node._run()
|
||||
|
||||
@ -62,7 +62,6 @@ def init_tool_node(config: dict):
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node.init_node_data(config.get("data", {}))
|
||||
return node
|
||||
|
||||
|
||||
|
||||
@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
@ -12,14 +11,19 @@ from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph.validation import GraphValidationError
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
class _TestNode(Node):
|
||||
class _TestNodeData(BaseNodeData):
|
||||
type: NodeType | str | None = None
|
||||
execution_type: NodeExecutionType | str | None = None
|
||||
|
||||
|
||||
class _TestNode(Node[_TestNodeData]):
|
||||
node_type = NodeType.ANSWER
|
||||
execution_type = NodeExecutionType.EXECUTABLE
|
||||
|
||||
@ -41,31 +45,8 @@ class _TestNode(Node):
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
data = config.get("data", {})
|
||||
if isinstance(data, Mapping):
|
||||
execution_type = data.get("execution_type")
|
||||
if isinstance(execution_type, str):
|
||||
self.execution_type = NodeExecutionType(execution_type)
|
||||
self._base_node_data = BaseNodeData(title=str(data.get("title", self.id)))
|
||||
self.data: dict[str, object] = {}
|
||||
|
||||
def init_node_data(self, data: Mapping[str, object]) -> None:
|
||||
title = str(data.get("title", self.id))
|
||||
desc = data.get("description")
|
||||
error_strategy_value = data.get("error_strategy")
|
||||
error_strategy: ErrorStrategy | None = None
|
||||
if isinstance(error_strategy_value, ErrorStrategy):
|
||||
error_strategy = error_strategy_value
|
||||
elif isinstance(error_strategy_value, str):
|
||||
error_strategy = ErrorStrategy(error_strategy_value)
|
||||
self._base_node_data = BaseNodeData(
|
||||
title=title,
|
||||
desc=str(desc) if desc is not None else None,
|
||||
error_strategy=error_strategy,
|
||||
)
|
||||
self.data = dict(data)
|
||||
|
||||
node_type_value = data.get("type")
|
||||
node_type_value = self.data.get("type")
|
||||
if isinstance(node_type_value, NodeType):
|
||||
self.node_type = node_type_value
|
||||
elif isinstance(node_type_value, str):
|
||||
@ -77,23 +58,19 @@ class _TestNode(Node):
|
||||
def _run(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return self._base_node_data.error_strategy
|
||||
def post_init(self) -> None:
|
||||
super().post_init()
|
||||
self._maybe_override_execution_type()
|
||||
self.data = dict(self.node_data.model_dump())
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._base_node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._base_node_data.title
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return self._base_node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._base_node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._base_node_data
|
||||
def _maybe_override_execution_type(self) -> None:
|
||||
execution_type_value = self.node_data.execution_type
|
||||
if execution_type_value is None:
|
||||
return
|
||||
if isinstance(execution_type_value, NodeExecutionType):
|
||||
self.execution_type = execution_type_value
|
||||
else:
|
||||
self.execution_type = NodeExecutionType(execution_type_value)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -109,7 +86,6 @@ class _SimpleNodeFactory:
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
)
|
||||
node.init_node_data(node_config.get("data", {}))
|
||||
return node
|
||||
|
||||
|
||||
|
||||
@ -32,7 +32,7 @@ def test_abort_command():
|
||||
# Create mock nodes with required attributes - using shared runtime state
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start"},
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
@ -45,7 +45,6 @@ def test_abort_command():
|
||||
),
|
||||
graph_runtime_state=shared_runtime_state,
|
||||
)
|
||||
start_node.init_node_data({"title": "start", "variables": []})
|
||||
mock_graph.nodes["start"] = start_node
|
||||
|
||||
# Mock graph methods
|
||||
@ -142,7 +141,7 @@ def test_pause_command():
|
||||
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start"},
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
@ -155,7 +154,6 @@ def test_pause_command():
|
||||
),
|
||||
graph_runtime_state=shared_runtime_state,
|
||||
)
|
||||
start_node.init_node_data({"title": "start", "variables": []})
|
||||
mock_graph.nodes["start"] = start_node
|
||||
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
|
||||
@ -63,7 +63,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
start_node.init_node_data(start_config["data"])
|
||||
|
||||
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
||||
llm_data = LLMNodeData(
|
||||
@ -88,7 +87,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
llm_node.init_node_data(llm_config["data"])
|
||||
return llm_node
|
||||
|
||||
llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")
|
||||
@ -105,7 +103,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
human_node.init_node_data(human_config["data"])
|
||||
|
||||
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
|
||||
llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")
|
||||
@ -125,7 +122,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
end_primary.init_node_data(end_primary_config["data"])
|
||||
|
||||
end_secondary_data = EndNodeData(
|
||||
title="End Secondary",
|
||||
@ -142,7 +138,6 @@ def _build_branching_graph(mock_config: MockConfig) -> tuple[Graph, GraphRuntime
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
end_secondary.init_node_data(end_secondary_config["data"])
|
||||
|
||||
graph = (
|
||||
Graph.new()
|
||||
|
||||
@ -62,7 +62,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
start_node.init_node_data(start_config["data"])
|
||||
|
||||
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
||||
llm_data = LLMNodeData(
|
||||
@ -87,7 +86,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
llm_node.init_node_data(llm_config["data"])
|
||||
return llm_node
|
||||
|
||||
llm_first = _create_llm_node("llm_initial", "Initial LLM", "Initial prompt")
|
||||
@ -104,7 +102,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
human_node.init_node_data(human_config["data"])
|
||||
|
||||
llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt")
|
||||
|
||||
@ -123,7 +120,6 @@ def _build_llm_human_llm_graph(mock_config: MockConfig) -> tuple[Graph, GraphRun
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
end_node.init_node_data(end_config["data"])
|
||||
|
||||
graph = (
|
||||
Graph.new()
|
||||
|
||||
@ -62,7 +62,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
start_node.init_node_data(start_config["data"])
|
||||
|
||||
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
||||
llm_data = LLMNodeData(
|
||||
@ -87,7 +86,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
llm_node.init_node_data(llm_config["data"])
|
||||
return llm_node
|
||||
|
||||
llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream")
|
||||
@ -118,7 +116,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
if_else_node.init_node_data(if_else_config["data"])
|
||||
|
||||
llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output")
|
||||
llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary")
|
||||
@ -138,7 +135,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
end_primary.init_node_data(end_primary_config["data"])
|
||||
|
||||
end_secondary_data = EndNodeData(
|
||||
title="End Secondary",
|
||||
@ -155,7 +151,6 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
end_secondary.init_node_data(end_secondary_config["data"])
|
||||
|
||||
graph = (
|
||||
Graph.new()
|
||||
|
||||
@ -111,9 +111,6 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
mock_config=self.mock_config,
|
||||
)
|
||||
|
||||
# Initialize node with provided data
|
||||
mock_instance.init_node_data(node_data)
|
||||
|
||||
return mock_instance
|
||||
|
||||
# For non-mocked node types, use parent implementation
|
||||
|
||||
@ -142,6 +142,8 @@ def test_mock_loop_node_preserves_config():
|
||||
"start_node_id": "node1",
|
||||
"loop_variables": [],
|
||||
"outputs": {},
|
||||
"break_conditions": [],
|
||||
"logical_operator": "and",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@ -63,7 +63,6 @@ class TestMockTemplateTransformNode:
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
mock_node.init_node_data(node_config["data"])
|
||||
|
||||
# Run the node
|
||||
result = mock_node._run()
|
||||
@ -125,7 +124,6 @@ class TestMockTemplateTransformNode:
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
mock_node.init_node_data(node_config["data"])
|
||||
|
||||
# Run the node
|
||||
result = mock_node._run()
|
||||
@ -184,7 +182,6 @@ class TestMockTemplateTransformNode:
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
mock_node.init_node_data(node_config["data"])
|
||||
|
||||
# Run the node
|
||||
result = mock_node._run()
|
||||
@ -246,7 +243,6 @@ class TestMockTemplateTransformNode:
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
mock_node.init_node_data(node_config["data"])
|
||||
|
||||
# Run the node
|
||||
result = mock_node._run()
|
||||
@ -311,7 +307,6 @@ class TestMockCodeNode:
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
mock_node.init_node_data(node_config["data"])
|
||||
|
||||
# Run the node
|
||||
result = mock_node._run()
|
||||
@ -376,7 +371,6 @@ class TestMockCodeNode:
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
mock_node.init_node_data(node_config["data"])
|
||||
|
||||
# Run the node
|
||||
result = mock_node._run()
|
||||
@ -445,7 +439,6 @@ class TestMockCodeNode:
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
)
|
||||
mock_node.init_node_data(node_config["data"])
|
||||
|
||||
# Run the node
|
||||
result = mock_node._run()
|
||||
|
||||
@ -83,9 +83,6 @@ def test_execute_answer():
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
|
||||
@ -1,4 +1,7 @@
|
||||
import pytest
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
# Ensures that all node classes are imported.
|
||||
@ -7,6 +10,12 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
|
||||
_ = NODE_TYPE_CLASSES_MAPPING
|
||||
|
||||
|
||||
class _TestNodeData(BaseNodeData):
|
||||
"""Test node data for unit tests."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def _get_all_subclasses(root: type[Node]) -> list[type[Node]]:
|
||||
subclasses = []
|
||||
queue = [root]
|
||||
@ -34,3 +43,79 @@ def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined
|
||||
node_type_and_version = (node_type, node_version)
|
||||
assert node_type_and_version not in type_version_set
|
||||
type_version_set.add(node_type_and_version)
|
||||
|
||||
|
||||
def test_extract_node_data_type_from_generic_extracts_type():
|
||||
"""When a class inherits from Node[T], it should extract T."""
|
||||
|
||||
class _ConcreteNode(Node[_TestNodeData]):
|
||||
node_type = NodeType.CODE
|
||||
|
||||
@staticmethod
|
||||
def version() -> str:
|
||||
return "1"
|
||||
|
||||
result = _ConcreteNode._extract_node_data_type_from_generic()
|
||||
|
||||
assert result is _TestNodeData
|
||||
|
||||
|
||||
def test_extract_node_data_type_from_generic_returns_none_for_base_node():
|
||||
"""The base Node class itself should return None (no generic parameter)."""
|
||||
result = Node._extract_node_data_type_from_generic()
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_extract_node_data_type_from_generic_raises_for_non_base_node_data():
|
||||
"""When generic parameter is not a BaseNodeData subtype, should raise TypeError."""
|
||||
with pytest.raises(TypeError, match="must parameterize Node with a BaseNodeData subtype"):
|
||||
|
||||
class _InvalidNode(Node[str]): # type: ignore[type-arg]
|
||||
pass
|
||||
|
||||
|
||||
def test_extract_node_data_type_from_generic_raises_for_non_type():
|
||||
"""When generic parameter is not a concrete type, should raise TypeError."""
|
||||
from typing import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
with pytest.raises(TypeError, match="must parameterize Node with a BaseNodeData subtype"):
|
||||
|
||||
class _InvalidNode(Node[T]): # type: ignore[type-arg]
|
||||
pass
|
||||
|
||||
|
||||
def test_init_subclass_raises_without_generic_or_explicit_type():
|
||||
"""A subclass must either use Node[T] or explicitly set _node_data_type."""
|
||||
with pytest.raises(TypeError, match="must inherit from Node\\[T\\] with a BaseNodeData subtype"):
|
||||
|
||||
class _InvalidNode(Node):
|
||||
pass
|
||||
|
||||
|
||||
def test_init_subclass_rejects_explicit_node_data_type_without_generic():
|
||||
"""Setting _node_data_type explicitly cannot bypass the Node[T] requirement."""
|
||||
with pytest.raises(TypeError, match="must inherit from Node\\[T\\] with a BaseNodeData subtype"):
|
||||
|
||||
class _ExplicitNode(Node):
|
||||
_node_data_type = _TestNodeData
|
||||
node_type = NodeType.CODE
|
||||
|
||||
@staticmethod
|
||||
def version() -> str:
|
||||
return "1"
|
||||
|
||||
|
||||
def test_init_subclass_sets_node_data_type_from_generic():
|
||||
"""Verify that __init_subclass__ sets _node_data_type from the generic parameter."""
|
||||
|
||||
class _AutoNode(Node[_TestNodeData]):
|
||||
node_type = NodeType.CODE
|
||||
|
||||
@staticmethod
|
||||
def version() -> str:
|
||||
return "1"
|
||||
|
||||
assert _AutoNode._node_data_type is _TestNodeData
|
||||
|
||||
@ -111,8 +111,6 @@ def llm_node(
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
llm_file_saver=mock_file_saver,
|
||||
)
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
return node
|
||||
|
||||
|
||||
@ -498,8 +496,6 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
llm_file_saver=mock_file_saver,
|
||||
)
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
return node, mock_file_saver
|
||||
|
||||
|
||||
|
||||
74
api/tests/unit_tests/core/workflow/nodes/test_base_node.py
Normal file
74
api/tests/unit_tests/core/workflow/nodes/test_base_node.py
Normal file
@ -0,0 +1,74 @@
|
||||
from collections.abc import Mapping
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
class _SampleNodeData(BaseNodeData):
|
||||
foo: str
|
||||
|
||||
|
||||
class _SampleNode(Node[_SampleNodeData]):
|
||||
node_type = NodeType.ANSWER
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "sample-test"
|
||||
|
||||
def _run(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]:
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config=graph_config,
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="debugger",
|
||||
call_depth=0,
|
||||
)
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}),
|
||||
start_at=0.0,
|
||||
)
|
||||
return init_params, runtime_state
|
||||
|
||||
|
||||
def test_node_hydrates_data_during_initialization():
|
||||
graph_config: dict[str, object] = {}
|
||||
init_params, runtime_state = _build_context(graph_config)
|
||||
|
||||
node = _SampleNode(
|
||||
id="node-1",
|
||||
config={"id": "node-1", "data": {"title": "Sample", "foo": "bar"}},
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
assert node.node_data.foo == "bar"
|
||||
assert node.title == "Sample"
|
||||
|
||||
|
||||
def test_missing_generic_argument_raises_type_error():
|
||||
graph_config: dict[str, object] = {}
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
class _InvalidNode(Node): # type: ignore[type-abstract]
|
||||
node_type = NodeType.ANSWER
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self):
|
||||
raise NotImplementedError
|
||||
@ -50,8 +50,6 @@ def document_extractor_node(graph_init_params):
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=Mock(),
|
||||
)
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
return node
|
||||
|
||||
|
||||
|
||||
@ -114,9 +114,6 @@ def test_execute_if_else_result_true():
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
@ -187,9 +184,6 @@ def test_execute_if_else_result_false():
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
@ -252,9 +246,6 @@ def test_array_file_contains_file_name():
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment(
|
||||
value=[
|
||||
File(
|
||||
@ -347,7 +338,6 @@ def test_execute_if_else_boolean_conditions(condition: Condition):
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config={"id": "if-else", "data": node_data},
|
||||
)
|
||||
node.init_node_data(node_data)
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
@ -417,7 +407,6 @@ def test_execute_if_else_boolean_false_conditions():
|
||||
"data": node_data,
|
||||
},
|
||||
)
|
||||
node.init_node_data(node_data)
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
@ -487,7 +476,6 @@ def test_execute_if_else_boolean_cases_structure():
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
config={"id": "if-else", "data": node_data},
|
||||
)
|
||||
node.init_node_data(node_data)
|
||||
|
||||
# Mock db.session.close()
|
||||
db.session.close = MagicMock()
|
||||
|
||||
@ -57,8 +57,6 @@ def list_operator_node():
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=MagicMock(),
|
||||
)
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
node.graph_runtime_state = MagicMock()
|
||||
node.graph_runtime_state.variable_pool = MagicMock()
|
||||
return node
|
||||
|
||||
@ -73,7 +73,6 @@ def tool_node(monkeypatch) -> "ToolNode":
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node.init_node_data(config["data"])
|
||||
return node
|
||||
|
||||
|
||||
|
||||
@ -101,9 +101,6 @@ def test_overwrite_string_variable():
|
||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
list(node.run())
|
||||
expected_var = StringVariable(
|
||||
id=conversation_variable.id,
|
||||
@ -203,9 +200,6 @@ def test_append_variable_to_array():
|
||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
list(node.run())
|
||||
expected_value = list(conversation_variable.value)
|
||||
expected_value.append(input_variable.value)
|
||||
@ -296,9 +290,6 @@ def test_clear_array():
|
||||
conv_var_updater_factory=mock_conv_var_updater_factory,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
list(node.run())
|
||||
expected_var = ArrayStringVariable(
|
||||
id=conversation_variable.id,
|
||||
|
||||
@ -139,11 +139,6 @@ def test_remove_first_from_array():
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
# Skip the mock assertion since we're in a test environment
|
||||
|
||||
# Run the node
|
||||
result = list(node.run())
|
||||
|
||||
@ -228,10 +223,6 @@ def test_remove_last_from_array():
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
# Skip the mock assertion since we're in a test environment
|
||||
list(node.run())
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
@ -313,10 +304,6 @@ def test_remove_first_from_empty_array():
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
# Skip the mock assertion since we're in a test environment
|
||||
list(node.run())
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
@ -398,10 +385,6 @@ def test_remove_last_from_empty_array():
|
||||
config=node_config,
|
||||
)
|
||||
|
||||
# Initialize node data
|
||||
node.init_node_data(node_config["data"])
|
||||
|
||||
# Skip the mock assertion since we're in a test environment
|
||||
list(node.run())
|
||||
|
||||
got = variable_pool.get(["conversation", conversation_variable.name])
|
||||
|
||||
@ -47,7 +47,6 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool)
|
||||
),
|
||||
)
|
||||
|
||||
node.init_node_data(node_config["data"])
|
||||
return node
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user