mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
feat: unified management stop event (#30479)
This commit is contained in:
@ -8,6 +8,7 @@ Domain-Driven Design principles for improved maintainability and testability.
|
||||
import contextvars
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
@ -75,10 +76,13 @@ class GraphEngine:
|
||||
scale_down_idle_time: float | None = None,
|
||||
) -> None:
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
# stop event
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
# Bind runtime state to current workflow context
|
||||
self._graph = graph
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._graph_runtime_state.stop_event = self._stop_event
|
||||
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
|
||||
self._command_channel = command_channel
|
||||
|
||||
@ -177,6 +181,7 @@ class GraphEngine:
|
||||
max_workers=self._max_workers,
|
||||
scale_up_threshold=self._scale_up_threshold,
|
||||
scale_down_idle_time=self._scale_down_idle_time,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
# === Orchestration ===
|
||||
@ -207,6 +212,7 @@ class GraphEngine:
|
||||
event_handler=self._event_handler_registry,
|
||||
execution_coordinator=self._execution_coordinator,
|
||||
event_emitter=self._event_manager,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
# === Validation ===
|
||||
@ -324,6 +330,7 @@ class GraphEngine:
|
||||
|
||||
def _start_execution(self, *, resume: bool = False) -> None:
|
||||
"""Start execution subsystems."""
|
||||
self._stop_event.clear()
|
||||
paused_nodes: list[str] = []
|
||||
if resume:
|
||||
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
|
||||
@ -351,13 +358,12 @@ class GraphEngine:
|
||||
|
||||
def _stop_execution(self) -> None:
|
||||
"""Stop execution subsystems."""
|
||||
self._stop_event.set()
|
||||
self._dispatcher.stop()
|
||||
self._worker_pool.stop()
|
||||
# Don't mark complete here as the dispatcher already does it
|
||||
|
||||
# Notify layers
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_graph_end(self._graph_execution.error)
|
||||
|
||||
@ -44,6 +44,7 @@ class Dispatcher:
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
event_handler: "EventHandler",
|
||||
execution_coordinator: ExecutionCoordinator,
|
||||
stop_event: threading.Event,
|
||||
event_emitter: EventManager | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
@ -61,7 +62,7 @@ class Dispatcher:
|
||||
self._event_emitter = event_emitter
|
||||
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = threading.Event()
|
||||
self._stop_event = stop_event
|
||||
self._start_time: float | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
@ -69,16 +70,14 @@ class Dispatcher:
|
||||
if self._thread and self._thread.is_alive():
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._start_time = time.time()
|
||||
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the dispatcher thread."""
|
||||
self._stop_event.set()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=10.0)
|
||||
self._thread.join(timeout=2.0)
|
||||
|
||||
def _dispatcher_loop(self) -> None:
|
||||
"""Main dispatcher loop."""
|
||||
|
||||
@ -42,6 +42,7 @@ class Worker(threading.Thread):
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: Sequence[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
worker_id: int = 0,
|
||||
flask_app: Flask | None = None,
|
||||
context_vars: contextvars.Context | None = None,
|
||||
@ -65,13 +66,16 @@ class Worker(threading.Thread):
|
||||
self._worker_id = worker_id
|
||||
self._flask_app = flask_app
|
||||
self._context_vars = context_vars
|
||||
self._stop_event = threading.Event()
|
||||
self._last_task_time = time.time()
|
||||
self._stop_event = stop_event
|
||||
self._layers = layers if layers is not None else []
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker to stop processing."""
|
||||
self._stop_event.set()
|
||||
"""Worker is controlled via shared stop_event from GraphEngine.
|
||||
|
||||
This method is a no-op retained for backward compatibility.
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
def is_idle(self) -> bool:
|
||||
|
||||
@ -41,6 +41,7 @@ class WorkerPool:
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: list[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
flask_app: "Flask | None" = None,
|
||||
context_vars: "Context | None" = None,
|
||||
min_workers: int | None = None,
|
||||
@ -81,6 +82,7 @@ class WorkerPool:
|
||||
self._worker_counter = 0
|
||||
self._lock = threading.RLock()
|
||||
self._running = False
|
||||
self._stop_event = stop_event
|
||||
|
||||
# No longer tracking worker states with callbacks to avoid lock contention
|
||||
|
||||
@ -135,7 +137,7 @@ class WorkerPool:
|
||||
# Wait for workers to finish
|
||||
for worker in self._workers:
|
||||
if worker.is_alive():
|
||||
worker.join(timeout=10.0)
|
||||
worker.join(timeout=2.0)
|
||||
|
||||
self._workers.clear()
|
||||
|
||||
@ -152,6 +154,7 @@ class WorkerPool:
|
||||
worker_id=worker_id,
|
||||
flask_app=self._flask_app,
|
||||
context_vars=self._context_vars,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
worker.start()
|
||||
|
||||
@ -264,6 +264,10 @@ class Node(Generic[NodeDataT]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _should_stop(self) -> bool:
|
||||
"""Check if execution should be stopped."""
|
||||
return self.graph_runtime_state.stop_event.is_set()
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
execution_id = self.ensure_execution_id()
|
||||
self._start_at = naive_utc_now()
|
||||
@ -332,6 +336,21 @@ class Node(Generic[NodeDataT]):
|
||||
yield event
|
||||
else:
|
||||
yield event
|
||||
|
||||
if self._should_stop():
|
||||
error_message = "Execution cancelled"
|
||||
yield NodeRunFailedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error_message,
|
||||
),
|
||||
error=error_message,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("Node %s failed to run", self._node_id)
|
||||
result = NodeRunResult(
|
||||
|
||||
@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import threading
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
@ -168,6 +169,7 @@ class GraphRuntimeState:
|
||||
self._pending_response_coordinator_dump: str | None = None
|
||||
self._pending_graph_execution_workflow_id: str | None = None
|
||||
self._paused_nodes: set[str] = set()
|
||||
self.stop_event: threading.Event = threading.Event()
|
||||
|
||||
if graph is not None:
|
||||
self.attach_graph(graph)
|
||||
|
||||
Reference in New Issue
Block a user