Merge branch 'feat/agent-node-v2' into deploy/dev

Novice
2025-12-30 13:43:40 +08:00
61 changed files with 6527 additions and 1246 deletions

View File

@@ -1,11 +1,16 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .tool_entities import ToolCall, ToolCallResult, ToolResult, ToolResultStatus
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"ToolCall",
"ToolCallResult",
"ToolResult",
"ToolResultStatus",
"WorkflowExecution",
"WorkflowNodeExecution",
]

View File

@@ -0,0 +1,33 @@
from enum import StrEnum
from pydantic import BaseModel, Field
from core.file import File
class ToolResultStatus(StrEnum):
SUCCESS = "success"
ERROR = "error"
class ToolCall(BaseModel):
id: str | None = Field(default=None, description="Unique identifier for this tool call")
name: str | None = Field(default=None, description="Name of the tool being called")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
class ToolResult(BaseModel):
id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to")
name: str | None = Field(default=None, description="Name of the tool")
output: str | None = Field(default=None, description="Tool output text (error or success message)")
files: list[str] = Field(default_factory=list, description="Files produced by the tool")
status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
class ToolCallResult(BaseModel):
id: str | None = Field(default=None, description="Identifier for the tool call")
name: str | None = Field(default=None, description="Name of the tool")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
output: str | None = Field(default=None, description="Tool output text (error or success message)")
files: list[File] = Field(default_factory=list, description="Files produced by the tool")
status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
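# A minimal sketch of how these entities compose during a single tool call
# (identifiers and payloads below are hypothetical):
from core.workflow.entities import ToolCall, ToolCallResult, ToolResultStatus

call = ToolCall(id="call-1", name="web_search", arguments='{"query": "dify"}')
result = ToolCallResult(
    id=call.id,
    name=call.name,
    arguments=call.arguments,
    output="3 results found",
    status=ToolResultStatus.SUCCESS,
)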

View File

@@ -247,6 +247,8 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
ERROR_STRATEGY = "error_strategy"  # nodes in continue-on-error mode return this field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
LLM_CONTENT_SEQUENCE = "llm_content_sequence"
LLM_TRACE = "llm_trace"
COMPLETED_REASON = "completed_reason"  # completion reason for the loop node

View File

@@ -16,7 +16,13 @@ from pydantic import BaseModel, Field
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.graph_events import (
ChunkType,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ToolCall,
ToolResult,
)
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from core.workflow.runtime import VariablePool
@@ -321,11 +327,24 @@ class ResponseStreamCoordinator:
selector: Sequence[str],
chunk: str,
is_final: bool = False,
chunk_type: ChunkType = ChunkType.TEXT,
tool_call: ToolCall | None = None,
tool_result: ToolResult | None = None,
) -> NodeRunStreamChunkEvent:
"""Create a stream chunk event with consistent structure.
For selectors with special prefixes (sys, env, conversation), we use the
active response node's information since these are not actual node IDs.
Args:
node_id: The node ID to attribute the event to
execution_id: The execution ID for this node
selector: The variable selector
chunk: The chunk content
is_final: Whether this is the final chunk
chunk_type: The semantic type of the chunk being streamed
tool_call: Structured data for tool_call chunks
tool_result: Structured data for tool_result chunks
"""
# Check if this is a special selector that doesn't correspond to a node
if selector and selector[0] not in self._graph.nodes and self._active_session:
@@ -338,6 +357,9 @@
selector=selector,
chunk=chunk,
is_final=is_final,
chunk_type=chunk_type,
tool_call=tool_call,
tool_result=tool_result,
)
# Standard case: selector refers to an actual node
@@ -349,6 +371,9 @@
selector=selector,
chunk=chunk,
is_final=is_final,
chunk_type=chunk_type,
tool_call=tool_call,
tool_result=tool_result,
)
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
@@ -356,6 +381,8 @@
Handles both regular node selectors and special system selectors (sys, env, conversation).
For special selectors, we attribute the output to the active response node.
For object-type variables, automatically streams all child fields that have stream events.
"""
events: list[NodeRunStreamChunkEvent] = []
source_selector_prefix = segment.selector[0] if segment.selector else ""
@@ -364,60 +391,81 @@
# Determine which node to attribute the output to
# For special selectors (sys, env, conversation), use the active response node
# For regular selectors, use the source node
if self._active_session and source_selector_prefix not in self._graph.nodes:
# Special selector - use active response node
output_node_id = self._active_session.node_id
else:
# Regular node selector
output_node_id = source_selector_prefix
active_session = self._active_session
special_selector = bool(active_session and source_selector_prefix not in self._graph.nodes)
output_node_id = active_session.node_id if special_selector and active_session else source_selector_prefix
execution_id = self._get_or_create_execution_id(output_node_id)
# Stream all available chunks
while self._has_unread_stream(segment.selector):
if event := self._pop_stream_chunk(segment.selector):
# For special selectors, we need to update the event to use
# the active response node's information
if self._active_session and source_selector_prefix not in self._graph.nodes:
response_node = self._graph.nodes[self._active_session.node_id]
# Create a new event with the response node's information
# but keep the original selector
updated_event = NodeRunStreamChunkEvent(
id=execution_id,
node_id=response_node.id,
node_type=response_node.node_type,
selector=event.selector, # Keep original selector
chunk=event.chunk,
is_final=event.is_final,
)
events.append(updated_event)
else:
# Regular node selector - use event as is
events.append(event)
# Check if there's a direct stream for this selector
has_direct_stream = (
tuple(segment.selector) in self._stream_buffers or tuple(segment.selector) in self._closed_streams
)
# Check if this is the last chunk by looking ahead
stream_closed = self._is_stream_closed(segment.selector)
# Check if stream is closed to determine if segment is complete
if stream_closed:
is_complete = True
stream_targets = [segment.selector] if has_direct_stream else sorted(self._find_child_streams(segment.selector))
elif value := self._variable_pool.get(segment.selector):
# Process scalar value
is_last_segment = bool(
self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
)
events.append(
self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=segment.selector,
chunk=value.markdown,
is_final=is_last_segment,
if stream_targets:
all_complete = True
for target_selector in stream_targets:
while self._has_unread_stream(target_selector):
if event := self._pop_stream_chunk(target_selector):
events.append(
self._rewrite_stream_event(
event=event,
output_node_id=output_node_id,
execution_id=execution_id,
special_selector=bool(special_selector),
)
)
if not self._is_stream_closed(target_selector):
all_complete = False
is_complete = all_complete
# Fallback: check if scalar value exists in variable pool
if not is_complete and not has_direct_stream:
if value := self._variable_pool.get(segment.selector):
# Process scalar value
is_last_segment = bool(
self._active_session
and self._active_session.index == len(self._active_session.template.segments) - 1
)
)
is_complete = True
events.append(
self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=segment.selector,
chunk=value.markdown,
is_final=is_last_segment,
)
)
is_complete = True
return events, is_complete
def _rewrite_stream_event(
self,
event: NodeRunStreamChunkEvent,
output_node_id: str,
execution_id: str,
special_selector: bool,
) -> NodeRunStreamChunkEvent:
"""Rewrite event to attribute to active response node when selector is special."""
if not special_selector:
return event
return self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=event.chunk_type,
tool_call=event.tool_call,
tool_result=event.tool_result,
)
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
"""Process a text segment. Returns (events, is_complete)."""
assert self._active_session is not None
@@ -513,6 +561,36 @@ class ResponseStreamCoordinator:
# ============= Internal Stream Management Methods =============
def _find_child_streams(self, parent_selector: Sequence[str]) -> list[tuple[str, ...]]:
"""Find all child stream selectors that are descendants of the parent selector.
For example, if parent_selector is ['llm', 'generation'], this will find:
- ['llm', 'generation', 'content']
- ['llm', 'generation', 'tool_calls']
- ['llm', 'generation', 'tool_results']
- ['llm', 'generation', 'thought']
Args:
parent_selector: The parent selector to search for children
Returns:
List of child selector tuples found in stream buffers or closed streams
"""
parent_key = tuple(parent_selector)
parent_len = len(parent_key)
child_streams: set[tuple[str, ...]] = set()
# Search in both active buffers and closed streams
all_selectors = set(self._stream_buffers.keys()) | self._closed_streams
for selector_key in all_selectors:
# Check if this selector is a direct child of the parent
# Direct child means: len(child) == len(parent) + 1 and child starts with parent
if len(selector_key) == parent_len + 1 and selector_key[:parent_len] == parent_key:
child_streams.add(selector_key)
return sorted(child_streams)
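# Illustration of the direct-child rule above (hypothetical selectors): only
# keys exactly one level deeper than the parent match.
parent = ("llm", "generation")
candidates = {
    ("llm", "generation", "content"),
    ("llm", "generation", "tool_calls"),
    ("llm", "generation", "tool_calls", "0"),  # grandchild: excluded
    ("llm", "text"),  # sibling: excluded
}
children = sorted(
    key for key in candidates if len(key) == len(parent) + 1 and key[: len(parent)] == parent
)
# children == [("llm", "generation", "content"), ("llm", "generation", "tool_calls")]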
def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
"""
Append a stream chunk to the internal buffer.

View File

@@ -36,6 +36,7 @@ from .loop import (
# Node events
from .node import (
ChunkType,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunPauseRequestedEvent,
@@ -44,10 +45,13 @@ from .node import (
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ToolCall,
ToolResult,
)
__all__ = [
"BaseGraphEvent",
"ChunkType",
"GraphEngineEvent",
"GraphNodeEventBase",
"GraphRunAbortedEvent",
@@ -73,4 +77,6 @@ __all__ = [
"NodeRunStartedEvent",
"NodeRunStreamChunkEvent",
"NodeRunSucceededEvent",
"ToolCall",
"ToolResult",
]

View File

@@ -1,10 +1,11 @@
from collections.abc import Sequence
from datetime import datetime
from enum import StrEnum
from pydantic import Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
from core.workflow.entities.pause_reason import PauseReason
from .base import GraphNodeEventBase
@@ -21,13 +22,37 @@ class NodeRunStartedEvent(GraphNodeEventBase):
provider_id: str = ""
class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""
TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)
class NodeRunStreamChunkEvent(GraphNodeEventBase):
# Spec-compliant fields
"""Stream chunk event for workflow node execution."""
# Base fields
selector: Sequence[str] = Field(
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
)
chunk: str = Field(..., description="the actual chunk content")
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
# Tool call fields (when chunk_type == TOOL_CALL)
tool_call: ToolCall | None = Field(
default=None,
description="structured payload for tool_call chunks",
)
# Tool result fields (when chunk_type == TOOL_RESULT)
tool_result: ToolResult | None = Field(
default=None,
description="structured payload for tool_result chunks",
)
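# Sketch of a tool-call chunk event (assumption: GraphNodeEventBase supplies
# defaults for its remaining fields; identifiers and arguments are hypothetical):
from core.workflow.enums import NodeType

event = NodeRunStreamChunkEvent(
    id="exec-1",
    node_id="llm",
    node_type=NodeType.LLM,
    selector=["llm", "generation", "tool_calls"],
    chunk='{"query": "dify"}',
    chunk_type=ChunkType.TOOL_CALL,
    tool_call=ToolCall(id="call-1", name="web_search", arguments='{"query": "dify"}'),
)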
class NodeRunRetrieverResourceEvent(GraphNodeEventBase):

View File

@@ -13,16 +13,21 @@ from .loop import (
LoopSucceededEvent,
)
from .node import (
ChunkType,
ModelInvokeCompletedEvent,
PauseRequestedEvent,
RunRetrieverResourceEvent,
RunRetryEvent,
StreamChunkEvent,
StreamCompletedEvent,
ThoughtChunkEvent,
ToolCallChunkEvent,
ToolResultChunkEvent,
)
__all__ = [
"AgentLogEvent",
"ChunkType",
"IterationFailedEvent",
"IterationNextEvent",
"IterationStartedEvent",
@@ -39,4 +44,7 @@ __all__ = [
"RunRetryEvent",
"StreamChunkEvent",
"StreamCompletedEvent",
"ThoughtChunkEvent",
"ToolCallChunkEvent",
"ToolResultChunkEvent",
]

View File

@@ -1,11 +1,13 @@
from collections.abc import Sequence
from datetime import datetime
from enum import StrEnum
from pydantic import Field
from core.file import File
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import ToolCall, ToolResult
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.node_events import NodeRunResult
@@ -32,13 +34,46 @@ class RunRetryEvent(NodeEventBase):
start_at: datetime = Field(..., description="Retry start time")
class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""
TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)
class StreamChunkEvent(NodeEventBase):
# Spec-compliant fields
"""Base stream chunk event - normal text streaming output."""
selector: Sequence[str] = Field(
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
)
chunk: str = Field(..., description="the actual chunk content")
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
tool_call: ToolCall | None = Field(default=None, description="structured payload for tool_call chunks")
tool_result: ToolResult | None = Field(default=None, description="structured payload for tool_result chunks")
class ToolCallChunkEvent(StreamChunkEvent):
"""Tool call streaming event - tool call arguments streaming output."""
chunk_type: ChunkType = Field(default=ChunkType.TOOL_CALL, frozen=True)
tool_call: ToolCall | None = Field(default=None, description="structured tool call payload")
class ToolResultChunkEvent(StreamChunkEvent):
"""Tool result event - tool execution result."""
chunk_type: ChunkType = Field(default=ChunkType.TOOL_RESULT, frozen=True)
tool_result: ToolResult | None = Field(default=None, description="structured tool result payload")
class ThoughtChunkEvent(StreamChunkEvent):
"""Agent thought streaming event - Agent thinking process (ReAct)."""
chunk_type: ChunkType = Field(default=ChunkType.THOUGHT, frozen=True)
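# Sketch: the subclasses only pin chunk_type (plus an optional payload), so a
# thought chunk needs no explicit chunk_type (selector/content are hypothetical):
thought = ThoughtChunkEvent(
    selector=["llm", "generation", "thought"],
    chunk="Deciding which tool to call...",
)
assert thought.chunk_type == ChunkType.THOUGHT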
class StreamCompletedEvent(NodeEventBase):

View File

@@ -46,6 +46,9 @@ from core.workflow.node_events import (
RunRetrieverResourceEvent,
StreamChunkEvent,
StreamCompletedEvent,
ThoughtChunkEvent,
ToolCallChunkEvent,
ToolResultChunkEvent,
)
from core.workflow.runtime import GraphRuntimeState
from libs.datetime_utils import naive_utc_now
@@ -543,6 +546,8 @@ class Node(Generic[NodeDataT]):
@_dispatch.register
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
from core.workflow.graph_events import ChunkType
return NodeRunStreamChunkEvent(
id=self.execution_id,
node_id=self._node_id,
@@ -550,6 +555,65 @@
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType(event.chunk_type.value),
tool_call=event.tool_call,
tool_result=event.tool_result,
)
@_dispatch.register
def _(self, event: ToolCallChunkEvent) -> NodeRunStreamChunkEvent:
from core.workflow.graph_events import ChunkType
return NodeRunStreamChunkEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType.TOOL_CALL,
tool_call=event.tool_call,
)
@_dispatch.register
def _(self, event: ToolResultChunkEvent) -> NodeRunStreamChunkEvent:
from core.workflow.entities import ToolResult, ToolResultStatus
from core.workflow.graph_events import ChunkType
tool_result = event.tool_result
status: ToolResultStatus = (
tool_result.status if tool_result and tool_result.status is not None else ToolResultStatus.SUCCESS
)
return NodeRunStreamChunkEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType.TOOL_RESULT,
tool_result=ToolResult(
id=tool_result.id if tool_result else None,
name=tool_result.name if tool_result else None,
output=tool_result.output if tool_result else None,
files=tool_result.files if tool_result else [],
status=status,
),
)
@_dispatch.register
def _(self, event: ThoughtChunkEvent) -> NodeRunStreamChunkEvent:
from core.workflow.graph_events import ChunkType
return NodeRunStreamChunkEvent(
id=self._node_execution_id,
node_id=self._node_id,
node_type=self.node_type,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType.THOUGHT,
)
@_dispatch.register

View File

@@ -3,6 +3,7 @@ from .entities import (
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
ModelConfig,
ToolMetadata,
VisionConfig,
)
from .node import LLMNode
@@ -13,5 +14,6 @@ __all__ = [
"LLMNodeCompletionModelPromptTemplate",
"LLMNodeData",
"ModelConfig",
"ToolMetadata",
"VisionConfig",
]

View File

@@ -1,10 +1,17 @@
import re
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from core.agent.entities import AgentLog, AgentResult
from core.file import File
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.model_runtime.entities.llm_entities import LLMUsage
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.entities import ToolCallResult
from core.workflow.node_events import AgentLogEvent
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
@@ -58,6 +65,235 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
jinja2_text: str | None = None
class ToolMetadata(BaseModel):
"""
Tool metadata for the LLM node's tool support.
Defines the essential fields needed for tool configuration,
particularly the 'type' field, which identifies the tool provider type.
"""
# Core fields
enabled: bool = True
type: ToolProviderType = Field(..., description="Tool provider type: builtin, api, mcp, workflow")
provider_name: str = Field(..., description="Tool provider name/identifier")
tool_name: str = Field(..., description="Tool name")
# Optional fields
plugin_unique_identifier: str | None = Field(None, description="Plugin unique identifier for plugin tools")
credential_id: str | None = Field(None, description="Credential ID for tools requiring authentication")
# Configuration fields
parameters: dict[str, Any] = Field(default_factory=dict, description="Tool parameters")
settings: dict[str, Any] = Field(default_factory=dict, description="Tool settings configuration")
extra: dict[str, Any] = Field(default_factory=dict, description="Extra tool configuration like custom description")
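# A hypothetical tool entry as it might appear in LLM node data (assumption:
# ToolProviderType.BUILT_IN is the builtin provider member; all values invented):
search_tool = ToolMetadata(
    type=ToolProviderType.BUILT_IN,
    provider_name="google",
    tool_name="google_search",
    parameters={"query": ""},
    extra={"description": "Search the web for fresh results"},
)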
class LLMTraceSegment(BaseModel):
"""
Streaming trace segment for LLM tool-enabled runs.
Order is preserved for replay. Tool calls are single entries containing both
arguments and results.
"""
type: Literal["thought", "content", "tool_call"]
# Common optional fields
text: str | None = Field(None, description="Text chunk for thought/content")
# Tool call fields (combined start + result)
tool_call: ToolCallResult | None = Field(
default=None,
description="Combined tool call arguments and result for this segment",
)
class LLMGenerationData(BaseModel):
"""Generation data from LLM invocation with tools.
For multi-turn tool calls like: thought1 -> text1 -> tool_call1 -> thought2 -> text2 -> tool_call2
- reasoning_contents: [thought1, thought2, ...] - one element per turn
- tool_calls: [{id, name, arguments, result}, ...] - all tool calls with results
"""
text: str = Field(..., description="Accumulated text content from all turns")
reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn")
tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results")
sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering")
usage: LLMUsage = Field(..., description="LLM usage statistics")
finish_reason: str | None = Field(None, description="Finish reason from LLM")
files: list[File] = Field(default_factory=list, description="Generated files")
trace: list[LLMTraceSegment] = Field(default_factory=list, description="Streaming trace in emitted order")
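# Sketch of the shape produced by a two-turn run
# (thought1 -> "text1" -> tool_call1 -> thought2 -> "text2"; values hypothetical).
# Content entries index into the accumulated text; reasoning/tool_call entries
# index into their respective lists.
example = LLMGenerationData(
    text="text1text2",
    reasoning_contents=["thought1", "thought2"],
    tool_calls=[ToolCallResult(id="call-1", name="web_search", arguments="{}", output="ok")],
    sequence=[
        {"type": "reasoning", "index": 0},
        {"type": "content", "start": 0, "end": 5},
        {"type": "tool_call", "index": 0},
        {"type": "reasoning", "index": 1},
        {"type": "content", "start": 5, "end": 10},
    ],
    usage=LLMUsage.empty_usage(),
)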
class ThinkTagStreamParser:
"""Lightweight state machine to split streaming chunks by <think> tags."""
_START_PATTERN = re.compile(r"<think(?:\s[^>]*)?>", re.IGNORECASE)
_END_PATTERN = re.compile(r"</think>", re.IGNORECASE)
_START_PREFIX = "<think"
_END_PREFIX = "</think"
def __init__(self):
self._buffer = ""
self._in_think = False
@staticmethod
def _suffix_prefix_len(text: str, prefix: str) -> int:
"""Return length of the longest suffix of `text` that is a prefix of `prefix`."""
max_len = min(len(text), len(prefix))  # allow holding the full prefix, e.g. a chunk ending exactly in "</think"
for i in range(max_len, 0, -1):
if text[-i:].lower() == prefix[:i].lower():
return i
return 0
def process(self, chunk: str) -> list[tuple[str, str]]:
"""
Split incoming chunk into ('thought' | 'text', content) tuples.
Content excludes the <think> tags themselves and handles split tags across chunks.
"""
parts: list[tuple[str, str]] = []
self._buffer += chunk
while self._buffer:
if self._in_think:
end_match = self._END_PATTERN.search(self._buffer)
if end_match:
thought_text = self._buffer[: end_match.start()]
if thought_text:
parts.append(("thought", thought_text))
self._buffer = self._buffer[end_match.end() :]
self._in_think = False
continue
hold_len = self._suffix_prefix_len(self._buffer, self._END_PREFIX)
emit = self._buffer[: len(self._buffer) - hold_len]
if emit:
parts.append(("thought", emit))
self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
break
start_match = self._START_PATTERN.search(self._buffer)
if start_match:
prefix = self._buffer[: start_match.start()]
if prefix:
parts.append(("text", prefix))
self._buffer = self._buffer[start_match.end() :]
self._in_think = True
continue
hold_len = self._suffix_prefix_len(self._buffer, self._START_PREFIX)
emit = self._buffer[: len(self._buffer) - hold_len]
if emit:
parts.append(("text", emit))
self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
break
cleaned_parts: list[tuple[str, str]] = []
for kind, content in parts:
# Extra safeguard: strip any stray tags that slipped through.
content = self._START_PATTERN.sub("", content)
content = self._END_PATTERN.sub("", content)
if content:
cleaned_parts.append((kind, content))
return cleaned_parts
def flush(self) -> list[tuple[str, str]]:
"""Flush remaining buffer when the stream ends."""
if not self._buffer:
return []
kind = "thought" if self._in_think else "text"
content = self._buffer
# Drop dangling partial tags instead of emitting them
if content.lower().startswith(self._START_PREFIX) or content.lower().startswith(self._END_PREFIX):
content = ""
self._buffer = ""
if not content:
return []
# Strip any complete tags that might still be present.
content = self._START_PATTERN.sub("", content)
content = self._END_PATTERN.sub("", content)
return [(kind, content)] if content else []
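# Sketch: the parser tolerates tags split across chunk boundaries
# (the chunks below are hypothetical):
parser = ThinkTagStreamParser()
parts: list[tuple[str, str]] = []
for piece in ("Hello <thi", "nk>plan A</th", "ink> world"):
    parts.extend(parser.process(piece))
parts.extend(parser.flush())
# parts == [("text", "Hello "), ("thought", "plan A"), ("text", " world")]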
class StreamBuffers(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
think_parser: ThinkTagStreamParser = Field(default_factory=ThinkTagStreamParser)
pending_thought: list[str] = Field(default_factory=list)
pending_content: list[str] = Field(default_factory=list)
current_turn_reasoning: list[str] = Field(default_factory=list)
reasoning_per_turn: list[str] = Field(default_factory=list)
class TraceState(BaseModel):
trace_segments: list[LLMTraceSegment] = Field(default_factory=list)
tool_trace_map: dict[str, LLMTraceSegment] = Field(default_factory=dict)
tool_call_index_map: dict[str, int] = Field(default_factory=dict)
class AggregatedResult(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
text: str = ""
files: list[File] = Field(default_factory=list)
usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage)
finish_reason: str | None = None
class AgentContext(BaseModel):
agent_logs: list[AgentLogEvent] = Field(default_factory=list)
agent_result: AgentResult | None = None
class ToolOutputState(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
stream: StreamBuffers = Field(default_factory=StreamBuffers)
trace: TraceState = Field(default_factory=TraceState)
aggregate: AggregatedResult = Field(default_factory=AggregatedResult)
agent: AgentContext = Field(default_factory=AgentContext)
class ToolLogPayload(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
tool_name: str = ""
tool_call_id: str = ""
tool_args: dict[str, Any] = Field(default_factory=dict)
tool_output: Any = None
tool_error: Any = None
files: list[Any] = Field(default_factory=list)
meta: dict[str, Any] = Field(default_factory=dict)
@classmethod
def from_log(cls, log: AgentLog) -> "ToolLogPayload":
data = log.data or {}
return cls(
tool_name=data.get("tool_name", ""),
tool_call_id=data.get("tool_call_id", ""),
tool_args=data.get("tool_args") or {},
tool_output=data.get("output"),
tool_error=data.get("error"),
files=data.get("files") or [],
meta=data.get("meta") or {},
)
@classmethod
def from_mapping(cls, data: Mapping[str, Any]) -> "ToolLogPayload":
return cls(
tool_name=data.get("tool_name", ""),
tool_call_id=data.get("tool_call_id", ""),
tool_args=data.get("tool_args") or {},
tool_output=data.get("output"),
tool_error=data.get("error"),
files=data.get("files") or [],
meta=data.get("meta") or {},
)
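# Sketch: normalizing a raw agent-log mapping into a typed payload; the keys
# mirror what AgentLog.data carries (values hypothetical):
payload = ToolLogPayload.from_mapping(
    {
        "tool_name": "web_search",
        "tool_call_id": "call-1",
        "tool_args": {"query": "dify"},
        "output": "3 results found",
    }
)
assert payload.tool_error is None and payload.files == []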
class LLMNodeData(BaseNodeData):
model: ModelConfig
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
@@ -86,6 +322,10 @@ class LLMNodeData(BaseNodeData):
),
)
# Tool support
tools: Sequence[ToolMetadata] = Field(default_factory=list)
max_iterations: int | None = Field(default=None, description="Maximum number of iterations for the LLM node")
@field_validator("prompt_config", mode="before")
@classmethod
def convert_none_prompt_config(cls, v: Any):

View File

@@ -9,6 +9,8 @@ from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import select
from core.agent.entities import AgentLog, AgentResult, AgentToolEntity, ExecutionContext
from core.agent.patterns import StrategyFactory
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import File, FileTransferMethod, FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
@@ -46,7 +48,9 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.__base.tool import Tool
from core.tools.signature import sign_upload_file
from core.tools.tool_manager import ToolManager
from core.variables import (
ArrayFileSegment,
ArraySegment,
@@ -56,7 +60,8 @@ from core.variables import (
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus
from core.workflow.entities.tool_entities import ToolCallResult
from core.workflow.enums import (
NodeType,
SystemVariableKey,
@@ -64,12 +69,16 @@
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import (
AgentLogEvent,
ModelInvokeCompletedEvent,
NodeEventBase,
NodeRunResult,
RunRetrieverResourceEvent,
StreamChunkEvent,
StreamCompletedEvent,
ThoughtChunkEvent,
ToolCallChunkEvent,
ToolResultChunkEvent,
)
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
@@ -81,10 +90,19 @@ from models.model import UploadFile
from . import llm_utils
from .entities import (
AgentContext,
AggregatedResult,
LLMGenerationData,
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
LLMTraceSegment,
ModelConfig,
StreamBuffers,
ThinkTagStreamParser,
ToolLogPayload,
ToolOutputState,
TraceState,
)
from .exc import (
InvalidContextStructureError,
@@ -149,11 +167,11 @@ class LLMNode(Node[LLMNodeData]):
def _run(self) -> Generator:
node_inputs: dict[str, Any] = {}
process_data: dict[str, Any] = {}
result_text = ""
clean_text = ""
usage = LLMUsage.empty_usage()
finish_reason = None
reasoning_content = None
reasoning_content = "" # Initialize as empty string for consistency
clean_text = "" # Initialize clean_text to avoid UnboundLocalError
variable_pool = self.graph_runtime_state.variable_pool
try:
@@ -234,55 +252,58 @@
context_files=context_files,
)
# handle invoke result
generator = LLMNode.invoke_llm(
node_data_model=self.node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=self.node_data.structured_output,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
node_type=self.node_type,
reasoning_format=self.node_data.reasoning_format,
)
# Variables for outputs
generation_data: LLMGenerationData | None = None
structured_output: LLMStructuredOutput | None = None
for event in generator:
if isinstance(event, StreamChunkEvent):
yield event
elif isinstance(event, ModelInvokeCompletedEvent):
# Raw text
result_text = event.text
usage = event.usage
finish_reason = event.finish_reason
reasoning_content = event.reasoning_content or ""
# Check if tools are configured
if self.tool_call_enabled:
# Use tool-enabled invocation (Agent V2 style)
generator = self._invoke_llm_with_tools(
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
files=files,
variable_pool=variable_pool,
node_inputs=node_inputs,
process_data=process_data,
)
else:
# Use traditional LLM invocation
generator = LLMNode.invoke_llm(
node_data_model=self._node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self._node_data.structured_output_enabled,
structured_output=self._node_data.structured_output,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
node_type=self.node_type,
reasoning_format=self._node_data.reasoning_format,
)
# For downstream nodes, determine clean text based on reasoning_format
if self.node_data.reasoning_format == "tagged":
# Keep <think> tags for backward compatibility
clean_text = result_text
else:
# Extract clean text from <think> tags
clean_text, _ = LLMNode._split_reasoning(result_text, self.node_data.reasoning_format)
(
clean_text,
reasoning_content,
generation_reasoning_content,
generation_clean_content,
usage,
finish_reason,
structured_output,
generation_data,
) = yield from self._stream_llm_events(generator, model_instance=model_instance)
# Process structured output if available from the event.
structured_output = (
LLMStructuredOutput(structured_output=event.structured_output)
if event.structured_output
else None
)
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break
elif isinstance(event, LLMStructuredOutput):
structured_output = event
# Extract variables from generation_data if available
if generation_data:
clean_text = generation_data.text
reasoning_content = ""
usage = generation_data.usage
finish_reason = generation_data.finish_reason
# Unified process_data building
process_data = {
"model_mode": model_config.mode,
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
@@ -293,24 +314,88 @@
"model_provider": model_config.provider,
"model_name": model_config.model,
}
if self.tool_call_enabled and self._node_data.tools:
process_data["tools"] = [
{
"type": tool.type.value if hasattr(tool.type, "value") else tool.type,
"provider_name": tool.provider_name,
"tool_name": tool.tool_name,
}
for tool in self._node_data.tools
if tool.enabled
]
# Unified outputs building
outputs = {
"text": clean_text,
"reasoning_content": reasoning_content,
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
}
# Build generation field
if generation_data:
# Use generation_data from tool invocation (supports multi-turn)
generation = {
"content": generation_data.text,
"reasoning_content": generation_data.reasoning_contents, # [thought1, thought2, ...]
"tool_calls": [self._serialize_tool_call(item) for item in generation_data.tool_calls],
"sequence": generation_data.sequence,
}
files_to_output = generation_data.files
else:
# Traditional LLM invocation
generation_reasoning = generation_reasoning_content or reasoning_content
generation_content = generation_clean_content or clean_text
sequence: list[dict[str, Any]] = []
if generation_reasoning:
sequence = [
{"type": "reasoning", "index": 0},
{"type": "content", "start": 0, "end": len(generation_content)},
]
generation = {
"content": generation_content,
"reasoning_content": [generation_reasoning] if generation_reasoning else [],
"tool_calls": [],
"sequence": sequence,
}
files_to_output = self._file_outputs
outputs["generation"] = generation
if files_to_output:
outputs["files"] = ArrayFileSegment(value=files_to_output)
if structured_output:
outputs["structured_output"] = structured_output.structured_output
if self._file_outputs:
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
# Send final chunk event to indicate streaming is complete
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk="",
is_final=True,
)
# When tool calling is enabled, the final chunk events are already emitted in
# _process_tool_outputs; otherwise close the streams here.
if not self.tool_call_enabled:
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk="",
is_final=True,
)
yield StreamChunkEvent(
selector=[self._node_id, "generation", "content"],
chunk="",
is_final=True,
)
yield ThoughtChunkEvent(
selector=[self._node_id, "generation", "thought"],
chunk="",
is_final=True,
)
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
}
if generation_data and generation_data.trace:
metadata[WorkflowNodeExecutionMetadataKey.LLM_TRACE] = [
segment.model_dump() for segment in generation_data.trace
]
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
@@ -318,11 +403,7 @@ class LLMNode(Node[LLMNodeData]):
inputs=node_inputs,
process_data=process_data,
outputs=outputs,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},
metadata=metadata,
llm_usage=usage,
)
)
@@ -444,6 +525,8 @@ class LLMNode(Node[LLMNodeData]):
usage = LLMUsage.empty_usage()
finish_reason = None
full_text_buffer = io.StringIO()
think_parser = ThinkTagStreamParser()
reasoning_chunks: list[str] = []
# Initialize streaming metrics tracking
start_time = request_start_time if request_start_time is not None else time.perf_counter()
@@ -472,12 +555,32 @@
has_content = True
full_text_buffer.write(text_part)
# Text output: always forward raw chunk (keep <think> tags intact)
yield StreamChunkEvent(
selector=[node_id, "text"],
chunk=text_part,
is_final=False,
)
# Generation output: split out thoughts, forward only non-thought content chunks
for kind, segment in think_parser.process(text_part):
if not segment:
continue
if kind == "thought":
reasoning_chunks.append(segment)
yield ThoughtChunkEvent(
selector=[node_id, "generation", "thought"],
chunk=segment,
is_final=False,
)
else:
yield StreamChunkEvent(
selector=[node_id, "generation", "content"],
chunk=segment,
is_final=False,
)
# Update the whole metadata
if not model and result.model:
model = result.model
@@ -492,16 +595,35 @@
except OutputParserError as e:
raise LLMNodeError(f"Failed to parse structured output: {e}")
for kind, segment in think_parser.flush():
if not segment:
continue
if kind == "thought":
reasoning_chunks.append(segment)
yield ThoughtChunkEvent(
selector=[node_id, "generation", "thought"],
chunk=segment,
is_final=False,
)
else:
yield StreamChunkEvent(
selector=[node_id, "generation", "content"],
chunk=segment,
is_final=False,
)
# Extract reasoning content from <think> tags in the main text
full_text = full_text_buffer.getvalue()
if reasoning_format == "tagged":
# Keep <think> tags in text for backward compatibility
clean_text = full_text
reasoning_content = ""
reasoning_content = "".join(reasoning_chunks)
else:
# Extract clean text and reasoning from <think> tags
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
if reasoning_chunks and not reasoning_content:
reasoning_content = "".join(reasoning_chunks)
# Calculate streaming metrics
end_time = time.perf_counter()
@@ -1266,6 +1388,635 @@ class LLMNode(Node[LLMNodeData]):
def retry(self) -> bool:
return self.node_data.retry_config.retry_enabled
@property
def tool_call_enabled(self) -> bool:
return (
self.node_data.tools is not None
and len(self.node_data.tools) > 0
and all(tool.enabled for tool in self.node_data.tools)
)
def _stream_llm_events(
self,
generator: Generator[NodeEventBase | LLMStructuredOutput, None, LLMGenerationData | None],
*,
model_instance: ModelInstance,
) -> Generator[
NodeEventBase,
None,
tuple[
str,
str,
str,
str,
LLMUsage,
str | None,
LLMStructuredOutput | None,
LLMGenerationData | None,
],
]:
"""
Stream events and capture generator return value in one place.
Uses generator delegation so _run stays concise while still emitting events.
"""
clean_text = ""
reasoning_content = ""
generation_reasoning_content = ""
generation_clean_content = ""
usage = LLMUsage.empty_usage()
finish_reason: str | None = None
structured_output: LLMStructuredOutput | None = None
generation_data: LLMGenerationData | None = None
completed = False
while True:
try:
event = next(generator)
except StopIteration as exc:
if isinstance(exc.value, LLMGenerationData):
generation_data = exc.value
break
if completed:
# After completion we still drain to reach StopIteration.value
continue
match event:
case StreamChunkEvent() | ThoughtChunkEvent():
yield event
case ModelInvokeCompletedEvent(
text=text,
usage=usage_event,
finish_reason=finish_reason_event,
reasoning_content=reasoning_event,
structured_output=structured_raw,
):
clean_text = text
usage = usage_event
finish_reason = finish_reason_event
reasoning_content = reasoning_event or ""
generation_reasoning_content = reasoning_content
generation_clean_content = clean_text
if self.node_data.reasoning_format == "tagged":
# Keep tagged text for output; also extract reasoning for generation field
generation_clean_content, generation_reasoning_content = LLMNode._split_reasoning(
clean_text, reasoning_format="separated"
)
else:
clean_text, generation_reasoning_content = LLMNode._split_reasoning(
clean_text, self.node_data.reasoning_format
)
generation_clean_content = clean_text
structured_output = (
LLMStructuredOutput(structured_output=structured_raw) if structured_raw else None
)
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
completed = True
case LLMStructuredOutput():
structured_output = event
case _:
continue
return (
clean_text,
reasoning_content,
generation_reasoning_content,
generation_clean_content,
usage,
finish_reason,
structured_output,
generation_data,
)
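# Sketch of the capture pattern used above: a generator's return value travels
# in StopIteration.value, so draining with next() keeps both the yielded events
# and the final result (the _demo_* names below are illustrative only):
def _demo_source():
    yield "chunk"
    return "final result"

def _demo_consumer():
    gen = _demo_source()
    while True:
        try:
            yield next(gen)
        except StopIteration as exc:
            return exc.value  # "final result"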
def _invoke_llm_with_tools(
self,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
stop: Sequence[str] | None,
files: Sequence["File"],
variable_pool: VariablePool,
node_inputs: dict[str, Any],
process_data: dict[str, Any],
) -> Generator[NodeEventBase, None, LLMGenerationData]:
"""Invoke LLM with tools support (from Agent V2).
Returns LLMGenerationData with text, reasoning_contents, tool_calls, usage, finish_reason, files
"""
# Get model features to determine strategy
model_features = self._get_model_features(model_instance)
# Prepare tool instances
tool_instances = self._prepare_tool_instances(variable_pool)
# Prepare prompt files (files that come from prompt variables, not vision files)
prompt_files = self._extract_prompt_files(variable_pool)
# Use factory to create appropriate strategy
strategy = StrategyFactory.create_strategy(
model_features=model_features,
model_instance=model_instance,
tools=tool_instances,
files=prompt_files,
max_iterations=self._node_data.max_iterations or 10,
context=ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id),
)
# Run strategy
outputs = strategy.run(
prompt_messages=list(prompt_messages),
model_parameters=self._node_data.model.completion_params,
stop=list(stop or []),
stream=True,
)
# Process outputs and return generation result
result = yield from self._process_tool_outputs(outputs)
return result
def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]:
"""Get model schema to determine features."""
try:
model_type_instance = model_instance.model_type_instance
model_schema = model_type_instance.get_model_schema(
model_instance.model,
model_instance.credentials,
)
return model_schema.features if model_schema and model_schema.features else []
except Exception:
logger.warning("Failed to get model schema, assuming no special features")
return []
def _prepare_tool_instances(self, variable_pool: VariablePool) -> list[Tool]:
"""Prepare tool instances from configuration."""
tool_instances = []
if self._node_data.tools:
for tool in self._node_data.tools:
try:
# Process settings to extract the correct structure
processed_settings = {}
for key, value in tool.settings.items():
if isinstance(value, dict) and "value" in value and isinstance(value["value"], dict):
# Extract the nested value if it has the ToolInput structure
if "type" in value["value"] and "value" in value["value"]:
processed_settings[key] = value["value"]
else:
processed_settings[key] = value
else:
processed_settings[key] = value
# Merge parameters with processed settings (similar to Agent Node logic)
merged_parameters = {**tool.parameters, **processed_settings}
# Create AgentToolEntity from ToolMetadata
agent_tool = AgentToolEntity(
provider_id=tool.provider_name,
provider_type=tool.type,
tool_name=tool.tool_name,
tool_parameters=merged_parameters,
plugin_unique_identifier=tool.plugin_unique_identifier,
credential_id=tool.credential_id,
)
# Get tool runtime from ToolManager
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=self.tenant_id,
app_id=self.app_id,
agent_tool=agent_tool,
invoke_from=self.invoke_from,
variable_pool=variable_pool,
)
# Apply custom description from extra field if available
if tool.extra.get("description") and tool_runtime.entity.description:
tool_runtime.entity.description.llm = (
tool.extra.get("description") or tool_runtime.entity.description.llm
)
tool_instances.append(tool_runtime)
except Exception as e:
logger.warning("Failed to load tool %s: %s", tool, str(e))
continue
return tool_instances
def _extract_prompt_files(self, variable_pool: VariablePool) -> list["File"]:
"""Extract files from prompt template variables."""
from core.variables import ArrayFileVariable, FileVariable
files: list[File] = []
# Extract variables from prompt template
if isinstance(self._node_data.prompt_template, list):
for message in self._node_data.prompt_template:
if message.text:
parser = VariableTemplateParser(message.text)
variable_selectors = parser.extract_variable_selectors()
for variable_selector in variable_selectors:
variable = variable_pool.get(variable_selector.value_selector)
if isinstance(variable, FileVariable) and variable.value:
files.append(variable.value)
elif isinstance(variable, ArrayFileVariable) and variable.value:
files.extend(variable.value)
return files
@staticmethod
def _serialize_tool_call(tool_call: ToolCallResult) -> dict[str, Any]:
"""Convert ToolCallResult into JSON-friendly dict."""
def _file_to_ref(file: File) -> str | None:
# Align with streamed tool result events which carry file IDs
return file.id or file.related_id
files = []
for file in tool_call.files or []:
ref = _file_to_ref(file)
if ref:
files.append(ref)
return {
"id": tool_call.id,
"name": tool_call.name,
"arguments": tool_call.arguments,
"output": tool_call.output,
"files": files,
"status": tool_call.status.value if hasattr(tool_call.status, "value") else tool_call.status,
}
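# Sketch: a completed ToolCallResult serializes to a JSON-friendly dict
# (values hypothetical; status becomes its string value, files become IDs):
serialized = LLMNode._serialize_tool_call(
    ToolCallResult(id="call-1", name="web_search", arguments="{}", output="3 results")
)
# serialized == {"id": "call-1", "name": "web_search", "arguments": "{}",
#                "output": "3 results", "files": [], "status": "success"}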
def _flush_thought_segment(self, buffers: StreamBuffers, trace_state: TraceState) -> None:
if not buffers.pending_thought:
return
trace_state.trace_segments.append(LLMTraceSegment(type="thought", text="".join(buffers.pending_thought)))
buffers.pending_thought.clear()
def _flush_content_segment(self, buffers: StreamBuffers, trace_state: TraceState) -> None:
if not buffers.pending_content:
return
trace_state.trace_segments.append(LLMTraceSegment(type="content", text="".join(buffers.pending_content)))
buffers.pending_content.clear()
def _handle_agent_log_output(
self, output: AgentLog, buffers: StreamBuffers, trace_state: TraceState, agent_context: AgentContext
) -> Generator[NodeEventBase, None, None]:
payload = ToolLogPayload.from_log(output)
agent_log_event = AgentLogEvent(
message_id=output.id,
label=output.label,
node_execution_id=self.id,
parent_id=output.parent_id,
error=output.error,
status=output.status.value,
data=output.data,
metadata={k.value: v for k, v in output.metadata.items()},
node_id=self._node_id,
)
for log in agent_context.agent_logs:
if log.message_id == agent_log_event.message_id:
log.data = agent_log_event.data
log.status = agent_log_event.status
log.error = agent_log_event.error
log.label = agent_log_event.label
log.metadata = agent_log_event.metadata
break
else:
agent_context.agent_logs.append(agent_log_event)
if output.log_type == AgentLog.LogType.TOOL_CALL and output.status == AgentLog.LogStatus.START:
tool_name = payload.tool_name
tool_call_id = payload.tool_call_id
tool_arguments = json.dumps(payload.tool_args) if payload.tool_args else ""
if tool_call_id and tool_call_id not in trace_state.tool_call_index_map:
trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map)
self._flush_thought_segment(buffers, trace_state)
self._flush_content_segment(buffers, trace_state)
tool_call_segment = LLMTraceSegment(
type="tool_call",
text=None,
tool_call=ToolCallResult(
id=tool_call_id,
name=tool_name,
arguments=tool_arguments,
),
)
trace_state.trace_segments.append(tool_call_segment)
if tool_call_id:
trace_state.tool_trace_map[tool_call_id] = tool_call_segment
yield ToolCallChunkEvent(
selector=[self._node_id, "generation", "tool_calls"],
chunk=tool_arguments,
tool_call=ToolCall(
id=tool_call_id,
name=tool_name,
arguments=tool_arguments,
),
is_final=False,
)
if output.log_type == AgentLog.LogType.TOOL_CALL and output.status != AgentLog.LogStatus.START:
tool_name = payload.tool_name
tool_output = payload.tool_output
tool_call_id = payload.tool_call_id
tool_files = payload.files if isinstance(payload.files, list) else []
tool_error = payload.tool_error
if tool_call_id and tool_call_id not in trace_state.tool_call_index_map:
trace_state.tool_call_index_map[tool_call_id] = len(trace_state.tool_call_index_map)
self._flush_thought_segment(buffers, trace_state)
self._flush_content_segment(buffers, trace_state)
if output.status == AgentLog.LogStatus.ERROR:
tool_error = output.error or payload.tool_error
if not tool_error and payload.meta:
tool_error = payload.meta.get("error")
else:
if payload.meta:
meta_error = payload.meta.get("error")
if meta_error:
tool_error = meta_error
existing_tool_segment = trace_state.tool_trace_map.get(tool_call_id)
tool_call_segment = existing_tool_segment or LLMTraceSegment(
type="tool_call",
text=None,
tool_call=ToolCallResult(
id=tool_call_id,
name=tool_name,
arguments=None,
),
)
if existing_tool_segment is None:
trace_state.trace_segments.append(tool_call_segment)
if tool_call_id:
trace_state.tool_trace_map[tool_call_id] = tool_call_segment
if tool_call_segment.tool_call is None:
tool_call_segment.tool_call = ToolCallResult(
id=tool_call_id,
name=tool_name,
arguments=None,
)
tool_call_segment.tool_call.output = (
str(tool_output) if tool_output is not None else str(tool_error) if tool_error is not None else None
)
tool_call_segment.tool_call.files = []
tool_call_segment.tool_call.status = ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS
result_output = str(tool_output) if tool_output is not None else str(tool_error) if tool_error else None
yield ToolResultChunkEvent(
selector=[self._node_id, "generation", "tool_results"],
chunk=result_output or "",
tool_result=ToolResult(
id=tool_call_id,
name=tool_name,
output=result_output,
files=tool_files,
status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS,
),
is_final=False,
)
if buffers.current_turn_reasoning:
buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning))
buffers.current_turn_reasoning.clear()
def _handle_llm_chunk_output(
self, output: LLMResultChunk, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult
) -> Generator[NodeEventBase, None, None]:
message = output.delta.message
if message and message.content:
chunk_text = message.content
if isinstance(chunk_text, list):
chunk_text = "".join(getattr(content, "data", str(content)) for content in chunk_text)
else:
chunk_text = str(chunk_text)
for kind, segment in buffers.think_parser.process(chunk_text):
if not segment:
continue
if kind == "thought":
self._flush_content_segment(buffers, trace_state)
buffers.current_turn_reasoning.append(segment)
buffers.pending_thought.append(segment)
yield ThoughtChunkEvent(
selector=[self._node_id, "generation", "thought"],
chunk=segment,
is_final=False,
)
else:
self._flush_thought_segment(buffers, trace_state)
aggregate.text += segment
buffers.pending_content.append(segment)
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=segment,
is_final=False,
)
yield StreamChunkEvent(
selector=[self._node_id, "generation", "content"],
chunk=segment,
is_final=False,
)
if output.delta.usage:
self._accumulate_usage(aggregate.usage, output.delta.usage)
if output.delta.finish_reason:
aggregate.finish_reason = output.delta.finish_reason
def _flush_remaining_stream(
self, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult
) -> Generator[NodeEventBase, None, None]:
for kind, segment in buffers.think_parser.flush():
if not segment:
continue
if kind == "thought":
self._flush_content_segment(buffers, trace_state)
buffers.current_turn_reasoning.append(segment)
buffers.pending_thought.append(segment)
yield ThoughtChunkEvent(
selector=[self._node_id, "generation", "thought"],
chunk=segment,
is_final=False,
)
else:
self._flush_thought_segment(buffers, trace_state)
aggregate.text += segment
buffers.pending_content.append(segment)
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk=segment,
is_final=False,
)
yield StreamChunkEvent(
selector=[self._node_id, "generation", "content"],
chunk=segment,
is_final=False,
)
if buffers.current_turn_reasoning:
buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning))
self._flush_thought_segment(buffers, trace_state)
self._flush_content_segment(buffers, trace_state)
def _close_streams(self) -> Generator[NodeEventBase, None, None]:
yield StreamChunkEvent(
selector=[self._node_id, "text"],
chunk="",
is_final=True,
)
yield StreamChunkEvent(
selector=[self._node_id, "generation", "content"],
chunk="",
is_final=True,
)
yield ThoughtChunkEvent(
selector=[self._node_id, "generation", "thought"],
chunk="",
is_final=True,
)
yield ToolCallChunkEvent(
selector=[self._node_id, "generation", "tool_calls"],
chunk="",
tool_call=ToolCall(
id="",
name="",
arguments="",
),
is_final=True,
)
yield ToolResultChunkEvent(
selector=[self._node_id, "generation", "tool_results"],
chunk="",
tool_result=ToolResult(
id="",
name="",
output="",
files=[],
status=ToolResultStatus.SUCCESS,
),
is_final=True,
)
def _build_generation_data(
self,
trace_state: TraceState,
agent_context: AgentContext,
aggregate: AggregatedResult,
buffers: StreamBuffers,
) -> LLMGenerationData:
sequence: list[dict[str, Any]] = []
reasoning_index = 0
content_position = 0
tool_call_seen_index: dict[str, int] = {}
for trace_segment in trace_state.trace_segments:
if trace_segment.type == "thought":
sequence.append({"type": "reasoning", "index": reasoning_index})
reasoning_index += 1
elif trace_segment.type == "content":
segment_text = trace_segment.text or ""
start = content_position
end = start + len(segment_text)
sequence.append({"type": "content", "start": start, "end": end})
content_position = end
elif trace_segment.type == "tool_call":
tool_id = trace_segment.tool_call.id if trace_segment.tool_call and trace_segment.tool_call.id else ""
if tool_id not in tool_call_seen_index:
tool_call_seen_index[tool_id] = len(tool_call_seen_index)
sequence.append({"type": "tool_call", "index": tool_call_seen_index[tool_id]})
tool_calls_for_generation: list[ToolCallResult] = []
for log in agent_context.agent_logs:
payload = ToolLogPayload.from_mapping(log.data or {})
tool_call_id = payload.tool_call_id
if not tool_call_id or log.status == AgentLog.LogStatus.START.value:
continue
tool_args = payload.tool_args
log_error = payload.tool_error
log_output = payload.tool_output
result_text = log_output or log_error or ""
status = ToolResultStatus.ERROR if log_error else ToolResultStatus.SUCCESS
tool_calls_for_generation.append(
ToolCallResult(
id=tool_call_id,
name=payload.tool_name,
arguments=json.dumps(tool_args) if tool_args else "",
output=result_text,
status=status,
)
)
tool_calls_for_generation.sort(
key=lambda item: trace_state.tool_call_index_map.get(item.id or "", len(trace_state.tool_call_index_map))
)
return LLMGenerationData(
text=aggregate.text,
reasoning_contents=buffers.reasoning_per_turn,
tool_calls=tool_calls_for_generation,
sequence=sequence,
usage=aggregate.usage,
finish_reason=aggregate.finish_reason,
files=aggregate.files,
trace=trace_state.trace_segments,
)
def _process_tool_outputs(
self,
outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult],
) -> Generator[NodeEventBase, None, LLMGenerationData]:
"""Process strategy outputs and convert to node events."""
state = ToolOutputState()
# Drain the generator with next() so the AgentResult returned via
# StopIteration.value can be captured; a plain for-loop would swallow
# StopIteration and discard that return value.
while True:
try:
output = next(outputs)
except StopIteration as exception:
if isinstance(exception.value, AgentResult):
state.agent.agent_result = exception.value
break
if isinstance(output, AgentLog):
yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent)
else:
yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate)
if state.agent.agent_result:
state.aggregate.text = state.agent.agent_result.text or state.aggregate.text
state.aggregate.files = state.agent.agent_result.files
if state.agent.agent_result.usage:
state.aggregate.usage = state.agent.agent_result.usage
if state.agent.agent_result.finish_reason:
state.aggregate.finish_reason = state.agent.agent_result.finish_reason
yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
yield from self._close_streams()
return self._build_generation_data(state.trace, state.agent, state.aggregate, state.stream)
def _accumulate_usage(self, total_usage: LLMUsage, delta_usage: LLMUsage) -> None:
"""Accumulate LLM usage statistics."""
total_usage.prompt_tokens += delta_usage.prompt_tokens
total_usage.completion_tokens += delta_usage.completion_tokens
total_usage.total_tokens += delta_usage.total_tokens
total_usage.prompt_price += delta_usage.prompt_price
total_usage.completion_price += delta_usage.completion_price
total_usage.total_price += delta_usage.total_price
def _combine_message_content_with_role(
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole