mirror of
https://github.com/langgenius/dify.git
synced 2026-04-29 23:18:05 +08:00
refactor: move workflow package to dify_graph (#32844)
This commit is contained in:
14
api/dify_graph/graph_engine/graph_traversal/__init__.py
Normal file
14
api/dify_graph/graph_engine/graph_traversal/__init__.py
Normal file
@ -0,0 +1,14 @@
|
||||
"""
|
||||
Graph traversal subsystem for graph engine.
|
||||
|
||||
This package handles graph navigation, edge processing,
|
||||
and skip propagation logic.
|
||||
"""
|
||||
|
||||
from .edge_processor import EdgeProcessor
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
__all__ = [
|
||||
"EdgeProcessor",
|
||||
"SkipPropagator",
|
||||
]
|
||||
201
api/dify_graph/graph_engine/graph_traversal/edge_processor.py
Normal file
201
api/dify_graph/graph_engine/graph_traversal/edge_processor.py
Normal file
@ -0,0 +1,201 @@
|
||||
"""
|
||||
Edge processing logic for graph traversal.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, final
|
||||
|
||||
from dify_graph.enums import NodeExecutionType
|
||||
from dify_graph.graph import Edge, Graph
|
||||
from dify_graph.graph_events import NodeRunStreamChunkEvent
|
||||
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
from ..response_coordinator import ResponseStreamCoordinator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .skip_propagator import SkipPropagator
|
||||
|
||||
|
||||
@final
|
||||
class EdgeProcessor:
|
||||
"""
|
||||
Processes edges during graph execution.
|
||||
|
||||
This handles marking edges as taken or skipped, notifying
|
||||
the response coordinator, triggering downstream node execution,
|
||||
and managing branch node logic.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
state_manager: GraphStateManager,
|
||||
response_coordinator: ResponseStreamCoordinator,
|
||||
skip_propagator: "SkipPropagator",
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the edge processor.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
state_manager: Unified state manager
|
||||
response_coordinator: Response stream coordinator
|
||||
skip_propagator: Propagator for skip states
|
||||
"""
|
||||
self._graph = graph
|
||||
self._state_manager = state_manager
|
||||
self._response_coordinator = response_coordinator
|
||||
self._skip_propagator = skip_propagator
|
||||
|
||||
def process_node_success(
|
||||
self, node_id: str, selected_handle: str | None = None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges after a node succeeds.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the succeeded node
|
||||
selected_handle: For branch nodes, the selected edge handle
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream node IDs that are now ready, list of streaming events)
|
||||
"""
|
||||
node = self._graph.nodes[node_id]
|
||||
|
||||
if node.execution_type == NodeExecutionType.BRANCH:
|
||||
return self._process_branch_node_edges(node_id, selected_handle)
|
||||
else:
|
||||
return self._process_non_branch_node_edges(node_id)
|
||||
|
||||
def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges for non-branch nodes (mark all as TAKEN).
|
||||
|
||||
Args:
|
||||
node_id: The ID of the succeeded node
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
"""
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
|
||||
for edge in outgoing_edges:
|
||||
nodes, events = self._process_taken_edge(edge)
|
||||
ready_nodes.extend(nodes)
|
||||
all_streaming_events.extend(events)
|
||||
|
||||
return ready_nodes, all_streaming_events
|
||||
|
||||
def _process_branch_node_edges(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Process edges for branch nodes.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected edge
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
|
||||
Raises:
|
||||
ValueError: If no edge was selected
|
||||
"""
|
||||
if not selected_handle:
|
||||
raise ValueError(f"Branch node {node_id} did not select any edge")
|
||||
|
||||
ready_nodes: list[str] = []
|
||||
all_streaming_events: list[NodeRunStreamChunkEvent] = []
|
||||
|
||||
# Categorize edges
|
||||
selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Process unselected edges first (mark as skipped)
|
||||
for edge in unselected_edges:
|
||||
self._process_skipped_edge(edge)
|
||||
|
||||
# Process selected edges
|
||||
for edge in selected_edges:
|
||||
nodes, events = self._process_taken_edge(edge)
|
||||
ready_nodes.extend(nodes)
|
||||
all_streaming_events.extend(events)
|
||||
|
||||
return ready_nodes, all_streaming_events
|
||||
|
||||
def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Mark edge as taken and check downstream node.
|
||||
|
||||
Args:
|
||||
edge: The edge to process
|
||||
|
||||
Returns:
|
||||
Tuple of (list containing downstream node ID if it's ready, list of streaming events)
|
||||
"""
|
||||
# Mark edge as taken
|
||||
self._state_manager.mark_edge_taken(edge.id)
|
||||
|
||||
# Notify response coordinator and get streaming events
|
||||
streaming_events = self._response_coordinator.on_edge_taken(edge.id)
|
||||
|
||||
# Check if downstream node is ready
|
||||
ready_nodes: list[str] = []
|
||||
if self._state_manager.is_node_ready(edge.head):
|
||||
ready_nodes.append(edge.head)
|
||||
|
||||
return ready_nodes, streaming_events
|
||||
|
||||
def _process_skipped_edge(self, edge: Edge) -> None:
|
||||
"""
|
||||
Mark edge as skipped.
|
||||
|
||||
Args:
|
||||
edge: The edge to skip
|
||||
"""
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
|
||||
def handle_branch_completion(
|
||||
self, node_id: str, selected_handle: str | None
|
||||
) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]:
|
||||
"""
|
||||
Handle completion of a branch node.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle of the selected branch
|
||||
|
||||
Returns:
|
||||
Tuple of (list of downstream nodes ready for execution, list of streaming events)
|
||||
|
||||
Raises:
|
||||
ValueError: If no branch was selected
|
||||
"""
|
||||
if not selected_handle:
|
||||
raise ValueError(f"Branch node {node_id} completed without selecting a branch")
|
||||
|
||||
# Categorize edges into selected and unselected
|
||||
_, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle)
|
||||
|
||||
# Skip all unselected paths
|
||||
self._skip_propagator.skip_branch_paths(unselected_edges)
|
||||
|
||||
# Process selected edges and get ready nodes and streaming events
|
||||
return self.process_node_success(node_id, selected_handle)
|
||||
|
||||
def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool:
|
||||
"""
|
||||
Validate that a branch selection is valid.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the branch node
|
||||
selected_handle: The handle to validate
|
||||
|
||||
Returns:
|
||||
True if the selection is valid
|
||||
"""
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
valid_handles = {edge.source_handle for edge in outgoing_edges}
|
||||
return selected_handle in valid_handles
|
||||
@ -0,0 +1,96 @@
|
||||
"""
|
||||
Skip state propagation through the graph.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from dify_graph.graph import Edge, Graph
|
||||
|
||||
from ..graph_state_manager import GraphStateManager
|
||||
|
||||
|
||||
@final
|
||||
class SkipPropagator:
|
||||
"""
|
||||
Propagates skip states through the graph.
|
||||
|
||||
When a node is skipped, this ensures all downstream nodes
|
||||
that depend solely on it are also skipped.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph: Graph,
|
||||
state_manager: GraphStateManager,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the skip propagator.
|
||||
|
||||
Args:
|
||||
graph: The workflow graph
|
||||
state_manager: Unified state manager
|
||||
"""
|
||||
self._graph = graph
|
||||
self._state_manager = state_manager
|
||||
|
||||
def propagate_skip_from_edge(self, edge_id: str) -> None:
|
||||
"""
|
||||
Recursively propagate skip state from a skipped edge.
|
||||
|
||||
Rules:
|
||||
- If a node has any UNKNOWN incoming edges, stop processing
|
||||
- If all incoming edges are SKIPPED, skip the node and its edges
|
||||
- If any incoming edge is TAKEN, the node may still execute
|
||||
|
||||
Args:
|
||||
edge_id: The ID of the skipped edge to start from
|
||||
"""
|
||||
downstream_node_id = self._graph.edges[edge_id].head
|
||||
incoming_edges = self._graph.get_incoming_edges(downstream_node_id)
|
||||
|
||||
# Analyze edge states
|
||||
edge_states = self._state_manager.analyze_edge_states(incoming_edges)
|
||||
|
||||
# Stop if there are unknown edges (not yet processed)
|
||||
if edge_states["has_unknown"]:
|
||||
return
|
||||
|
||||
# If any edge is taken, node may still execute
|
||||
if edge_states["has_taken"]:
|
||||
# Enqueue node
|
||||
self._state_manager.enqueue_node(downstream_node_id)
|
||||
self._state_manager.start_execution(downstream_node_id)
|
||||
return
|
||||
|
||||
# All edges are skipped, propagate skip to this node
|
||||
if edge_states["all_skipped"]:
|
||||
self._propagate_skip_to_node(downstream_node_id)
|
||||
|
||||
def _propagate_skip_to_node(self, node_id: str) -> None:
|
||||
"""
|
||||
Mark a node and all its outgoing edges as skipped.
|
||||
|
||||
Args:
|
||||
node_id: The ID of the node to skip
|
||||
"""
|
||||
# Mark node as skipped
|
||||
self._state_manager.mark_node_skipped(node_id)
|
||||
|
||||
# Mark all outgoing edges as skipped and propagate
|
||||
outgoing_edges = self._graph.get_outgoing_edges(node_id)
|
||||
for edge in outgoing_edges:
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
# Recursively propagate skip
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
|
||||
def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None:
|
||||
"""
|
||||
Skip all paths from unselected branch edges.
|
||||
|
||||
Args:
|
||||
unselected_edges: List of edges not taken by the branch
|
||||
"""
|
||||
for edge in unselected_edges:
|
||||
self._state_manager.mark_edge_skipped(edge.id)
|
||||
self.propagate_skip_from_edge(edge.id)
|
||||
Reference in New Issue
Block a user