fix(api): ensure node and edge states are properly persisted while pausing

This commit is contained in:
QuantumGhost
2026-01-28 08:17:14 +08:00
parent 19e3d07baf
commit 966a87b81a
3 changed files with 571 additions and 4 deletions

View File

@ -7,9 +7,11 @@ from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Protocol
from pydantic import BaseModel, Field
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.enums import NodeState
from core.workflow.runtime.variable_pool import VariablePool
if TYPE_CHECKING:
@ -104,14 +106,33 @@ class ResponseStreamCoordinatorProtocol(Protocol):
...
class NodeProtocol(Protocol):
"""Structural interface for graph nodes."""
id: str
state: NodeState
class EdgeProtocol(Protocol):
id: str
state: NodeState
class GraphProtocol(Protocol):
"""Structural interface required from graph instances attached to the runtime state."""
nodes: Mapping[str, object]
edges: Mapping[str, object]
root_node: object
nodes: Mapping[str, NodeProtocol]
edges: Mapping[str, EdgeProtocol]
root_node: NodeProtocol
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
class _GraphStateSnapshot(BaseModel):
"""Serializable graph state snapshot for node/edge states."""
nodes: dict[str, NodeState] = Field(default_factory=dict)
edges: dict[str, NodeState] = Field(default_factory=dict)
@dataclass(slots=True)
@ -130,6 +151,8 @@ class _GraphRuntimeStateSnapshot:
response_coordinator_dump: str | None
paused_nodes: tuple[str, ...]
deferred_nodes: tuple[str, ...]
graph_node_states: dict[str, NodeState]
graph_edge_states: dict[str, NodeState]
class GraphRuntimeState:
@ -180,6 +203,14 @@ class GraphRuntimeState:
self._paused_nodes: set[str] = set()
self._deferred_nodes: set[str] = set()
# Node and edges states needed to be restored into
# graph object.
#
# These two fields are non-None only when resuming from a snapshot.
# Once the graph is attached, these two fields will be set to None.
self._pending_graph_node_states: dict[str, NodeState] | None = None
self._pending_graph_edge_states: dict[str, NodeState] | None = None
if graph is not None:
self.attach_graph(graph)
@ -199,6 +230,7 @@ class GraphRuntimeState:
if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
self._response_coordinator.loads(self._pending_response_coordinator_dump)
self._pending_response_coordinator_dump = None
self._apply_pending_graph_state()
def configure(self, *, graph: GraphProtocol | None = None) -> None:
"""Ensure core collaborators are initialized with the provided context."""
@ -323,6 +355,10 @@ class GraphRuntimeState:
"deferred_nodes": list(self._deferred_nodes),
}
graph_state = self._snapshot_graph_state()
if graph_state is not None:
snapshot["graph_state"] = graph_state
if self._response_coordinator is not None and self._graph is not None:
snapshot["response_coordinator"] = self._response_coordinator.dumps()
@ -447,6 +483,9 @@ class GraphRuntimeState:
response_payload = payload.get("response_coordinator")
paused_nodes_payload = payload.get("paused_nodes", [])
deferred_nodes_payload = payload.get("deferred_nodes", [])
graph_state_payload = payload.get("graph_state", {}) or {}
graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes")
graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges")
return _GraphRuntimeStateSnapshot(
start_at=start_at,
@ -461,6 +500,8 @@ class GraphRuntimeState:
response_coordinator_dump=response_payload,
paused_nodes=tuple(map(str, paused_nodes_payload)),
deferred_nodes=tuple(map(str, deferred_nodes_payload)),
graph_node_states=graph_node_states,
graph_edge_states=graph_edge_states,
)
def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
@ -477,6 +518,9 @@ class GraphRuntimeState:
self._restore_response_coordinator(snapshot.response_coordinator_dump)
self._paused_nodes = set(snapshot.paused_nodes)
self._deferred_nodes = set(snapshot.deferred_nodes)
self._pending_graph_node_states = snapshot.graph_node_states or None
self._pending_graph_edge_states = snapshot.graph_edge_states or None
self._apply_pending_graph_state()
def _restore_ready_queue(self, payload: str | None) -> None:
if payload is not None:
@ -513,3 +557,68 @@ class GraphRuntimeState:
self._pending_response_coordinator_dump = payload
self._response_coordinator = None
def _snapshot_graph_state(self) -> _GraphStateSnapshot:
graph = self._graph
if graph is None:
if self._pending_graph_node_states is None and self._pending_graph_edge_states is None:
return _GraphStateSnapshot()
return _GraphStateSnapshot(
nodes=self._pending_graph_node_states or {},
edges=self._pending_graph_edge_states or {},
)
nodes = graph.nodes
edges = graph.edges
if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping):
return _GraphStateSnapshot()
node_states = {}
for node_id, node in nodes.items():
if not isinstance(node_id, str):
continue
node_states[node_id] = node.state
edge_states = {}
for edge_id, edge in edges.items():
if not isinstance(edge_id, str):
continue
edge_states[edge_id] = edge.state
return _GraphStateSnapshot(nodes=node_states, edges=edge_states)
def _apply_pending_graph_state(self) -> None:
if self._graph is None:
return
if self._pending_graph_node_states:
for node_id, state in self._pending_graph_node_states.items():
node = self._graph.nodes.get(node_id)
if node is None:
continue
node.state = state
if self._pending_graph_edge_states:
for edge_id, state in self._pending_graph_edge_states.items():
edge = self._graph.edges.get(edge_id)
if edge is None:
continue
edge.state = state
self._pending_graph_node_states = None
self._pending_graph_edge_states = None
def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]:
if not isinstance(payload, Mapping):
return {}
raw_map = payload.get(key, {})
if not isinstance(raw_map, Mapping):
return {}
result: dict[str, NodeState] = {}
for node_id, raw_state in raw_map.items():
if not isinstance(node_id, str):
continue
try:
result[node_id] = NodeState(str(raw_state))
except ValueError:
continue
return result