fix(api): prevent node from running after pausing

author QuantumGhost
date 2026-01-08 10:03:22 +08:00
parent 3c79bea28f
commit 2a6b6a873e
10 changed files with 787 additions and 12 deletions


@@ -195,9 +195,13 @@ class EventHandler:
self._event_collector.collect(edge_event)
# Enqueue ready nodes
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
if self._graph_execution.is_paused:
for node_id in ready_nodes:
self._graph_runtime_state.register_deferred_node(node_id)
else:
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Update execution tracking
self._state_manager.finish_execution(event.node_id)
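
The EventHandler change above is the core of the fix: nodes that become ready while the graph is paused are parked as deferred instead of being scheduled. A minimal sketch of that branch, with hypothetical stand-ins (_RuntimeState, _StateManager) in place of the real classes:

from dataclasses import dataclass, field


@dataclass
class _RuntimeState:
    # Stand-in for GraphRuntimeState: remembers nodes deferred during pause.
    deferred: set[str] = field(default_factory=set)

    def register_deferred_node(self, node_id: str) -> None:
        self.deferred.add(node_id)


@dataclass
class _StateManager:
    # Stand-in for GraphStateManager: records nodes that actually get scheduled.
    scheduled: list[str] = field(default_factory=list)

    def enqueue_node(self, node_id: str) -> None:
        self.scheduled.append(node_id)

    def start_execution(self, node_id: str) -> None:
        pass  # execution-tracking bookkeeping is elided in this sketch


def route_ready_nodes(ready: list[str], *, paused: bool,
                      state: _RuntimeState, manager: _StateManager) -> None:
    # Mirrors the new branch: while paused, ready nodes are parked so that
    # nothing new starts running after a pause request.
    if paused:
        for node_id in ready:
            state.register_deferred_node(node_id)
    else:
        for node_id in ready:
            manager.enqueue_node(node_id)
            manager.start_execution(node_id)


state, manager = _RuntimeState(), _StateManager()
route_ready_nodes(["llm_b"], paused=True, state=state, manager=manager)
assert state.deferred == {"llm_b"} and manager.scheduled == []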


@@ -317,8 +317,10 @@ class GraphEngine:
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
paused_nodes: list[str] = []
deferred_nodes: list[str] = []
if resume:
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
deferred_nodes = self._graph_runtime_state.consume_deferred_nodes()
# Start worker pool (it calculates initial workers internally)
self._worker_pool.start()
@@ -334,7 +336,11 @@
self._state_manager.enqueue_node(root_node.id)
self._state_manager.start_execution(root_node.id)
else:
for node_id in paused_nodes:
seen_nodes: set[str] = set()
for node_id in paused_nodes + deferred_nodes:
if node_id in seen_nodes:
continue
seen_nodes.add(node_id)
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
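
On resume, the engine now re-enqueues both paused and deferred nodes, deduplicating because a node can plausibly land in both sets. The order-preserving dedup step in isolation (the helper name is illustrative, not from the diff):

def merge_resume_nodes(paused: list[str], deferred: list[str]) -> list[str]:
    # A node registered both as paused and as deferred must be
    # enqueued exactly once on resume; the first occurrence wins.
    seen: set[str] = set()
    merged: list[str] = []
    for node_id in paused + deferred:
        if node_id not in seen:
            seen.add(node_id)
            merged.append(node_id)
    return merged


assert merge_resume_nodes(["a", "b"], ["b", "c"]) == ["a", "b", "c"]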


@@ -224,6 +224,8 @@ class GraphStateManager:
Returns:
Number of executing nodes
"""
# This count is a best-effort snapshot and can change concurrently.
# Only use it for pause-drain checks where scheduling is already frozen.
with self._lock:
return len(self._executing_nodes)


@@ -84,13 +84,16 @@ class Dispatcher:
"""Main dispatcher loop."""
try:
self._process_commands()
paused = False
while not self._stop_event.is_set():
if (
self._execution_coordinator.aborted
or self._execution_coordinator.paused
or self._execution_coordinator.execution_complete
):
break
if self._execution_coordinator.paused:
paused = True
break
self._execution_coordinator.check_scaling()
try:
@@ -102,13 +105,10 @@
time.sleep(0.1)
self._process_commands()
while True:
try:
event = self._event_queue.get(block=False)
self._event_handler.dispatch(event)
self._event_queue.task_done()
except queue.Empty:
break
if paused:
self._drain_events_until_idle()
else:
self._drain_event_queue()
except Exception as e:
logger.exception("Dispatcher error")
@@ -123,3 +123,24 @@
def _process_commands(self, event: GraphNodeEventBase | None = None):
if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS):
self._execution_coordinator.process_commands()
def _drain_event_queue(self) -> None:
while True:
try:
event = self._event_queue.get(block=False)
self._event_handler.dispatch(event)
self._event_queue.task_done()
except queue.Empty:
break
def _drain_events_until_idle(self) -> None:
while not self._stop_event.is_set():
try:
event = self._event_queue.get(timeout=0.1)
self._event_handler.dispatch(event)
self._event_queue.task_done()
self._process_commands(event)
except queue.Empty:
if not self._execution_coordinator.has_executing_nodes():
break
self._drain_event_queue()
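
The dispatcher now distinguishes two drain modes: the plain non-blocking sweep, and a pause-time loop that keeps handling events until no node is still executing. A sketch of the pause-time pattern with a plain queue.Queue and caller-supplied callbacks (handle and still_executing are hypothetical parameters; the stop-event and command-processing steps are omitted):

import queue


def drain_until_idle(q: "queue.Queue[str]", handle, *, still_executing,
                     poll_timeout: float = 0.1) -> None:
    # In-flight workers may still emit events after a pause request.
    # Keep handling them until the queue is momentarily empty AND no
    # node is executing, then sweep stragglers with a non-blocking pass.
    while True:
        try:
            event = q.get(timeout=poll_timeout)
            handle(event)
            q.task_done()
        except queue.Empty:
            if not still_executing():
                break
    while True:  # final non-blocking drain
        try:
            event = q.get(block=False)
            handle(event)
            q.task_done()
        except queue.Empty:
            break


q: "queue.Queue[str]" = queue.Queue()
q.put("node-1 succeeded")
drain_until_idle(q, handle=print, still_executing=lambda: False)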


@@ -94,3 +94,11 @@ class ExecutionCoordinator:
self._worker_pool.stop()
self._state_manager.clear_executing()
def has_executing_nodes(self) -> bool:
"""Return True if any nodes are currently marked as executing."""
# This check is only safe once execution has already paused.
# Before pause, executing state can change concurrently, which makes the result unreliable.
if not self._graph_execution.is_paused:
raise AssertionError("has_executing_nodes should only be called after execution is paused")
return self._state_manager.get_executing_count() > 0
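
The precondition matters: before a pause freezes scheduling, the executing count can change under the caller's feet, so the guard fails loudly instead of returning a stale answer. A toy model of the same contract (field names are stand-ins):

class MiniCoordinator:
    def __init__(self) -> None:
        self.paused = False          # stand-in for graph_execution.is_paused
        self.executing_count = 0     # stand-in for the state manager's count

    def has_executing_nodes(self) -> bool:
        # The snapshot is only trustworthy once scheduling is frozen.
        if not self.paused:
            raise AssertionError(
                "has_executing_nodes should only be called after execution is paused"
            )
        return self.executing_count > 0


c = MiniCoordinator()
try:
    c.has_executing_nodes()
except AssertionError:
    pass  # querying before pause is rejected, as the unit test below asserts
c.paused = True
assert c.has_executing_nodes() is False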


@@ -129,6 +129,7 @@ class _GraphRuntimeStateSnapshot:
graph_execution_dump: str | None
response_coordinator_dump: str | None
paused_nodes: tuple[str, ...]
deferred_nodes: tuple[str, ...]
class GraphRuntimeState:
@@ -177,6 +178,7 @@
self._pending_response_coordinator_dump: str | None = None
self._pending_graph_execution_workflow_id: str | None = None
self._paused_nodes: set[str] = set()
self._deferred_nodes: set[str] = set()
# Tracks nodes that are being resumed in the current execution cycle.
# Populated when paused nodes are consumed during resume.
self._resuming_nodes: set[str] = set()
@@ -321,6 +323,7 @@
"ready_queue": self.ready_queue.dumps(),
"graph_execution": self.graph_execution.dumps(),
"paused_nodes": list(self._paused_nodes),
"deferred_nodes": list(self._deferred_nodes),
}
if self._response_coordinator is not None and self._graph is not None:
@@ -370,6 +373,23 @@
self._resuming_nodes.update(nodes)
return nodes
def register_deferred_node(self, node_id: str) -> None:
"""Record a node that became ready during pause and should resume later."""
self._deferred_nodes.add(node_id)
def get_deferred_nodes(self) -> list[str]:
"""Retrieve deferred nodes without mutating internal state."""
return list(self._deferred_nodes)
def consume_deferred_nodes(self) -> list[str]:
"""Retrieve and clear deferred nodes awaiting resume."""
nodes = list(self._deferred_nodes)
self._deferred_nodes.clear()
return nodes
def consume_resuming_node(self, node_id: str) -> bool:
"""
Return True iff `node_id` is in the resuming set and remove it.
@@ -440,6 +460,7 @@
graph_execution_payload = payload.get("graph_execution")
response_payload = payload.get("response_coordinator")
paused_nodes_payload = payload.get("paused_nodes", [])
deferred_nodes_payload = payload.get("deferred_nodes", [])
return _GraphRuntimeStateSnapshot(
start_at=start_at,
@@ -453,6 +474,7 @@
graph_execution_dump=graph_execution_payload,
response_coordinator_dump=response_payload,
paused_nodes=tuple(map(str, paused_nodes_payload)),
deferred_nodes=tuple(map(str, deferred_nodes_payload)),
)
def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
@@ -468,6 +490,7 @@
self._restore_graph_execution(snapshot.graph_execution_dump)
self._restore_response_coordinator(snapshot.response_coordinator_dump)
self._paused_nodes = set(snapshot.paused_nodes)
self._deferred_nodes = set(snapshot.deferred_nodes)
def _restore_ready_queue(self, payload: str | None) -> None:
if payload is not None:
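
Deferred nodes survive a pause because they travel with the runtime-state snapshot. A toy round-trip assuming only the JSON shape visible in the diff (everything else about GraphRuntimeState is simplified away):

import json


class MiniState:
    def __init__(self) -> None:
        self._deferred_nodes: set[str] = set()

    def register_deferred_node(self, node_id: str) -> None:
        self._deferred_nodes.add(node_id)

    def consume_deferred_nodes(self) -> list[str]:
        # Retrieve and clear, exactly like the real method.
        nodes = list(self._deferred_nodes)
        self._deferred_nodes.clear()
        return nodes

    def dumps(self) -> str:
        return json.dumps({"deferred_nodes": list(self._deferred_nodes)})

    @classmethod
    def from_snapshot(cls, payload: str) -> "MiniState":
        state = cls()
        state._deferred_nodes = set(json.loads(payload).get("deferred_nodes", []))
        return state


state = MiniState()
state.register_deferred_node("llm_b")
restored = MiniState.from_snapshot(state.dumps())
assert restored.consume_deferred_nodes() == ["llm_b"]
assert restored.consume_deferred_nodes() == []  # consuming clears the set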


@@ -0,0 +1,73 @@
import queue
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.orchestration.dispatcher import Dispatcher
from core.workflow.graph_events import NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
from libs.datetime_utils import naive_utc_now
class StubExecutionCoordinator:
def __init__(self, paused: bool) -> None:
self._paused = paused
self.mark_complete_called = False
self.failed_error: Exception | None = None
@property
def aborted(self) -> bool:
return False
@property
def paused(self) -> bool:
return self._paused
@property
def execution_complete(self) -> bool:
return False
def check_scaling(self) -> None:
return None
def process_commands(self) -> None:
return None
def mark_complete(self) -> None:
self.mark_complete_called = True
def mark_failed(self, error: Exception) -> None:
self.failed_error = error
class StubEventHandler:
def __init__(self) -> None:
self.events: list[object] = []
def dispatch(self, event: object) -> None:
self.events.append(event)
def test_dispatcher_drains_events_when_paused() -> None:
event_queue: queue.Queue = queue.Queue()
event = NodeRunSucceededEvent(
id="exec-1",
node_id="node-1",
node_type=NodeType.START,
start_at=naive_utc_now(),
node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
)
event_queue.put(event)
handler = StubEventHandler()
coordinator = StubExecutionCoordinator(paused=True)
dispatcher = Dispatcher(
event_queue=event_queue,
event_handler=handler,
execution_coordinator=coordinator,
event_emitter=None,
)
dispatcher._dispatcher_loop()
assert handler.events == [event]
assert coordinator.mark_complete_called is True


@@ -1,5 +1,6 @@
"""Unit tests for the execution coordinator orchestration logic."""
import pytest
from unittest.mock import MagicMock
from core.workflow.graph_engine.command_processing.command_processor import CommandProcessor
@@ -48,3 +49,13 @@ def test_handle_pause_noop_when_execution_running() -> None:
worker_pool.stop.assert_not_called()
state_manager.clear_executing.assert_not_called()
def test_has_executing_nodes_requires_pause() -> None:
graph_execution = GraphExecution(workflow_id="workflow")
graph_execution.start()
coordinator, _, _ = _build_coordinator(graph_execution)
with pytest.raises(AssertionError):
coordinator.has_executing_nodes()


@@ -0,0 +1,326 @@
import time
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_events import (
GraphRunPausedEvent,
GraphRunStartedEvent,
NodeRunPauseRequestedEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
HumanInputFormEntity,
HumanInputFormRepository,
)
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from .test_mock_config import MockConfig, NodeMockConfig
from .test_mock_nodes import MockLLMNode
@dataclass
class StaticForm(HumanInputFormEntity):
form_id: str
rendered: str
is_submitted: bool
action_id: str | None = None
data: Mapping[str, Any] | None = None
status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
expiration: datetime = naive_utc_now() + timedelta(days=1)
@property
def id(self) -> str:
return self.form_id
@property
def web_app_token(self) -> str | None:
return "token"
@property
def recipients(self) -> list:
return []
@property
def rendered_content(self) -> str:
return self.rendered
@property
def selected_action_id(self) -> str | None:
return self.action_id
@property
def submitted_data(self) -> Mapping[str, Any] | None:
return self.data
@property
def submitted(self) -> bool:
return self.is_submitted
@property
def status(self) -> HumanInputFormStatus:
return self.status_value
@property
def expiration_time(self) -> datetime:
return self.expiration
class StaticRepo(HumanInputFormRepository):
def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None:
self._forms_by_node_id = dict(forms_by_node_id)
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
return self._forms_by_node_id.get(node_id)
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
raise AssertionError("create_form should not be called in resume scenario")
class DelayedHumanInputNode(HumanInputNode):
def __init__(self, delay_seconds: float, **kwargs: Any) -> None:
super().__init__(**kwargs)
self._delay_seconds = delay_seconds
def _run(self):
if self._delay_seconds > 0:
time.sleep(self._delay_seconds)
yield from super()._run()
def _build_runtime_state() -> GraphRuntimeState:
variable_pool = VariablePool(
system_variables=SystemVariable(
user_id="user",
app_id="app",
workflow_id="workflow",
workflow_execution_id="exec-1",
),
user_inputs={},
conversation_variables=[],
)
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph:
graph_config: dict[str, object] = {"nodes": [], "edges": []}
graph_init_params = GraphInitParams(
tenant_id="tenant",
app_id="app",
workflow_id="workflow",
graph_config=graph_config,
user_id="user",
user_from="account",
invoke_from="debugger",
call_depth=0,
)
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
start_node = StartNode(
id=start_config["id"],
config=start_config,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
human_data = HumanInputNodeData(
title="Human Input",
form_content="Human input required",
inputs=[],
user_actions=[UserAction(id="approve", title="Approve")],
)
human_a_config = {"id": "human_a", "data": human_data.model_dump()}
human_a = HumanInputNode(
id=human_a_config["id"],
config=human_a_config,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
)
human_b_config = {"id": "human_b", "data": human_data.model_dump()}
human_b = DelayedHumanInputNode(
id=human_b_config["id"],
config=human_b_config,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
delay_seconds=0.2,
)
llm_data = LLMNodeData(
title="LLM A",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
prompt_template=[
LLMNodeChatModelMessage(
text="Prompt A",
role=PromptMessageRole.USER,
edition_type="basic",
)
],
context=ContextConfig(enabled=False, variable_selector=None),
vision=VisionConfig(enabled=False),
reasoning_format="tagged",
)
llm_config = {"id": "llm_a", "data": llm_data.model_dump()}
llm_a = MockLLMNode(
id=llm_config["id"],
config=llm_config,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
mock_config=mock_config,
)
return (
Graph.new()
.add_root(start_node)
.add_node(human_a, from_node_id="start")
.add_node(human_b, from_node_id="start")
.add_node(llm_a, from_node_id="human_a", source_handle="approve")
.build()
)
def test_parallel_human_input_pause_preserves_node_finished() -> None:
runtime_state = _build_runtime_state()
runtime_state.graph_execution.start()
runtime_state.register_paused_node("human_a")
runtime_state.register_paused_node("human_b")
submitted = StaticForm(
form_id="form-a",
rendered="rendered",
is_submitted=True,
action_id="approve",
data={},
status_value=HumanInputFormStatus.SUBMITTED,
)
pending = StaticForm(
form_id="form-b",
rendered="rendered",
is_submitted=False,
action_id=None,
data=None,
status_value=HumanInputFormStatus.WAITING,
)
repo = StaticRepo({"human_a": submitted, "human_b": pending})
mock_config = MockConfig()
mock_config.simulate_delays = True
mock_config.set_node_config(
"llm_a",
NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
)
graph = _build_graph(runtime_state, repo, mock_config)
engine = GraphEngine(
workflow_id="workflow",
graph=graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
min_workers=2,
max_workers=2,
scale_up_threshold=1,
scale_down_idle_time=30.0,
)
events = list(engine.run())
llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events)
llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events)
human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events)
graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events)
graph_started = any(isinstance(e, GraphRunStartedEvent) for e in events)
assert graph_started
assert graph_paused
assert human_b_pause
assert llm_started
assert llm_succeeded
def test_parallel_human_input_pause_preserves_node_finished_after_snapshot_resume() -> None:
base_state = _build_runtime_state()
base_state.graph_execution.start()
base_state.register_paused_node("human_a")
base_state.register_paused_node("human_b")
snapshot = base_state.dumps()
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
submitted = StaticForm(
form_id="form-a",
rendered="rendered",
is_submitted=True,
action_id="approve",
data={},
status_value=HumanInputFormStatus.SUBMITTED,
)
pending = StaticForm(
form_id="form-b",
rendered="rendered",
is_submitted=False,
action_id=None,
data=None,
status_value=HumanInputFormStatus.WAITING,
)
repo = StaticRepo({"human_a": submitted, "human_b": pending})
mock_config = MockConfig()
mock_config.simulate_delays = True
mock_config.set_node_config(
"llm_a",
NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
)
graph = _build_graph(resumed_state, repo, mock_config)
engine = GraphEngine(
workflow_id="workflow",
graph=graph,
graph_runtime_state=resumed_state,
command_channel=InMemoryChannel(),
min_workers=2,
max_workers=2,
scale_up_threshold=1,
scale_down_idle_time=30.0,
)
events = list(engine.run())
start_event = next(e for e in events if isinstance(e, GraphRunStartedEvent))
assert start_event.is_resumption is True
llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events)
llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events)
human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events)
graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events)
assert graph_paused
assert human_b_pause
assert llm_started
assert llm_succeeded


@@ -0,0 +1,301 @@
import time
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine.command_channels.in_memory_channel import InMemoryChannel
from core.workflow.graph_engine.graph_engine import GraphEngine
from core.workflow.graph_events import (
GraphRunPausedEvent,
GraphRunStartedEvent,
NodeRunStartedEvent,
NodeRunSucceededEvent,
)
from core.workflow.nodes.end.end_node import EndNode
from core.workflow.nodes.end.entities import EndNodeData
from core.workflow.nodes.human_input.entities import HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
from core.workflow.nodes.llm.entities import (
ContextConfig,
LLMNodeChatModelMessage,
LLMNodeData,
ModelConfig,
VisionConfig,
)
from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
HumanInputFormEntity,
HumanInputFormRepository,
)
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from libs.datetime_utils import naive_utc_now
from .test_mock_config import MockConfig, NodeMockConfig
from .test_mock_nodes import MockLLMNode
@dataclass
class StaticForm(HumanInputFormEntity):
form_id: str
rendered: str
is_submitted: bool
action_id: str | None = None
data: Mapping[str, Any] | None = None
status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING
expiration: datetime = naive_utc_now() + timedelta(days=1)
@property
def id(self) -> str:
return self.form_id
@property
def web_app_token(self) -> str | None:
return "token"
@property
def recipients(self) -> list:
return []
@property
def rendered_content(self) -> str:
return self.rendered
@property
def selected_action_id(self) -> str | None:
return self.action_id
@property
def submitted_data(self) -> Mapping[str, Any] | None:
return self.data
@property
def submitted(self) -> bool:
return self.is_submitted
@property
def status(self) -> HumanInputFormStatus:
return self.status_value
@property
def expiration_time(self) -> datetime:
return self.expiration
class StaticRepo(HumanInputFormRepository):
def __init__(self, form: HumanInputFormEntity) -> None:
self._form = form
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
if node_id != "human_pause":
return None
return self._form
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
raise AssertionError("create_form should not be called in this test")
def _build_runtime_state() -> GraphRuntimeState:
variable_pool = VariablePool(
system_variables=SystemVariable(
user_id="user",
app_id="app",
workflow_id="workflow",
workflow_execution_id="exec-1",
),
user_inputs={},
conversation_variables=[],
)
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph:
graph_config: dict[str, object] = {"nodes": [], "edges": []}
graph_init_params = GraphInitParams(
tenant_id="tenant",
app_id="app",
workflow_id="workflow",
graph_config=graph_config,
user_id="user",
user_from="account",
invoke_from="debugger",
call_depth=0,
)
start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()}
start_node = StartNode(
id=start_config["id"],
config=start_config,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
llm_a_data = LLMNodeData(
title="LLM A",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
prompt_template=[
LLMNodeChatModelMessage(
text="Prompt A",
role=PromptMessageRole.USER,
edition_type="basic",
)
],
context=ContextConfig(enabled=False, variable_selector=None),
vision=VisionConfig(enabled=False),
reasoning_format="tagged",
)
llm_a_config = {"id": "llm_a", "data": llm_a_data.model_dump()}
llm_a = MockLLMNode(
id=llm_a_config["id"],
config=llm_a_config,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
mock_config=mock_config,
)
llm_b_data = LLMNodeData(
title="LLM B",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
prompt_template=[
LLMNodeChatModelMessage(
text="Prompt B",
role=PromptMessageRole.USER,
edition_type="basic",
)
],
context=ContextConfig(enabled=False, variable_selector=None),
vision=VisionConfig(enabled=False),
reasoning_format="tagged",
)
llm_b_config = {"id": "llm_b", "data": llm_b_data.model_dump()}
llm_b = MockLLMNode(
id=llm_b_config["id"],
config=llm_b_config,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
mock_config=mock_config,
)
human_data = HumanInputNodeData(
title="Human Input",
form_content="Pause here",
inputs=[],
user_actions=[UserAction(id="approve", title="Approve")],
)
human_config = {"id": "human_pause", "data": human_data.model_dump()}
human_node = HumanInputNode(
id=human_config["id"],
config=human_config,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
form_repository=repo,
)
end_human_data = EndNodeData(title="End Human", outputs=[], desc=None)
end_human_config = {"id": "end_human", "data": end_human_data.model_dump()}
end_human = EndNode(
id=end_human_config["id"],
config=end_human_config,
graph_init_params=graph_init_params,
graph_runtime_state=runtime_state,
)
return (
Graph.new()
.add_root(start_node)
.add_node(llm_a, from_node_id="start")
.add_node(human_node, from_node_id="start")
.add_node(llm_b, from_node_id="llm_a")
.add_node(end_human, from_node_id="human_pause", source_handle="approve")
.build()
)
def _get_node_started_event(events: list[object], node_id: str) -> NodeRunStartedEvent | None:
for event in events:
if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id:
return event
return None
def test_pause_defers_ready_nodes_until_resume() -> None:
runtime_state = _build_runtime_state()
paused_form = StaticForm(
form_id="form-pause",
rendered="rendered",
is_submitted=False,
status_value=HumanInputFormStatus.WAITING,
)
pause_repo = StaticRepo(paused_form)
mock_config = MockConfig()
mock_config.simulate_delays = True
mock_config.set_node_config(
"llm_a",
NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5),
)
mock_config.set_node_config(
"llm_b",
NodeMockConfig(node_id="llm_b", outputs={"text": "LLM B output"}, delay=0.0),
)
graph = _build_graph(runtime_state, pause_repo, mock_config)
engine = GraphEngine(
workflow_id="workflow",
graph=graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
min_workers=2,
max_workers=2,
scale_up_threshold=1,
scale_down_idle_time=30.0,
)
paused_events = list(engine.run())
assert any(isinstance(e, GraphRunPausedEvent) for e in paused_events)
assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in paused_events)
assert _get_node_started_event(paused_events, "llm_b") is None
snapshot = runtime_state.dumps()
resumed_state = GraphRuntimeState.from_snapshot(snapshot)
submitted_form = StaticForm(
form_id="form-pause",
rendered="rendered",
is_submitted=True,
action_id="approve",
data={},
status_value=HumanInputFormStatus.SUBMITTED,
)
resume_repo = StaticRepo(submitted_form)
resumed_graph = _build_graph(resumed_state, resume_repo, mock_config)
resumed_engine = GraphEngine(
workflow_id="workflow",
graph=resumed_graph,
graph_runtime_state=resumed_state,
command_channel=InMemoryChannel(),
min_workers=2,
max_workers=2,
scale_up_threshold=1,
scale_down_idle_time=30.0,
)
resumed_events = list(resumed_engine.run())
start_event = next(e for e in resumed_events if isinstance(e, GraphRunStartedEvent))
assert start_event.is_resumption is True
llm_b_started = _get_node_started_event(resumed_events, "llm_b")
assert llm_b_started is not None
assert llm_b_started.is_resumption is False
assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_b" for e in resumed_events)