mirror of
https://github.com/langgenius/dify.git
synced 2026-01-19 11:45:05 +08:00
feat(graph-engine): add command to update variables at runtime (#30563)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@ -3,8 +3,15 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.variables import IntegerVariable, StringVariable
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, GraphEngineCommand
|
||||
from core.workflow.graph_engine.entities.commands import (
|
||||
AbortCommand,
|
||||
CommandType,
|
||||
GraphEngineCommand,
|
||||
UpdateVariablesCommand,
|
||||
VariableUpdate,
|
||||
)
|
||||
|
||||
|
||||
class TestRedisChannel:
|
||||
@ -148,6 +155,43 @@ class TestRedisChannel:
|
||||
assert commands[0].command_type == CommandType.ABORT
|
||||
assert isinstance(commands[1], AbortCommand)
|
||||
|
||||
def test_fetch_commands_with_update_variables_command(self):
|
||||
"""Test fetching update variables command from Redis."""
|
||||
mock_redis = MagicMock()
|
||||
pending_pipe = MagicMock()
|
||||
fetch_pipe = MagicMock()
|
||||
pending_context = MagicMock()
|
||||
fetch_context = MagicMock()
|
||||
pending_context.__enter__.return_value = pending_pipe
|
||||
pending_context.__exit__.return_value = None
|
||||
fetch_context.__enter__.return_value = fetch_pipe
|
||||
fetch_context.__exit__.return_value = None
|
||||
mock_redis.pipeline.side_effect = [pending_context, fetch_context]
|
||||
|
||||
update_command = UpdateVariablesCommand(
|
||||
updates=[
|
||||
VariableUpdate(
|
||||
value=StringVariable(name="foo", value="bar", selector=["node1", "foo"]),
|
||||
),
|
||||
VariableUpdate(
|
||||
value=IntegerVariable(name="baz", value=123, selector=["node2", "baz"]),
|
||||
),
|
||||
]
|
||||
)
|
||||
command_json = json.dumps(update_command.model_dump())
|
||||
|
||||
pending_pipe.execute.return_value = [b"1", 1]
|
||||
fetch_pipe.execute.return_value = [[command_json.encode()], 1]
|
||||
|
||||
channel = RedisChannel(mock_redis, "test:key")
|
||||
commands = channel.fetch_commands()
|
||||
|
||||
assert len(commands) == 1
|
||||
assert isinstance(commands[0], UpdateVariablesCommand)
|
||||
assert isinstance(commands[0].updates[0].value, StringVariable)
|
||||
assert list(commands[0].updates[0].value.selector) == ["node1", "foo"]
|
||||
assert commands[0].updates[0].value.value == "bar"
|
||||
|
||||
def test_fetch_commands_skips_invalid_json(self):
|
||||
"""Test that invalid JSON commands are skipped."""
|
||||
mock_redis = MagicMock()
|
||||
|
||||
@ -4,12 +4,19 @@ import time
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.variables import IntegerVariable, StringVariable
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand
|
||||
from core.workflow.graph_engine.entities.commands import (
|
||||
AbortCommand,
|
||||
CommandType,
|
||||
PauseCommand,
|
||||
UpdateVariablesCommand,
|
||||
VariableUpdate,
|
||||
)
|
||||
from core.workflow.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
@ -180,3 +187,67 @@ def test_pause_command():
|
||||
|
||||
graph_execution = engine.graph_runtime_state.graph_execution
|
||||
assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")]
|
||||
|
||||
|
||||
def test_update_variables_command_updates_pool():
|
||||
"""Test that GraphEngine updates variable pool via update variables command."""
|
||||
|
||||
shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
shared_runtime_state.variable_pool.add(("node1", "foo"), "old value")
|
||||
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
mock_graph.root_node.id = "start"
|
||||
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=shared_runtime_state,
|
||||
)
|
||||
mock_graph.nodes["start"] = start_node
|
||||
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
command_channel = InMemoryChannel()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=shared_runtime_state,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
update_command = UpdateVariablesCommand(
|
||||
updates=[
|
||||
VariableUpdate(
|
||||
value=StringVariable(name="foo", value="new value", selector=["node1", "foo"]),
|
||||
),
|
||||
VariableUpdate(
|
||||
value=IntegerVariable(name="bar", value=123, selector=["node2", "bar"]),
|
||||
),
|
||||
]
|
||||
)
|
||||
command_channel.send_command(update_command)
|
||||
|
||||
list(engine.run())
|
||||
|
||||
updated_existing = shared_runtime_state.variable_pool.get(["node1", "foo"])
|
||||
added_new = shared_runtime_state.variable_pool.get(["node2", "bar"])
|
||||
|
||||
assert updated_existing is not None
|
||||
assert updated_existing.value == "new value"
|
||||
assert added_new is not None
|
||||
assert added_new.value == 123
|
||||
|
||||
Reference in New Issue
Block a user