mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
refactor: Unify NodeConfigDict.data and BaseNodeData (#32780)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@ -12,6 +12,8 @@ from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, ge
|
||||
from uuid import uuid4
|
||||
|
||||
from dify_graph.entities import AgentNodeStrategyInit, GraphInitParams
|
||||
from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
|
||||
from dify_graph.enums import (
|
||||
ErrorStrategy,
|
||||
@ -62,8 +64,6 @@ from dify_graph.node_events import (
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .entities import BaseNodeData, RetryConfig
|
||||
|
||||
NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData)
|
||||
_MISSING_RUN_CONTEXT_VALUE = object()
|
||||
|
||||
@ -153,11 +153,11 @@ class Node(Generic[NodeDataT]):
|
||||
Later, in __init__:
|
||||
::
|
||||
|
||||
config["data"] ──► _hydrate_node_data() ──► _node_data_type.model_validate()
|
||||
│
|
||||
▼
|
||||
CodeNodeData instance
|
||||
(stored in self._node_data)
|
||||
config["data"] ──► _node_data_type.model_validate(..., from_attributes=True)
|
||||
│
|
||||
▼
|
||||
CodeNodeData instance
|
||||
(stored in self._node_data)
|
||||
|
||||
Example:
|
||||
class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted
|
||||
@ -241,7 +241,7 @@ class Node(Generic[NodeDataT]):
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> None:
|
||||
@ -254,22 +254,21 @@ class Node(Generic[NodeDataT]):
|
||||
self.graph_runtime_state = graph_runtime_state
|
||||
self.state: NodeState = NodeState.UNKNOWN # node execution state
|
||||
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required.")
|
||||
node_id = config["id"]
|
||||
|
||||
self._node_id = node_id
|
||||
self._node_execution_id: str = ""
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
raw_node_data = config.get("data") or {}
|
||||
if not isinstance(raw_node_data, Mapping):
|
||||
raise ValueError("Node config data must be a mapping.")
|
||||
|
||||
self._node_data: NodeDataT = self._hydrate_node_data(raw_node_data)
|
||||
self._node_data = self.validate_node_data(config["data"])
|
||||
|
||||
self.post_init()
|
||||
|
||||
@classmethod
|
||||
def validate_node_data(cls, node_data: BaseNodeData) -> NodeDataT:
|
||||
"""Validate shared graph node payloads against the subclass-declared NodeData model."""
|
||||
return cast(NodeDataT, cls._node_data_type.model_validate(node_data, from_attributes=True))
|
||||
|
||||
def post_init(self) -> None:
|
||||
"""Optional hook for subclasses requiring extra initialization."""
|
||||
return
|
||||
@ -342,9 +341,6 @@ class Node(Generic[NodeDataT]):
|
||||
return None
|
||||
return str(execution_id)
|
||||
|
||||
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
|
||||
return cast(NodeDataT, self._node_data_type.model_validate(data))
|
||||
|
||||
@abstractmethod
|
||||
def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
@ -389,8 +385,6 @@ class Node(Generic[NodeDataT]):
|
||||
start_event.provider_id = getattr(self.node_data, "provider_id", "")
|
||||
start_event.provider_type = getattr(self.node_data, "provider_type", "")
|
||||
|
||||
from typing import cast
|
||||
|
||||
from dify_graph.nodes.agent.agent_node import AgentNode
|
||||
from dify_graph.nodes.agent.entities import AgentNodeData
|
||||
|
||||
@ -442,7 +436,7 @@ class Node(Generic[NodeDataT]):
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
config: Mapping[str, Any],
|
||||
config: NodeConfigDict,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""Extracts references variable selectors from node configuration.
|
||||
|
||||
@ -480,13 +474,12 @@ class Node(Generic[NodeDataT]):
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_id = config.get("id")
|
||||
if not node_id:
|
||||
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
|
||||
|
||||
# Pass raw dict data instead of creating NodeData instance
|
||||
node_id = config["id"]
|
||||
node_data = cls.validate_node_data(config["data"])
|
||||
data = cls._extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
|
||||
graph_config=graph_config,
|
||||
node_id=node_id,
|
||||
node_data=node_data,
|
||||
)
|
||||
return data
|
||||
|
||||
@ -496,7 +489,7 @@ class Node(Generic[NodeDataT]):
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
node_data: NodeDataT,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
return {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user