mirror of
https://github.com/langgenius/dify.git
synced 2026-03-07 16:45:58 +08:00
fix(api): Ensure is_resumption for node_started event is correctly set
This commit is contained in:
@ -9,10 +9,8 @@ from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotFoundError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from models.human_input import RecipientType
|
||||
from models.model import App, EndUser
|
||||
from services.human_input_service import Form, FormNotFoundError, HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -23,6 +21,10 @@ def _jsonify_form_definition(form: Form) -> Response:
|
||||
return Response(form.get_definition().model_dump_json(), mimetype="application/json")
|
||||
|
||||
|
||||
# TODO(QuantumGhost): disable authorization for web app
|
||||
# form api temporarily
|
||||
|
||||
|
||||
@web_ns.route("/form/human_input/<string:form_token>")
|
||||
# class HumanInputFormApi(WebApiResource):
|
||||
class HumanInputFormApi(Resource):
|
||||
|
||||
@ -319,7 +319,10 @@ class QueueNodeStartedEvent(AppQueueEvent):
|
||||
# FIXME(-LAN-): only for ToolNode, need to refactor
|
||||
provider_type: str # should be a core.tools.entities.tool_entities.ToolProviderType
|
||||
provider_id: str
|
||||
is_resumption: bool = False
|
||||
is_resumption: bool = Field(
|
||||
default=False,
|
||||
description="True only when this node had already started and execution resumed after a pause.",
|
||||
)
|
||||
|
||||
|
||||
class QueueNodeSucceededEvent(AppQueueEvent):
|
||||
|
||||
@ -236,3 +236,10 @@ class GraphExecution:
|
||||
def record_node_failure(self) -> None:
    """Register one more failed node on this execution.

    Bumps ``self.exceptions_count`` by exactly one; nothing is returned.
    """
    self.exceptions_count = self.exceptions_count + 1
|
||||
|
||||
def is_node_resumption(self, node_id: str, execution_id: str) -> bool:
    """Tell whether ``execution_id`` matches the execution already recorded for ``node_id``.

    Args:
        node_id: Key looked up in ``self.node_executions``.
        execution_id: Execution id of the run that is about to start.

    Returns:
        True only when an entry exists for ``node_id``, that entry carries a
        truthy ``execution_id``, and its string form equals ``execution_id``.
        False in every other case (unknown node, or no execution id recorded).
    """
    tracked = self.node_executions.get(node_id)
    if tracked and tracked.execution_id:
        # Compare as strings: the stored id is coerced with str() before matching.
        return str(tracked.execution_id) == execution_id
    return False
|
||||
|
||||
@ -131,9 +131,6 @@ class EventHandler:
|
||||
node_execution.mark_started(event.id)
|
||||
self._graph_runtime_state.increment_node_run_steps()
|
||||
|
||||
# Mark whether this start is part of a resume flow
|
||||
event.is_resumption = self._graph_runtime_state.consume_resuming_node(event.node_id)
|
||||
|
||||
# Track in response coordinator for stream ordering
|
||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
|
||||
|
||||
@ -15,7 +15,10 @@ class NodeRunStartedEvent(GraphNodeEventBase):
|
||||
predecessor_node_id: str | None = None
|
||||
agent_strategy: AgentNodeStrategyInit | None = None
|
||||
start_at: datetime = Field(..., description="node start time")
|
||||
is_resumption: bool = False
|
||||
is_resumption: bool = Field(
|
||||
default=False,
|
||||
description="True only when this node had already started and execution resumed after a pause.",
|
||||
)
|
||||
|
||||
# FIXME(-LAN-): only for ToolNode
|
||||
provider_type: str = ""
|
||||
|
||||
@ -301,6 +301,7 @@ class Node(Generic[NodeDataT]):
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
execution_id = self.ensure_execution_id()
|
||||
self._start_at = naive_utc_now()
|
||||
is_resumption = self.graph_runtime_state.is_node_resumption(self._node_id, execution_id)
|
||||
|
||||
# Create and push start event with required fields
|
||||
start_event = NodeRunStartedEvent(
|
||||
@ -310,6 +311,7 @@ class Node(Generic[NodeDataT]):
|
||||
node_title=self.title,
|
||||
in_iteration_id=None,
|
||||
start_at=self._start_at,
|
||||
is_resumption=is_resumption,
|
||||
)
|
||||
|
||||
# === FIXME(-LAN-): Needs to refactor.
|
||||
|
||||
@ -79,6 +79,10 @@ class GraphExecutionProtocol(Protocol):
|
||||
"""Record an unrecoverable error and end execution."""
|
||||
...
|
||||
|
||||
def is_node_resumption(self, node_id: str, execution_id: str) -> bool:
    """Report whether ``node_id`` is resuming a previously started execution.

    Protocol stub: the body is intentionally empty and concrete graph
    execution implementations supply the logic.

    Args:
        node_id: Identifier of the node being started.
        execution_id: Execution id of the run being started.
    """
    ...
|
||||
|
||||
def dumps(self) -> str:
    """Return the execution state serialized as a JSON payload.

    Protocol stub: the body is intentionally empty and concrete
    implementations supply the serialization.
    """
    ...
