fix(api): Ensure is_resumption for node_started event is correctly set

This commit is contained in:
QuantumGhost
2026-01-13 09:25:44 +08:00
parent 5523df6023
commit 6bcd4ad740
9 changed files with 76 additions and 20 deletions

View File

@ -9,10 +9,8 @@ from flask_restx import Resource, reqparse
from controllers.web import web_ns
from controllers.web.error import NotFoundError
from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from models.human_input import RecipientType
from models.model import App, EndUser
from services.human_input_service import Form, FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
@ -23,6 +21,10 @@ def _jsonify_form_definition(form: Form) -> Response:
return Response(form.get_definition().model_dump_json(), mimetype="application/json")
# TODO(QuantumGhost): disable authorization for web app
# form api temporarily
@web_ns.route("/form/human_input/<string:form_token>")
# class HumanInputFormApi(WebApiResource):
class HumanInputFormApi(Resource):

View File

@ -319,7 +319,10 @@ class QueueNodeStartedEvent(AppQueueEvent):
# FIXME(-LAN-): only for ToolNode, need to refactor
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
provider_id: str
is_resumption: bool = False
is_resumption: bool = Field(
default=False,
description="True only when this node had already started and execution resumed after a pause.",
)
class QueueNodeSucceededEvent(AppQueueEvent):

View File

@ -236,3 +236,10 @@ class GraphExecution:
def record_node_failure(self) -> None:
"""Increment the count of node failures encountered during execution."""
self.exceptions_count += 1
def is_node_resumption(self, node_id: str, execution_id: str) -> bool:
"""Return True if the node is resuming a previously started execution."""
node_execution = self.node_executions.get(node_id)
if not node_execution or not node_execution.execution_id:
return False
return str(node_execution.execution_id) == execution_id

View File

@ -131,9 +131,6 @@ class EventHandler:
node_execution.mark_started(event.id)
self._graph_runtime_state.increment_node_run_steps()
# Mark whether this start is part of a resume flow
event.is_resumption = self._graph_runtime_state.consume_resuming_node(event.node_id)
# Track in response coordinator for stream ordering
self._response_coordinator.track_node_execution(event.node_id, event.id)

View File

@ -15,7 +15,10 @@ class NodeRunStartedEvent(GraphNodeEventBase):
predecessor_node_id: str | None = None
agent_strategy: AgentNodeStrategyInit | None = None
start_at: datetime = Field(..., description="node start time")
is_resumption: bool = False
is_resumption: bool = Field(
default=False,
description="True only when this node had already started and execution resumed after a pause.",
)
# FIXME(-LAN-): only for ToolNode
provider_type: str = ""

View File

@ -301,6 +301,7 @@ class Node(Generic[NodeDataT]):
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
is_resumption = self.graph_runtime_state.is_node_resumption(self._node_id, execution_id)
# Create and push start event with required fields
start_event = NodeRunStartedEvent(
@ -310,6 +311,7 @@ class Node(Generic[NodeDataT]):
node_title=self.title,
in_iteration_id=None,
start_at=self._start_at,
is_resumption=is_resumption,
)
# === FIXME(-LAN-): Needs to refactor.

View File

@ -79,6 +79,10 @@ class GraphExecutionProtocol(Protocol):
"""Record an unrecoverable error and end execution."""
...
def is_node_resumption(self, node_id: str, execution_id: str) -> bool:
"""Return True if the node is resuming a previously started execution."""
...
def dumps(self) -> str:
"""Serialize execution state into a JSON payload."""
...
@ -179,8 +183,11 @@ class GraphRuntimeState:
self._pending_graph_execution_workflow_id: str | None = None
self._paused_nodes: set[str] = set()
self._deferred_nodes: set[str] = set()
# Tracks nodes that are being resumed in the current execution cycle.
# Populated when paused nodes are consumed during resume.
# Semantic meaning:
# A node id in this set represents "the same node execution is continuing after a pause".
# It means the node has already started in a previous cycle, was paused, and is now resuming,
# so its next node_started event should be marked as a resumption.
# It does NOT mean "any node that runs after resume", and excludes never-run nodes.
self._resuming_nodes: set[str] = set()
if graph is not None:
@ -399,6 +406,14 @@ class GraphRuntimeState:
return True
return False
def is_node_resumption(self, node_id: str, execution_id: str) -> bool:
"""
Return True if the node is resuming a previously started execution.
"""
if self.consume_resuming_node(node_id):
return True
return self.graph_execution.is_node_resumption(node_id, execution_id)
# ------------------------------------------------------------------
# Builders
# ------------------------------------------------------------------

View File

@ -2,6 +2,10 @@
from __future__ import annotations
from dataclasses import dataclass
import pytest
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
@ -119,22 +123,43 @@ def test_retry_does_not_emit_additional_start_event() -> None:
assert node_execution.retry_count == 1
def test_node_start_marks_resumption_when_resuming_node() -> None:
"""Ensure NodeRunStartedEvent is annotated with is_resumption when resuming."""
@dataclass(frozen=True)
class _ResumptionFlagCase:
node_id: str
execution_id: str
node_title: str
is_resumption: bool
node_id = "resumed-node"
handler, event_manager, _ = _build_event_handler(node_id)
# Simulate paused node being consumed for resume
handler._graph_runtime_state.register_paused_node(node_id)
handler._graph_runtime_state.consume_paused_nodes()
@pytest.mark.parametrize(
"case",
[
_ResumptionFlagCase(
node_id="resumed-node",
execution_id="exec-1",
node_title="Resumed Node",
is_resumption=True,
),
_ResumptionFlagCase(
node_id="fresh-node",
execution_id="exec-2",
node_title="Fresh Node",
is_resumption=False,
),
],
)
def test_node_start_preserves_resumption_flag(case: _ResumptionFlagCase) -> None:
"""Ensure NodeRunStartedEvent preserves resumption flag."""
handler, event_manager, _ = _build_event_handler(case.node_id)
start_event = NodeRunStartedEvent(
id="exec-1",
node_id=node_id,
id=case.execution_id,
node_id=case.node_id,
node_type=NodeType.CODE,
node_title="Resumed Node",
node_title=case.node_title,
start_at=naive_utc_now(),
is_resumption=case.is_resumption,
)
handler.dispatch(start_event)
@ -142,7 +167,7 @@ def test_node_start_marks_resumption_when_resuming_node() -> None:
assert len(collected) == 1
emitted_event = collected[0]
assert isinstance(emitted_event, NodeRunStartedEvent)
assert emitted_event.is_resumption is True
assert emitted_event.is_resumption is case.is_resumption
def test_node_start_marks_fresh_run_as_not_resumption() -> None:

View File

@ -207,6 +207,8 @@ def test_engine_resume_restores_state_and_completion():
assert paused_human_started is not None
assert resumed_human_started is not None
assert paused_human_started.id == resumed_human_started.id
assert paused_human_started.is_resumption is False
assert resumed_human_started.is_resumption is True
assert baseline_state.outputs == resumed_state.outputs
assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value(