mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 01:18:05 +08:00
fix(api): ensure node and edge states are properly persisted while pausing
This commit is contained in:
@ -7,9 +7,11 @@ from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.enums import NodeState
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -104,14 +106,33 @@ class ResponseStreamCoordinatorProtocol(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class NodeProtocol(Protocol):
|
||||
"""Structural interface for graph nodes."""
|
||||
|
||||
id: str
|
||||
state: NodeState
|
||||
|
||||
|
||||
class EdgeProtocol(Protocol):
|
||||
id: str
|
||||
state: NodeState
|
||||
|
||||
|
||||
class GraphProtocol(Protocol):
|
||||
"""Structural interface required from graph instances attached to the runtime state."""
|
||||
|
||||
nodes: Mapping[str, object]
|
||||
edges: Mapping[str, object]
|
||||
root_node: object
|
||||
nodes: Mapping[str, NodeProtocol]
|
||||
edges: Mapping[str, EdgeProtocol]
|
||||
root_node: NodeProtocol
|
||||
|
||||
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
|
||||
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
|
||||
|
||||
|
||||
class _GraphStateSnapshot(BaseModel):
|
||||
"""Serializable graph state snapshot for node/edge states."""
|
||||
|
||||
nodes: dict[str, NodeState] = Field(default_factory=dict)
|
||||
edges: dict[str, NodeState] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@ -130,6 +151,8 @@ class _GraphRuntimeStateSnapshot:
|
||||
response_coordinator_dump: str | None
|
||||
paused_nodes: tuple[str, ...]
|
||||
deferred_nodes: tuple[str, ...]
|
||||
graph_node_states: dict[str, NodeState]
|
||||
graph_edge_states: dict[str, NodeState]
|
||||
|
||||
|
||||
class GraphRuntimeState:
|
||||
@ -180,6 +203,14 @@ class GraphRuntimeState:
|
||||
self._paused_nodes: set[str] = set()
|
||||
self._deferred_nodes: set[str] = set()
|
||||
|
||||
# Node and edges states needed to be restored into
|
||||
# graph object.
|
||||
#
|
||||
# These two fields are non-None only when resuming from a snapshot.
|
||||
# Once the graph is attached, these two fields will be set to None.
|
||||
self._pending_graph_node_states: dict[str, NodeState] | None = None
|
||||
self._pending_graph_edge_states: dict[str, NodeState] | None = None
|
||||
|
||||
if graph is not None:
|
||||
self.attach_graph(graph)
|
||||
|
||||
@ -199,6 +230,7 @@ class GraphRuntimeState:
|
||||
if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
|
||||
self._response_coordinator.loads(self._pending_response_coordinator_dump)
|
||||
self._pending_response_coordinator_dump = None
|
||||
self._apply_pending_graph_state()
|
||||
|
||||
def configure(self, *, graph: GraphProtocol | None = None) -> None:
|
||||
"""Ensure core collaborators are initialized with the provided context."""
|
||||
@ -323,6 +355,10 @@ class GraphRuntimeState:
|
||||
"deferred_nodes": list(self._deferred_nodes),
|
||||
}
|
||||
|
||||
graph_state = self._snapshot_graph_state()
|
||||
if graph_state is not None:
|
||||
snapshot["graph_state"] = graph_state
|
||||
|
||||
if self._response_coordinator is not None and self._graph is not None:
|
||||
snapshot["response_coordinator"] = self._response_coordinator.dumps()
|
||||
|
||||
@ -447,6 +483,9 @@ class GraphRuntimeState:
|
||||
response_payload = payload.get("response_coordinator")
|
||||
paused_nodes_payload = payload.get("paused_nodes", [])
|
||||
deferred_nodes_payload = payload.get("deferred_nodes", [])
|
||||
graph_state_payload = payload.get("graph_state", {}) or {}
|
||||
graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes")
|
||||
graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges")
|
||||
|
||||
return _GraphRuntimeStateSnapshot(
|
||||
start_at=start_at,
|
||||
@ -461,6 +500,8 @@ class GraphRuntimeState:
|
||||
response_coordinator_dump=response_payload,
|
||||
paused_nodes=tuple(map(str, paused_nodes_payload)),
|
||||
deferred_nodes=tuple(map(str, deferred_nodes_payload)),
|
||||
graph_node_states=graph_node_states,
|
||||
graph_edge_states=graph_edge_states,
|
||||
)
|
||||
|
||||
def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
|
||||
@ -477,6 +518,9 @@ class GraphRuntimeState:
|
||||
self._restore_response_coordinator(snapshot.response_coordinator_dump)
|
||||
self._paused_nodes = set(snapshot.paused_nodes)
|
||||
self._deferred_nodes = set(snapshot.deferred_nodes)
|
||||
self._pending_graph_node_states = snapshot.graph_node_states or None
|
||||
self._pending_graph_edge_states = snapshot.graph_edge_states or None
|
||||
self._apply_pending_graph_state()
|
||||
|
||||
def _restore_ready_queue(self, payload: str | None) -> None:
|
||||
if payload is not None:
|
||||
@ -513,3 +557,68 @@ class GraphRuntimeState:
|
||||
|
||||
self._pending_response_coordinator_dump = payload
|
||||
self._response_coordinator = None
|
||||
|
||||
def _snapshot_graph_state(self) -> _GraphStateSnapshot:
|
||||
graph = self._graph
|
||||
if graph is None:
|
||||
if self._pending_graph_node_states is None and self._pending_graph_edge_states is None:
|
||||
return _GraphStateSnapshot()
|
||||
return _GraphStateSnapshot(
|
||||
nodes=self._pending_graph_node_states or {},
|
||||
edges=self._pending_graph_edge_states or {},
|
||||
)
|
||||
|
||||
nodes = graph.nodes
|
||||
edges = graph.edges
|
||||
if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping):
|
||||
return _GraphStateSnapshot()
|
||||
|
||||
node_states = {}
|
||||
for node_id, node in nodes.items():
|
||||
if not isinstance(node_id, str):
|
||||
continue
|
||||
node_states[node_id] = node.state
|
||||
|
||||
edge_states = {}
|
||||
for edge_id, edge in edges.items():
|
||||
if not isinstance(edge_id, str):
|
||||
continue
|
||||
edge_states[edge_id] = edge.state
|
||||
|
||||
return _GraphStateSnapshot(nodes=node_states, edges=edge_states)
|
||||
|
||||
def _apply_pending_graph_state(self) -> None:
|
||||
if self._graph is None:
|
||||
return
|
||||
if self._pending_graph_node_states:
|
||||
for node_id, state in self._pending_graph_node_states.items():
|
||||
node = self._graph.nodes.get(node_id)
|
||||
if node is None:
|
||||
continue
|
||||
node.state = state
|
||||
if self._pending_graph_edge_states:
|
||||
for edge_id, state in self._pending_graph_edge_states.items():
|
||||
edge = self._graph.edges.get(edge_id)
|
||||
if edge is None:
|
||||
continue
|
||||
edge.state = state
|
||||
|
||||
self._pending_graph_node_states = None
|
||||
self._pending_graph_edge_states = None
|
||||
|
||||
|
||||
def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]:
|
||||
if not isinstance(payload, Mapping):
|
||||
return {}
|
||||
raw_map = payload.get(key, {})
|
||||
if not isinstance(raw_map, Mapping):
|
||||
return {}
|
||||
result: dict[str, NodeState] = {}
|
||||
for node_id, raw_state in raw_map.items():
|
||||
if not isinstance(node_id, str):
|
||||
continue
|
||||
try:
|
||||
result[node_id] = NodeState(str(raw_state))
|
||||
except ValueError:
|
||||
continue
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user