|
||||
@ -179,8 +183,11 @@ class GraphRuntimeState:
|
||||
self._pending_graph_execution_workflow_id: str | None = None
|
||||
self._paused_nodes: set[str] = set()
|
||||
self._deferred_nodes: set[str] = set()
|
||||
# Tracks nodes that are being resumed in the current execution cycle.
|
||||
# Populated when paused nodes are consumed during resume.
|
||||
# Semantic meaning:
|
||||
# A node id in this set represents "the same node execution is continuing after a pause".
|
||||
# It means the node has already started in a previous cycle, was paused, and is now resuming,
|
||||
# so its next node_started event should be marked as a resumption.
|
||||
# It does NOT mean "any node that runs after resume", and excludes never-run nodes.
|
||||
self._resuming_nodes: set[str] = set()
|
||||
|
||||
if graph is not None:
|
||||
@ -399,6 +406,14 @@ class GraphRuntimeState:
|
||||
return True
|
||||
return False
|
||||
|
||||
def is_node_resumption(self, node_id: str, execution_id: str) -> bool:
    """Decide whether starting ``node_id`` continues an earlier execution.

    ``consume_resuming_node`` is always called first, so this check is not a
    pure query (NOTE(review): presumably that call removes the node from the
    resuming set -- confirm against its definition). Only when it reports a
    falsy result is the decision delegated to
    ``self.graph_execution.is_node_resumption``.

    Args:
        node_id: Identifier of the node being started.
        execution_id: Execution id of the run being started.

    Returns:
        True when the node was flagged as resuming; otherwise whatever the
        graph execution reports for this node/execution pair.
    """
    flagged_as_resuming = self.consume_resuming_node(node_id)
    if not flagged_as_resuming:
        return self.graph_execution.is_node_resumption(node_id, execution_id)
    return True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Builders
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@ -2,6 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.domain.graph_execution import GraphExecution
|
||||
@ -119,22 +123,43 @@ def test_retry_does_not_emit_additional_start_event() -> None:
|
||||
assert node_execution.retry_count == 1
|
||||
|
||||
|
||||
def test_node_start_marks_resumption_when_resuming_node() -> None:
|
||||
"""Ensure NodeRunStartedEvent is annotated with is_resumption when resuming."""
|
||||
@dataclass(frozen=True)
|
||||
class _ResumptionFlagCase:
|
||||
node_id: str
|
||||
execution_id: str
|
||||
node_title: str
|
||||
is_resumption: bool
|
||||
|
||||
node_id = "resumed-node"
|
||||
handler, event_manager, _ = _build_event_handler(node_id)
|
||||
|
||||
# Simulate paused node being consumed for resume
|
||||
handler._graph_runtime_state.register_paused_node(node_id)
|
||||
handler._graph_runtime_state.consume_paused_nodes()
|
||||
@pytest.mark.parametrize(
|
||||
"case",
|
||||
[
|
||||
_ResumptionFlagCase(
|
||||
node_id="resumed-node",
|
||||
execution_id="exec-1",
|
||||
node_title="Resumed Node",
|
||||
is_resumption=True,
|
||||
),
|
||||
_ResumptionFlagCase(
|
||||
node_id="fresh-node",
|
||||
execution_id="exec-2",
|
||||
node_title="Fresh Node",
|
||||
is_resumption=False,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_node_start_preserves_resumption_flag(case: _ResumptionFlagCase) -> None:
|
||||
"""Ensure NodeRunStartedEvent preserves resumption flag."""
|
||||
|
||||
handler, event_manager, _ = _build_event_handler(case.node_id)
|
||||
|
||||
start_event = NodeRunStartedEvent(
|
||||
id="exec-1",
|
||||
node_id=node_id,
|
||||
id=case.execution_id,
|
||||
node_id=case.node_id,
|
||||
node_type=NodeType.CODE,
|
||||
node_title="Resumed Node",
|
||||
node_title=case.node_title,
|
||||
start_at=naive_utc_now(),
|
||||
is_resumption=case.is_resumption,
|
||||
)
|
||||
handler.dispatch(start_event)
|
||||
|
||||
@ -142,7 +167,7 @@ def test_node_start_marks_resumption_when_resuming_node() -> None:
|
||||
assert len(collected) == 1
|
||||
emitted_event = collected[0]
|
||||
assert isinstance(emitted_event, NodeRunStartedEvent)
|
||||
assert emitted_event.is_resumption is True
|
||||
assert emitted_event.is_resumption is case.is_resumption
|
||||
|
||||
|
||||
def test_node_start_marks_fresh_run_as_not_resumption() -> None:
|
||||
|
||||
@ -207,6 +207,8 @@ def test_engine_resume_restores_state_and_completion():
|
||||
assert paused_human_started is not None
|
||||
assert resumed_human_started is not None
|
||||
assert paused_human_started.id == resumed_human_started.id
|
||||
assert paused_human_started.is_resumption is False
|
||||
assert resumed_human_started.is_resumption is True
|
||||
|
||||
assert baseline_state.outputs == resumed_state.outputs
|
||||
assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value(
|
||||
|
||||
Reference in New Issue
Block a user