Files
dify/api/dify_graph/runtime/graph_runtime_state.py

632 lines
23 KiB
Python

from __future__ import annotations
import importlib
import json
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, ClassVar, Protocol
from pydantic import BaseModel, Field
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.enums import NodeExecutionType, NodeState, NodeType
from dify_graph.runtime.variable_pool import VariablePool
if TYPE_CHECKING:
from dify_graph.entities.pause_reason import PauseReason
class ReadyQueueProtocol(Protocol):
    """Structural contract every ready-queue implementation must satisfy.

    The queue holds identifiers of nodes that are ready to execute and must be
    able to round-trip its contents through a string payload (``dumps`` /
    ``loads``) so it can be persisted inside a runtime-state snapshot.
    """

    def put(self, item: str) -> None:
        """Push *item*, the id of a node that can now run, onto the queue."""
        ...

    def get(self, timeout: float | None = None) -> str:
        """Pop the next ready node id, blocking until one is available or *timeout* expires."""
        ...

    def task_done(self) -> None:
        """Report that processing of the most recently popped node has finished."""
        ...

    def empty(self) -> bool:
        """Tell whether no node ids are currently waiting in the queue."""
        ...

    def qsize(self) -> int:
        """Give an approximate count of node ids still waiting to run."""
        ...

    def dumps(self) -> str:
        """Encode the queued node ids as a string suitable for persistence."""
        ...

    def loads(self, data: str) -> None:
        """Replace the queue contents with those decoded from *data*."""
        ...
class GraphExecutionProtocol(Protocol):
    """Structural contract for the graph-execution aggregate.

    Lists the attributes and lifecycle operations the runtime needs from a
    GraphExecution entity in order to orchestrate a run and persist/restore it.
    """

    # Identity of the workflow being executed.
    workflow_id: str
    # Lifecycle flags.
    started: bool
    completed: bool
    aborted: bool
    # Terminal error recorded by ``fail`` (None while no error occurred).
    error: Exception | None
    exceptions_count: int
    pause_reasons: list[PauseReason]

    def start(self) -> None:
        """Move the execution into its running state."""
        ...

    def complete(self) -> None:
        """Flag the execution as having finished successfully."""
        ...

    def abort(self, reason: str) -> None:
        """Stop the execution because an external stop was requested, citing *reason*."""
        ...

    def fail(self, error: Exception) -> None:
        """Store *error* as unrecoverable and terminate the execution."""
        ...

    def dumps(self) -> str:
        """Encode the execution state as a JSON string."""
        ...

    def loads(self, data: str) -> None:
        """Rehydrate the execution state from a payload produced by ``dumps``."""
        ...
class ResponseStreamCoordinatorProtocol(Protocol):
    """Structural contract for the response-stream coordinator.

    The coordinator tracks response nodes whose outputs are streamed and can
    persist its own state as a string payload.
    """

    def register(self, response_node_id: str) -> None:
        """Enrol the node *response_node_id* so that its outputs get streamed."""
        ...

    def loads(self, data: str) -> None:
        """Rehydrate coordinator state from the serialized *data*."""
        ...

    def dumps(self) -> str:
        """Encode coordinator state as a string for persistence."""
        ...
class NodeProtocol(Protocol):
    """Structural contract for a node of the graph attached to the runtime state."""

    id: str
    # Mutable execution state; restored from snapshots by GraphRuntimeState.
    state: NodeState
    execution_type: NodeExecutionType
    # Declared at class level by concrete node classes.
    node_type: ClassVar[NodeType]

    def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool:
        """Tell whether this node blocks output of any of *variable_selectors*."""
        ...
class EdgeProtocol(Protocol):
    """Structural contract for an edge of the graph attached to the runtime state."""

    id: str
    # Edges reuse the NodeState enum for their traversal state.
    state: NodeState
    # Ids of the nodes the edge connects.
    tail: str
    head: str
    source_handle: str
class GraphProtocol(Protocol):
    """Structural contract for graph objects bound to a GraphRuntimeState."""

    # Node and edge lookup tables, keyed by id.
    nodes: Mapping[str, NodeProtocol]
    edges: Mapping[str, EdgeProtocol]
    root_node: NodeProtocol

    def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]:
        """Return the edges that leave the node identified by *node_id*."""
        ...
