# Mirror of https://github.com/langgenius/dify.git
# Synced 2026-03-04 23:36:20 +08:00
# 291 lines, 8.3 KiB, Python
"""
Graph state manager that combines node, edge, and execution tracking.
"""
|
|
|
|
import threading
|
|
from collections.abc import Sequence
|
|
from typing import TypedDict, final
|
|
|
|
from dify_graph.enums import NodeState
|
|
from dify_graph.graph import Edge, Graph
|
|
|
|
from .ready_queue import ReadyQueue
|
|
|
|
|
|
class EdgeStateAnalysis(TypedDict):
    """Analysis result for edge states.

    Summary flags computed over a collection of edges by
    ``GraphStateManager.analyze_edge_states``.
    """

    # True when at least one analyzed edge is still in the UNKNOWN state.
    has_unknown: bool
    # True when at least one analyzed edge is in the TAKEN state.
    has_taken: bool
    # True when every analyzed edge is SKIPPED (also True for an empty collection).
    all_skipped: bool
|
|
|
|
|
|
@final
class GraphStateManager:
    """
    Single access point for node, edge, and execution-tracking state.

    Node and edge states live on the objects held by the wrapped graph; the
    set of currently executing node IDs is owned by this manager. Operations
    that touch that state are performed while holding one re-entrant lock
    (``threading.RLock``), except ``get_queue_depth``, which only asks the
    ready queue for its size.
    """

    def __init__(self, graph: Graph, ready_queue: ReadyQueue) -> None:
        """
        Initialize the state manager.

        Args:
            graph: The workflow graph
            ready_queue: Queue for nodes ready to execute
        """
        self._graph = graph
        self._ready_queue = ready_queue
        self._lock = threading.RLock()

        # Execution tracking state
        self._executing_nodes: set[str] = set()

    # ============= Node State Operations =============

    def enqueue_node(self, node_id: str) -> None:
        """
        Mark a node as TAKEN and add it to the ready queue.

        This combines the state transition and enqueueing operations
        that always occur together when preparing a node for execution.
        Both steps happen under the lock, so the state is updated before
        the node ID is put on the queue.

        Args:
            node_id: The ID of the node to enqueue

        Raises:
            KeyError: If ``graph.nodes`` is a plain mapping and has no such ID.
        """
        with self._lock:
            self._graph.nodes[node_id].state = NodeState.TAKEN
            self._ready_queue.put(node_id)

    def mark_node_skipped(self, node_id: str) -> None:
        """
        Mark a node as SKIPPED.

        Args:
            node_id: The ID of the node to skip
        """
        with self._lock:
            self._graph.nodes[node_id].state = NodeState.SKIPPED

    def is_node_ready(self, node_id: str) -> bool:
        """
        Check if a node is ready to be executed.

        Readiness rules, evaluated in order:
        1. A node without incoming edges is always ready.
        2. A node with any incoming edge still UNKNOWN is not ready.
        3. Otherwise the node is ready when at least one incoming edge is
           TAKEN (so a node whose incoming edges are all SKIPPED is not ready).

        Args:
            node_id: The ID of the node to check

        Returns:
            True if the node is ready for execution
        """
        with self._lock:
            # Get all incoming edges to this node
            incoming_edges = self._graph.get_incoming_edges(node_id)

            # If no incoming edges, node is always ready
            if not incoming_edges:
                return True

            # If any edge is UNKNOWN, node is not ready
            if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges):
                return False

            # Node is ready if at least one edge is TAKEN
            return any(edge.state == NodeState.TAKEN for edge in incoming_edges)

    def get_node_state(self, node_id: str) -> NodeState:
        """
        Get the current state of a node.

        Args:
            node_id: The ID of the node

        Returns:
            The current node state
        """
        with self._lock:
            return self._graph.nodes[node_id].state

    # ============= Edge State Operations =============

    def mark_edge_taken(self, edge_id: str) -> None:
        """
        Mark an edge as TAKEN.

        Args:
            edge_id: The ID of the edge to mark
        """
        with self._lock:
            self._graph.edges[edge_id].state = NodeState.TAKEN

    def mark_edge_skipped(self, edge_id: str) -> None:
        """
        Mark an edge as SKIPPED.

        Args:
            edge_id: The ID of the edge to mark
        """
        with self._lock:
            self._graph.edges[edge_id].state = NodeState.SKIPPED

    def analyze_edge_states(self, edges: Sequence[Edge]) -> EdgeStateAnalysis:
        """
        Analyze the states of edges and return summary flags.

        Args:
            edges: Edges to analyze. Any sequence is accepted, including the
                sequences returned by ``categorize_branch_edges``.

        Returns:
            Analysis result with state flags. For an empty ``edges``,
            ``has_unknown`` and ``has_taken`` are False and ``all_skipped``
            is True.
        """
        with self._lock:
            # Collapse to the set of distinct states; only membership matters.
            states = {edge.state for edge in edges}

            return EdgeStateAnalysis(
                has_unknown=NodeState.UNKNOWN in states,
                has_taken=NodeState.TAKEN in states,
                # An empty collection counts as "all skipped".
                all_skipped=(not states) or states == {NodeState.SKIPPED},
            )

    def get_edge_state(self, edge_id: str) -> NodeState:
        """
        Get the current state of an edge.

        Args:
            edge_id: The ID of the edge

        Returns:
            The current edge state
        """
        with self._lock:
            return self._graph.edges[edge_id].state

    def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]:
        """
        Categorize branch edges into selected and unselected.

        An outgoing edge is "selected" when its ``source_handle`` equals
        ``selected_handle``; every other outgoing edge is "unselected".
        The original edge order is preserved within each group.

        Args:
            node_id: The ID of the branch node
            selected_handle: The handle of the selected edge

        Returns:
            A tuple of (selected_edges, unselected_edges)
        """
        with self._lock:
            outgoing_edges = self._graph.get_outgoing_edges(node_id)
            selected_edges: list[Edge] = []
            unselected_edges: list[Edge] = []

            for edge in outgoing_edges:
                if edge.source_handle == selected_handle:
                    selected_edges.append(edge)
                else:
                    unselected_edges.append(edge)

            return selected_edges, unselected_edges

    # ============= Execution Tracking Operations =============

    def start_execution(self, node_id: str) -> None:
        """
        Mark a node as executing.

        Adding an ID that is already tracked has no additional effect.

        Args:
            node_id: The ID of the node starting execution
        """
        with self._lock:
            self._executing_nodes.add(node_id)

    def finish_execution(self, node_id: str) -> None:
        """
        Mark a node as no longer executing.

        Unknown IDs are ignored (``set.discard`` semantics).

        Args:
            node_id: The ID of the node finishing execution
        """
        with self._lock:
            self._executing_nodes.discard(node_id)

    def is_executing(self, node_id: str) -> bool:
        """
        Check if a node is currently executing.

        Args:
            node_id: The ID of the node to check

        Returns:
            True if the node is executing
        """
        with self._lock:
            return node_id in self._executing_nodes

    def get_executing_count(self) -> int:
        """
        Get the count of currently executing nodes.

        Returns:
            Number of executing nodes
        """
        # This count is a best-effort snapshot and can change concurrently.
        # Only use it for pause-drain checks where scheduling is already frozen.
        with self._lock:
            return len(self._executing_nodes)

    def get_executing_nodes(self) -> set[str]:
        """
        Get a copy of the set of executing node IDs.

        Returns:
            Set of node IDs currently executing; mutating it does not affect
            the manager's internal state.
        """
        with self._lock:
            return self._executing_nodes.copy()

    def clear_executing(self) -> None:
        """Clear all executing nodes."""
        with self._lock:
            self._executing_nodes.clear()

    # ============= Composite Operations =============

    def is_execution_complete(self) -> bool:
        """
        Check if graph execution is complete.

        Execution is complete when:
        - Ready queue is empty
        - No nodes are executing

        Returns:
            True if execution is complete
        """
        with self._lock:
            return self._ready_queue.empty() and len(self._executing_nodes) == 0

    def get_queue_depth(self) -> int:
        """
        Get the current depth of the ready queue.

        The manager's lock is not taken here; the value is whatever the
        ready queue's ``qsize()`` reports at the time of the call.

        Returns:
            Number of nodes in the ready queue
        """
        return self._ready_queue.qsize()

    def get_execution_stats(self) -> dict[str, int]:
        """
        Get execution statistics.

        Returns:
            Dictionary with the keys ``queue_depth``, ``executing``,
            ``taken_nodes``, ``skipped_nodes`` and ``unknown_nodes``.
        """
        with self._lock:
            taken_nodes = 0
            skipped_nodes = 0
            unknown_nodes = 0

            # Tally all three states in a single pass over the graph's nodes.
            for node in self._graph.nodes.values():
                state = node.state
                if state == NodeState.TAKEN:
                    taken_nodes += 1
                elif state == NodeState.SKIPPED:
                    skipped_nodes += 1
                elif state == NodeState.UNKNOWN:
                    unknown_nodes += 1

            return {
                "queue_depth": self._ready_queue.qsize(),
                "executing": len(self._executing_nodes),
                "taken_nodes": taken_nodes,
                "skipped_nodes": skipped_nodes,
                "unknown_nodes": unknown_nodes,
            }
|