diff --git a/api/core/workflow/graph_engine/command_channels/redis_channel.py b/api/core/workflow/graph_engine/command_channels/redis_channel.py index 4be3adb8f8..0fccd4a0fd 100644 --- a/api/core/workflow/graph_engine/command_channels/redis_channel.py +++ b/api/core/workflow/graph_engine/command_channels/redis_channel.py @@ -9,7 +9,7 @@ Each instance uses a unique key for its command queue. import json from typing import TYPE_CHECKING, Any, final -from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand +from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand if TYPE_CHECKING: from extensions.ext_redis import RedisClientWrapper @@ -113,6 +113,8 @@ class RedisChannel: return AbortCommand.model_validate(data) if command_type == CommandType.PAUSE: return PauseCommand.model_validate(data) + if command_type == CommandType.UPDATE_VARIABLES: + return UpdateVariablesCommand.model_validate(data) # For other command types, use base class return GraphEngineCommand.model_validate(data) diff --git a/api/core/workflow/graph_engine/command_processing/__init__.py b/api/core/workflow/graph_engine/command_processing/__init__.py index 837f5e55fd..7b4f0dfff7 100644 --- a/api/core/workflow/graph_engine/command_processing/__init__.py +++ b/api/core/workflow/graph_engine/command_processing/__init__.py @@ -5,11 +5,12 @@ This package handles external commands sent to the engine during execution. """ -from .command_handlers import AbortCommandHandler, PauseCommandHandler +from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler from .command_processor import CommandProcessor __all__ = [ "AbortCommandHandler", "CommandProcessor", "PauseCommandHandler", + "UpdateVariablesCommandHandler", ] diff --git a/api/core/workflow/graph_engine/command_processing/command_handlers.py b/api/core/workflow/graph_engine/command_processing/command_handlers.py index e9f109c88c..cfe856d9e8 100644 --- a/api/core/workflow/graph_engine/command_processing/command_handlers.py +++ b/api/core/workflow/graph_engine/command_processing/command_handlers.py @@ -4,9 +4,10 @@ from typing import final from typing_extensions import override from core.workflow.entities.pause_reason import SchedulingPause +from core.workflow.runtime import VariablePool from ..domain.graph_execution import GraphExecution -from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand +from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand from .command_processor import CommandHandler logger = logging.getLogger(__name__) @@ -31,3 +32,25 @@ class PauseCommandHandler(CommandHandler): reason = command.reason pause_reason = SchedulingPause(message=reason) execution.pause(pause_reason) + + +@final +class UpdateVariablesCommandHandler(CommandHandler): + def __init__(self, variable_pool: VariablePool) -> None: + self._variable_pool = variable_pool + + @override + def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: + assert isinstance(command, UpdateVariablesCommand) + for update in command.updates: + try: + variable = update.value + self._variable_pool.add(variable.selector, variable) + logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id) + except ValueError as exc: + logger.warning( + "Skipping invalid variable selector %s for workflow %s: %s", + getattr(update.value, "selector", None), + execution.workflow_id, + exc, + ) diff --git a/api/core/workflow/graph_engine/entities/commands.py b/api/core/workflow/graph_engine/entities/commands.py index 0d51b2b716..6dce03c94d 100644 --- a/api/core/workflow/graph_engine/entities/commands.py +++ b/api/core/workflow/graph_engine/entities/commands.py @@ -5,17 +5,21 @@ This module defines command types that can be sent to a running GraphEngine instance to control its execution flow. """ -from enum import StrEnum +from collections.abc import Sequence +from enum import StrEnum, auto from typing import Any from pydantic import BaseModel, Field +from core.variables.variables import VariableUnion + class CommandType(StrEnum): """Types of commands that can be sent to GraphEngine.""" - ABORT = "abort" - PAUSE = "pause" + ABORT = auto() + PAUSE = auto() + UPDATE_VARIABLES = auto() class GraphEngineCommand(BaseModel): @@ -37,3 +41,16 @@ class PauseCommand(GraphEngineCommand): command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command") reason: str = Field(default="unknown reason", description="reason for pause") + + +class VariableUpdate(BaseModel): + """Represents a single variable update instruction.""" + + value: VariableUnion = Field(description="New variable value") + + +class UpdateVariablesCommand(GraphEngineCommand): + """Command to update a group of variables in the variable pool.""" + + command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command") + updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates") diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index e4816afaf8..88d6e5cac1 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -30,8 +30,13 @@ from core.workflow.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWr if TYPE_CHECKING: # pragma: no cover - used only for static analysis from core.workflow.runtime.graph_runtime_state import GraphProtocol -from .command_processing import AbortCommandHandler, CommandProcessor, PauseCommandHandler -from .entities.commands import AbortCommand, PauseCommand +from .command_processing import ( + AbortCommandHandler, + CommandProcessor, + PauseCommandHandler, + UpdateVariablesCommandHandler, +) +from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand from .error_handler import ErrorHandler from .event_management import EventHandler, EventManager from .graph_state_manager import GraphStateManager @@ -140,6 +145,9 @@ class GraphEngine: pause_handler = PauseCommandHandler() self._command_processor.register_handler(PauseCommand, pause_handler) + update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool) + self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler) + # === Extensibility === # Layers allow plugins to extend engine functionality self._layers: list[GraphEngineLayer] = [] diff --git a/api/core/workflow/graph_engine/manager.py b/api/core/workflow/graph_engine/manager.py index 0577ba8f02..d2cfa755d9 100644 --- a/api/core/workflow/graph_engine/manager.py +++ b/api/core/workflow/graph_engine/manager.py @@ -3,14 +3,20 @@ GraphEngine Manager for sending control commands via Redis channel. This module provides a simplified interface for controlling workflow executions using the new Redis command channel, without requiring user permission checks. -Supports stop, pause, and resume operations. """ import logging +from collections.abc import Sequence from typing import final from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, GraphEngineCommand, PauseCommand +from core.workflow.graph_engine.entities.commands import ( + AbortCommand, + GraphEngineCommand, + PauseCommand, + UpdateVariablesCommand, + VariableUpdate, +) from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) @@ -23,7 +29,6 @@ class GraphEngineManager: This class provides a simple interface for controlling workflow executions by sending commands through Redis channels, without user validation. - Supports stop and pause operations. """ @staticmethod @@ -45,6 +50,16 @@ class GraphEngineManager: pause_command = PauseCommand(reason=reason or "User requested pause") GraphEngineManager._send_command(task_id, pause_command) + @staticmethod + def send_update_variables_command(task_id: str, updates: Sequence[VariableUpdate]) -> None: + """Send a command to update variables in a running workflow.""" + + if not updates: + return + + update_command = UpdateVariablesCommand(updates=updates) + GraphEngineManager._send_command(task_id, update_command) + @staticmethod def _send_command(task_id: str, command: GraphEngineCommand) -> None: """Send a command to the workflow-specific Redis channel.""" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py index 8677325d4e..f33fd0deeb 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py @@ -3,8 +3,15 @@ import json from unittest.mock import MagicMock +from core.variables import IntegerVariable, StringVariable from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, GraphEngineCommand +from core.workflow.graph_engine.entities.commands import ( + AbortCommand, + CommandType, + GraphEngineCommand, + UpdateVariablesCommand, + VariableUpdate, +) class TestRedisChannel: @@ -148,6 +155,43 @@ class TestRedisChannel: assert commands[0].command_type == CommandType.ABORT assert isinstance(commands[1], AbortCommand) + def test_fetch_commands_with_update_variables_command(self): + """Test fetching update variables command from Redis.""" + mock_redis = MagicMock() + pending_pipe = MagicMock() + fetch_pipe = MagicMock() + pending_context = MagicMock() + fetch_context = MagicMock() + pending_context.__enter__.return_value = pending_pipe + pending_context.__exit__.return_value = None + fetch_context.__enter__.return_value = fetch_pipe + fetch_context.__exit__.return_value = None + mock_redis.pipeline.side_effect = [pending_context, fetch_context] + + update_command = UpdateVariablesCommand( + updates=[ + VariableUpdate( + value=StringVariable(name="foo", value="bar", selector=["node1", "foo"]), + ), + VariableUpdate( + value=IntegerVariable(name="baz", value=123, selector=["node2", "baz"]), + ), + ] + ) + command_json = json.dumps(update_command.model_dump()) + + pending_pipe.execute.return_value = [b"1", 1] + fetch_pipe.execute.return_value = [[command_json.encode()], 1] + + channel = RedisChannel(mock_redis, "test:key") + commands = channel.fetch_commands() + + assert len(commands) == 1 + assert isinstance(commands[0], UpdateVariablesCommand) + assert isinstance(commands[0].updates[0].value, StringVariable) + assert list(commands[0].updates[0].value.selector) == ["node1", "foo"] + assert commands[0].updates[0].value.value == "bar" + def test_fetch_commands_skips_invalid_json(self): """Test that invalid JSON commands are skipped.""" mock_redis = MagicMock() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py index b074a11be9..d826f7a900 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py @@ -4,12 +4,19 @@ import time from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom +from core.variables import IntegerVariable, StringVariable from core.workflow.entities.graph_init_params import GraphInitParams from core.workflow.entities.pause_reason import SchedulingPause from core.workflow.graph import Graph from core.workflow.graph_engine import GraphEngine from core.workflow.graph_engine.command_channels import InMemoryChannel -from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand +from core.workflow.graph_engine.entities.commands import ( + AbortCommand, + CommandType, + PauseCommand, + UpdateVariablesCommand, + VariableUpdate, +) from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent from core.workflow.nodes.start.start_node import StartNode from core.workflow.runtime import GraphRuntimeState, VariablePool @@ -180,3 +187,67 @@ def test_pause_command(): graph_execution = engine.graph_runtime_state.graph_execution assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")] + + +def test_update_variables_command_updates_pool(): + """Test that GraphEngine updates variable pool via update variables command.""" + + shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) + shared_runtime_state.variable_pool.add(("node1", "foo"), "old value") + + mock_graph = MagicMock(spec=Graph) + mock_graph.nodes = {} + mock_graph.edges = {} + mock_graph.root_node = MagicMock() + mock_graph.root_node.id = "start" + + start_node = StartNode( + id="start", + config={"id": "start", "data": {"title": "start", "variables": []}}, + graph_init_params=GraphInitParams( + tenant_id="test_tenant", + app_id="test_app", + workflow_id="test_workflow", + graph_config={}, + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ), + graph_runtime_state=shared_runtime_state, + ) + mock_graph.nodes["start"] = start_node + + mock_graph.get_outgoing_edges = MagicMock(return_value=[]) + mock_graph.get_incoming_edges = MagicMock(return_value=[]) + + command_channel = InMemoryChannel() + + engine = GraphEngine( + workflow_id="test_workflow", + graph=mock_graph, + graph_runtime_state=shared_runtime_state, + command_channel=command_channel, + ) + + update_command = UpdateVariablesCommand( + updates=[ + VariableUpdate( + value=StringVariable(name="foo", value="new value", selector=["node1", "foo"]), + ), + VariableUpdate( + value=IntegerVariable(name="bar", value=123, selector=["node2", "bar"]), + ), + ] + ) + command_channel.send_command(update_command) + + list(engine.run()) + + updated_existing = shared_runtime_state.variable_pool.get(["node1", "foo"]) + added_new = shared_runtime_state.variable_pool.get(["node2", "bar"]) + + assert updated_existing is not None + assert updated_existing.value == "new value" + assert added_new is not None + assert added_new.value == 123