mirror of
https://github.com/langgenius/dify.git
synced 2026-04-26 13:45:57 +08:00
fix: preserve timing metrics in parallel iteration (#33216)
This commit is contained in:
@ -159,6 +159,7 @@ class ErrorHandler:
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
inputs=event.node_run_result.inputs,
|
||||
@ -198,6 +199,7 @@ class ErrorHandler:
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
start_at=event.start_at,
|
||||
finished_at=event.finished_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.EXCEPTION,
|
||||
inputs=event.node_run_result.inputs,
|
||||
|
||||
@ -15,10 +15,13 @@ from typing import TYPE_CHECKING, final
|
||||
from typing_extensions import override
|
||||
|
||||
from dify_graph.context import IExecutionContext
|
||||
from dify_graph.enums import WorkflowNodeExecutionStatus
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, is_node_result_event
|
||||
from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event
|
||||
from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .ready_queue import ReadyQueue
|
||||
|
||||
@ -65,6 +68,7 @@ class Worker(threading.Thread):
|
||||
self._stop_event = threading.Event()
|
||||
self._layers = layers if layers is not None else []
|
||||
self._last_task_time = time.time()
|
||||
self._current_node_started_at: datetime | None = None
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Signal the worker to stop processing."""
|
||||
@ -104,18 +108,15 @@ class Worker(threading.Thread):
|
||||
self._last_task_time = time.time()
|
||||
node = self._graph.nodes[node_id]
|
||||
try:
|
||||
self._current_node_started_at = None
|
||||
self._execute_node(node)
|
||||
self._ready_queue.task_done()
|
||||
except Exception as e:
|
||||
error_event = NodeRunFailedEvent(
|
||||
id=node.execution_id,
|
||||
node_id=node.id,
|
||||
node_type=node.node_type,
|
||||
in_iteration_id=None,
|
||||
error=str(e),
|
||||
start_at=datetime.now(),
|
||||
self._event_queue.put(
|
||||
self._build_fallback_failure_event(node, e, started_at=self._current_node_started_at)
|
||||
)
|
||||
self._event_queue.put(error_event)
|
||||
finally:
|
||||
self._current_node_started_at = None
|
||||
|
||||
def _execute_node(self, node: Node) -> None:
|
||||
"""
|
||||
@ -136,6 +137,8 @@ class Worker(threading.Thread):
|
||||
try:
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id:
|
||||
self._current_node_started_at = event.start_at
|
||||
self._event_queue.put(event)
|
||||
if is_node_result_event(event):
|
||||
result_event = event
|
||||
@ -149,6 +152,8 @@ class Worker(threading.Thread):
|
||||
try:
|
||||
node_events = node.run()
|
||||
for event in node_events:
|
||||
if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id:
|
||||
self._current_node_started_at = event.start_at
|
||||
self._event_queue.put(event)
|
||||
if is_node_result_event(event):
|
||||
result_event = event
|
||||
@ -177,3 +182,24 @@ class Worker(threading.Thread):
|
||||
except Exception:
|
||||
# Silently ignore layer errors to prevent disrupting node execution
|
||||
continue
|
||||
|
||||
def _build_fallback_failure_event(
    self, node: Node, error: Exception, *, started_at: datetime | None = None
) -> NodeRunFailedEvent:
    """Create the failure event the worker reports when node execution raises.

    Args:
        node: The node whose execution aborted; supplies the execution id,
            node id and node type copied onto the event.
        error: The exception that escaped execution. Its string form becomes
            the error message and its class name becomes ``error_type``.
        started_at: Start time to report for the node, if one is known.
            When it is not provided, the failure timestamp is used for
            both the start and the finish time.

    Returns:
        A ``NodeRunFailedEvent`` carrying a FAILED ``NodeRunResult``.
    """
    # Take a single timestamp so start/finish stay consistent in the fallback case.
    now = naive_utc_now()
    message = str(error)
    begin = started_at or now

    failed_result = NodeRunResult(
        status=WorkflowNodeExecutionStatus.FAILED,
        error=message,
        error_type=type(error).__name__,
    )
    return NodeRunFailedEvent(
        id=node.execution_id,
        node_id=node.id,
        node_type=node.node_type,
        in_iteration_id=None,
        error=message,
        start_at=begin,
        finished_at=now,
        node_run_result=failed_result,
    )
|
||||
|
||||
@ -36,16 +36,19 @@ class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
|
||||
|
||||
class NodeRunSucceededEvent(GraphNodeEventBase):
    # Graph node event describing a node run that succeeded.

    # Required: when the node started running.
    start_at: datetime = Field(
        ...,
        description="node start time",
    )
    # Optional: when the node finished; stays None if the producer does not set it.
    finished_at: datetime | None = Field(
        default=None,
        description="node finish time",
    )
|
||||
|
||||
|
||||
class NodeRunFailedEvent(GraphNodeEventBase):
    # Graph node event describing a node run that failed.

    # Required: the error message for the failure.
    error: str = Field(
        ...,
        description="error",
    )
    # Required: when the node started running.
    start_at: datetime = Field(
        ...,
        description="node start time",
    )
    # Optional: when the node finished; stays None if the producer does not set it.
    finished_at: datetime | None = Field(
        default=None,
        description="node finish time",
    )
|
||||
|
||||
|
||||
class NodeRunExceptionEvent(GraphNodeEventBase):
    # Graph node event carrying an error together with node timing.
    # NOTE(review): presumably emitted when a node error is handled by an
    # error strategy rather than failing the run — confirm against the producer.

    # Required: the error message.
    error: str = Field(
        ...,
        description="error",
    )
    # Required: when the node started running.
    start_at: datetime = Field(
        ...,
        description="node start time",
    )
    # Optional: when the node finished; stays None if the producer does not set it.
    finished_at: datetime | None = Field(
        default=None,
        description="node finish time",
    )
|
||||
|
||||
|
||||
class NodeRunRetryEvent(NodeRunStartedEvent):
|
||||
|
||||
@ -406,11 +406,13 @@ class Node(Generic[NodeDataT]):
|
||||
error=str(e),
|
||||
error_type="WorkflowNodeError",
|
||||
)
|
||||
finished_at = naive_utc_now()
|
||||
yield NodeRunFailedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=result,
|
||||
error=str(e),
|
||||
)
|
||||
@ -568,6 +570,7 @@ class Node(Generic[NodeDataT]):
|
||||
return self._node_data
|
||||
|
||||
def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase:
|
||||
finished_at = naive_utc_now()
|
||||
match result.status:
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
return NodeRunFailedEvent(
|
||||
@ -575,6 +578,7 @@ class Node(Generic[NodeDataT]):
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=result,
|
||||
error=result.error,
|
||||
)
|
||||
@ -584,6 +588,7 @@ class Node(Generic[NodeDataT]):
|
||||
node_id=self.id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=result,
|
||||
)
|
||||
case _:
|
||||
@ -606,6 +611,7 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent:
|
||||
finished_at = naive_utc_now()
|
||||
match event.node_run_result.status:
|
||||
case WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
return NodeRunSucceededEvent(
|
||||
@ -613,6 +619,7 @@ class Node(Generic[NodeDataT]):
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=event.node_run_result,
|
||||
)
|
||||
case WorkflowNodeExecutionStatus.FAILED:
|
||||
@ -621,6 +628,7 @@ class Node(Generic[NodeDataT]):
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
finished_at=finished_at,
|
||||
node_run_result=event.node_run_result,
|
||||
error=event.node_run_result.error,
|
||||
)
|
||||
|
||||
@ -236,7 +236,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
future_to_index: dict[
|
||||
Future[
|
||||
tuple[
|
||||
datetime,
|
||||
float,
|
||||
list[GraphNodeEventBase],
|
||||
object | None,
|
||||
dict[str, Variable],
|
||||
@ -261,7 +261,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
try:
|
||||
result = future.result()
|
||||
(
|
||||
iter_start_at,
|
||||
iteration_duration,
|
||||
events,
|
||||
output_value,
|
||||
conversation_snapshot,
|
||||
@ -274,8 +274,9 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
# Yield all events from this iteration
|
||||
yield from events
|
||||
|
||||
# Update tokens and timing
|
||||
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
# The worker computes duration before we replay buffered events here,
|
||||
# so slow downstream consumers don't inflate per-iteration timing.
|
||||
iter_run_map[str(index)] = iteration_duration
|
||||
|
||||
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
|
||||
|
||||
@ -305,7 +306,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
index: int,
|
||||
item: object,
|
||||
execution_context: "IExecutionContext",
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
|
||||
) -> tuple[float, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with execution_context:
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
@ -327,9 +328,10 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
conversation_snapshot = self._extract_conversation_variable_snapshot(
|
||||
variable_pool=graph_engine.graph_runtime_state.variable_pool
|
||||
)
|
||||
iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
|
||||
|
||||
return (
|
||||
iter_start_at,
|
||||
iteration_duration,
|
||||
events,
|
||||
output_value,
|
||||
conversation_snapshot,
|
||||
|
||||
Reference in New Issue
Block a user