mirror of
https://github.com/langgenius/dify.git
synced 2026-04-27 05:58:14 +08:00
refactor(workflow): inject redis into graph engine manager (#32622)
This commit is contained in:
@ -7,12 +7,28 @@ Each instance uses a unique key for its command queue.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, final
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Any, Protocol, final
|
||||
|
||||
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from extensions.ext_redis import RedisClientWrapper
|
||||
|
||||
class RedisPipelineProtocol(Protocol):
|
||||
"""Minimal Redis pipeline contract used by the command channel."""
|
||||
|
||||
def lrange(self, name: str, start: int, end: int) -> Any: ...
|
||||
def delete(self, *names: str) -> Any: ...
|
||||
def execute(self) -> list[Any]: ...
|
||||
def rpush(self, name: str, *values: str) -> Any: ...
|
||||
def expire(self, name: str, time: int) -> Any: ...
|
||||
def set(self, name: str, value: str, ex: int | None = None) -> Any: ...
|
||||
def get(self, name: str) -> Any: ...
|
||||
|
||||
|
||||
class RedisClientProtocol(Protocol):
|
||||
"""Redis client contract required by the command channel."""
|
||||
|
||||
def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ...
|
||||
|
||||
|
||||
@final
|
||||
@ -26,7 +42,7 @@ class RedisChannel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: "RedisClientWrapper",
|
||||
redis_client: RedisClientProtocol,
|
||||
channel_key: str,
|
||||
command_ttl: int = 3600,
|
||||
) -> None:
|
||||
|
||||
@ -3,13 +3,14 @@ GraphEngine Manager for sending control commands via Redis channel.
|
||||
|
||||
This module provides a simplified interface for controlling workflow executions
|
||||
using the new Redis command channel, without requiring user permission checks.
|
||||
Callers must provide a Redis client dependency from outside the workflow package.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import final
|
||||
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol
|
||||
from core.workflow.graph_engine.entities.commands import (
|
||||
AbortCommand,
|
||||
GraphEngineCommand,
|
||||
@ -17,7 +18,6 @@ from core.workflow.graph_engine.entities.commands import (
|
||||
UpdateVariablesCommand,
|
||||
VariableUpdate,
|
||||
)
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -31,8 +31,12 @@ class GraphEngineManager:
|
||||
by sending commands through Redis channels, without user validation.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def send_stop_command(task_id: str, reason: str | None = None) -> None:
|
||||
_redis_client: RedisClientProtocol
|
||||
|
||||
def __init__(self, redis_client: RedisClientProtocol) -> None:
|
||||
self._redis_client = redis_client
|
||||
|
||||
def send_stop_command(self, task_id: str, reason: str | None = None) -> None:
|
||||
"""
|
||||
Send a stop command to a running workflow.
|
||||
|
||||
@ -41,34 +45,31 @@ class GraphEngineManager:
|
||||
reason: Optional reason for stopping (defaults to "User requested stop")
|
||||
"""
|
||||
abort_command = AbortCommand(reason=reason or "User requested stop")
|
||||
GraphEngineManager._send_command(task_id, abort_command)
|
||||
self._send_command(task_id, abort_command)
|
||||
|
||||
@staticmethod
|
||||
def send_pause_command(task_id: str, reason: str | None = None) -> None:
|
||||
def send_pause_command(self, task_id: str, reason: str | None = None) -> None:
|
||||
"""Send a pause command to a running workflow."""
|
||||
|
||||
pause_command = PauseCommand(reason=reason or "User requested pause")
|
||||
GraphEngineManager._send_command(task_id, pause_command)
|
||||
self._send_command(task_id, pause_command)
|
||||
|
||||
@staticmethod
|
||||
def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None:
|
||||
def send_update_variables_command(self, task_id: str, updates: Sequence[VariableUpdate]) -> None:
|
||||
"""Send a command to update variables in a running workflow."""
|
||||
|
||||
if not updates:
|
||||
return
|
||||
|
||||
update_command = UpdateVariablesCommand(updates=updates)
|
||||
GraphEngineManager._send_command(task_id, update_command)
|
||||
self._send_command(task_id, update_command)
|
||||
|
||||
@staticmethod
|
||||
def _send_command(task_id: str, command: GraphEngineCommand) -> None:
|
||||
def _send_command(self, task_id: str, command: GraphEngineCommand) -> None:
|
||||
"""Send a command to the workflow-specific Redis channel."""
|
||||
|
||||
if not task_id:
|
||||
return
|
||||
|
||||
channel_key = f"workflow:{task_id}:commands"
|
||||
channel = RedisChannel(redis_client, channel_key)
|
||||
channel = RedisChannel(self._redis_client, channel_key)
|
||||
|
||||
try:
|
||||
channel.send_command(command)
|
||||
|
||||
Reference in New Issue
Block a user