fix: preserve timing metrics in parallel iteration (#33216)

This commit is contained in:
盐粒 Yanli
2026-03-19 18:05:52 +08:00
committed by GitHub
parent 2b8823f38d
commit df0ded210f
13 changed files with 388 additions and 20 deletions

View File

@ -5,6 +5,7 @@ Unit tests for WorkflowResponseConverter focusing on process_data truncation fun
import uuid
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import UTC, datetime
from typing import Any
from unittest.mock import Mock
@ -234,6 +235,50 @@ class TestWorkflowResponseConverter:
assert response.data.process_data == {}
assert response.data.process_data_truncated is False
def test_workflow_node_finish_response_prefers_event_finished_at(
self,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Finished timestamps should come from the event, not delayed queue processing time."""
converter = self.create_workflow_response_converter()
start_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None)
finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None)
delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None)
monkeypatch.setattr(
"core.app.apps.common.workflow_response_converter.naive_utc_now",
lambda: delayed_processing_time,
)
converter.workflow_start_to_stream_response(
task_id="bootstrap",
workflow_run_id="run-id",
workflow_id="wf-id",
reason=WorkflowStartReason.INITIAL,
)
event = QueueNodeSucceededEvent(
node_id="test-node-id",
node_type=BuiltinNodeTypes.CODE,
node_execution_id="node-exec-1",
start_at=start_at,
finished_at=finished_at,
in_iteration_id=None,
in_loop_id=None,
inputs={},
process_data={},
outputs={},
execution_metadata={},
)
response = converter.workflow_node_finish_to_stream_response(
event=event,
task_id="test-task-id",
)
assert response is not None
assert response.data.elapsed_time == 2.0
assert response.data.finished_at == int(finished_at.timestamp())
def test_workflow_node_retry_response_uses_truncated_process_data(self):
"""Test that node retry response uses get_response_process_data()."""
converter = self.create_workflow_response_converter()

View File

@ -0,0 +1,60 @@
from datetime import UTC, datetime
from unittest.mock import Mock
import pytest
from core.app.workflow.layers.persistence import (
PersistenceWorkflowInfo,
WorkflowPersistenceLayer,
_NodeRuntimeSnapshot,
)
from dify_graph.enums import WorkflowNodeExecutionStatus, WorkflowType
from dify_graph.node_events import NodeRunResult
def _build_layer() -> WorkflowPersistenceLayer:
application_generate_entity = Mock()
application_generate_entity.inputs = {}
return WorkflowPersistenceLayer(
application_generate_entity=application_generate_entity,
workflow_info=PersistenceWorkflowInfo(
workflow_id="workflow-id",
workflow_type=WorkflowType.WORKFLOW,
version="1",
graph_data={},
),
workflow_execution_repository=Mock(),
workflow_node_execution_repository=Mock(),
)
def test_update_node_execution_prefers_event_finished_at(monkeypatch: pytest.MonkeyPatch) -> None:
layer = _build_layer()
node_execution = Mock()
node_execution.id = "node-exec-1"
node_execution.created_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None)
node_execution.update_from_mapping = Mock()
layer._node_snapshots[node_execution.id] = _NodeRuntimeSnapshot(
node_id="node-id",
title="LLM",
predecessor_node_id=None,
iteration_id="iter-1",
loop_id=None,
created_at=node_execution.created_at,
)
event_finished_at = datetime(2024, 1, 1, 0, 0, 2, tzinfo=UTC).replace(tzinfo=None)
delayed_processing_time = datetime(2024, 1, 1, 0, 0, 10, tzinfo=UTC).replace(tzinfo=None)
monkeypatch.setattr("core.app.workflow.layers.persistence.naive_utc_now", lambda: delayed_processing_time)
layer._update_node_execution(
node_execution,
NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED),
WorkflowNodeExecutionStatus.SUCCEEDED,
finished_at=event_finished_at,
)
assert node_execution.finished_at == event_finished_at
assert node_execution.elapsed_time == 2.0

View File

@ -0,0 +1,145 @@
import queue
from collections.abc import Generator
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue
from dify_graph.graph_engine.worker import Worker
from dify_graph.graph_events import NodeRunFailedEvent, NodeRunStartedEvent
def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None:
fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
mocker.patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=fixed_time)
worker = Worker(
ready_queue=InMemoryReadyQueue(),
event_queue=queue.Queue(),
graph=MagicMock(),
layers=[],
)
node = SimpleNamespace(
execution_id="exec-1",
id="node-1",
node_type=BuiltinNodeTypes.LLM,
)
event = worker._build_fallback_failure_event(node, RuntimeError("boom"))
assert event.start_at == fixed_time
assert event.finished_at == fixed_time
assert event.error == "boom"
assert event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
assert event.node_run_result.error == "boom"
assert event.node_run_result.error_type == "RuntimeError"
def test_worker_fallback_failure_event_reuses_observed_start_time() -> None:
start_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
failure_time = start_at + timedelta(seconds=5)
captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = []
class FakeNode:
execution_id = "exec-1"
id = "node-1"
node_type = BuiltinNodeTypes.LLM
def ensure_execution_id(self) -> str:
return self.execution_id
def run(self) -> Generator[NodeRunStartedEvent, None, None]:
yield NodeRunStartedEvent(
id=self.execution_id,
node_id=self.id,
node_type=self.node_type,
node_title="LLM",
start_at=start_at,
)
worker = Worker(
ready_queue=MagicMock(),
event_queue=MagicMock(),
graph=MagicMock(nodes={"node-1": FakeNode()}),
layers=[],
)
worker._ready_queue.get.side_effect = ["node-1"]
def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None:
captured_events.append(event)
if len(captured_events) == 1:
raise RuntimeError("queue boom")
worker.stop()
worker._event_queue.put.side_effect = put_side_effect
with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time):
worker.run()
fallback_event = captured_events[-1]
assert isinstance(fallback_event, NodeRunFailedEvent)
assert fallback_event.start_at == start_at
assert fallback_event.finished_at == failure_time
assert fallback_event.error == "queue boom"
assert fallback_event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED
def test_worker_fallback_failure_event_ignores_nested_iteration_child_start_times() -> None:
parent_start = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None)
child_start = parent_start + timedelta(seconds=3)
failure_time = parent_start + timedelta(seconds=5)
captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = []
class FakeIterationNode:
execution_id = "iteration-exec"
id = "iteration-node"
node_type = BuiltinNodeTypes.ITERATION
def ensure_execution_id(self) -> str:
return self.execution_id
def run(self) -> Generator[NodeRunStartedEvent, None, None]:
yield NodeRunStartedEvent(
id=self.execution_id,
node_id=self.id,
node_type=self.node_type,
node_title="Iteration",
start_at=parent_start,
)
yield NodeRunStartedEvent(
id="child-exec",
node_id="child-node",
node_type=BuiltinNodeTypes.LLM,
node_title="LLM",
start_at=child_start,
in_iteration_id=self.id,
)
worker = Worker(
ready_queue=MagicMock(),
event_queue=MagicMock(),
graph=MagicMock(nodes={"iteration-node": FakeIterationNode()}),
layers=[],
)
worker._ready_queue.get.side_effect = ["iteration-node"]
def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None:
captured_events.append(event)
if len(captured_events) == 2:
raise RuntimeError("queue boom")
worker.stop()
worker._event_queue.put.side_effect = put_side_effect
with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time):
worker.run()
fallback_event = captured_events[-1]
assert isinstance(fallback_event, NodeRunFailedEvent)
assert fallback_event.start_at == parent_start
assert fallback_event.finished_at == failure_time

View File

@ -0,0 +1,63 @@
import time
from contextlib import nullcontext
from datetime import UTC, datetime
import pytest
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.graph_events import NodeRunSucceededEvent
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from dify_graph.nodes.iteration.iteration_node import IterationNode
def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None:
node = IterationNode.__new__(IterationNode)
node._node_data = IterationNodeData(
title="Parallel Iteration",
iterator_selector=["start", "items"],
output_selector=["iteration", "output"],
is_parallel=True,
parallel_nums=2,
error_handle_mode=ErrorHandleMode.TERMINATED,
)
node._capture_execution_context = lambda: nullcontext()
node._sync_conversation_variables_from_snapshot = lambda snapshot: None
node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new)
def fake_execute_single_iteration_parallel(*, index: int, item: object, execution_context: object):
return (
0.1 + (index * 0.1),
[
NodeRunSucceededEvent(
id=f"exec-{index}",
node_id=f"llm-{index}",
node_type=BuiltinNodeTypes.LLM,
start_at=datetime.now(UTC).replace(tzinfo=None),
),
],
f"output-{item}",
{},
LLMUsage.empty_usage(),
)
node._execute_single_iteration_parallel = fake_execute_single_iteration_parallel
outputs: list[object] = []
iter_run_map: dict[str, float] = {}
usage_accumulator = [LLMUsage.empty_usage()]
generator = node._execute_parallel_iterations(
iterator_list_value=["a", "b"],
outputs=outputs,
iter_run_map=iter_run_map,
usage_accumulator=usage_accumulator,
)
for _ in generator:
# Simulate a slow consumer replaying buffered events.
time.sleep(0.02)
assert outputs == ["output-a", "output-b"]
assert iter_run_map["0"] == pytest.approx(0.1)
assert iter_run_map["1"] == pytest.approx(0.2)