Fix: surface workflow container LLM usage (#27021)

This commit is contained in:
-LAN-
2025-10-21 16:05:26 +08:00
committed by GitHub
parent 2bcf96565a
commit 4a6398fc1f
10 changed files with 283 additions and 59 deletions

View File

@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, NewType, cast
from flask import Flask, current_app
from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
@ -34,6 +35,7 @@ from core.workflow.node_events import (
NodeRunResult,
StreamCompletedEvent,
)
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
@ -58,7 +60,7 @@ logger = logging.getLogger(__name__)
EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment)
class IterationNode(Node):
class IterationNode(LLMUsageTrackingMixin, Node):
"""
Iteration Node.
"""
@ -118,6 +120,7 @@ class IterationNode(Node):
started_at = naive_utc_now()
iter_run_map: dict[str, float] = {}
outputs: list[object] = []
usage_accumulator = [LLMUsage.empty_usage()]
yield IterationStartedEvent(
start_at=started_at,
@ -130,22 +133,27 @@ class IterationNode(Node):
iterator_list_value=iterator_list_value,
outputs=outputs,
iter_run_map=iter_run_map,
usage_accumulator=usage_accumulator,
)
self._accumulate_usage(usage_accumulator[0])
yield from self._handle_iteration_success(
started_at=started_at,
inputs=inputs,
outputs=outputs,
iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map,
usage=usage_accumulator[0],
)
except IterationNodeError as e:
self._accumulate_usage(usage_accumulator[0])
yield from self._handle_iteration_failure(
started_at=started_at,
inputs=inputs,
outputs=outputs,
iterator_list_value=iterator_list_value,
iter_run_map=iter_run_map,
usage=usage_accumulator[0],
error=e,
)
@ -196,6 +204,7 @@ class IterationNode(Node):
iterator_list_value: Sequence[object],
outputs: list[object],
iter_run_map: dict[str, float],
usage_accumulator: list[LLMUsage],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
if self._node_data.is_parallel:
# Parallel mode execution
@ -203,6 +212,7 @@ class IterationNode(Node):
iterator_list_value=iterator_list_value,
outputs=outputs,
iter_run_map=iter_run_map,
usage_accumulator=usage_accumulator,
)
else:
# Sequential mode execution
@ -228,6 +238,9 @@ class IterationNode(Node):
# Update the total tokens from this iteration
self.graph_runtime_state.total_tokens += graph_engine.graph_runtime_state.total_tokens
usage_accumulator[0] = self._merge_usage(
usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage
)
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
def _execute_parallel_iterations(
@ -235,6 +248,7 @@ class IterationNode(Node):
iterator_list_value: Sequence[object],
outputs: list[object],
iter_run_map: dict[str, float],
usage_accumulator: list[LLMUsage],
) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]:
# Initialize outputs list with None values to maintain order
outputs.extend([None] * len(iterator_list_value))
@ -245,7 +259,16 @@ class IterationNode(Node):
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit all iteration tasks
future_to_index: dict[
Future[tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]],
Future[
tuple[
datetime,
list[GraphNodeEventBase],
object | None,
int,
dict[str, VariableUnion],
LLMUsage,
]
],
int,
] = {}
for index, item in enumerate(iterator_list_value):
@ -264,7 +287,14 @@ class IterationNode(Node):
index = future_to_index[future]
try:
result = future.result()
iter_start_at, events, output_value, tokens_used, conversation_snapshot = result
(
iter_start_at,
events,
output_value,
tokens_used,
conversation_snapshot,
iteration_usage,
) = result
# Update outputs at the correct index
outputs[index] = output_value
@ -276,6 +306,8 @@ class IterationNode(Node):
self.graph_runtime_state.total_tokens += tokens_used
iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds()
usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage)
# Sync conversation variables after iteration completion
self._sync_conversation_variables_from_snapshot(conversation_snapshot)
@ -303,7 +335,7 @@ class IterationNode(Node):
item: object,
flask_app: Flask,
context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion]]:
) -> tuple[datetime, list[GraphNodeEventBase], object | None, int, dict[str, VariableUnion], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@ -332,6 +364,7 @@ class IterationNode(Node):
output_value,
graph_engine.graph_runtime_state.total_tokens,
conversation_snapshot,
graph_engine.graph_runtime_state.llm_usage,
)
def _handle_iteration_success(
@ -341,6 +374,8 @@ class IterationNode(Node):
outputs: list[object],
iterator_list_value: Sequence[object],
iter_run_map: dict[str, float],
*,
usage: LLMUsage,
) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists
flattened_outputs = self._flatten_outputs_if_needed(outputs)
@ -351,7 +386,9 @@ class IterationNode(Node):
outputs={"output": flattened_outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
)
@ -362,8 +399,11 @@ class IterationNode(Node):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": flattened_outputs},
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
)
@ -400,6 +440,8 @@ class IterationNode(Node):
outputs: list[object],
iterator_list_value: Sequence[object],
iter_run_map: dict[str, float],
*,
usage: LLMUsage,
error: IterationNodeError,
) -> Generator[NodeEventBase, None, None]:
# Flatten the list of lists if all outputs are lists (even in failure case)
@ -411,7 +453,9 @@ class IterationNode(Node):
outputs={"output": flattened_outputs},
steps=len(iterator_list_value),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: self.graph_runtime_state.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
},
error=str(error),
@ -420,6 +464,12 @@ class IterationNode(Node):
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=str(error),
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
llm_usage=usage,
)
)