mirror of
https://github.com/langgenius/dify.git
synced 2026-03-30 02:20:16 +08:00
WIP: resume
This commit is contained in:
@ -0,0 +1,278 @@
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
API_DIR = str(Path(__file__).resolve().parents[5])
|
||||
if API_DIR not in sys.path:
|
||||
sys.path.insert(0, API_DIR)
|
||||
|
||||
import core.workflow.nodes.human_input.entities # noqa: F401
|
||||
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
|
||||
from core.app.apps.workflow import app_generator as wf_app_gen_module
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine
|
||||
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
|
||||
from core.workflow.graph_engine.entities.commands import PauseCommand
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.end.entities import EndNodeData
|
||||
from core.workflow.nodes.start.entities import StartNodeData
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
if "core.ops.ops_trace_manager" not in sys.modules:
|
||||
ops_stub = ModuleType("core.ops.ops_trace_manager")
|
||||
|
||||
class _StubTraceQueueManager:
|
||||
def __init__(self, *_, **__):
|
||||
pass
|
||||
|
||||
ops_stub.TraceQueueManager = _StubTraceQueueManager
|
||||
sys.modules["core.ops.ops_trace_manager"] = ops_stub
|
||||
|
||||
|
||||
class _StubToolNodeData(BaseNodeData):
|
||||
pass
|
||||
|
||||
|
||||
class _StubToolNode(Node):
|
||||
node_type = NodeType.TOOL
|
||||
|
||||
def init_node_data(self, data):
|
||||
self._node_data = _StubToolNodeData.model_validate(data)
|
||||
|
||||
def _get_error_strategy(self):
|
||||
return self._node_data.error_strategy
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return self._node_data.retry_config
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return self._node_data.title
|
||||
|
||||
def _get_description(self):
|
||||
return self._node_data.desc
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return self._node_data.default_value_dict
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
def _run(self):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"value": f"{self.id}-done"},
|
||||
)
|
||||
|
||||
|
||||
def _patch_tool_node(mocker):
|
||||
from core.workflow.nodes import node_factory
|
||||
|
||||
custom_mapping = dict(node_factory.NODE_TYPE_CLASSES_MAPPING)
|
||||
custom_versions = dict(custom_mapping[NodeType.TOOL])
|
||||
custom_versions[node_factory.LATEST_VERSION] = _StubToolNode
|
||||
custom_mapping[NodeType.TOOL] = custom_versions
|
||||
mocker.patch("core.workflow.nodes.node_factory.NODE_TYPE_CLASSES_MAPPING", custom_mapping)
|
||||
|
||||
|
||||
def _build_graph(runtime_state: GraphRuntimeState) -> Graph:
|
||||
params = GraphInitParams(
|
||||
tenant_id="tenant",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
graph_config={},
|
||||
user_id="user",
|
||||
user_from="account",
|
||||
invoke_from="service-api",
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
start_data = StartNodeData(title="start", variables=[])
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": start_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
start_node.init_node_data(start_data.model_dump())
|
||||
|
||||
tool_data = _StubToolNodeData(title="tool")
|
||||
tool_a = _StubToolNode(
|
||||
id="tool_a",
|
||||
config={"id": "tool_a", "data": tool_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
tool_a.init_node_data(tool_data.model_dump())
|
||||
|
||||
tool_b = _StubToolNode(
|
||||
id="tool_b",
|
||||
config={"id": "tool_b", "data": tool_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
tool_b.init_node_data(tool_data.model_dump())
|
||||
|
||||
tool_c = _StubToolNode(
|
||||
id="tool_c",
|
||||
config={"id": "tool_c", "data": tool_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
tool_c.init_node_data(tool_data.model_dump())
|
||||
|
||||
end_data = EndNodeData(
|
||||
title="end",
|
||||
outputs=[VariableSelector(variable="result", value_selector=["tool_c", "value"])],
|
||||
desc=None,
|
||||
)
|
||||
end_node = EndNode(
|
||||
id="end",
|
||||
config={"id": "end", "data": end_data.model_dump()},
|
||||
graph_init_params=params,
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
end_node.init_node_data(end_data.model_dump())
|
||||
|
||||
return (
|
||||
Graph.new()
|
||||
.add_root(start_node)
|
||||
.add_node(tool_a)
|
||||
.add_node(tool_b)
|
||||
.add_node(tool_c)
|
||||
.add_node(end_node)
|
||||
.add_edge("tool_a", "tool_b")
|
||||
.add_edge("tool_b", "tool_c")
|
||||
.add_edge("tool_c", "end")
|
||||
.build()
|
||||
)
|
||||
|
||||
|
||||
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
|
||||
user_inputs={},
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.system_variables.workflow_execution_id = run_id
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
def _run_with_optional_pause(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> list[GraphEngineEvent]:
|
||||
command_channel = InMemoryChannel()
|
||||
graph = _build_graph(runtime_state)
|
||||
engine = GraphEngine(
|
||||
workflow_id="workflow",
|
||||
graph=graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=command_channel,
|
||||
)
|
||||
|
||||
events: list[GraphEngineEvent] = []
|
||||
for event in engine.run():
|
||||
if isinstance(event, NodeRunSucceededEvent) and pause_on and event.node_id == pause_on:
|
||||
command_channel.send_command(PauseCommand(reason="test pause"))
|
||||
engine._command_processor.process_commands() # type: ignore[attr-defined]
|
||||
events.append(event)
|
||||
return events
|
||||
|
||||
|
||||
def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
|
||||
return [evt.node_id for evt in events if isinstance(evt, NodeRunSucceededEvent)]
|
||||
|
||||
|
||||
def test_workflow_app_pause_resume_matches_baseline(mocker):
|
||||
_patch_tool_node(mocker)
|
||||
|
||||
baseline_state = _build_runtime_state("baseline")
|
||||
baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
|
||||
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
|
||||
baseline_nodes = _node_successes(baseline_events)
|
||||
baseline_outputs = baseline_state.outputs
|
||||
|
||||
paused_state = _build_runtime_state("paused-run")
|
||||
paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
|
||||
assert isinstance(paused_events[-1], GraphRunPausedEvent)
|
||||
paused_nodes = _node_successes(paused_events)
|
||||
snapshot = paused_state.dumps()
|
||||
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
|
||||
generator = wf_app_gen_module.WorkflowAppGenerator()
|
||||
|
||||
def _fake_generate(**kwargs):
|
||||
state: GraphRuntimeState = kwargs["graph_runtime_state"]
|
||||
events = _run_with_optional_pause(state, pause_on=None)
|
||||
return _node_successes(events)
|
||||
|
||||
mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
|
||||
|
||||
resumed_nodes = generator.resume(
|
||||
app_model=SimpleNamespace(mode="workflow"),
|
||||
workflow=SimpleNamespace(),
|
||||
user=SimpleNamespace(),
|
||||
application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
|
||||
graph_runtime_state=resumed_state,
|
||||
workflow_execution_repository=SimpleNamespace(),
|
||||
workflow_node_execution_repository=SimpleNamespace(),
|
||||
)
|
||||
|
||||
assert paused_nodes + resumed_nodes == baseline_nodes
|
||||
assert resumed_state.outputs == baseline_outputs
|
||||
|
||||
|
||||
def test_advanced_chat_pause_resume_matches_baseline(mocker):
|
||||
_patch_tool_node(mocker)
|
||||
|
||||
baseline_state = _build_runtime_state("adv-baseline")
|
||||
baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
|
||||
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
|
||||
baseline_nodes = _node_successes(baseline_events)
|
||||
baseline_outputs = baseline_state.outputs
|
||||
|
||||
paused_state = _build_runtime_state("adv-paused")
|
||||
paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
|
||||
assert isinstance(paused_events[-1], GraphRunPausedEvent)
|
||||
paused_nodes = _node_successes(paused_events)
|
||||
snapshot = paused_state.dumps()
|
||||
|
||||
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
|
||||
|
||||
generator = adv_app_gen_module.AdvancedChatAppGenerator()
|
||||
|
||||
def _fake_generate(**kwargs):
|
||||
state: GraphRuntimeState = kwargs["graph_runtime_state"]
|
||||
events = _run_with_optional_pause(state, pause_on=None)
|
||||
return _node_successes(events)
|
||||
|
||||
mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
|
||||
|
||||
resumed_nodes = generator.resume(
|
||||
app_model=SimpleNamespace(mode="workflow"),
|
||||
workflow=SimpleNamespace(),
|
||||
user=SimpleNamespace(),
|
||||
conversation=SimpleNamespace(id="conv"),
|
||||
message=SimpleNamespace(id="msg"),
|
||||
application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
|
||||
workflow_execution_repository=SimpleNamespace(),
|
||||
workflow_node_execution_repository=SimpleNamespace(),
|
||||
graph_runtime_state=resumed_state,
|
||||
)
|
||||
|
||||
assert paused_nodes + resumed_nodes == baseline_nodes
|
||||
assert resumed_state.outputs == baseline_outputs
|
||||
Reference in New Issue
Block a user