mirror of
https://github.com/langgenius/dify.git
synced 2026-03-18 13:19:56 +08:00
refactor: move workflow package to dify_graph (#32844)
This commit is contained in:
150
api/dify_graph/graph_engine/layers/execution_limits.py
Normal file
150
api/dify_graph/graph_engine/layers/execution_limits.py
Normal file
@ -0,0 +1,150 @@
|
||||
"""
|
||||
Execution limits layer for GraphEngine.
|
||||
|
||||
This layer monitors workflow execution to enforce limits on:
|
||||
- Maximum execution steps
|
||||
- Maximum execution time
|
||||
|
||||
When limits are exceeded, the layer automatically aborts execution.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from enum import StrEnum
|
||||
from typing import final
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType
|
||||
from dify_graph.graph_engine.layers import GraphEngineLayer
|
||||
from dify_graph.graph_events import (
|
||||
GraphEngineEvent,
|
||||
NodeRunStartedEvent,
|
||||
)
|
||||
from dify_graph.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent
|
||||
|
||||
|
||||
class LimitType(StrEnum):
|
||||
"""Types of execution limits that can be exceeded."""
|
||||
|
||||
STEP_LIMIT = "step_limit"
|
||||
TIME_LIMIT = "time_limit"
|
||||
|
||||
|
||||
@final
|
||||
class ExecutionLimitsLayer(GraphEngineLayer):
|
||||
"""
|
||||
Layer that enforces execution limits for workflows.
|
||||
|
||||
Monitors:
|
||||
- Step count: Tracks number of node executions
|
||||
- Time limit: Monitors total execution time
|
||||
|
||||
Automatically aborts execution when limits are exceeded.
|
||||
"""
|
||||
|
||||
def __init__(self, max_steps: int, max_time: int) -> None:
|
||||
"""
|
||||
Initialize the execution limits layer.
|
||||
|
||||
Args:
|
||||
max_steps: Maximum number of execution steps allowed
|
||||
max_time: Maximum execution time in seconds allowed
|
||||
"""
|
||||
super().__init__()
|
||||
self.max_steps = max_steps
|
||||
self.max_time = max_time
|
||||
|
||||
# Runtime tracking
|
||||
self.start_time: float | None = None
|
||||
self.step_count = 0
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
# State tracking
|
||||
self._execution_started = False
|
||||
self._execution_ended = False
|
||||
self._abort_sent = False # Track if abort command has been sent
|
||||
|
||||
@override
|
||||
def on_graph_start(self) -> None:
|
||||
"""Called when graph execution starts."""
|
||||
self.start_time = time.time()
|
||||
self.step_count = 0
|
||||
self._execution_started = True
|
||||
self._execution_ended = False
|
||||
self._abort_sent = False
|
||||
|
||||
self.logger.debug("Execution limits monitoring started")
|
||||
|
||||
@override
|
||||
def on_event(self, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Called for every event emitted by the engine.
|
||||
|
||||
Monitors execution progress and enforces limits.
|
||||
"""
|
||||
if not self._execution_started or self._execution_ended or self._abort_sent:
|
||||
return
|
||||
|
||||
# Track step count for node execution events
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
self.step_count += 1
|
||||
self.logger.debug("Step %d started: %s", self.step_count, event.node_id)
|
||||
|
||||
# Check step limit when node execution completes
|
||||
if isinstance(event, NodeRunSucceededEvent | NodeRunFailedEvent):
|
||||
if self._reached_step_limitation():
|
||||
self._send_abort_command(LimitType.STEP_LIMIT)
|
||||
|
||||
if self._reached_time_limitation():
|
||||
self._send_abort_command(LimitType.TIME_LIMIT)
|
||||
|
||||
@override
|
||||
def on_graph_end(self, error: Exception | None) -> None:
|
||||
"""Called when graph execution ends."""
|
||||
if self._execution_started and not self._execution_ended:
|
||||
self._execution_ended = True
|
||||
|
||||
if self.start_time:
|
||||
total_time = time.time() - self.start_time
|
||||
self.logger.debug("Execution completed: %d steps in %.2f seconds", self.step_count, total_time)
|
||||
|
||||
def _reached_step_limitation(self) -> bool:
|
||||
"""Check if step count limit has been exceeded."""
|
||||
return self.step_count > self.max_steps
|
||||
|
||||
def _reached_time_limitation(self) -> bool:
|
||||
"""Check if time limit has been exceeded."""
|
||||
return self.start_time is not None and (time.time() - self.start_time) > self.max_time
|
||||
|
||||
def _send_abort_command(self, limit_type: LimitType) -> None:
|
||||
"""
|
||||
Send abort command due to limit violation.
|
||||
|
||||
Args:
|
||||
limit_type: Type of limit exceeded
|
||||
"""
|
||||
if not self.command_channel or not self._execution_started or self._execution_ended or self._abort_sent:
|
||||
return
|
||||
|
||||
# Format detailed reason message
|
||||
if limit_type == LimitType.STEP_LIMIT:
|
||||
reason = f"Maximum execution steps exceeded: {self.step_count} > {self.max_steps}"
|
||||
elif limit_type == LimitType.TIME_LIMIT:
|
||||
elapsed_time = time.time() - self.start_time if self.start_time else 0
|
||||
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
|
||||
|
||||
self.logger.warning("Execution limit exceeded: %s", reason)
|
||||
|
||||
try:
|
||||
# Send abort command to the engine
|
||||
abort_command = AbortCommand(command_type=CommandType.ABORT, reason=reason)
|
||||
self.command_channel.send_command(abort_command)
|
||||
|
||||
# Mark that abort has been sent to prevent duplicate commands
|
||||
self._abort_sent = True
|
||||
|
||||
self.logger.debug("Abort command sent to engine")
|
||||
|
||||
except Exception:
|
||||
self.logger.exception("Failed to send abort command")
|
||||
Reference in New Issue
Block a user