Merge remote-tracking branch 'origin/main' into feat/trigger

This commit is contained in:
lyzno1
2025-10-15 20:39:17 +08:00
73 changed files with 635 additions and 670 deletions

View File

@ -41,6 +41,7 @@ class RedisChannel:
self._redis = redis_client
self._key = channel_key
self._command_ttl = command_ttl
self._pending_key = f"{channel_key}:pending"
def fetch_commands(self) -> list[GraphEngineCommand]:
"""
@ -49,6 +50,9 @@ class RedisChannel:
Returns:
List of pending commands (drains the Redis list)
"""
if not self._has_pending_commands():
return []
commands: list[GraphEngineCommand] = []
# Use pipeline for atomic operations
@ -85,6 +89,7 @@ class RedisChannel:
with self._redis.pipeline() as pipe:
pipe.rpush(self._key, command_json)
pipe.expire(self._key, self._command_ttl)
pipe.set(self._pending_key, "1", ex=self._command_ttl)
pipe.execute()
def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None:
@ -112,3 +117,17 @@ class RedisChannel:
except (ValueError, TypeError):
return None
def _has_pending_commands(self) -> bool:
"""
Check and consume the pending marker to avoid unnecessary list reads.
Returns:
True if commands should be fetched from Redis.
"""
with self._redis.pipeline() as pipe:
pipe.get(self._pending_key)
pipe.delete(self._pending_key)
pending_value, _ = pipe.execute()
return pending_value is not None

View File

@ -7,6 +7,7 @@ from collections.abc import Mapping
from functools import singledispatchmethod
from typing import TYPE_CHECKING, final
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities import GraphRuntimeState
from core.workflow.enums import ErrorStrategy, NodeExecutionType
from core.workflow.graph import Graph
@ -125,6 +126,7 @@ class EventHandler:
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
is_initial_attempt = node_execution.retry_count == 0
node_execution.mark_started(event.id)
self._graph_runtime_state.increment_node_run_steps()
# Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id)
@ -163,6 +165,8 @@ class EventHandler:
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
self._accumulate_node_usage(event.node_run_result.llm_usage)
# Store outputs in variable pool
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
@ -212,6 +216,8 @@ class EventHandler:
node_execution.mark_failed(event.error)
self._graph_execution.record_node_failure()
self._accumulate_node_usage(event.node_run_result.llm_usage)
result = self._error_handler.handle_node_failure(event)
if result:
@ -235,6 +241,8 @@ class EventHandler:
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
node_execution.mark_taken()
self._accumulate_node_usage(event.node_run_result.llm_usage)
# Persist outputs produced by the exception strategy (e.g. default values)
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
@ -286,6 +294,19 @@ class EventHandler:
self._state_manager.enqueue_node(event.node_id)
self._state_manager.start_execution(event.node_id)
def _accumulate_node_usage(self, usage: LLMUsage) -> None:
"""Accumulate token usage into the shared runtime state."""
if usage.total_tokens <= 0:
return
self._graph_runtime_state.add_tokens(usage.total_tokens)
current_usage = self._graph_runtime_state.llm_usage
if current_usage.total_tokens == 0:
self._graph_runtime_state.llm_usage = usage
else:
self._graph_runtime_state.llm_usage = current_usage.plus(usage)
def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None:
"""
Store node outputs in the variable pool.

View File

@ -8,7 +8,12 @@ import threading
import time
from typing import TYPE_CHECKING, final
from core.workflow.graph_events.base import GraphNodeEventBase
from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunSucceededEvent,
)
from ..event_management import EventManager
from .execution_coordinator import ExecutionCoordinator
@ -72,13 +77,16 @@ class Dispatcher:
if self._thread and self._thread.is_alive():
self._thread.join(timeout=10.0)
_COMMAND_TRIGGER_EVENTS = (
NodeRunSucceededEvent,
NodeRunFailedEvent,
NodeRunExceptionEvent,
)
def _dispatcher_loop(self) -> None:
"""Main dispatcher loop."""
try:
while not self._stop_event.is_set():
# Check for commands
self._execution_coordinator.check_commands()
# Check for scaling
self._execution_coordinator.check_scaling()
@ -87,6 +95,8 @@ class Dispatcher:
event = self._event_queue.get(timeout=0.1)
# Route to the event handler
self._event_handler.dispatch(event)
if self._should_check_commands(event):
self._execution_coordinator.check_commands()
self._event_queue.task_done()
except queue.Empty:
# Check if execution is complete
@ -102,3 +112,7 @@ class Dispatcher:
# Signal the event emitter that execution is complete
if self._event_emitter:
self._event_emitter.mark_complete()
def _should_check_commands(self, event: GraphNodeEventBase) -> bool:
"""Return True if the event represents a node completion."""
return isinstance(event, self._COMMAND_TRIGGER_EVENTS)