feat: enhance model event handling with new identity and metrics fields

This commit is contained in:
Novice
2026-03-05 14:08:37 +08:00
parent e26d8a63da
commit 1cb5ee918f
11 changed files with 185 additions and 37 deletions

View File

@ -274,6 +274,7 @@ class TraceState(BaseModel):
tool_trace_map: dict[str, LLMTraceSegment] = Field(default_factory=dict)
tool_call_index_map: dict[str, int] = Field(default_factory=dict)
model_segment_start_time: float | None = Field(default=None, description="Start time for current model segment")
model_start_emitted: bool = Field(default=False, description="Whether model_start has been emitted for this turn")
pending_usage: LLMUsage | None = Field(default=None, description="Pending usage for current model segment")

View File

@ -105,7 +105,7 @@ from core.workflow.node_events import (
ToolCallChunkEvent,
ToolResultChunkEvent,
)
from core.workflow.node_events.node import ThoughtEndChunkEvent, ThoughtStartChunkEvent
from core.workflow.node_events.node import ChunkType, ThoughtEndChunkEvent, ThoughtStartChunkEvent
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
@ -2277,23 +2277,42 @@ class LLMNode(Node[LLMNodeData]):
except Exception:
return None
def _emit_model_start(self, trace_state: TraceState) -> Generator[NodeEventBase, None, None]:
    """Emit a single MODEL_START stream chunk at the beginning of a model turn.

    Idempotent per turn: guarded by ``trace_state.model_start_emitted`` so
    repeated calls yield nothing after the first. Also captures the segment
    start timestamp if one has not been recorded yet, so the matching
    MODEL_END can report a duration.
    """
    if not trace_state.model_start_emitted:
        trace_state.model_start_emitted = True
        # Record when this model segment began, unless a timestamp already exists.
        if trace_state.model_segment_start_time is None:
            trace_state.model_segment_start_time = time.perf_counter()
        model_config = self._node_data.model
        icon_url = self._generate_model_provider_icon_url
        yield StreamChunkEvent(
            selector=[self._node_id, "generation", "model_start"],
            chunk="",
            chunk_type=ChunkType.MODEL_START,
            is_final=False,
            model_provider=model_config.provider,
            model_name=model_config.name,
            model_icon=icon_url(model_config.provider),
            model_icon_dark=icon_url(model_config.provider, dark=True),
        )
def _flush_model_segment(
self,
buffers: StreamBuffers,
trace_state: TraceState,
error: str | None = None,
) -> None:
"""Flush pending thought/content buffers into a single model trace segment."""
) -> Generator[NodeEventBase, None, None]:
"""Flush pending thought/content buffers into a single model trace segment
and yield a MODEL_END chunk event with usage/duration metrics."""
if not buffers.pending_thought and not buffers.pending_content and not buffers.pending_tool_calls:
return
now = time.perf_counter()
duration = now - trace_state.model_segment_start_time if trace_state.model_segment_start_time else 0.0
# Use pending_usage from trace_state (captured from THOUGHT log)
usage = trace_state.pending_usage
# Generate model provider icon URL
provider = self._node_data.model.provider
model_name = self._node_data.model.name
model_icon = self._generate_model_provider_icon_url(provider)
@ -2317,10 +2336,21 @@ class LLMNode(Node[LLMNodeData]):
status="error" if error else "success",
)
)
yield StreamChunkEvent(
selector=[self._node_id, "generation", "model_end"],
chunk="",
chunk_type=ChunkType.MODEL_END,
is_final=False,
model_usage=usage,
model_duration=duration,
)
buffers.pending_thought.clear()
buffers.pending_content.clear()
buffers.pending_tool_calls.clear()
trace_state.model_segment_start_time = None
trace_state.model_start_emitted = False
trace_state.pending_usage = None
def _handle_agent_log_output(
@ -2356,18 +2386,18 @@ class LLMNode(Node[LLMNodeData]):
trace_state.pending_usage = llm_usage
if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START:
yield from self._emit_model_start(trace_state)
tool_name = payload.tool_name
tool_call_id = payload.tool_call_id
tool_arguments = json.dumps(payload.tool_args or {})
# Get icon from metadata (available at START)
tool_icon = output.metadata.get(AgentLog.LogMetadata.ICON) if output.metadata else None
tool_icon_dark = output.metadata.get(AgentLog.LogMetadata.ICON_DARK) if output.metadata else None
if tool_call_id and tool_call_id not in trace_state.tool_call_index_map:
trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map)
# Add tool call to pending list for model segment
buffers.pending_tool_calls.append(ToolCall(id=tool_call_id, name=tool_name, arguments=tool_arguments))
yield ToolCallChunkEvent(
@ -2395,7 +2425,7 @@ class LLMNode(Node[LLMNodeData]):
trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map)
# Flush model segment before tool result processing
self._flush_model_segment(buffers, trace_state)
yield from self._flush_model_segment(buffers, trace_state)
if output.status == AgentLog.LogStatus.ERROR:
tool_error = output.error or payload.tool_error
@ -2450,6 +2480,7 @@ class LLMNode(Node[LLMNodeData]):
elapsed_time=elapsed_time,
icon=tool_icon,
icon_dark=tool_icon_dark,
provider=tool_provider,
),
is_final=False,
)
@ -2474,9 +2505,7 @@ class LLMNode(Node[LLMNodeData]):
if not segment and kind not in {"thought_start", "thought_end"}:
continue
# Start tracking model segment time on first output
if trace_state.model_segment_start_time is None:
trace_state.model_segment_start_time = time.perf_counter()
yield from self._emit_model_start(trace_state)
if kind == "thought_start":
yield ThoughtStartChunkEvent(
@ -2525,9 +2554,7 @@ class LLMNode(Node[LLMNodeData]):
if not segment and kind not in {"thought_start", "thought_end"}:
continue
# Start tracking model segment time on first output
if trace_state.model_segment_start_time is None:
trace_state.model_segment_start_time = time.perf_counter()
yield from self._emit_model_start(trace_state)
if kind == "thought_start":
yield ThoughtStartChunkEvent(
@ -2572,7 +2599,7 @@ class LLMNode(Node[LLMNodeData]):
trace_state.pending_usage = aggregate.usage
# Flush final model segment
self._flush_model_segment(buffers, trace_state)
yield from self._flush_model_segment(buffers, trace_state)
def _close_streams(self) -> Generator[NodeEventBase, None, None]:
yield StreamChunkEvent(
@ -2612,6 +2639,16 @@ class LLMNode(Node[LLMNodeData]):
),
is_final=True,
)
yield StreamChunkEvent(
selector=[self._node_id, "generation", "model_start"],
chunk="",
is_final=True,
)
yield StreamChunkEvent(
selector=[self._node_id, "generation", "model_end"],
chunk="",
is_final=True,
)
def _build_generation_data(
self,