mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 00:48:04 +08:00
Merge branch 'feat/queue-based-graph-engine' into feat/rag-2
This commit is contained in:
@ -1,9 +1,9 @@
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, Protocol, cast
|
||||
from typing import Any, Protocol, cast
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
from .edge import Edge
|
||||
@ -36,10 +36,10 @@ class Graph:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
nodes: Optional[dict[str, Node]] = None,
|
||||
edges: Optional[dict[str, Edge]] = None,
|
||||
in_edges: Optional[dict[str, list[str]]] = None,
|
||||
out_edges: Optional[dict[str, list[str]]] = None,
|
||||
nodes: dict[str, Node] | None = None,
|
||||
edges: dict[str, Edge] | None = None,
|
||||
in_edges: dict[str, list[str]] | None = None,
|
||||
out_edges: dict[str, list[str]] | None = None,
|
||||
root_node: Node,
|
||||
):
|
||||
"""
|
||||
@ -81,7 +81,7 @@ class Graph:
|
||||
cls,
|
||||
node_configs_map: dict[str, dict[str, Any]],
|
||||
edge_configs: list[dict[str, Any]],
|
||||
root_node_id: Optional[str] = None,
|
||||
root_node_id: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Find the root node ID if not specified.
|
||||
@ -186,13 +186,79 @@ class Graph:
|
||||
|
||||
return nodes
|
||||
|
||||
@classmethod
|
||||
def _mark_inactive_root_branches(
|
||||
cls,
|
||||
nodes: dict[str, Node],
|
||||
edges: dict[str, Edge],
|
||||
in_edges: dict[str, list[str]],
|
||||
out_edges: dict[str, list[str]],
|
||||
active_root_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Mark nodes and edges from inactive root branches as skipped.
|
||||
|
||||
Algorithm:
|
||||
1. Mark inactive root nodes as skipped
|
||||
2. For skipped nodes, mark all their outgoing edges as skipped
|
||||
3. For each edge marked as skipped, check its target node:
|
||||
- If ALL incoming edges are skipped, mark the node as skipped
|
||||
- Otherwise, leave the node state unchanged
|
||||
|
||||
:param nodes: mapping of node ID to node instance
|
||||
:param edges: mapping of edge ID to edge instance
|
||||
:param in_edges: mapping of node ID to incoming edge IDs
|
||||
:param out_edges: mapping of node ID to outgoing edge IDs
|
||||
:param active_root_id: ID of the active root node
|
||||
"""
|
||||
# Find all top-level root nodes (nodes with ROOT execution type and no incoming edges)
|
||||
top_level_roots: list[str] = [
|
||||
node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT
|
||||
]
|
||||
|
||||
# If there's only one root or the active root is not a top-level root, no marking needed
|
||||
if len(top_level_roots) <= 1 or active_root_id not in top_level_roots:
|
||||
return
|
||||
|
||||
# Mark inactive root nodes as skipped
|
||||
inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id]
|
||||
for root_id in inactive_roots:
|
||||
if root_id in nodes:
|
||||
nodes[root_id].state = NodeState.SKIPPED
|
||||
|
||||
# Recursively mark downstream nodes and edges
|
||||
def mark_downstream(node_id: str) -> None:
|
||||
"""Recursively mark downstream nodes and edges as skipped."""
|
||||
if nodes[node_id].state != NodeState.SKIPPED:
|
||||
return
|
||||
# If this node is skipped, mark all its outgoing edges as skipped
|
||||
out_edge_ids = out_edges.get(node_id, [])
|
||||
for edge_id in out_edge_ids:
|
||||
edge = edges[edge_id]
|
||||
edge.state = NodeState.SKIPPED
|
||||
|
||||
# Check the target node of this edge
|
||||
target_node = nodes[edge.head]
|
||||
in_edge_ids = in_edges.get(target_node.id, [])
|
||||
in_edge_states = [edges[eid].state for eid in in_edge_ids]
|
||||
|
||||
# If all incoming edges are skipped, mark the node as skipped
|
||||
if all(state == NodeState.SKIPPED for state in in_edge_states):
|
||||
target_node.state = NodeState.SKIPPED
|
||||
# Recursively process downstream nodes
|
||||
mark_downstream(target_node.id)
|
||||
|
||||
# Process each inactive root and its downstream nodes
|
||||
for root_id in inactive_roots:
|
||||
mark_downstream(root_id)
|
||||
|
||||
@classmethod
|
||||
def init(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_factory: "NodeFactory",
|
||||
root_node_id: Optional[str] = None,
|
||||
root_node_id: str | None = None,
|
||||
) -> "Graph":
|
||||
"""
|
||||
Initialize graph
|
||||
@ -227,6 +293,9 @@ class Graph:
|
||||
# Get root node instance
|
||||
root_node = nodes[root_node_id]
|
||||
|
||||
# Mark inactive root branches as skipped
|
||||
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
|
||||
|
||||
# Create and return the graph
|
||||
return cls(
|
||||
nodes=nodes,
|
||||
|
||||
Reference in New Issue
Block a user