# Source mirrored from https://github.com/langgenius/dify.git
# (mirror synced 2026-03-04 15:26:21 +08:00)
"""
|
|
QueueBasedGraphEngine - Main orchestrator for queue-based workflow execution.
|
|
|
|
This engine uses a modular architecture with separated packages following
|
|
Domain-Driven Design principles for improved maintainability and testability.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import queue
|
|
from collections.abc import Generator
|
|
from typing import TYPE_CHECKING, cast, final
|
|
|
|
from dify_graph.context import capture_current_context
|
|
from dify_graph.entities.workflow_start_reason import WorkflowStartReason
|
|
from dify_graph.enums import NodeExecutionType
|
|
from dify_graph.graph import Graph
|
|
from dify_graph.graph_events import (
|
|
GraphEngineEvent,
|
|
GraphNodeEventBase,
|
|
GraphRunAbortedEvent,
|
|
GraphRunFailedEvent,
|
|
GraphRunPartialSucceededEvent,
|
|
GraphRunPausedEvent,
|
|
GraphRunStartedEvent,
|
|
GraphRunSucceededEvent,
|
|
)
|
|
from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper
|
|
|
|
if TYPE_CHECKING: # pragma: no cover - used only for static analysis
|
|
from dify_graph.runtime.graph_runtime_state import GraphProtocol
|
|
|
|
from .command_processing import (
|
|
AbortCommandHandler,
|
|
CommandProcessor,
|
|
PauseCommandHandler,
|
|
UpdateVariablesCommandHandler,
|
|
)
|
|
from .config import GraphEngineConfig
|
|
from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand
|
|
from .error_handler import ErrorHandler
|
|
from .event_management import EventHandler, EventManager
|
|
from .graph_state_manager import GraphStateManager
|
|
from .graph_traversal import EdgeProcessor, SkipPropagator
|
|
from .layers.base import GraphEngineLayer
|
|
from .orchestration import Dispatcher, ExecutionCoordinator
|
|
from .protocols.command_channel import CommandChannel
|
|
from .worker_management import WorkerPool
|
|
|
|
if TYPE_CHECKING:
|
|
from dify_graph.graph_engine.domain.graph_execution import GraphExecution
|
|
from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
_DEFAULT_CONFIG = GraphEngineConfig()
|
|
|
|
|
|
@final
class GraphEngine:
    """
    Queue-based graph execution engine.

    Uses a modular architecture that delegates responsibilities to specialized
    subsystems, following Domain-Driven Design and SOLID principles.

    Typical usage: construct, optionally attach layers via `layer(...)`, then
    iterate the generator returned by `run()` to consume execution events.
    """

    def __init__(
        self,
        workflow_id: str,
        graph: Graph,
        graph_runtime_state: GraphRuntimeState,
        command_channel: CommandChannel,
        config: GraphEngineConfig = _DEFAULT_CONFIG,
    ) -> None:
        """Initialize the graph engine with all subsystems and dependencies.

        Args:
            workflow_id: Identifier recorded on the graph execution state.
            graph: Workflow graph whose nodes will be executed.
            graph_runtime_state: Shared runtime state; every node in `graph`
                must reference this exact instance (validated below).
            command_channel: Channel polled for external commands
                (abort / pause / variable updates).
            config: Engine configuration; defaults to a shared module-level
                instance (assumed immutable — see `_DEFAULT_CONFIG`).

        Raises:
            ValueError: If any node holds a different GraphRuntimeState
                instance (raised by `_validate_graph_state_consistency`).
        """

        # Bind runtime state to current workflow context
        self._graph = graph
        self._graph_runtime_state = graph_runtime_state
        self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
        self._command_channel = command_channel
        self._config = config

        # Graph execution tracks the overall execution state
        self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution)
        self._graph_execution.workflow_id = workflow_id

        # === Execution Queues ===
        # Ready queue comes from runtime state so it can survive pause/resume.
        self._ready_queue = self._graph_runtime_state.ready_queue

        # Queue for events generated during execution (producers: workers;
        # consumer: the dispatcher).
        self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue()

        # === State Management ===
        # Unified state manager handles all node state transitions and queue operations
        self._state_manager = GraphStateManager(self._graph, self._ready_queue)

        # === Response Coordination ===
        # Coordinates response streaming from response nodes
        self._response_coordinator = cast("ResponseStreamCoordinator", self._graph_runtime_state.response_coordinator)

        # === Event Management ===
        # Event manager handles both collection and emission of events
        self._event_manager = EventManager()

        # === Error Handling ===
        # Centralized error handler for graph execution errors
        self._error_handler = ErrorHandler(self._graph, self._graph_execution)

        # === Graph Traversal Components ===
        # Propagates skip status through the graph when conditions aren't met
        self._skip_propagator = SkipPropagator(
            graph=self._graph,
            state_manager=self._state_manager,
        )

        # Processes edges to determine next nodes after execution
        # Also handles conditional branching and route selection
        self._edge_processor = EdgeProcessor(
            graph=self._graph,
            state_manager=self._state_manager,
            response_coordinator=self._response_coordinator,
            skip_propagator=self._skip_propagator,
        )

        # === Command Processing ===
        # Processes external commands (e.g., abort requests)
        self._command_processor = CommandProcessor(
            command_channel=self._command_channel,
            graph_execution=self._graph_execution,
        )

        # Register command handlers
        abort_handler = AbortCommandHandler()
        self._command_processor.register_handler(AbortCommand, abort_handler)

        pause_handler = PauseCommandHandler()
        self._command_processor.register_handler(PauseCommand, pause_handler)

        update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool)
        self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler)

        # === Extensibility ===
        # Layers allow plugins to extend engine functionality.
        # NOTE: the list is shared by reference with the worker pool below, so
        # layers added later via `layer()` are visible to workers too.
        self._layers: list[GraphEngineLayer] = []

        # === Worker Pool Setup ===
        # Capture execution context for worker threads
        execution_context = capture_current_context()

        # Create worker pool for parallel node execution
        self._worker_pool = WorkerPool(
            ready_queue=self._ready_queue,
            event_queue=self._event_queue,
            graph=self._graph,
            layers=self._layers,
            execution_context=execution_context,
            config=self._config,
        )

        # === Orchestration ===
        # Coordinates the overall execution lifecycle
        self._execution_coordinator = ExecutionCoordinator(
            graph_execution=self._graph_execution,
            state_manager=self._state_manager,
            command_processor=self._command_processor,
            worker_pool=self._worker_pool,
        )

        # === Event Handler Registry ===
        # Central registry for handling all node execution events
        self._event_handler_registry = EventHandler(
            graph=self._graph,
            graph_runtime_state=self._graph_runtime_state,
            graph_execution=self._graph_execution,
            response_coordinator=self._response_coordinator,
            event_collector=self._event_manager,
            edge_processor=self._edge_processor,
            state_manager=self._state_manager,
            error_handler=self._error_handler,
        )

        # Dispatches events and manages execution flow
        self._dispatcher = Dispatcher(
            event_queue=self._event_queue,
            event_handler=self._event_handler_registry,
            execution_coordinator=self._execution_coordinator,
            event_emitter=self._event_manager,
        )

        # === Validation ===
        # Ensure all nodes share the same GraphRuntimeState instance
        self._validate_graph_state_consistency()

    def _validate_graph_state_consistency(self) -> None:
        """Validate that all nodes share the same GraphRuntimeState.

        Raises:
            ValueError: If a node references a different state instance
                (compared by identity, not equality).
        """
        expected_state_id = id(self._graph_runtime_state)
        for node in self._graph.nodes.values():
            if id(node.graph_runtime_state) != expected_state_id:
                raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance")

    def _bind_layer_context(
        self,
        layer: GraphEngineLayer,
    ) -> None:
        """Give a layer read-only runtime state and the command channel."""
        layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel)

    def layer(self, layer: GraphEngineLayer) -> GraphEngine:
        """Add a layer for extending functionality.

        Returns self to allow fluent chaining: `engine.layer(a).layer(b)`.
        """
        self._layers.append(layer)
        self._bind_layer_context(layer)
        return self

    def run(self) -> Generator[GraphEngineEvent, None, None]:
        """
        Execute the graph using the modular architecture.

        Emits a GraphRunStartedEvent, streams node events while subsystems
        run, then emits exactly one terminal event: paused, aborted, partial
        success, success — or failed (in which case the original exception is
        re-raised after the event is yielded).

        Returns:
            Generator yielding GraphEngineEvent instances
        """
        try:
            # Initialize layers
            self._initialize_layers()

            # A graph execution that has already started is being resumed
            # after a pause; clear the pause markers instead of restarting.
            is_resume = self._graph_execution.started
            if not is_resume:
                self._graph_execution.start()
            else:
                self._graph_execution.paused = False
                self._graph_execution.pause_reasons = []

            start_event = GraphRunStartedEvent(
                reason=WorkflowStartReason.RESUMPTION if is_resume else WorkflowStartReason.INITIAL,
            )
            self._event_manager.notify_layers(start_event)
            yield start_event

            # Start subsystems
            self._start_execution(resume=is_resume)

            # Yield events as they occur (blocks until execution drains)
            yield from self._event_manager.emit_events()

            # Handle completion — the branches below are checked in priority
            # order: paused, aborted, errored, then (partial) success.
            if self._graph_execution.is_paused:
                pause_reasons = self._graph_execution.pause_reasons
                assert pause_reasons, "pause_reasons should not be empty when execution is paused."
                # Ensure we have a valid PauseReason for the event
                paused_event = GraphRunPausedEvent(
                    reasons=pause_reasons,
                    outputs=self._graph_runtime_state.outputs,
                )
                self._event_manager.notify_layers(paused_event)
                yield paused_event
            elif self._graph_execution.aborted:
                # Prefer the recorded error message over the generic reason.
                abort_reason = "Workflow execution aborted by user command"
                if self._graph_execution.error:
                    abort_reason = str(self._graph_execution.error)
                aborted_event = GraphRunAbortedEvent(
                    reason=abort_reason,
                    outputs=self._graph_runtime_state.outputs,
                )
                self._event_manager.notify_layers(aborted_event)
                yield aborted_event
            elif self._graph_execution.has_error:
                # Re-raise so the except-block below emits GraphRunFailedEvent.
                if self._graph_execution.error:
                    raise self._graph_execution.error
            else:
                outputs = self._graph_runtime_state.outputs
                exceptions_count = self._graph_execution.exceptions_count
                # Node-level exceptions that were handled (not fatal) demote
                # the run from full success to partial success.
                if exceptions_count > 0:
                    partial_event = GraphRunPartialSucceededEvent(
                        exceptions_count=exceptions_count,
                        outputs=outputs,
                    )
                    self._event_manager.notify_layers(partial_event)
                    yield partial_event
                else:
                    succeeded_event = GraphRunSucceededEvent(
                        outputs=outputs,
                    )
                    self._event_manager.notify_layers(succeeded_event)
                    yield succeeded_event

        except Exception as e:
            # Surface the failure as an event, then propagate to the caller.
            failed_event = GraphRunFailedEvent(
                error=str(e),
                exceptions_count=self._graph_execution.exceptions_count,
            )
            self._event_manager.notify_layers(failed_event)
            yield failed_event
            raise

        finally:
            # Always tear down the dispatcher/worker pool and notify layers,
            # even when the generator is closed early or an error escaped.
            self._stop_execution()

    def _initialize_layers(self) -> None:
        """Initialize layers with context.

        A layer failing in on_graph_start is logged but does not block the
        run or the remaining layers.
        """
        self._event_manager.set_layers(self._layers)
        for layer in self._layers:
            try:
                layer.on_graph_start()
            except Exception:
                logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__)

    def _start_execution(self, *, resume: bool = False) -> None:
        """Start execution subsystems.

        Fresh runs enqueue the graph's root node; resumed runs instead
        re-enqueue previously paused and deferred nodes (deduplicated,
        paused first).
        """
        paused_nodes: list[str] = []
        deferred_nodes: list[str] = []
        if resume:
            # consume_* drains the stored node lists from runtime state.
            paused_nodes = self._graph_runtime_state.consume_paused_nodes()
            deferred_nodes = self._graph_runtime_state.consume_deferred_nodes()

        # Start worker pool (it calculates initial workers internally)
        self._worker_pool.start()

        # Register response nodes
        for node in self._graph.nodes.values():
            if node.execution_type == NodeExecutionType.RESPONSE:
                self._response_coordinator.register(node.id)

        if not resume:
            # Enqueue root node
            root_node = self._graph.root_node
            self._state_manager.enqueue_node(root_node.id)
            self._state_manager.start_execution(root_node.id)
        else:
            # A node may appear in both lists; enqueue each only once.
            seen_nodes: set[str] = set()
            for node_id in paused_nodes + deferred_nodes:
                if node_id in seen_nodes:
                    continue
                seen_nodes.add(node_id)
                self._state_manager.enqueue_node(node_id)
                self._state_manager.start_execution(node_id)

        # Start dispatcher
        self._dispatcher.start()

    def _stop_execution(self) -> None:
        """Stop execution subsystems and notify layers of graph end."""
        self._dispatcher.stop()
        self._worker_pool.stop()
        # Don't mark complete here as the dispatcher already does it

        # Notify layers; a failing layer is logged and does not stop the rest.
        for layer in self._layers:
            try:
                layer.on_graph_end(self._graph_execution.error)
            except Exception:
                logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__)

    # Public property accessors for attributes that need external access
    @property
    def graph_runtime_state(self) -> GraphRuntimeState:
        """Get the graph runtime state."""
        return self._graph_runtime_state