feat(graph-engine): add command to update variables at runtime (#30563)

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
-LAN-
2026-01-05 16:47:34 +08:00
committed by GitHub
parent 6f8bd58e19
commit a9e2c05a10
8 changed files with 194 additions and 13 deletions

View File

@ -3,8 +3,15 @@
import json
from unittest.mock import MagicMock
from core.variables import IntegerVariable, StringVariable
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, GraphEngineCommand
from core.workflow.graph_engine.entities.commands import (
AbortCommand,
CommandType,
GraphEngineCommand,
UpdateVariablesCommand,
VariableUpdate,
)
class TestRedisChannel:
@ -148,6 +155,43 @@ class TestRedisChannel:
assert commands[0].command_type == CommandType.ABORT
assert isinstance(commands[1], AbortCommand)
def test_fetch_commands_with_update_variables_command(self):
"""Test fetching update variables command from Redis."""
mock_redis = MagicMock()
pending_pipe = MagicMock()
fetch_pipe = MagicMock()
pending_context = MagicMock()
fetch_context = MagicMock()
pending_context.__enter__.return_value = pending_pipe
pending_context.__exit__.return_value = None
fetch_context.__enter__.return_value = fetch_pipe
fetch_context.__exit__.return_value = None
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
update_command = UpdateVariablesCommand(
updates=[
VariableUpdate(
value=StringVariable(name="foo", value="bar", selector=["node1", "foo"]),
),
VariableUpdate(
value=IntegerVariable(name="baz", value=123, selector=["node2", "baz"]),
),
]
)
command_json = json.dumps(update_command.model_dump())
pending_pipe.execute.return_value = [b"1", 1]
fetch_pipe.execute.return_value = [[command_json.encode()], 1]
channel = RedisChannel(mock_redis, "test:key")
commands = channel.fetch_commands()
assert len(commands) == 1
assert isinstance(commands[0], UpdateVariablesCommand)
assert isinstance(commands[0].updates[0].value, StringVariable)
assert list(commands[0].updates[0].value.selector) == ["node1", "foo"]
assert commands[0].updates[0].value.value == "bar"
def test_fetch_commands_skips_invalid_json(self):
"""Test that invalid JSON commands are skipped."""
mock_redis = MagicMock()

View File

@ -4,12 +4,19 @@ import time
from unittest.mock import MagicMock
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import IntegerVariable, StringVariable
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.entities.pause_reason import SchedulingPause
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand
from core.workflow.graph_engine.entities.commands import (
AbortCommand,
CommandType,
PauseCommand,
UpdateVariablesCommand,
VariableUpdate,
)
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
@ -180,3 +187,67 @@ def test_pause_command():
graph_execution = engine.graph_runtime_state.graph_execution
assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")]
def test_update_variables_command_updates_pool():
"""Test that GraphEngine updates variable pool via update variables command."""
shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
shared_runtime_state.variable_pool.add(("node1", "foo"), "old value")
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start"
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=shared_runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
command_channel = InMemoryChannel()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=shared_runtime_state,
command_channel=command_channel,
)
update_command = UpdateVariablesCommand(
updates=[
VariableUpdate(
value=StringVariable(name="foo", value="new value", selector=["node1", "foo"]),
),
VariableUpdate(
value=IntegerVariable(name="bar", value=123, selector=["node2", "bar"]),
),
]
)
command_channel.send_command(update_command)
list(engine.run())
updated_existing = shared_runtime_state.variable_pool.get(["node1", "foo"])
added_new = shared_runtime_state.variable_pool.get(["node2", "bar"])
assert updated_existing is not None
assert updated_existing.value == "new value"
assert added_new is not None
assert added_new.value == 123