WIP: resume

This commit is contained in:
QuantumGhost
2025-11-21 10:13:20 +08:00
parent c0e15b9e1b
commit c0f1aeddbe
49 changed files with 2160 additions and 1445 deletions

View File

@ -0,0 +1,278 @@
import sys
import time
from pathlib import Path
from types import ModuleType, SimpleNamespace
from typing import Any
API_DIR = str(Path(__file__).resolve().parents[5])
if API_DIR not in sys.path:
sys.path.insert(0, API_DIR)
import core.workflow.nodes.human_input.entities # noqa: F401
from core.app.apps.advanced_chat import app_generator as adv_app_gen_module
from core.app.apps.workflow import app_generator as wf_app_gen_module
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_engine.entities.commands import PauseCommand
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunPausedEvent,
GraphRunSucceededEvent,
NodeRunSucceededEvent,
)
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig, VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
if "core.ops.ops_trace_manager" not in sys.modules:
ops_stub = ModuleType("core.ops.ops_trace_manager")
class _StubTraceQueueManager:
def __init__(self, *_, **__):
pass
ops_stub.TraceQueueManager = _StubTraceQueueManager
sys.modules["core.ops.ops_trace_manager"] = ops_stub
class _StubToolNodeData(BaseNodeData):
pass
class _StubToolNode(Node):
node_type = NodeType.TOOL
def init_node_data(self, data):
self._node_data = _StubToolNodeData.model_validate(data)
def _get_error_strategy(self):
return self._node_data.error_strategy
def _get_retry_config(self) -> RetryConfig:
return self._node_data.retry_config
def _get_title(self) -> str:
return self._node_data.title
def _get_description(self):
return self._node_data.desc
def _get_default_value_dict(self) -> dict[str, Any]:
return self._node_data.default_value_dict
def get_base_node_data(self) -> BaseNodeData:
return self._node_data
def _run(self):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"value": f"{self.id}-done"},
)
def _patch_tool_node(mocker):
from core.workflow.nodes import node_factory
custom_mapping = dict(node_factory.NODE_TYPE_CLASSES_MAPPING)
custom_versions = dict(custom_mapping[NodeType.TOOL])
custom_versions[node_factory.LATEST_VERSION] = _StubToolNode
custom_mapping[NodeType.TOOL] = custom_versions
mocker.patch("core.workflow.nodes.node_factory.NODE_TYPE_CLASSES_MAPPING", custom_mapping)
def _build_graph(runtime_state: GraphRuntimeState) -> Graph:
params = GraphInitParams(
tenant_id="tenant",
app_id="app",
workflow_id="workflow",
graph_config={},
user_id="user",
user_from="account",
invoke_from="service-api",
call_depth=0,
)
start_data = StartNodeData(title="start", variables=[])
start_node = StartNode(
id="start",
config={"id": "start", "data": start_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
start_node.init_node_data(start_data.model_dump())
tool_data = _StubToolNodeData(title="tool")
tool_a = _StubToolNode(
id="tool_a",
config={"id": "tool_a", "data": tool_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
tool_a.init_node_data(tool_data.model_dump())
tool_b = _StubToolNode(
id="tool_b",
config={"id": "tool_b", "data": tool_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
tool_b.init_node_data(tool_data.model_dump())
tool_c = _StubToolNode(
id="tool_c",
config={"id": "tool_c", "data": tool_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
tool_c.init_node_data(tool_data.model_dump())
end_data = EndNodeData(
title="end",
outputs=[VariableSelector(variable="result", value_selector=["tool_c", "value"])],
desc=None,
)
end_node = EndNode(
id="end",
config={"id": "end", "data": end_data.model_dump()},
graph_init_params=params,
graph_runtime_state=runtime_state,
)
end_node.init_node_data(end_data.model_dump())
return (
Graph.new()
.add_root(start_node)
.add_node(tool_a)
.add_node(tool_b)
.add_node(tool_c)
.add_node(end_node)
.add_edge("tool_a", "tool_b")
.add_edge("tool_b", "tool_c")
.add_edge("tool_c", "end")
.build()
)
def _build_runtime_state(run_id: str) -> GraphRuntimeState:
variable_pool = VariablePool(
system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"),
user_inputs={},
conversation_variables=[],
)
variable_pool.system_variables.workflow_execution_id = run_id
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
def _run_with_optional_pause(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> list[GraphEngineEvent]:
command_channel = InMemoryChannel()
graph = _build_graph(runtime_state)
engine = GraphEngine(
workflow_id="workflow",
graph=graph,
graph_runtime_state=runtime_state,
command_channel=command_channel,
)
events: list[GraphEngineEvent] = []
for event in engine.run():
if isinstance(event, NodeRunSucceededEvent) and pause_on and event.node_id == pause_on:
command_channel.send_command(PauseCommand(reason="test pause"))
engine._command_processor.process_commands() # type: ignore[attr-defined]
events.append(event)
return events
def _node_successes(events: list[GraphEngineEvent]) -> list[str]:
return [evt.node_id for evt in events if isinstance(evt, NodeRunSucceededEvent)]
def test_workflow_app_pause_resume_matches_baseline(mocker):
_patch_tool_node(mocker)
baseline_state = _build_runtime_state("baseline")
baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
baseline_nodes = _node_successes(baseline_events)
baseline_outputs = baseline_state.outputs
paused_state = _build_runtime_state("paused-run")
paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
assert isinstance(paused_events[-1], GraphRunPausedEvent)
paused_nodes = _node_successes(paused_events)
snapshot = paused_state.dumps()
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
generator = wf_app_gen_module.WorkflowAppGenerator()
def _fake_generate(**kwargs):
state: GraphRuntimeState = kwargs["graph_runtime_state"]
events = _run_with_optional_pause(state, pause_on=None)
return _node_successes(events)
mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
resumed_nodes = generator.resume(
app_model=SimpleNamespace(mode="workflow"),
workflow=SimpleNamespace(),
user=SimpleNamespace(),
application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
graph_runtime_state=resumed_state,
workflow_execution_repository=SimpleNamespace(),
workflow_node_execution_repository=SimpleNamespace(),
)
assert paused_nodes + resumed_nodes == baseline_nodes
assert resumed_state.outputs == baseline_outputs
def test_advanced_chat_pause_resume_matches_baseline(mocker):
_patch_tool_node(mocker)
baseline_state = _build_runtime_state("adv-baseline")
baseline_events = _run_with_optional_pause(baseline_state, pause_on=None)
assert isinstance(baseline_events[-1], GraphRunSucceededEvent)
baseline_nodes = _node_successes(baseline_events)
baseline_outputs = baseline_state.outputs
paused_state = _build_runtime_state("adv-paused")
paused_events = _run_with_optional_pause(paused_state, pause_on="tool_a")
assert isinstance(paused_events[-1], GraphRunPausedEvent)
paused_nodes = _node_successes(paused_events)
snapshot = paused_state.dumps()
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
generator = adv_app_gen_module.AdvancedChatAppGenerator()
def _fake_generate(**kwargs):
state: GraphRuntimeState = kwargs["graph_runtime_state"]
events = _run_with_optional_pause(state, pause_on=None)
return _node_successes(events)
mocker.patch.object(generator, "_generate", side_effect=_fake_generate)
resumed_nodes = generator.resume(
app_model=SimpleNamespace(mode="workflow"),
workflow=SimpleNamespace(),
user=SimpleNamespace(),
conversation=SimpleNamespace(id="conv"),
message=SimpleNamespace(id="msg"),
application_generate_entity=SimpleNamespace(stream=False, invoke_from=InvokeFrom.SERVICE_API),
workflow_execution_repository=SimpleNamespace(),
workflow_node_execution_repository=SimpleNamespace(),
graph_runtime_state=resumed_state,
)
assert paused_nodes + resumed_nodes == baseline_nodes
assert resumed_state.outputs == baseline_outputs