refactor: move workflow package to dify_graph (#32844)

This commit is contained in:
-LAN-
2026-03-02 18:42:30 +08:00
committed by GitHub
parent 9c33923985
commit c917838f9c
613 changed files with 2008 additions and 2012 deletions

View File

@ -0,0 +1,14 @@
"""
Graph traversal subsystem for graph engine.
This package handles graph navigation, edge processing,
and skip propagation logic.
"""
from .edge_processor import EdgeProcessor
from .skip_propagator import SkipPropagator
__all__ = [
"EdgeProcessor",
"SkipPropagator",
]

View File

@ -0,0 +1,201 @@
"""
Edge processing logic for graph traversal.
"""
from collections.abc import Sequence
from typing import TYPE_CHECKING, final
from dify_graph.enums import NodeExecutionType
from dify_graph.graph import Edge, Graph
from dify_graph.graph_events import NodeRunStreamChunkEvent
from ..graph_state_manager import GraphStateManager
from ..response_coordinator import ResponseStreamCoordinator
if TYPE_CHECKING:
from .skip_propagator import SkipPropagator
@final
class EdgeProcessor:
"""
Processes edges during graph execution.
This handles marking edges as taken or skipped, notifying
the response coordinator, triggering downstream node execution,
and managing branch node logic.
"""
def __init__(
self,
graph: Graph,
state_manager: GraphStateManager,
response_coordinator: ResponseStreamCoordinator,
skip_propagator: "SkipPropagator",
) -> None:
"""
Initialize the edge processor.
Args:
graph: The workflow graph
state_manager: Unified state manager
response_coordinator: Response stream coordinator
skip_propagator: Propagator for skip states
"""
self._graph = graph
self._state_manager = state_manager
self._response_coordinator = response_coordinator
self._skip_propagator = skip_propagator
def process_node_success(
self, node_id: str, selected_handle: str | None = None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges after a node succeeds.
Args:
node_id: The ID of the succeeded node
selected_handle: For branch nodes, the selected edge handle
Returns:
Tuple of (list of downstream node IDs that are now ready, list of streaming events)
"""
node = self._graph.nodes[node_id]
if node.execution_type == NodeExecutionType.BRANCH:
return self._process_branch_node_edges(node_id, selected_handle)
else:
return self._process_non_branch_node_edges(node_id)
def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges for non-branch nodes (mark all as TAKEN).
Args:
node_id: The ID of the succeeded node
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
"""
ready_nodes: list[str] = []
all_streaming_events: list[NodeRunStreamChunkEvent] = []
outgoing_edges = self._graph.get_outgoing_edges(node_id)
for edge in outgoing_edges:
nodes, events = self._process_taken_edge(edge)
ready_nodes.extend(nodes)
all_streaming_events.extend(events)
return ready_nodes, all_streaming_events
def _process_branch_node_edges(
self, node_id: str, selected_handle: str | None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Process edges for branch nodes.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected edge
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
Raises:
ValueError: If no edge was selected
"""
if not selected_handle:
raise ValueError(f"Branch node {node_id} did not select any edge")
ready_nodes: list[str] = []
all_streaming_events: list[NodeRunStreamChunkEvent] = []
# Categorize edges
selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
# Process unselected edges first (mark as skipped)
for edge in unselected_edges:
self._process_skipped_edge(edge)
# Process selected edges
for edge in selected_edges:
nodes, events = self._process_taken_edge(edge)
ready_nodes.extend(nodes)
all_streaming_events.extend(events)
return ready_nodes, all_streaming_events
def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Mark edge as taken and check downstream node.
Args:
edge: The edge to process
Returns:
Tuple of (list containing downstream node ID if it's ready, list of streaming events)
"""
# Mark edge as taken
self._state_manager.mark_edge_taken(edge.id)
# Notify response coordinator and get streaming events
streaming_events = self._response_coordinator.on_edge_taken(edge.id)
# Check if downstream node is ready
ready_nodes: list[str] = []
if self._state_manager.is_node_ready(edge.head):
ready_nodes.append(edge.head)
return ready_nodes, streaming_events
def _process_skipped_edge(self, edge: Edge) -> None:
"""
Mark edge as skipped.
Args:
edge: The edge to skip
"""
self._state_manager.mark_edge_skipped(edge.id)
def handle_branch_completion(
self, node_id: str, selected_handle: str | None
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
"""
Handle completion of a branch node.
Args:
node_id: The ID of the branch node
selected_handle: The handle of the selected branch
Returns:
Tuple of (list of downstream nodes ready for execution, list of streaming events)
Raises:
ValueError: If no branch was selected
"""
if not selected_handle:
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
# Categorize edges into selected and unselected
_, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
# Skip all unselected paths
self._skip_propagator.skip_branch_paths(unselected_edges)
# Process selected edges and get ready nodes and streaming events
return self.process_node_success(node_id, selected_handle)
def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
"""
Validate that a branch selection is valid.
Args:
node_id: The ID of the branch node
selected_handle: The handle to validate
Returns:
True if the selection is valid
"""
outgoing_edges = self._graph.get_outgoing_edges(node_id)
valid_handles = {edge.source_handle for edge in outgoing_edges}
return selected_handle in valid_handles

View File

@ -0,0 +1,96 @@
"""
Skip state propagation through the graph.
"""
from collections.abc import Sequence
from typing import final
from dify_graph.graph import Edge, Graph
from ..graph_state_manager import GraphStateManager
@final
class SkipPropagator:
"""
Propagates skip states through the graph.
When a node is skipped, this ensures all downstream nodes
that depend solely on it are also skipped.
"""
def __init__(
self,
graph: Graph,
state_manager: GraphStateManager,
) -> None:
"""
Initialize the skip propagator.
Args:
graph: The workflow graph
state_manager: Unified state manager
"""
self._graph = graph
self._state_manager = state_manager
def propagate_skip_from_edge(self, edge_id: str) -> None:
"""
Recursively propagate skip state from a skipped edge.
Rules:
- If a node has any UNKNOWN incoming edges, stop processing
- If all incoming edges are SKIPPED, skip the node and its edges
- If any incoming edge is TAKEN, the node may still execute
Args:
edge_id: The ID of the skipped edge to start from
"""
downstream_node_id = self._graph.edges[edge_id].head
incoming_edges = self._graph.get_incoming_edges(downstream_node_id)
# Analyze edge states
edge_states = self._state_manager.analyze_edge_states(incoming_edges)
# Stop if there are unknown edges (not yet processed)
if edge_states["has_unknown"]:
return
# If any edge is taken, node may still execute
if edge_states["has_taken"]:
# Enqueue node
self._state_manager.enqueue_node(downstream_node_id)
self._state_manager.start_execution(downstream_node_id)
return
# All edges are skipped, propagate skip to this node
if edge_states["all_skipped"]:
self._propagate_skip_to_node(downstream_node_id)
def _propagate_skip_to_node(self, node_id: str) -> None:
"""
Mark a node and all its outgoing edges as skipped.
Args:
node_id: The ID of the node to skip
"""
# Mark node as skipped
self._state_manager.mark_node_skipped(node_id)
# Mark all outgoing edges as skipped and propagate
outgoing_edges = self._graph.get_outgoing_edges(node_id)
for edge in outgoing_edges:
self._state_manager.mark_edge_skipped(edge.id)
# Recursively propagate skip
self.propagate_skip_from_edge(edge.id)
def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None:
"""
Skip all paths from unselected branch edges.
Args:
unselected_edges: List of edges not taken by the branch
"""
for edge in unselected_edges:
self._state_manager.mark_edge_skipped(edge.id)
self.propagate_skip_from_edge(edge.id)