class _GraphStateSnapshot(BaseModel):
    """Pydantic container for the per-node and per-edge states of a graph.

    Both maps default to empty dicts and are serialized as part of a
    runtime-state snapshot.
    """

    # node id -> state of that node
    nodes: dict[str, NodeState] = Field(default_factory=dict)
    # edge id -> state of that edge (edges reuse the NodeState enum)
    edges: dict[str, NodeState] = Field(default_factory=dict)
@dataclass(slots=True)
class _GraphRuntimeStateSnapshot:
"""Immutable view of a serialized runtime state snapshot."""
start_at: float
total_tokens: int
node_run_steps: int
llm_usage: LLMUsage
outputs: dict[str, Any]
variable_pool: VariablePool
has_variable_pool: bool
ready_queue_dump: str | None
graph_execution_dump: str | None
response_coordinator_dump: str | None
paused_nodes: tuple[str, ...]
deferred_nodes: tuple[str, ...]
graph_node_states: dict[str, NodeState]
graph_edge_states: dict[str, NodeState]
class GraphRuntimeState:
    """Mutable runtime state shared across graph execution components.

    `GraphRuntimeState` encapsulates the runtime state of workflow execution,
    including scheduling details, variable values, and timing information.

    Values that are initialized prior to workflow execution and remain constant
    throughout the execution should be part of `GraphInitParams` instead.

    The ready queue, graph-execution aggregate and response coordinator may be
    injected through the constructor; otherwise they are built lazily (via
    ``importlib``) from ``dify_graph.graph_engine`` the first time they are
    accessed.
    """

    def __init__(
        self,
        *,
        variable_pool: VariablePool,
        start_at: float,
        total_tokens: int = 0,
        llm_usage: LLMUsage | None = None,
        outputs: dict[str, object] | None = None,
        node_run_steps: int = 0,
        ready_queue: ReadyQueueProtocol | None = None,
        graph_execution: GraphExecutionProtocol | None = None,
        response_coordinator: ResponseStreamCoordinatorProtocol | None = None,
        graph: GraphProtocol | None = None,
    ) -> None:
        """Create a runtime state.

        Args:
            variable_pool: Pool holding the workflow's variable values.
            start_at: Execution start timestamp (clock chosen by the caller).
            total_tokens: Initial token count; must be non-negative.
            llm_usage: Initial LLM usage; copied. Defaults to ``LLMUsage.empty_usage()``.
            outputs: Initial outputs; deep-copied. Defaults to an empty dict.
            node_run_steps: Initial step counter; must be non-negative.
            ready_queue: Optional pre-built ready queue.
            graph_execution: Optional pre-built graph-execution aggregate.
            response_coordinator: Optional pre-built response coordinator.
            graph: Optional graph to attach immediately (see ``attach_graph``).

        Raises:
            ValueError: If ``total_tokens`` or ``node_run_steps`` is negative.
        """
        self._variable_pool = variable_pool
        self._start_at = start_at
        if total_tokens < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = total_tokens
        self._llm_usage = (llm_usage or LLMUsage.empty_usage()).model_copy()
        # Deep copy so later mutation of the caller's dict cannot leak into this state.
        self._outputs = deepcopy(outputs) if outputs is not None else {}
        if node_run_steps < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = node_run_steps

        self._graph: GraphProtocol | None = None
        self._ready_queue = ready_queue
        self._graph_execution = graph_execution
        self._response_coordinator = response_coordinator
        # Coordinator payload restored while no graph was attached; loaded in attach_graph().
        self._pending_response_coordinator_dump: str | None = None
        # Workflow id handed to the next _build_graph_execution() call, then cleared.
        self._pending_graph_execution_workflow_id: str | None = None
        self._paused_nodes: set[str] = set()
        self._deferred_nodes: set[str] = set()

        # Node and edge states that still need to be restored into the graph object.
        #
        # These two fields are non-None only when resuming from a snapshot.
        # Once the graph is attached, these two fields will be set to None.
        self._pending_graph_node_states: dict[str, NodeState] | None = None
        self._pending_graph_edge_states: dict[str, NodeState] | None = None

        if graph is not None:
            self.attach_graph(graph)

    # ------------------------------------------------------------------
    # Context binding helpers
    # ------------------------------------------------------------------
    def attach_graph(self, graph: GraphProtocol) -> None:
        """Attach the materialized graph to the runtime state.

        Re-attaching the same graph instance is allowed. On attach, a response
        coordinator is built if none exists. Any coordinator payload restored
        before a graph was available is then loaded into it, and pending
        node/edge states from a snapshot are applied to the graph.

        Raises:
            ValueError: If a different graph instance is already attached.
        """
        if self._graph is not None and self._graph is not graph:
            raise ValueError("GraphRuntimeState already attached to a different graph instance")
        self._graph = graph
        if self._response_coordinator is None:
            self._response_coordinator = self._build_response_coordinator(graph)
        if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
            self._response_coordinator.loads(self._pending_response_coordinator_dump)
            self._pending_response_coordinator_dump = None
        self._apply_pending_graph_state()

    def configure(self, *, graph: GraphProtocol | None = None) -> None:
        """Ensure core collaborators are initialized with the provided context.

        Attaches *graph* if given. Then forces creation of the ready queue and
        the graph execution, plus the response coordinator when a graph is
        attached.
        """
        if graph is not None:
            self.attach_graph(graph)

        # Touch the lazy properties so the collaborators get instantiated.
        _ = self.ready_queue
        _ = self.graph_execution
        if self._graph is not None:
            _ = self.response_coordinator

    # ------------------------------------------------------------------
    # Primary collaborators
    # ------------------------------------------------------------------
    @property
    def variable_pool(self) -> VariablePool:
        """The variable pool (returned by reference, not copied)."""
        return self._variable_pool

    @property
    def ready_queue(self) -> ReadyQueueProtocol:
        """The ready queue, built on first access when none was injected."""
        if self._ready_queue is None:
            self._ready_queue = self._build_ready_queue()
        return self._ready_queue

    @property
    def graph_execution(self) -> GraphExecutionProtocol:
        """The graph-execution aggregate, built on first access when none was injected."""
        if self._graph_execution is None:
            self._graph_execution = self._build_graph_execution()
        return self._graph_execution

    @property
    def response_coordinator(self) -> ResponseStreamCoordinatorProtocol:
        """The response coordinator, built on first access.

        Raises:
            ValueError: If it must be built but no graph is attached yet.
        """
        if self._response_coordinator is None:
            if self._graph is None:
                raise ValueError("Graph must be attached before accessing response coordinator")
            self._response_coordinator = self._build_response_coordinator(self._graph)
        return self._response_coordinator

    # ------------------------------------------------------------------
    # Scalar state
    # ------------------------------------------------------------------
    @property
    def start_at(self) -> float:
        return self._start_at

    @start_at.setter
    def start_at(self, value: float) -> None:
        self._start_at = value

    @property
    def total_tokens(self) -> int:
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: int) -> None:
        if value < 0:
            raise ValueError("total_tokens must be non-negative")
        self._total_tokens = value

    @property
    def llm_usage(self) -> LLMUsage:
        """A copy of the accumulated LLM usage; mutate it and assign back to update."""
        return self._llm_usage.model_copy()

    @llm_usage.setter
    def llm_usage(self, value: LLMUsage) -> None:
        self._llm_usage = value.model_copy()

    @property
    def outputs(self) -> dict[str, Any]:
        """A deep copy of the outputs; changes to it do not affect this state."""
        return deepcopy(self._outputs)

    @outputs.setter
    def outputs(self, value: dict[str, Any]) -> None:
        self._outputs = deepcopy(value)

    def set_output(self, key: str, value: object) -> None:
        """Store a deep copy of *value* under *key*."""
        self._outputs[key] = deepcopy(value)

    def get_output(self, key: str, default: object = None) -> object:
        """Return a deep copy of the output stored under *key*, or of *default*."""
        return deepcopy(self._outputs.get(key, default))

    def update_outputs(self, updates: dict[str, object]) -> None:
        """Merge deep copies of every entry in *updates* into the outputs."""
        for key, value in updates.items():
            self._outputs[key] = deepcopy(value)

    @property
    def node_run_steps(self) -> int:
        return self._node_run_steps

    @node_run_steps.setter
    def node_run_steps(self, value: int) -> None:
        if value < 0:
            raise ValueError("node_run_steps must be non-negative")
        self._node_run_steps = value

    def increment_node_run_steps(self) -> None:
        self._node_run_steps += 1

    def add_tokens(self, tokens: int) -> None:
        """Add *tokens* to the running total.

        Raises:
            ValueError: If *tokens* is negative.
        """
        if tokens < 0:
            raise ValueError("tokens must be non-negative")
        self._total_tokens += tokens

    # ------------------------------------------------------------------
    # Serialization
    # ------------------------------------------------------------------
    def dumps(self) -> str:
        """Serialize runtime state into a JSON string.

        Collaborators that were never created are instantiated here so their
        state is included. Graph node/edge states and response-coordinator
        state that were restored from a snapshot but not yet applied to a graph
        are written back out, so a load → dump round trip loses nothing.
        """
        snapshot: dict[str, Any] = {
            "version": "1.0",
            "start_at": self._start_at,
            "total_tokens": self._total_tokens,
            "node_run_steps": self._node_run_steps,
            "llm_usage": self._llm_usage.model_dump(mode="json"),
            "outputs": self.outputs,
            "variable_pool": self.variable_pool.model_dump(mode="json"),
            "ready_queue": self.ready_queue.dumps(),
            "graph_execution": self.graph_execution.dumps(),
            # Sorted so identical states serialize identically (str set order
            # varies between interpreter runs); both are restored into sets.
            "paused_nodes": sorted(self._paused_nodes),
            "deferred_nodes": sorted(self._deferred_nodes),
            # Always a model (possibly empty); encoded by pydantic_encoder below.
            "graph_state": self._snapshot_graph_state(),
        }
        if self._response_coordinator is not None and self._graph is not None:
            snapshot["response_coordinator"] = self._response_coordinator.dumps()
        elif self._pending_response_coordinator_dump is not None:
            # Restored from a snapshot before any graph was attached: re-emit the
            # stored payload rather than silently dropping coordinator state.
            snapshot["response_coordinator"] = self._pending_response_coordinator_dump
        return json.dumps(snapshot, default=pydantic_encoder)

    @classmethod
    def from_snapshot(cls, data: str | Mapping[str, Any]) -> GraphRuntimeState:
        """Restore runtime state from a serialized snapshot.

        Args:
            data: A JSON string produced by ``dumps`` or an already-decoded mapping.

        Raises:
            ValueError: On an unsupported snapshot version or negative counters.
        """
        snapshot = cls._parse_snapshot_payload(data)
        state = cls(
            variable_pool=snapshot.variable_pool,
            start_at=snapshot.start_at,
            total_tokens=snapshot.total_tokens,
            llm_usage=snapshot.llm_usage,
            outputs=snapshot.outputs,
            node_run_steps=snapshot.node_run_steps,
        )
        state._apply_snapshot(snapshot)
        return state

    def loads(self, data: str | Mapping[str, Any]) -> None:
        """Restore runtime state from a serialized snapshot in place (legacy API)."""
        snapshot = self._parse_snapshot_payload(data)
        self._apply_snapshot(snapshot)

    def register_paused_node(self, node_id: str) -> None:
        """Record a node that should resume when execution is continued."""
        self._paused_nodes.add(node_id)

    def get_paused_nodes(self) -> list[str]:
        """Retrieve the list of paused nodes without mutating internal state."""
        return list(self._paused_nodes)

    def consume_paused_nodes(self) -> list[str]:
        """Retrieve and clear the list of paused nodes awaiting resume."""
        nodes = list(self._paused_nodes)
        self._paused_nodes.clear()
        return nodes

    def register_deferred_node(self, node_id: str) -> None:
        """Record a node that became ready during pause and should resume later."""
        self._deferred_nodes.add(node_id)

    def get_deferred_nodes(self) -> list[str]:
        """Retrieve deferred nodes without mutating internal state."""
        return list(self._deferred_nodes)

    def consume_deferred_nodes(self) -> list[str]:
        """Retrieve and clear deferred nodes awaiting resume."""
        nodes = list(self._deferred_nodes)
        self._deferred_nodes.clear()
        return nodes

    # ------------------------------------------------------------------
    # Builders
    # ------------------------------------------------------------------
    def _build_ready_queue(self) -> ReadyQueueProtocol:
        # Import lazily to avoid breaching architecture boundaries enforced by import-linter.
        module = importlib.import_module("dify_graph.graph_engine.ready_queue")
        in_memory_cls = module.InMemoryReadyQueue
        return in_memory_cls()

    def _build_graph_execution(self) -> GraphExecutionProtocol:
        # Lazily import to keep the runtime domain decoupled from graph_engine modules.
        module = importlib.import_module("dify_graph.graph_engine.domain.graph_execution")
        graph_execution_cls = module.GraphExecution
        # Consume the workflow id captured while restoring a snapshot (if any).
        workflow_id = self._pending_graph_execution_workflow_id or ""
        self._pending_graph_execution_workflow_id = None
        return graph_execution_cls(workflow_id=workflow_id)  # type: ignore[invalid-return-type]

    def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol:
        # Lazily import to keep the runtime domain decoupled from graph_engine modules.
        module = importlib.import_module("dify_graph.graph_engine.response_coordinator")
        coordinator_cls = module.ResponseStreamCoordinator
        return coordinator_cls(variable_pool=self.variable_pool, graph=graph)

    # ------------------------------------------------------------------
    # Snapshot helpers
    # ------------------------------------------------------------------
    @classmethod
    def _parse_snapshot_payload(cls, data: str | Mapping[str, Any]) -> _GraphRuntimeStateSnapshot:
        """Decode and validate a snapshot into a `_GraphRuntimeStateSnapshot`.

        Missing or null optional sections fall back to empty defaults.

        Raises:
            ValueError: If the version is not "1.0" or a counter is negative.
        """
        payload: dict[str, Any]
        if isinstance(data, str):
            payload = json.loads(data)
        else:
            payload = dict(data)

        version = payload.get("version")
        if version != "1.0":
            raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}")

        start_at = float(payload.get("start_at", 0.0))

        total_tokens = int(payload.get("total_tokens", 0))
        if total_tokens < 0:
            raise ValueError("total_tokens must be non-negative")

        node_run_steps = int(payload.get("node_run_steps", 0))
        if node_run_steps < 0:
            raise ValueError("node_run_steps must be non-negative")

        # Match __init__: absent usage means "no usage yet" rather than validating {}.
        llm_usage_payload = payload.get("llm_usage")
        llm_usage = (
            LLMUsage.empty_usage() if llm_usage_payload is None else LLMUsage.model_validate(llm_usage_payload)
        )

        # `or {}` also guards against an explicit null, which would otherwise
        # leave _outputs as None and break set_output()/update_outputs().
        outputs_payload = deepcopy(payload.get("outputs") or {})

        variable_pool_payload = payload.get("variable_pool")
        has_variable_pool = variable_pool_payload is not None
        variable_pool = VariablePool.model_validate(variable_pool_payload) if has_variable_pool else VariablePool()

        ready_queue_payload = payload.get("ready_queue")
        graph_execution_payload = payload.get("graph_execution")
        response_payload = payload.get("response_coordinator")
        paused_nodes_payload = payload.get("paused_nodes") or []
        deferred_nodes_payload = payload.get("deferred_nodes") or []

        graph_state_payload = payload.get("graph_state", {}) or {}
        graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes")
        graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges")

        return _GraphRuntimeStateSnapshot(
            start_at=start_at,
            total_tokens=total_tokens,
            node_run_steps=node_run_steps,
            llm_usage=llm_usage,
            outputs=outputs_payload,
            variable_pool=variable_pool,
            has_variable_pool=has_variable_pool,
            ready_queue_dump=ready_queue_payload,
            graph_execution_dump=graph_execution_payload,
            response_coordinator_dump=response_payload,
            paused_nodes=tuple(map(str, paused_nodes_payload)),
            deferred_nodes=tuple(map(str, deferred_nodes_payload)),
            graph_node_states=graph_node_states,
            graph_edge_states=graph_edge_states,
        )

    def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
        """Overwrite this state with the contents of *snapshot*."""
        self._start_at = snapshot.start_at
        self._total_tokens = snapshot.total_tokens
        self._node_run_steps = snapshot.node_run_steps
        self._llm_usage = snapshot.llm_usage.model_copy()
        self._outputs = deepcopy(snapshot.outputs)
        # Keep the current pool when the snapshot carried none.
        if snapshot.has_variable_pool or self._variable_pool is None:
            self._variable_pool = snapshot.variable_pool

        self._restore_ready_queue(snapshot.ready_queue_dump)
        self._restore_graph_execution(snapshot.graph_execution_dump)
        self._restore_response_coordinator(snapshot.response_coordinator_dump)
        self._paused_nodes = set(snapshot.paused_nodes)
        self._deferred_nodes = set(snapshot.deferred_nodes)
        self._pending_graph_node_states = snapshot.graph_node_states or None
        self._pending_graph_edge_states = snapshot.graph_edge_states or None
        # Applied now if a graph is attached; otherwise deferred to attach_graph().
        self._apply_pending_graph_state()

    def _restore_ready_queue(self, payload: str | None) -> None:
        """Rebuild the ready queue from *payload*, or reset it to be built lazily."""
        if payload is not None:
            self._ready_queue = self._build_ready_queue()
            self._ready_queue.loads(payload)
        else:
            self._ready_queue = None

    def _restore_graph_execution(self, payload: str | None) -> None:
        """Rebuild the graph execution from *payload*, or reset it to be built lazily."""
        self._graph_execution = None
        self._pending_graph_execution_workflow_id = None

        if payload is None:
            return

        # Best effort: pre-read the workflow id so the rebuilt aggregate is
        # constructed with it; a malformed payload is left for loads() to reject.
        try:
            execution_payload = json.loads(payload)
            self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id")
        except (json.JSONDecodeError, TypeError, AttributeError):
            self._pending_graph_execution_workflow_id = None

        self.graph_execution.loads(payload)

    def _restore_response_coordinator(self, payload: str | None) -> None:
        """Load coordinator state now if a graph is attached, else stash it for attach_graph()."""
        if payload is None:
            self._pending_response_coordinator_dump = None
            self._response_coordinator = None
            return

        if self._graph is not None:
            # NOTE(review): an existing coordinator is reused here. If the snapshot
            # replaced the variable pool, it presumably still references the
            # previous pool — confirm whether it should be rebuilt.
            self.response_coordinator.loads(payload)
            self._pending_response_coordinator_dump = None
            return

        self._pending_response_coordinator_dump = payload
        self._response_coordinator = None

    def _snapshot_graph_state(self) -> _GraphStateSnapshot:
        """Capture node/edge states from the graph, or the pending ones when no graph is attached."""
        graph = self._graph
        if graph is None:
            if self._pending_graph_node_states is None and self._pending_graph_edge_states is None:
                return _GraphStateSnapshot()
            return _GraphStateSnapshot(
                nodes=self._pending_graph_node_states or {},
                edges=self._pending_graph_edge_states or {},
            )

        nodes = graph.nodes
        edges = graph.edges
        if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping):
            return _GraphStateSnapshot()

        node_states = {}
        for node_id, node in nodes.items():
            if not isinstance(node_id, str):
                continue
            node_states[node_id] = node.state

        edge_states = {}
        for edge_id, edge in edges.items():
            if not isinstance(edge_id, str):
                continue
            edge_states[edge_id] = edge.state

        return _GraphStateSnapshot(nodes=node_states, edges=edge_states)

    def _apply_pending_graph_state(self) -> None:
        """Push pending node/edge states onto the attached graph, then clear them.

        Ids that do not exist in the graph are ignored. No-op without a graph.
        """
        if self._graph is None:
            return
        if self._pending_graph_node_states:
            for node_id, state in self._pending_graph_node_states.items():
                node = self._graph.nodes.get(node_id)
                if node is None:
                    continue
                node.state = state

        if self._pending_graph_edge_states:
            for edge_id, state in self._pending_graph_edge_states.items():
                edge = self._graph.edges.get(edge_id)
                if edge is None:
                    continue
                edge.state = state

        self._pending_graph_node_states = None
        self._pending_graph_edge_states = None
def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]:
if not isinstance(payload, Mapping):
return {}
raw_map = payload.get(key, {})
if not isinstance(raw_map, Mapping):
return {}
result: dict[str, NodeState] = {}
for node_id, raw_state in raw_map.items():
if not isinstance(node_id, str):
continue
try:
result[node_id] = NodeState(str(raw_state))
except ValueError:
continue
return result