mirror of
https://github.com/langgenius/dify.git
synced 2026-03-04 23:36:20 +08:00
154 lines
4.9 KiB
Python
154 lines
4.9 KiB
Python
"""
|
|
Redis-based implementation of CommandChannel for distributed scenarios.

This implementation uses Redis lists for command queuing, supporting
multi-instance deployments and cross-server communication.
Each channel instance uses a unique key for its command queue.
|
|
"""
|
|
|
|
import json
|
|
from contextlib import AbstractContextManager
|
|
from typing import Any, Protocol, final
|
|
|
|
from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand
|
|
|
|
|
|
class RedisPipelineProtocol(Protocol):
|
|
"""Minimal Redis pipeline contract used by the command channel."""
|
|
|
|
def lrange(self, name: str, start: int, end: int) -> Any: ...
|
|
def delete(self, *names: str) -> Any: ...
|
|
def execute(self) -> list[Any]: ...
|
|
def rpush(self, name: str, *values: str) -> Any: ...
|
|
def expire(self, name: str, time: int) -> Any: ...
|
|
def set(self, name: str, value: str, ex: int | None = None) -> Any: ...
|
|
def get(self, name: str) -> Any: ...
|
|
|
|
|
|
class RedisClientProtocol(Protocol):
    """Structural type for the Redis client dependency of the command channel.

    Only pipeline creation is required; every Redis call the channel makes goes
    through a pipeline used as a context manager.
    """

    def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]:
        """Return a pipeline to be used as ``with client.pipeline() as pipe: ...``."""
        ...
|
|
|
|
|
|
@final
class RedisChannel:
    """
    Redis-based command channel implementation for distributed systems.

    Each instance uses a unique Redis key for its command queue.
    Commands are JSON-serialized for transport.

    Keys used in Redis:
        ``<channel_key>``          -- list of JSON-encoded commands (RPUSH'd by senders).
        ``<channel_key>:pending``  -- marker set to "1" on every send; consumed by
                                      readers so idle polls can skip reading the list.
    Both keys get a TTL of ``command_ttl`` seconds on every send.
    """

    def __init__(
        self,
        redis_client: RedisClientProtocol,
        channel_key: str,
        command_ttl: int = 3600,
    ) -> None:
        """
        Initialize the Redis channel.

        Args:
            redis_client: Redis client instance
            channel_key: Unique key for this channel's command queue
            command_ttl: TTL for command keys in seconds (default: 3600)
        """
        self._redis = redis_client
        self._key = channel_key
        self._command_ttl = command_ttl
        self._pending_key = f"{channel_key}:pending"

    def fetch_commands(self) -> list[GraphEngineCommand]:
        """
        Fetch and drain all pending commands from Redis.

        The pending marker is checked (and consumed) first so that idle polls do
        not read the command list. Entries that are not valid JSON objects or do
        not describe a valid command are skipped individually: the list has
        already been deleted by the time parsing happens, so aborting on one bad
        entry would silently drop every other command in the batch.

        Returns:
            List of pending commands in the order they were pushed
            (the Redis list is emptied).
        """
        if not self._has_pending_commands():
            return []

        # Read the whole list and clear it in one pipeline round trip.
        # NOTE(review): this is only atomic if pipeline() is transactional
        # (MULTI/EXEC, the redis-py default) -- confirm for the injected client.
        with self._redis.pipeline() as pipe:
            pipe.lrange(self._key, 0, -1)
            pipe.delete(self._key)
            results = pipe.execute()

        commands: list[GraphEngineCommand] = []
        for command_json in results[0] or []:
            try:
                command_data = json.loads(command_json)
            except ValueError:
                # Covers json.JSONDecodeError and UnicodeDecodeError: skip corrupt entries.
                continue

            command = self._deserialize_command(command_data)
            if command is not None:
                commands.append(command)

        return commands

    def send_command(self, command: GraphEngineCommand) -> None:
        """
        Send a command to Redis.

        Appends the JSON-encoded command to the queue, refreshes the queue TTL,
        and sets the pending marker (with the same TTL) so readers know to fetch.

        Args:
            command: The command to send
        """
        command_json = json.dumps(command.model_dump())

        # Push to list, refresh expiry, and flag the queue as non-empty.
        with self._redis.pipeline() as pipe:
            pipe.rpush(self._key, command_json)
            pipe.expire(self._key, self._command_ttl)
            pipe.set(self._pending_key, "1", ex=self._command_ttl)
            pipe.execute()

    def _deserialize_command(self, data: object) -> GraphEngineCommand | None:
        """
        Deserialize a command from decoded JSON data.

        Args:
            data: Decoded JSON value of one queue entry. Anything other than a
                dict carrying a string ``command_type`` is rejected.

        Returns:
            The model matching ``command_type`` (falling back to the base
            ``GraphEngineCommand`` for types without a dedicated model), or
            None if the data is not a valid command.
        """
        # A queue entry may decode to a list/str/number; without this guard
        # ``data.get`` would raise AttributeError and abort the whole fetch.
        if not isinstance(data, dict):
            return None

        command_type_value = data.get("command_type")
        if not isinstance(command_type_value, str):
            return None

        try:
            command_type = CommandType(command_type_value)

            if command_type == CommandType.ABORT:
                return AbortCommand.model_validate(data)
            if command_type == CommandType.PAUSE:
                return PauseCommand.model_validate(data)
            if command_type == CommandType.UPDATE_VARIABLES:
                return UpdateVariablesCommand.model_validate(data)

            # For other command types, use base class
            return GraphEngineCommand.model_validate(data)

        except (ValueError, TypeError):
            # Unknown command_type values and validation failures both land here.
            return None

    def _has_pending_commands(self) -> bool:
        """
        Check and consume the pending marker to avoid unnecessary list reads.

        The marker is cleared before the list is read. A command sent between
        this call and the subsequent list read is still picked up by that read,
        and it re-sets the marker, which only costs one extra empty read on the
        next poll.

        Returns:
            True if commands should be fetched from Redis.
        """
        with self._redis.pipeline() as pipe:
            pipe.get(self._pending_key)
            pipe.delete(self._pending_key)
            pending_value, _ = pipe.execute()

        return pending_value is not None
|