Merge branch 'feat/queue-based-graph-engine' into feat/rag-2

This commit is contained in:
jyong
2025-08-28 18:12:49 +08:00
71 changed files with 801 additions and 2326 deletions

View File

@ -1,9 +1,9 @@
import logging
from collections import defaultdict
from collections.abc import Mapping
from typing import Any, Optional, Protocol, cast
from typing import Any, Protocol, cast
from core.workflow.enums import NodeType
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node
from .edge import Edge
@ -36,10 +36,10 @@ class Graph:
def __init__(
self,
*,
nodes: Optional[dict[str, Node]] = None,
edges: Optional[dict[str, Edge]] = None,
in_edges: Optional[dict[str, list[str]]] = None,
out_edges: Optional[dict[str, list[str]]] = None,
nodes: dict[str, Node] | None = None,
edges: dict[str, Edge] | None = None,
in_edges: dict[str, list[str]] | None = None,
out_edges: dict[str, list[str]] | None = None,
root_node: Node,
):
"""
@ -81,7 +81,7 @@ class Graph:
cls,
node_configs_map: dict[str, dict[str, Any]],
edge_configs: list[dict[str, Any]],
root_node_id: Optional[str] = None,
root_node_id: str | None = None,
) -> str:
"""
Find the root node ID if not specified.
@ -186,13 +186,79 @@ class Graph:
return nodes
@classmethod
def _mark_inactive_root_branches(
cls,
nodes: dict[str, Node],
edges: dict[str, Edge],
in_edges: dict[str, list[str]],
out_edges: dict[str, list[str]],
active_root_id: str,
) -> None:
"""
Mark nodes and edges from inactive root branches as skipped.
Algorithm:
1. Mark inactive root nodes as skipped
2. For skipped nodes, mark all their outgoing edges as skipped
3. For each edge marked as skipped, check its target node:
- If ALL incoming edges are skipped, mark the node as skipped
- Otherwise, leave the node state unchanged
:param nodes: mapping of node ID to node instance
:param edges: mapping of edge ID to edge instance
:param in_edges: mapping of node ID to incoming edge IDs
:param out_edges: mapping of node ID to outgoing edge IDs
:param active_root_id: ID of the active root node
"""
# Find all top-level root nodes (nodes with ROOT execution type and no incoming edges)
top_level_roots: list[str] = [
node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT
]
# If there's only one root or the active root is not a top-level root, no marking needed
if len(top_level_roots) <= 1 or active_root_id not in top_level_roots:
return
# Mark inactive root nodes as skipped
inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id]
for root_id in inactive_roots:
if root_id in nodes:
nodes[root_id].state = NodeState.SKIPPED
# Recursively mark downstream nodes and edges
def mark_downstream(node_id: str) -> None:
"""Recursively mark downstream nodes and edges as skipped."""
if nodes[node_id].state != NodeState.SKIPPED:
return
# If this node is skipped, mark all its outgoing edges as skipped
out_edge_ids = out_edges.get(node_id, [])
for edge_id in out_edge_ids:
edge = edges[edge_id]
edge.state = NodeState.SKIPPED
# Check the target node of this edge
target_node = nodes[edge.head]
in_edge_ids = in_edges.get(target_node.id, [])
in_edge_states = [edges[eid].state for eid in in_edge_ids]
# If all incoming edges are skipped, mark the node as skipped
if all(state == NodeState.SKIPPED for state in in_edge_states):
target_node.state = NodeState.SKIPPED
# Recursively process downstream nodes
mark_downstream(target_node.id)
# Process each inactive root and its downstream nodes
for root_id in inactive_roots:
mark_downstream(root_id)
@classmethod
def init(
cls,
*,
graph_config: Mapping[str, Any],
node_factory: "NodeFactory",
root_node_id: Optional[str] = None,
root_node_id: str | None = None,
) -> "Graph":
"""
Initialize graph
@ -227,6 +293,9 @@ class Graph:
# Get root node instance
root_node = nodes[root_node_id]
# Mark inactive root branches as skipped
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
# Create and return the graph
return cls(
nodes=nodes,