mirror of
https://github.com/langgenius/dify.git
synced 2026-04-25 13:16:16 +08:00
Merge branch 'feat/agent-node-v2' into deploy/dev
This commit is contained in:
@ -1,11 +1,16 @@
|
||||
from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .tool_entities import ToolCall, ToolCallResult, ToolResult, ToolResultStatus
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"ToolCall",
|
||||
"ToolCallResult",
|
||||
"ToolResult",
|
||||
"ToolResultStatus",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
]
|
||||
|
||||
33
api/core/workflow/entities/tool_entities.py
Normal file
33
api/core/workflow/entities/tool_entities.py
Normal file
@ -0,0 +1,33 @@
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.file import File
|
||||
|
||||
|
||||
class ToolResultStatus(StrEnum):
|
||||
SUCCESS = "success"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
    """Structured payload describing a tool invocation.

    Every field is optional and defaults to ``None``.
    """

    id: str | None = Field(
        default=None,
        description="Unique identifier for this tool call",
    )
    name: str | None = Field(
        default=None,
        description="Name of the tool being called",
    )
    arguments: str | None = Field(
        default=None,
        description="Accumulated tool arguments JSON",
    )
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
    """Structured payload describing the result of a tool execution.

    ``files`` holds plain strings here, whereas ``ToolCallResult.files`` holds
    ``File`` objects.
    """

    id: str | None = Field(
        default=None,
        description="Identifier of the tool call this result belongs to",
    )
    name: str | None = Field(
        default=None,
        description="Name of the tool",
    )
    output: str | None = Field(
        default=None,
        description="Tool output text, error or success message",
    )
    files: list[str] = Field(
        default_factory=list,
        description="File produced by tool",
    )
    # May be explicitly None; the default is SUCCESS.
    status: ToolResultStatus | None = Field(
        default=ToolResultStatus.SUCCESS,
        description="Tool execution status",
    )
|
||||
|
||||
|
||||
class ToolCallResult(BaseModel):
    """Combined record of a tool call (id, name, arguments) and its result."""

    # --- call side ---
    id: str | None = Field(
        default=None,
        description="Identifier for the tool call",
    )
    name: str | None = Field(
        default=None,
        description="Name of the tool",
    )
    arguments: str | None = Field(
        default=None,
        description="Accumulated tool arguments JSON",
    )

    # --- result side ---
    output: str | None = Field(
        default=None,
        description="Tool output text, error or success message",
    )
    files: list[File] = Field(
        default_factory=list,
        description="File produced by tool",
    )
    status: ToolResultStatus = Field(
        default=ToolResultStatus.SUCCESS,
        description="Tool execution status",
    )
|
||||
@ -247,6 +247,8 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
LLM_CONTENT_SEQUENCE = "llm_content_sequence"
|
||||
LLM_TRACE = "llm_trace"
|
||||
COMPLETED_REASON = "completed_reason" # completed reason for loop node
|
||||
|
||||
|
||||
|
||||
@ -16,7 +16,13 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.enums import NodeExecutionType, NodeState
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
|
||||
from core.workflow.graph_events import (
|
||||
ChunkType,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ToolCall,
|
||||
ToolResult,
|
||||
)
|
||||
from core.workflow.nodes.base.template import TextSegment, VariableSegment
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
@ -321,11 +327,24 @@ class ResponseStreamCoordinator:
|
||||
selector: Sequence[str],
|
||||
chunk: str,
|
||||
is_final: bool = False,
|
||||
chunk_type: ChunkType = ChunkType.TEXT,
|
||||
tool_call: ToolCall | None = None,
|
||||
tool_result: ToolResult | None = None,
|
||||
) -> NodeRunStreamChunkEvent:
|
||||
"""Create a stream chunk event with consistent structure.
|
||||
|
||||
For selectors with special prefixes (sys, env, conversation), we use the
|
||||
active response node's information since these are not actual node IDs.
|
||||
|
||||
Args:
|
||||
node_id: The node ID to attribute the event to
|
||||
execution_id: The execution ID for this node
|
||||
selector: The variable selector
|
||||
chunk: The chunk content
|
||||
is_final: Whether this is the final chunk
|
||||
chunk_type: The semantic type of the chunk being streamed
|
||||
tool_call: Structured data for tool_call chunks
|
||||
tool_result: Structured data for tool_result chunks
|
||||
"""
|
||||
# Check if this is a special selector that doesn't correspond to a node
|
||||
if selector and selector[0] not in self._graph.nodes and self._active_session:
|
||||
@ -338,6 +357,9 @@ class ResponseStreamCoordinator:
|
||||
selector=selector,
|
||||
chunk=chunk,
|
||||
is_final=is_final,
|
||||
chunk_type=chunk_type,
|
||||
tool_call=tool_call,
|
||||
tool_result=tool_result,
|
||||
)
|
||||
|
||||
# Standard case: selector refers to an actual node
|
||||
@ -349,6 +371,9 @@ class ResponseStreamCoordinator:
|
||||
selector=selector,
|
||||
chunk=chunk,
|
||||
is_final=is_final,
|
||||
chunk_type=chunk_type,
|
||||
tool_call=tool_call,
|
||||
tool_result=tool_result,
|
||||
)
|
||||
|
||||
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
|
||||
@ -356,6 +381,8 @@ class ResponseStreamCoordinator:
|
||||
|
||||
Handles both regular node selectors and special system selectors (sys, env, conversation).
|
||||
For special selectors, we attribute the output to the active response node.
|
||||
|
||||
For object-type variables, automatically streams all child fields that have stream events.
|
||||
"""
|
||||
events: list[NodeRunStreamChunkEvent] = []
|
||||
source_selector_prefix = segment.selector[0] if segment.selector else ""
|
||||
@ -364,60 +391,81 @@ class ResponseStreamCoordinator:
|
||||
# Determine which node to attribute the output to
|
||||
# For special selectors (sys, env, conversation), use the active response node
|
||||
# For regular selectors, use the source node
|
||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
||||
# Special selector - use active response node
|
||||
output_node_id = self._active_session.node_id
|
||||
else:
|
||||
# Regular node selector
|
||||
output_node_id = source_selector_prefix
|
||||
active_session = self._active_session
|
||||
special_selector = bool(active_session and source_selector_prefix not in self._graph.nodes)
|
||||
output_node_id = active_session.node_id if special_selector and active_session else source_selector_prefix
|
||||
execution_id = self._get_or_create_execution_id(output_node_id)
|
||||
|
||||
# Stream all available chunks
|
||||
while self._has_unread_stream(segment.selector):
|
||||
if event := self._pop_stream_chunk(segment.selector):
|
||||
# For special selectors, we need to update the event to use
|
||||
# the active response node's information
|
||||
if self._active_session and source_selector_prefix not in self._graph.nodes:
|
||||
response_node = self._graph.nodes[self._active_session.node_id]
|
||||
# Create a new event with the response node's information
|
||||
# but keep the original selector
|
||||
updated_event = NodeRunStreamChunkEvent(
|
||||
id=execution_id,
|
||||
node_id=response_node.id,
|
||||
node_type=response_node.node_type,
|
||||
selector=event.selector, # Keep original selector
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
)
|
||||
events.append(updated_event)
|
||||
else:
|
||||
# Regular node selector - use event as is
|
||||
events.append(event)
|
||||
# Check if there's a direct stream for this selector
|
||||
has_direct_stream = (
|
||||
tuple(segment.selector) in self._stream_buffers or tuple(segment.selector) in self._closed_streams
|
||||
)
|
||||
|
||||
# Check if this is the last chunk by looking ahead
|
||||
stream_closed = self._is_stream_closed(segment.selector)
|
||||
# Check if stream is closed to determine if segment is complete
|
||||
if stream_closed:
|
||||
is_complete = True
|
||||
stream_targets = [segment.selector] if has_direct_stream else sorted(self._find_child_streams(segment.selector))
|
||||
|
||||
elif value := self._variable_pool.get(segment.selector):
|
||||
# Process scalar value
|
||||
is_last_segment = bool(
|
||||
self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
|
||||
)
|
||||
events.append(
|
||||
self._create_stream_chunk_event(
|
||||
node_id=output_node_id,
|
||||
execution_id=execution_id,
|
||||
selector=segment.selector,
|
||||
chunk=value.markdown,
|
||||
is_final=is_last_segment,
|
||||
if stream_targets:
|
||||
all_complete = True
|
||||
|
||||
for target_selector in stream_targets:
|
||||
while self._has_unread_stream(target_selector):
|
||||
if event := self._pop_stream_chunk(target_selector):
|
||||
events.append(
|
||||
self._rewrite_stream_event(
|
||||
event=event,
|
||||
output_node_id=output_node_id,
|
||||
execution_id=execution_id,
|
||||
special_selector=bool(special_selector),
|
||||
)
|
||||
)
|
||||
|
||||
if not self._is_stream_closed(target_selector):
|
||||
all_complete = False
|
||||
|
||||
is_complete = all_complete
|
||||
|
||||
# Fallback: check if scalar value exists in variable pool
|
||||
if not is_complete and not has_direct_stream:
|
||||
if value := self._variable_pool.get(segment.selector):
|
||||
# Process scalar value
|
||||
is_last_segment = bool(
|
||||
self._active_session
|
||||
and self._active_session.index == len(self._active_session.template.segments) - 1
|
||||
)
|
||||
)
|
||||
is_complete = True
|
||||
events.append(
|
||||
self._create_stream_chunk_event(
|
||||
node_id=output_node_id,
|
||||
execution_id=execution_id,
|
||||
selector=segment.selector,
|
||||
chunk=value.markdown,
|
||||
is_final=is_last_segment,
|
||||
)
|
||||
)
|
||||
is_complete = True
|
||||
|
||||
return events, is_complete
|
||||
|
||||
def _rewrite_stream_event(
    self,
    event: NodeRunStreamChunkEvent,
    output_node_id: str,
    execution_id: str,
    special_selector: bool,
) -> NodeRunStreamChunkEvent:
    """Re-attribute a stream event to another node when its selector is special.

    Args:
        event: The buffered stream chunk event.
        output_node_id: Node ID the rewritten event is attributed to.
        execution_id: Execution ID used for the rewritten event.
        special_selector: When false, ``event`` is returned unchanged.

    Returns:
        ``event`` itself for regular selectors; otherwise a new event built by
        ``_create_stream_chunk_event`` that carries over the selector, chunk,
        finality flag, chunk type and tool payloads.
    """
    if special_selector:
        return self._create_stream_chunk_event(
            node_id=output_node_id,
            execution_id=execution_id,
            selector=event.selector,
            chunk=event.chunk,
            is_final=event.is_final,
            chunk_type=event.chunk_type,
            tool_call=event.tool_call,
            tool_result=event.tool_result,
        )

    return event
|
||||
|
||||
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
|
||||
"""Process a text segment. Returns (events, is_complete)."""
|
||||
assert self._active_session is not None
|
||||
@ -513,6 +561,36 @@ class ResponseStreamCoordinator:
|
||||
|
||||
# ============= Internal Stream Management Methods =============
|
||||
|
||||
def _find_child_streams(self, parent_selector: Sequence[str]) -> list[tuple[str, ...]]:
|
||||
"""Find all child stream selectors that are descendants of the parent selector.
|
||||
|
||||
For example, if parent_selector is ['llm', 'generation'], this will find:
|
||||
- ['llm', 'generation', 'content']
|
||||
- ['llm', 'generation', 'tool_calls']
|
||||
- ['llm', 'generation', 'tool_results']
|
||||
- ['llm', 'generation', 'thought']
|
||||
|
||||
Args:
|
||||
parent_selector: The parent selector to search for children
|
||||
|
||||
Returns:
|
||||
List of child selector tuples found in stream buffers or closed streams
|
||||
"""
|
||||
parent_key = tuple(parent_selector)
|
||||
parent_len = len(parent_key)
|
||||
child_streams: set[tuple[str, ...]] = set()
|
||||
|
||||
# Search in both active buffers and closed streams
|
||||
all_selectors = set(self._stream_buffers.keys()) | self._closed_streams
|
||||
|
||||
for selector_key in all_selectors:
|
||||
# Check if this selector is a direct child of the parent
|
||||
# Direct child means: len(child) == len(parent) + 1 and child starts with parent
|
||||
if len(selector_key) == parent_len + 1 and selector_key[:parent_len] == parent_key:
|
||||
child_streams.add(selector_key)
|
||||
|
||||
return sorted(child_streams)
|
||||
|
||||
def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
|
||||
"""
|
||||
Append a stream chunk to the internal buffer.
|
||||
|
||||
@ -36,6 +36,7 @@ from .loop import (
|
||||
|
||||
# Node events
|
||||
from .node import (
|
||||
ChunkType,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
@ -44,10 +45,13 @@ from .node import (
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ToolCall,
|
||||
ToolResult,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BaseGraphEvent",
|
||||
"ChunkType",
|
||||
"GraphEngineEvent",
|
||||
"GraphNodeEventBase",
|
||||
"GraphRunAbortedEvent",
|
||||
@ -73,4 +77,6 @@ __all__ = [
|
||||
"NodeRunStartedEvent",
|
||||
"NodeRunStreamChunkEvent",
|
||||
"NodeRunSucceededEvent",
|
||||
"ToolCall",
|
||||
"ToolResult",
|
||||
]
|
||||
|
||||
@ -1,10 +1,11 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
|
||||
from .base import GraphNodeEventBase
|
||||
@ -21,13 +22,37 @@ class NodeRunStartedEvent(GraphNodeEventBase):
|
||||
provider_id: str = ""
|
||||
|
||||
|
||||
class ChunkType(StrEnum):
|
||||
"""Stream chunk type for LLM-related events."""
|
||||
|
||||
TEXT = "text" # Normal text streaming
|
||||
TOOL_CALL = "tool_call" # Tool call arguments streaming
|
||||
TOOL_RESULT = "tool_result" # Tool execution result
|
||||
THOUGHT = "thought" # Agent thinking process (ReAct)
|
||||
|
||||
|
||||
class NodeRunStreamChunkEvent(GraphNodeEventBase):
    """Stream chunk event for workflow node execution."""

    # Base fields
    selector: Sequence[str] = Field(
        ...,
        description="selector identifying the output location (e.g., ['nodeA', 'text'])",
    )
    chunk: str = Field(
        ...,
        description="the actual chunk content",
    )
    is_final: bool = Field(
        default=False,
        description="indicates if this is the last chunk",
    )
    chunk_type: ChunkType = Field(
        default=ChunkType.TEXT,
        description="type of the chunk",
    )

    # Tool call fields (when chunk_type == TOOL_CALL)
    tool_call: ToolCall | None = Field(default=None, description="structured payload for tool_call chunks")

    # Tool result fields (when chunk_type == TOOL_RESULT)
    tool_result: ToolResult | None = Field(default=None, description="structured payload for tool_result chunks")
|
||||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(GraphNodeEventBase):
|
||||
|
||||
@ -13,16 +13,21 @@ from .loop import (
|
||||
LoopSucceededEvent,
|
||||
)
|
||||
from .node import (
|
||||
ChunkType,
|
||||
ModelInvokeCompletedEvent,
|
||||
PauseRequestedEvent,
|
||||
RunRetrieverResourceEvent,
|
||||
RunRetryEvent,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
ThoughtChunkEvent,
|
||||
ToolCallChunkEvent,
|
||||
ToolResultChunkEvent,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AgentLogEvent",
|
||||
"ChunkType",
|
||||
"IterationFailedEvent",
|
||||
"IterationNextEvent",
|
||||
"IterationStartedEvent",
|
||||
@ -39,4 +44,7 @@ __all__ = [
|
||||
"RunRetryEvent",
|
||||
"StreamChunkEvent",
|
||||
"StreamCompletedEvent",
|
||||
"ThoughtChunkEvent",
|
||||
"ToolCallChunkEvent",
|
||||
"ToolResultChunkEvent",
|
||||
]
|
||||
|
||||
@ -1,11 +1,13 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.file import File
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import ToolCall, ToolResult
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
|
||||
@ -32,13 +34,46 @@ class RunRetryEvent(NodeEventBase):
|
||||
start_at: datetime = Field(..., description="Retry start time")
|
||||
|
||||
|
||||
class ChunkType(StrEnum):
|
||||
"""Stream chunk type for LLM-related events."""
|
||||
|
||||
TEXT = "text" # Normal text streaming
|
||||
TOOL_CALL = "tool_call" # Tool call arguments streaming
|
||||
TOOL_RESULT = "tool_result" # Tool execution result
|
||||
THOUGHT = "thought" # Agent thinking process (ReAct)
|
||||
|
||||
|
||||
class StreamChunkEvent(NodeEventBase):
    """Base stream chunk event - normal text streaming output."""

    selector: Sequence[str] = Field(
        ...,
        description="selector identifying the output location (e.g., ['nodeA', 'text'])",
    )
    chunk: str = Field(
        ...,
        description="the actual chunk content",
    )
    is_final: bool = Field(
        default=False,
        description="indicates if this is the last chunk",
    )
    # Defaults to plain text; subclasses override the default.
    chunk_type: ChunkType = Field(
        default=ChunkType.TEXT,
        description="type of the chunk",
    )
    tool_call: ToolCall | None = Field(
        default=None,
        description="structured payload for tool_call chunks",
    )
    tool_result: ToolResult | None = Field(
        default=None,
        description="structured payload for tool_result chunks",
    )
|
||||
|
||||
|
||||
class ToolCallChunkEvent(StreamChunkEvent):
    """Tool call streaming event - tool call arguments streaming output."""

    # The chunk type is declared frozen and defaults to TOOL_CALL.
    chunk_type: ChunkType = Field(
        default=ChunkType.TOOL_CALL,
        frozen=True,
    )
    tool_call: ToolCall | None = Field(
        default=None,
        description="structured tool call payload",
    )
|
||||
|
||||
|
||||
class ToolResultChunkEvent(StreamChunkEvent):
    """Tool result event - tool execution result."""

    # The chunk type is declared frozen and defaults to TOOL_RESULT.
    chunk_type: ChunkType = Field(
        default=ChunkType.TOOL_RESULT,
        frozen=True,
    )
    tool_result: ToolResult | None = Field(
        default=None,
        description="structured tool result payload",
    )
|
||||
|
||||
|
||||
class ThoughtChunkEvent(StreamChunkEvent):
    """Agent thought streaming event - Agent thinking process (ReAct)."""

    # The chunk type is declared frozen and defaults to THOUGHT.
    chunk_type: ChunkType = Field(
        default=ChunkType.THOUGHT,
        frozen=True,
    )
|
||||
|
||||
|
||||
class StreamCompletedEvent(NodeEventBase):
|
||||
|
||||
@ -46,6 +46,9 @@ from core.workflow.node_events import (
|
||||
RunRetrieverResourceEvent,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
ThoughtChunkEvent,
|
||||
ToolCallChunkEvent,
|
||||
ToolResultChunkEvent,
|
||||
)
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@ -543,6 +546,8 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent:
|
||||
from core.workflow.graph_events import ChunkType
|
||||
|
||||
return NodeRunStreamChunkEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
@ -550,6 +555,65 @@ class Node(Generic[NodeDataT]):
|
||||
selector=event.selector,
|
||||
chunk=event.chunk,
|
||||
is_final=event.is_final,
|
||||
chunk_type=ChunkType(event.chunk_type.value),
|
||||
tool_call=event.tool_call,
|
||||
tool_result=event.tool_result,
|
||||
)
|
||||
|
||||
@_dispatch.register
def _(self, event: ToolCallChunkEvent) -> NodeRunStreamChunkEvent:
    """Translate a node-level tool-call chunk into a graph-level stream chunk event."""
    from core.workflow.graph_events import ChunkType

    graph_event = NodeRunStreamChunkEvent(
        id=self._node_execution_id,
        node_id=self._node_id,
        node_type=self.node_type,
        selector=event.selector,
        chunk=event.chunk,
        is_final=event.is_final,
        # Always emitted as TOOL_CALL, whatever the incoming event carries.
        chunk_type=ChunkType.TOOL_CALL,
        tool_call=event.tool_call,
    )
    return graph_event
|
||||
|
||||
@_dispatch.register
def _(self, event: ToolResultChunkEvent) -> NodeRunStreamChunkEvent:
    """Translate a node-level tool-result chunk into a graph-level stream chunk event.

    A fresh ``ToolResult`` is always built: a missing payload becomes an empty
    result with SUCCESS status, and a payload whose status is ``None`` is
    normalized to SUCCESS.
    """
    from core.workflow.entities import ToolResult, ToolResultStatus
    from core.workflow.graph_events import ChunkType

    source = event.tool_result
    if source is None:
        payload = ToolResult(
            id=None,
            name=None,
            output=None,
            files=[],
            status=ToolResultStatus.SUCCESS,
        )
    else:
        payload = ToolResult(
            id=source.id,
            name=source.name,
            output=source.output,
            files=source.files,
            status=source.status if source.status is not None else ToolResultStatus.SUCCESS,
        )

    return NodeRunStreamChunkEvent(
        id=self._node_execution_id,
        node_id=self._node_id,
        node_type=self.node_type,
        selector=event.selector,
        chunk=event.chunk,
        is_final=event.is_final,
        chunk_type=ChunkType.TOOL_RESULT,
        tool_result=payload,
    )
|
||||
|
||||
@_dispatch.register
def _(self, event: ThoughtChunkEvent) -> NodeRunStreamChunkEvent:
    """Translate a node-level thought chunk into a graph-level stream chunk event."""
    from core.workflow.graph_events import ChunkType

    graph_event = NodeRunStreamChunkEvent(
        id=self._node_execution_id,
        node_id=self._node_id,
        node_type=self.node_type,
        selector=event.selector,
        chunk=event.chunk,
        is_final=event.is_final,
        # Always emitted as THOUGHT; no tool payloads are passed on.
        chunk_type=ChunkType.THOUGHT,
    )
    return graph_event
|
||||
|
||||
@_dispatch.register
|
||||
|
||||
@ -3,6 +3,7 @@ from .entities import (
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
ToolMetadata,
|
||||
VisionConfig,
|
||||
)
|
||||
from .node import LLMNode
|
||||
@ -13,5 +14,6 @@ __all__ = [
|
||||
"LLMNodeCompletionModelPromptTemplate",
|
||||
"LLMNodeData",
|
||||
"ModelConfig",
|
||||
"ToolMetadata",
|
||||
"VisionConfig",
|
||||
]
|
||||
|
||||
@ -1,10 +1,17 @@
|
||||
import re
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.workflow.entities import ToolCallResult
|
||||
from core.workflow.node_events import AgentLogEvent
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
|
||||
@ -58,6 +65,235 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
|
||||
jinja2_text: str | None = None
|
||||
|
||||
|
||||
class ToolMetadata(BaseModel):
    """
    Tool metadata for LLM node with tool support.

    Carries the fields needed to configure a tool; ``type`` identifies the
    tool provider type.
    """

    # Core fields
    enabled: bool = True
    type: ToolProviderType = Field(
        ...,
        description="Tool provider type: builtin, api, mcp, workflow",
    )
    provider_name: str = Field(
        ...,
        description="Tool provider name/identifier",
    )
    tool_name: str = Field(
        ...,
        description="Tool name",
    )

    # Optional fields
    plugin_unique_identifier: str | None = Field(
        default=None,
        description="Plugin unique identifier for plugin tools",
    )
    credential_id: str | None = Field(
        default=None,
        description="Credential ID for tools requiring authentication",
    )

    # Configuration fields
    parameters: dict[str, Any] = Field(
        default_factory=dict,
        description="Tool parameters",
    )
    settings: dict[str, Any] = Field(
        default_factory=dict,
        description="Tool settings configuration",
    )
    extra: dict[str, Any] = Field(
        default_factory=dict,
        description="Extra tool configuration like custom description",
    )
|
||||
|
||||
|
||||
class LLMTraceSegment(BaseModel):
    """
    Streaming trace segment for LLM tool-enabled runs.

    Segments keep their emitted order so a run can be replayed. A tool call is
    one entry that holds both its arguments and its result.
    """

    type: Literal["thought", "content", "tool_call"]

    # Common optional fields
    text: str | None = Field(
        default=None,
        description="Text chunk for thought/content",
    )

    # Tool call fields (combined start + result)
    tool_call: ToolCallResult | None = Field(
        default=None,
        description="Combined tool call arguments and result for this segment",
    )
|
||||
|
||||
|
||||
class LLMGenerationData(BaseModel):
    """Generation data from LLM invocation with tools.

    For multi-turn tool calls like: thought1 -> text1 -> tool_call1 -> thought2 -> text2 -> tool_call2
    - reasoning_contents: [thought1, thought2, ...] - one element per turn
    - tool_calls: [{id, name, arguments, result}, ...] - all tool calls with results
    """

    text: str = Field(
        ...,
        description="Accumulated text content from all turns",
    )
    reasoning_contents: list[str] = Field(
        default_factory=list,
        description="Reasoning content per turn",
    )
    tool_calls: list[ToolCallResult] = Field(
        default_factory=list,
        description="Tool calls with results",
    )
    sequence: list[dict[str, Any]] = Field(
        default_factory=list,
        description="Ordered segments for rendering",
    )
    usage: LLMUsage = Field(
        ...,
        description="LLM usage statistics",
    )
    finish_reason: str | None = Field(
        default=None,
        description="Finish reason from LLM",
    )
    files: list[File] = Field(
        default_factory=list,
        description="Generated files",
    )
    trace: list[LLMTraceSegment] = Field(
        default_factory=list,
        description="Streaming trace in emitted order",
    )
|
||||
|
||||
|
||||
class ThinkTagStreamParser:
    """Lightweight state machine to split streaming chunks by <think> tags.

    Text inside ``<think>...</think>`` is reported as ``"thought"``; everything
    else is reported as ``"text"``. Tags are matched case-insensitively and may
    be split across chunk boundaries: an ambiguous tail of the buffer is held
    back until the next chunk (or ``flush``) resolves it.
    """

    _START_PATTERN = re.compile(r"<think(?:\s[^>]*)?>", re.IGNORECASE)
    _END_PATTERN = re.compile(r"</think>", re.IGNORECASE)
    _START_PREFIX = "<think"
    _END_PREFIX = "</think"

    def __init__(self):
        # Unprocessed tail that may still turn out to be (part of) a tag.
        self._buffer = ""
        # True while the parser is between a start tag and its end tag.
        self._in_think = False

    @staticmethod
    def _suffix_prefix_len(text: str, prefix: str) -> int:
        """Return length of the longest suffix of `text` that is a prefix of `prefix`.

        The whole of ``prefix`` counts as a candidate. The tag prefixes used
        here ("<think", "</think") lack the closing ">", so a buffer ending
        with the full prefix is still an incomplete tag and must be held back.
        Capping at ``len(prefix) - 1`` would emit "<think" as content whenever
        a chunk boundary falls right before ">".
        """
        max_len = min(len(text), len(prefix))
        for i in range(max_len, 0, -1):
            if text[-i:].lower() == prefix[:i].lower():
                return i
        return 0

    def process(self, chunk: str) -> list[tuple[str, str]]:
        """
        Split incoming chunk into ('thought' | 'text', content) tuples.

        Content excludes the <think> tags themselves and handles split tags across chunks.
        Empty pieces are never returned.
        """
        parts: list[tuple[str, str]] = []
        self._buffer += chunk

        while self._buffer:
            if self._in_think:
                end_match = self._END_PATTERN.search(self._buffer)
                if end_match:
                    thought_text = self._buffer[: end_match.start()]
                    if thought_text:
                        parts.append(("thought", thought_text))
                    self._buffer = self._buffer[end_match.end() :]
                    self._in_think = False
                    continue

                # No complete end tag yet: emit all but a possible partial "</think".
                hold_len = self._suffix_prefix_len(self._buffer, self._END_PREFIX)
                emit = self._buffer[: len(self._buffer) - hold_len]
                if emit:
                    parts.append(("thought", emit))
                self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
                break

            start_match = self._START_PATTERN.search(self._buffer)
            if start_match:
                prefix = self._buffer[: start_match.start()]
                if prefix:
                    parts.append(("text", prefix))
                self._buffer = self._buffer[start_match.end() :]
                self._in_think = True
                continue

            # No complete start tag yet: emit all but a possible partial "<think".
            hold_len = self._suffix_prefix_len(self._buffer, self._START_PREFIX)
            emit = self._buffer[: len(self._buffer) - hold_len]
            if emit:
                parts.append(("text", emit))
            self._buffer = self._buffer[-hold_len:] if hold_len > 0 else ""
            break

        cleaned_parts: list[tuple[str, str]] = []
        for kind, content in parts:
            # Extra safeguard: strip any stray tags that slipped through.
            content = self._START_PATTERN.sub("", content)
            content = self._END_PATTERN.sub("", content)
            if content:
                cleaned_parts.append((kind, content))

        return cleaned_parts

    def flush(self) -> list[tuple[str, str]]:
        """Flush remaining buffer when the stream ends.

        A held-back tail that starts with "<think" or "</think" is treated as a
        dangling tag and dropped. Any other remainder is emitted under the
        current state's kind. The buffer is always cleared.
        """
        if not self._buffer:
            return []
        kind = "thought" if self._in_think else "text"
        content = self._buffer
        # Drop dangling partial tags instead of emitting them
        if content.lower().startswith((self._START_PREFIX, self._END_PREFIX)):
            content = ""
        self._buffer = ""
        if not content:
            return []
        # Strip any complete tags that might still be present.
        content = self._START_PATTERN.sub("", content)
        content = self._END_PATTERN.sub("", content)
        return [(kind, content)] if content else []
|
||||
|
||||
|
||||
class StreamBuffers(BaseModel):
    """Per-run streaming state: a think-tag parser plus pending and per-turn string lists."""

    # ThinkTagStreamParser is a plain class, so pydantic must allow arbitrary types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    think_parser: ThinkTagStreamParser = Field(
        default_factory=ThinkTagStreamParser,
    )
    pending_thought: list[str] = Field(
        default_factory=list,
    )
    pending_content: list[str] = Field(
        default_factory=list,
    )
    current_turn_reasoning: list[str] = Field(
        default_factory=list,
    )
    reasoning_per_turn: list[str] = Field(
        default_factory=list,
    )
|
||||
|
||||
|
||||
class TraceState(BaseModel):
    """Trace bookkeeping: a list of segments plus two string-keyed lookup maps."""

    trace_segments: list[LLMTraceSegment] = Field(
        default_factory=list,
    )
    # NOTE(review): keys are presumably tool call ids - confirm against the code that fills them.
    tool_trace_map: dict[str, LLMTraceSegment] = Field(
        default_factory=dict,
    )
    tool_call_index_map: dict[str, int] = Field(
        default_factory=dict,
    )
|
||||
|
||||
|
||||
class AggregatedResult(BaseModel):
    """Aggregated result values: text, files, usage and finish reason."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    text: str = ""
    files: list[File] = Field(
        default_factory=list,
    )
    # The default comes from LLMUsage.empty_usage.
    usage: LLMUsage = Field(
        default_factory=LLMUsage.empty_usage,
    )
    finish_reason: str | None = None
|
||||
|
||||
|
||||
class AgentContext(BaseModel):
    """Holds a list of agent log events and an optional agent result."""

    agent_logs: list[AgentLogEvent] = Field(
        default_factory=list,
    )
    # None by default.
    agent_result: AgentResult | None = None
|
||||
|
||||
|
||||
class ToolOutputState(BaseModel):
    """Groups the four state containers: stream, trace, aggregate and agent."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Each sub-state gets its own fresh instance by default.
    stream: StreamBuffers = Field(
        default_factory=StreamBuffers,
    )
    trace: TraceState = Field(
        default_factory=TraceState,
    )
    aggregate: AggregatedResult = Field(
        default_factory=AggregatedResult,
    )
    agent: AgentContext = Field(
        default_factory=AgentContext,
    )
|
||||
|
||||
|
||||
class ToolLogPayload(BaseModel):
    """Typed view over the tool-related keys of an agent log's data mapping.

    Missing keys fall back to empty values (``""``, ``{}``, ``[]``) or ``None``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    tool_name: str = ""
    tool_call_id: str = ""
    tool_args: dict[str, Any] = Field(default_factory=dict)
    tool_output: Any = None
    tool_error: Any = None
    files: list[Any] = Field(default_factory=list)
    meta: dict[str, Any] = Field(default_factory=dict)

    @classmethod
    def from_log(cls, log: AgentLog) -> "ToolLogPayload":
        """Build a payload from ``log.data``; a falsy ``data`` counts as empty."""
        return cls.from_mapping(log.data or {})

    @classmethod
    def from_mapping(cls, data: Mapping[str, Any]) -> "ToolLogPayload":
        """Build a payload from a raw mapping.

        Falsy ``tool_args`` / ``files`` / ``meta`` values are replaced by empty
        containers. ``output`` and ``error`` are passed through as-is.
        """
        name = data.get("tool_name", "")
        call_id = data.get("tool_call_id", "")
        args = data.get("tool_args") or {}
        produced_files = data.get("files") or []
        metadata = data.get("meta") or {}
        return cls(
            tool_name=name,
            tool_call_id=call_id,
            tool_args=args,
            tool_output=data.get("output"),
            tool_error=data.get("error"),
            files=produced_files,
            meta=metadata,
        )
|
||||
|
||||
|
||||
class LLMNodeData(BaseNodeData):
|
||||
model: ModelConfig
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
@ -86,6 +322,10 @@ class LLMNodeData(BaseNodeData):
|
||||
),
|
||||
)
|
||||
|
||||
# Tool support
|
||||
tools: Sequence[ToolMetadata] = Field(default_factory=list)
|
||||
max_iterations: int | None = Field(default=None, description="Maximum number of iterations for the LLM node")
|
||||
|
||||
@field_validator("prompt_config", mode="before")
|
||||
@classmethod
|
||||
def convert_none_prompt_config(cls, v: Any):
|
||||
|
||||
@ -9,6 +9,8 @@ from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, AgentToolEntity, ExecutionContext
|
||||
from core.agent.patterns import StrategyFactory
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import File, FileTransferMethod, FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
@ -46,7 +48,9 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.variables import (
|
||||
ArrayFileSegment,
|
||||
ArraySegment,
|
||||
@ -56,7 +60,8 @@ from core.variables import (
|
||||
StringSegment,
|
||||
)
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus
|
||||
from core.workflow.entities.tool_entities import ToolCallResult
|
||||
from core.workflow.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
@ -64,12 +69,16 @@ from core.workflow.enums import (
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from core.workflow.node_events import (
|
||||
AgentLogEvent,
|
||||
ModelInvokeCompletedEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
RunRetrieverResourceEvent,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
ThoughtChunkEvent,
|
||||
ToolCallChunkEvent,
|
||||
ToolResultChunkEvent,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
@ -81,10 +90,19 @@ from models.model import UploadFile
|
||||
|
||||
from . import llm_utils
|
||||
from .entities import (
|
||||
AgentContext,
|
||||
AggregatedResult,
|
||||
LLMGenerationData,
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
LLMNodeData,
|
||||
LLMTraceSegment,
|
||||
ModelConfig,
|
||||
StreamBuffers,
|
||||
ThinkTagStreamParser,
|
||||
ToolLogPayload,
|
||||
ToolOutputState,
|
||||
TraceState,
|
||||
)
|
||||
from .exc import (
|
||||
InvalidContextStructureError,
|
||||
@ -149,11 +167,11 @@ class LLMNode(Node[LLMNodeData]):
|
||||
def _run(self) -> Generator:
|
||||
node_inputs: dict[str, Any] = {}
|
||||
process_data: dict[str, Any] = {}
|
||||
result_text = ""
|
||||
clean_text = ""
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
reasoning_content = None
|
||||
reasoning_content = "" # Initialize as empty string for consistency
|
||||
clean_text = "" # Initialize clean_text to avoid UnboundLocalError
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
try:
|
||||
@ -234,55 +252,58 @@ class LLMNode(Node[LLMNodeData]):
|
||||
context_files=context_files,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
generator = LLMNode.invoke_llm(
|
||||
node_data_model=self.node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
user_id=self.user_id,
|
||||
structured_output_enabled=self.node_data.structured_output_enabled,
|
||||
structured_output=self.node_data.structured_output,
|
||||
file_saver=self._llm_file_saver,
|
||||
file_outputs=self._file_outputs,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
reasoning_format=self.node_data.reasoning_format,
|
||||
)
|
||||
|
||||
# Variables for outputs
|
||||
generation_data: LLMGenerationData | None = None
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
|
||||
for event in generator:
|
||||
if isinstance(event, StreamChunkEvent):
|
||||
yield event
|
||||
elif isinstance(event, ModelInvokeCompletedEvent):
|
||||
# Raw text
|
||||
result_text = event.text
|
||||
usage = event.usage
|
||||
finish_reason = event.finish_reason
|
||||
reasoning_content = event.reasoning_content or ""
|
||||
# Check if tools are configured
|
||||
if self.tool_call_enabled:
|
||||
# Use tool-enabled invocation (Agent V2 style)
|
||||
generator = self._invoke_llm_with_tools(
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
files=files,
|
||||
variable_pool=variable_pool,
|
||||
node_inputs=node_inputs,
|
||||
process_data=process_data,
|
||||
)
|
||||
else:
|
||||
# Use traditional LLM invocation
|
||||
generator = LLMNode.invoke_llm(
|
||||
node_data_model=self._node_data.model,
|
||||
model_instance=model_instance,
|
||||
prompt_messages=prompt_messages,
|
||||
stop=stop,
|
||||
user_id=self.user_id,
|
||||
structured_output_enabled=self._node_data.structured_output_enabled,
|
||||
structured_output=self._node_data.structured_output,
|
||||
file_saver=self._llm_file_saver,
|
||||
file_outputs=self._file_outputs,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
reasoning_format=self._node_data.reasoning_format,
|
||||
)
|
||||
|
||||
# For downstream nodes, determine clean text based on reasoning_format
|
||||
if self.node_data.reasoning_format == "tagged":
|
||||
# Keep <think> tags for backward compatibility
|
||||
clean_text = result_text
|
||||
else:
|
||||
# Extract clean text from <think> tags
|
||||
clean_text, _ = LLMNode._split_reasoning(result_text, self.node_data.reasoning_format)
|
||||
(
|
||||
clean_text,
|
||||
reasoning_content,
|
||||
generation_reasoning_content,
|
||||
generation_clean_content,
|
||||
usage,
|
||||
finish_reason,
|
||||
structured_output,
|
||||
generation_data,
|
||||
) = yield from self._stream_llm_events(generator, model_instance=model_instance)
|
||||
|
||||
# Process structured output if available from the event.
|
||||
structured_output = (
|
||||
LLMStructuredOutput(structured_output=event.structured_output)
|
||||
if event.structured_output
|
||||
else None
|
||||
)
|
||||
|
||||
# deduct quota
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
break
|
||||
elif isinstance(event, LLMStructuredOutput):
|
||||
structured_output = event
|
||||
# Extract variables from generation_data if available
|
||||
if generation_data:
|
||||
clean_text = generation_data.text
|
||||
reasoning_content = ""
|
||||
usage = generation_data.usage
|
||||
finish_reason = generation_data.finish_reason
|
||||
|
||||
# Unified process_data building
|
||||
process_data = {
|
||||
"model_mode": model_config.mode,
|
||||
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
@ -293,24 +314,88 @@ class LLMNode(Node[LLMNodeData]):
|
||||
"model_provider": model_config.provider,
|
||||
"model_name": model_config.model,
|
||||
}
|
||||
if self.tool_call_enabled and self._node_data.tools:
|
||||
process_data["tools"] = [
|
||||
{
|
||||
"type": tool.type.value if hasattr(tool.type, "value") else tool.type,
|
||||
"provider_name": tool.provider_name,
|
||||
"tool_name": tool.tool_name,
|
||||
}
|
||||
for tool in self._node_data.tools
|
||||
if tool.enabled
|
||||
]
|
||||
|
||||
# Unified outputs building
|
||||
outputs = {
|
||||
"text": clean_text,
|
||||
"reasoning_content": reasoning_content,
|
||||
"usage": jsonable_encoder(usage),
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
|
||||
# Build generation field
|
||||
if generation_data:
|
||||
# Use generation_data from tool invocation (supports multi-turn)
|
||||
generation = {
|
||||
"content": generation_data.text,
|
||||
"reasoning_content": generation_data.reasoning_contents, # [thought1, thought2, ...]
|
||||
"tool_calls": [self._serialize_tool_call(item) for item in generation_data.tool_calls],
|
||||
"sequence": generation_data.sequence,
|
||||
}
|
||||
files_to_output = generation_data.files
|
||||
else:
|
||||
# Traditional LLM invocation
|
||||
generation_reasoning = generation_reasoning_content or reasoning_content
|
||||
generation_content = generation_clean_content or clean_text
|
||||
sequence: list[dict[str, Any]] = []
|
||||
if generation_reasoning:
|
||||
sequence = [
|
||||
{"type": "reasoning", "index": 0},
|
||||
{"type": "content", "start": 0, "end": len(generation_content)},
|
||||
]
|
||||
generation = {
|
||||
"content": generation_content,
|
||||
"reasoning_content": [generation_reasoning] if generation_reasoning else [],
|
||||
"tool_calls": [],
|
||||
"sequence": sequence,
|
||||
}
|
||||
files_to_output = self._file_outputs
|
||||
|
||||
outputs["generation"] = generation
|
||||
if files_to_output:
|
||||
outputs["files"] = ArrayFileSegment(value=files_to_output)
|
||||
if structured_output:
|
||||
outputs["structured_output"] = structured_output.structured_output
|
||||
if self._file_outputs:
|
||||
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
|
||||
|
||||
# Send final chunk event to indicate streaming is complete
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
if not self.tool_call_enabled:
|
||||
# For tool calls, final events are already sent in _process_tool_outputs
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "generation", "content"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
yield ThoughtChunkEvent(
|
||||
selector=[self._node_id, "generation", "thought"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = {
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
}
|
||||
|
||||
if generation_data and generation_data.trace:
|
||||
metadata[WorkflowNodeExecutionMetadataKey.LLM_TRACE] = [
|
||||
segment.model_dump() for segment in generation_data.trace
|
||||
]
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
@ -318,11 +403,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
inputs=node_inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
metadata={
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
|
||||
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
|
||||
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
|
||||
},
|
||||
metadata=metadata,
|
||||
llm_usage=usage,
|
||||
)
|
||||
)
|
||||
@ -444,6 +525,8 @@ class LLMNode(Node[LLMNodeData]):
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason = None
|
||||
full_text_buffer = io.StringIO()
|
||||
think_parser = ThinkTagStreamParser()
|
||||
reasoning_chunks: list[str] = []
|
||||
|
||||
# Initialize streaming metrics tracking
|
||||
start_time = request_start_time if request_start_time is not None else time.perf_counter()
|
||||
@ -472,12 +555,32 @@ class LLMNode(Node[LLMNodeData]):
|
||||
has_content = True
|
||||
|
||||
full_text_buffer.write(text_part)
|
||||
# Text output: always forward raw chunk (keep <think> tags intact)
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=text_part,
|
||||
is_final=False,
|
||||
)
|
||||
|
||||
# Generation output: split out thoughts, forward only non-thought content chunks
|
||||
for kind, segment in think_parser.process(text_part):
|
||||
if not segment:
|
||||
continue
|
||||
|
||||
if kind == "thought":
|
||||
reasoning_chunks.append(segment)
|
||||
yield ThoughtChunkEvent(
|
||||
selector=[node_id, "generation", "thought"],
|
||||
chunk=segment,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "generation", "content"],
|
||||
chunk=segment,
|
||||
is_final=False,
|
||||
)
|
||||
|
||||
# Update the whole metadata
|
||||
if not model and result.model:
|
||||
model = result.model
|
||||
@ -492,16 +595,35 @@ class LLMNode(Node[LLMNodeData]):
|
||||
except OutputParserError as e:
|
||||
raise LLMNodeError(f"Failed to parse structured output: {e}")
|
||||
|
||||
for kind, segment in think_parser.flush():
|
||||
if not segment:
|
||||
continue
|
||||
if kind == "thought":
|
||||
reasoning_chunks.append(segment)
|
||||
yield ThoughtChunkEvent(
|
||||
selector=[node_id, "generation", "thought"],
|
||||
chunk=segment,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "generation", "content"],
|
||||
chunk=segment,
|
||||
is_final=False,
|
||||
)
|
||||
|
||||
# Extract reasoning content from <think> tags in the main text
|
||||
full_text = full_text_buffer.getvalue()
|
||||
|
||||
if reasoning_format == "tagged":
|
||||
# Keep <think> tags in text for backward compatibility
|
||||
clean_text = full_text
|
||||
reasoning_content = ""
|
||||
reasoning_content = "".join(reasoning_chunks)
|
||||
else:
|
||||
# Extract clean text and reasoning from <think> tags
|
||||
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
|
||||
if reasoning_chunks and not reasoning_content:
|
||||
reasoning_content = "".join(reasoning_chunks)
|
||||
|
||||
# Calculate streaming metrics
|
||||
end_time = time.perf_counter()
|
||||
@ -1266,6 +1388,635 @@ class LLMNode(Node[LLMNodeData]):
|
||||
def retry(self) -> bool:
|
||||
return self.node_data.retry_config.retry_enabled
|
||||
|
||||
@property
def tool_call_enabled(self) -> bool:
    """Whether this node should run the tool-enabled invocation path.

    True only when the node has at least one configured tool and every
    configured tool is enabled.
    """
    # NOTE(review): a single disabled tool turns tool calling off entirely,
    # while process_data elsewhere filters on `tool.enabled` — confirm that
    # `all` (rather than `any`) is intended.
    tools = self.node_data.tools
    if tools is None or len(tools) == 0:
        return False
    return all(tool.enabled for tool in tools)
|
||||
|
||||
def _stream_llm_events(
|
||||
self,
|
||||
generator: Generator[NodeEventBase | LLMStructuredOutput, None, LLMGenerationData | None],
|
||||
*,
|
||||
model_instance: ModelInstance,
|
||||
) -> Generator[
|
||||
NodeEventBase,
|
||||
None,
|
||||
tuple[
|
||||
str,
|
||||
str,
|
||||
str,
|
||||
str,
|
||||
LLMUsage,
|
||||
str | None,
|
||||
LLMStructuredOutput | None,
|
||||
LLMGenerationData | None,
|
||||
],
|
||||
]:
|
||||
"""
|
||||
Stream events and capture generator return value in one place.
|
||||
Uses generator delegation so _run stays concise while still emitting events.
|
||||
"""
|
||||
clean_text = ""
|
||||
reasoning_content = ""
|
||||
generation_reasoning_content = ""
|
||||
generation_clean_content = ""
|
||||
usage = LLMUsage.empty_usage()
|
||||
finish_reason: str | None = None
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
generation_data: LLMGenerationData | None = None
|
||||
completed = False
|
||||
|
||||
while True:
|
||||
try:
|
||||
event = next(generator)
|
||||
except StopIteration as exc:
|
||||
if isinstance(exc.value, LLMGenerationData):
|
||||
generation_data = exc.value
|
||||
break
|
||||
|
||||
if completed:
|
||||
# After completion we still drain to reach StopIteration.value
|
||||
continue
|
||||
|
||||
match event:
|
||||
case StreamChunkEvent() | ThoughtChunkEvent():
|
||||
yield event
|
||||
|
||||
case ModelInvokeCompletedEvent(
|
||||
text=text,
|
||||
usage=usage_event,
|
||||
finish_reason=finish_reason_event,
|
||||
reasoning_content=reasoning_event,
|
||||
structured_output=structured_raw,
|
||||
):
|
||||
clean_text = text
|
||||
usage = usage_event
|
||||
finish_reason = finish_reason_event
|
||||
reasoning_content = reasoning_event or ""
|
||||
generation_reasoning_content = reasoning_content
|
||||
generation_clean_content = clean_text
|
||||
|
||||
if self.node_data.reasoning_format == "tagged":
|
||||
# Keep tagged text for output; also extract reasoning for generation field
|
||||
generation_clean_content, generation_reasoning_content = LLMNode._split_reasoning(
|
||||
clean_text, reasoning_format="separated"
|
||||
)
|
||||
else:
|
||||
clean_text, generation_reasoning_content = LLMNode._split_reasoning(
|
||||
clean_text, self.node_data.reasoning_format
|
||||
)
|
||||
generation_clean_content = clean_text
|
||||
|
||||
structured_output = (
|
||||
LLMStructuredOutput(structured_output=structured_raw) if structured_raw else None
|
||||
)
|
||||
|
||||
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
|
||||
completed = True
|
||||
|
||||
case LLMStructuredOutput():
|
||||
structured_output = event
|
||||
|
||||
case _:
|
||||
continue
|
||||
|
||||
return (
|
||||
clean_text,
|
||||
reasoning_content,
|
||||
generation_reasoning_content,
|
||||
generation_clean_content,
|
||||
usage,
|
||||
finish_reason,
|
||||
structured_output,
|
||||
generation_data,
|
||||
)
|
||||
|
||||
def _invoke_llm_with_tools(
    self,
    model_instance: ModelInstance,
    prompt_messages: Sequence[PromptMessage],
    stop: Sequence[str] | None,
    files: Sequence["File"],
    variable_pool: VariablePool,
    node_inputs: dict[str, Any],
    process_data: dict[str, Any],
) -> Generator[NodeEventBase, None, LLMGenerationData]:
    """Run the LLM through an agent strategy with the node's tools.

    Events produced while processing the strategy outputs are yielded to the
    caller; the generator's return value is the resulting LLMGenerationData.
    ``files``, ``node_inputs`` and ``process_data`` are accepted but not read
    by this method.
    """
    # The model's feature list is what the factory uses to pick a strategy.
    features = self._get_model_features(model_instance)
    tools = self._prepare_tool_instances(variable_pool)
    # Files referenced by prompt-template variables.
    template_files = self._extract_prompt_files(variable_pool)

    # A missing (or zero) max_iterations falls back to 10.
    iteration_limit = self._node_data.max_iterations or 10
    execution_context = ExecutionContext(user_id=self.user_id, app_id=self.app_id, tenant_id=self.tenant_id)

    strategy = StrategyFactory.create_strategy(
        model_features=features,
        model_instance=model_instance,
        tools=tools,
        files=template_files,
        max_iterations=iteration_limit,
        context=execution_context,
    )

    strategy_outputs = strategy.run(
        prompt_messages=list(prompt_messages),
        model_parameters=self._node_data.model.completion_params,
        stop=list(stop or []),
        stream=True,
    )

    return (yield from self._process_tool_outputs(strategy_outputs))
|
||||
|
||||
def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]:
    """Return the feature list from the model's schema.

    Best effort: returns an empty list when the schema is missing, declares
    no features, or cannot be fetched. A fetch failure is logged together
    with its traceback so the cause is not lost.
    """
    try:
        model_type_instance = model_instance.model_type_instance
        model_schema = model_type_instance.get_model_schema(
            model_instance.model,
            model_instance.credentials,
        )
        return model_schema.features if model_schema and model_schema.features else []
    except Exception:
        # Deliberately broad: any failure degrades to "no special features".
        # exc_info keeps the underlying error visible in the logs.
        logger.warning("Failed to get model schema, assuming no special features", exc_info=True)
        return []
|
||||
|
||||
def _prepare_tool_instances(self, variable_pool: VariablePool) -> list[Tool]:
    """Build tool runtimes for the tools configured on this node.

    A tool whose setup raises is logged and skipped, so the returned list may
    be shorter than the configured tool list.
    """

    def _unwrap_setting(raw: Any) -> Any:
        # A setting shaped like {"value": {"type": ..., "value": ...}} is
        # replaced by its inner dict; anything else is kept unchanged.
        if not isinstance(raw, dict) or "value" not in raw:
            return raw
        inner = raw["value"]
        if isinstance(inner, dict) and "type" in inner and "value" in inner:
            return inner
        return raw

    runtimes = []

    for tool in self._node_data.tools or []:
        try:
            settings = {name: _unwrap_setting(raw) for name, raw in tool.settings.items()}

            # Settings override parameters that share the same key.
            merged_parameters = {**tool.parameters, **settings}

            agent_tool = AgentToolEntity(
                provider_id=tool.provider_name,
                provider_type=tool.type,
                tool_name=tool.tool_name,
                tool_parameters=merged_parameters,
                plugin_unique_identifier=tool.plugin_unique_identifier,
                credential_id=tool.credential_id,
            )

            runtime = ToolManager.get_agent_tool_runtime(
                tenant_id=self.tenant_id,
                app_id=self.app_id,
                agent_tool=agent_tool,
                invoke_from=self.invoke_from,
                variable_pool=variable_pool,
            )

            # A custom description in `extra` replaces the LLM-facing
            # description, but only when the tool already has a description.
            custom_description = tool.extra.get("description")
            if custom_description and runtime.entity.description:
                runtime.entity.description.llm = custom_description

            runtimes.append(runtime)
        except Exception as e:
            logger.warning("Failed to load tool %s: %s", tool, str(e))
            continue

    return runtimes
|
||||
|
||||
def _extract_prompt_files(self, variable_pool: VariablePool) -> list["File"]:
    """Collect files referenced by variables in the prompt template.

    Only list-form prompt templates are scanned; any other template form
    yields an empty list.
    """
    from core.variables import ArrayFileVariable, FileVariable

    collected: list[File] = []

    template = self._node_data.prompt_template
    if not isinstance(template, list):
        return collected

    for message in template:
        if not message.text:
            continue

        selectors = VariableTemplateParser(message.text).extract_variable_selectors()
        for selector in selectors:
            resolved = variable_pool.get(selector.value_selector)
            # Single-file variables contribute one file, array variables all
            # of theirs; empty values are skipped.
            if isinstance(resolved, FileVariable) and resolved.value:
                collected.append(resolved.value)
            elif isinstance(resolved, ArrayFileVariable) and resolved.value:
                collected.extend(resolved.value)

    return collected
|
||||
|
||||
@staticmethod
def _serialize_tool_call(tool_call: ToolCallResult) -> dict[str, Any]:
    """Turn a ToolCallResult into a plain dict.

    Files are reduced to references (``id``, falling back to ``related_id``);
    files with neither are dropped. ``status`` is emitted as its ``value``
    when it has one, otherwise as-is.
    """
    # References rather than file objects, matching the streamed tool result
    # events which carry file IDs.
    file_refs = [ref for item in (tool_call.files or []) if (ref := item.id or item.related_id)]

    status = tool_call.status
    return {
        "id": tool_call.id,
        "name": tool_call.name,
        "arguments": tool_call.arguments,
        "output": tool_call.output,
        "files": file_refs,
        "status": getattr(status, "value", status),
    }
|
||||
|
||||
def _flush_thought_segment(self, buffers: StreamBuffers, trace_state: TraceState) -> None:
    """Move buffered thought chunks into the trace as one "thought" segment.

    Does nothing when no thought text is pending.
    """
    pending = buffers.pending_thought
    if pending:
        segment = LLMTraceSegment(type="thought", text="".join(pending))
        trace_state.trace_segments.append(segment)
        pending.clear()
|
||||
|
||||
def _flush_content_segment(self, buffers: StreamBuffers, trace_state: TraceState) -> None:
    """Move buffered content chunks into the trace as one "content" segment.

    Does nothing when no content text is pending.
    """
    pending = buffers.pending_content
    if pending:
        segment = LLMTraceSegment(type="content", text="".join(pending))
        trace_state.trace_segments.append(segment)
        pending.clear()
|
||||
|
||||
def _handle_agent_log_output(
    self, output: AgentLog, buffers: StreamBuffers, trace_state: TraceState, agent_context: AgentContext
) -> Generator[NodeEventBase, None, None]:
    """Record an agent log and emit tool-call / tool-result chunk events.

    The log is stored in ``agent_context.agent_logs``, updating the existing
    entry when one with the same message id is present. For tool-call logs,
    a START status yields a ToolCallChunkEvent and any other status yields a
    ToolResultChunkEvent; in both cases pending thought/content buffers are
    flushed into the trace first. Logs of other types yield nothing.
    """
    payload = ToolLogPayload.from_log(output)
    log_event = AgentLogEvent(
        message_id=output.id,
        label=output.label,
        node_execution_id=self.id,
        parent_id=output.parent_id,
        error=output.error,
        status=output.status.value,
        data=output.data,
        metadata={key.value: value for key, value in output.metadata.items()},
        node_id=self._node_id,
    )

    # Upsert by message id.
    known = next(
        (entry for entry in agent_context.agent_logs if entry.message_id == log_event.message_id),
        None,
    )
    if known is None:
        agent_context.agent_logs.append(log_event)
    else:
        known.data = log_event.data
        known.status = log_event.status
        known.error = log_event.error
        known.label = log_event.label
        known.metadata = log_event.metadata

    is_tool_call = output.log_type == AgentLog.LogType.TOOL_CALL
    if not is_tool_call:
        return

    call_id = payload.tool_call_id
    call_name = payload.tool_name

    if output.status == AgentLog.LogStatus.START:
        serialized_args = json.dumps(payload.tool_args) if payload.tool_args else ""

        # First sighting of this call id gets the next index.
        if call_id and call_id not in trace_state.tool_call_index_map:
            trace_state.tool_call_index_map[call_id] = len(trace_state.tool_call_index_map)

        self._flush_thought_segment(buffers, trace_state)
        self._flush_content_segment(buffers, trace_state)

        started_segment = LLMTraceSegment(
            type="tool_call",
            text=None,
            tool_call=ToolCallResult(
                id=call_id,
                name=call_name,
                arguments=serialized_args,
            ),
        )
        trace_state.trace_segments.append(started_segment)
        if call_id:
            trace_state.tool_trace_map[call_id] = started_segment

        yield ToolCallChunkEvent(
            selector=[self._node_id, "generation", "tool_calls"],
            chunk=serialized_args,
            tool_call=ToolCall(
                id=call_id,
                name=call_name,
                arguments=serialized_args,
            ),
            is_final=False,
        )
    else:
        call_output = payload.tool_output
        call_files = payload.files if isinstance(payload.files, list) else []
        call_error = payload.tool_error

        if call_id and call_id not in trace_state.tool_call_index_map:
            trace_state.tool_call_index_map[call_id] = len(trace_state.tool_call_index_map)

        self._flush_thought_segment(buffers, trace_state)
        self._flush_content_segment(buffers, trace_state)

        # Resolve the error: an ERROR log prefers the log's own error, then
        # the payload's, then meta["error"]; otherwise a truthy meta["error"]
        # overrides the payload's error.
        if output.status == AgentLog.LogStatus.ERROR:
            call_error = output.error or payload.tool_error
            if not call_error and payload.meta:
                call_error = payload.meta.get("error")
        elif payload.meta:
            meta_error = payload.meta.get("error")
            if meta_error:
                call_error = meta_error

        # Reuse the segment created at START when there is one.
        previous_segment = trace_state.tool_trace_map.get(call_id)
        segment = previous_segment or LLMTraceSegment(
            type="tool_call",
            text=None,
            tool_call=ToolCallResult(
                id=call_id,
                name=call_name,
                arguments=None,
            ),
        )
        if previous_segment is None:
            trace_state.trace_segments.append(segment)
            if call_id:
                trace_state.tool_trace_map[call_id] = segment

        if segment.tool_call is None:
            segment.tool_call = ToolCallResult(
                id=call_id,
                name=call_name,
                arguments=None,
            )

        outcome = ToolResultStatus.ERROR if call_error else ToolResultStatus.SUCCESS

        # NOTE(review): the trace output falls back to the error when it is
        # not None, while the streamed output below requires it to be truthy;
        # the two differ for a falsy non-None error — confirm intent.
        if call_output is not None:
            traced_output = str(call_output)
        elif call_error is not None:
            traced_output = str(call_error)
        else:
            traced_output = None
        segment.tool_call.output = traced_output
        # NOTE(review): trace files are always reset to empty even though the
        # streamed ToolResult carries the payload files — confirm intent.
        segment.tool_call.files = []
        segment.tool_call.status = outcome

        if call_output is not None:
            streamed_output = str(call_output)
        elif call_error:
            streamed_output = str(call_error)
        else:
            streamed_output = None

        yield ToolResultChunkEvent(
            selector=[self._node_id, "generation", "tool_results"],
            chunk=streamed_output or "",
            tool_result=ToolResult(
                id=call_id,
                name=call_name,
                output=streamed_output,
                files=call_files,
                status=outcome,
            ),
            is_final=False,
        )

        # A tool result closes the current turn's reasoning.
        if buffers.current_turn_reasoning:
            buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning))
            buffers.current_turn_reasoning.clear()
|
||||
|
||||
def _handle_llm_chunk_output(
    self, output: LLMResultChunk, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult
) -> Generator[NodeEventBase, None, None]:
    """Split one LLM chunk into thought and content events and update totals.

    Thought segments are yielded as ThoughtChunkEvent. Content segments are
    appended to ``aggregate.text`` and yielded twice: on the "text" selector
    and on the "generation"/"content" selector. Usage and finish reason from
    the chunk delta are folded into ``aggregate`` afterwards.
    """
    delta = output.delta
    message = delta.message

    if message and message.content:
        raw_content = message.content
        if isinstance(raw_content, list):
            # Multi-part content: use each part's `data`, else its str().
            text = "".join(getattr(part, "data", str(part)) for part in raw_content)
        else:
            text = str(raw_content)

        for kind, piece in buffers.think_parser.process(text):
            if not piece:
                continue

            if kind != "thought":
                # Switching to content closes any open thought segment.
                self._flush_thought_segment(buffers, trace_state)
                aggregate.text += piece
                buffers.pending_content.append(piece)
                yield StreamChunkEvent(
                    selector=[self._node_id, "text"],
                    chunk=piece,
                    is_final=False,
                )
                yield StreamChunkEvent(
                    selector=[self._node_id, "generation", "content"],
                    chunk=piece,
                    is_final=False,
                )
                continue

            # Switching to thought closes any open content segment.
            self._flush_content_segment(buffers, trace_state)
            buffers.current_turn_reasoning.append(piece)
            buffers.pending_thought.append(piece)
            yield ThoughtChunkEvent(
                selector=[self._node_id, "generation", "thought"],
                chunk=piece,
                is_final=False,
            )

    if delta.usage:
        self._accumulate_usage(aggregate.usage, delta.usage)

    if delta.finish_reason:
        aggregate.finish_reason = delta.finish_reason
|
||||
|
||||
def _flush_remaining_stream(
|
||||
self, buffers: StreamBuffers, trace_state: TraceState, aggregate: AggregatedResult
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
for kind, segment in buffers.think_parser.flush():
|
||||
if not segment:
|
||||
continue
|
||||
if kind == "thought":
|
||||
self._flush_content_segment(buffers, trace_state)
|
||||
buffers.current_turn_reasoning.append(segment)
|
||||
buffers.pending_thought.append(segment)
|
||||
yield ThoughtChunkEvent(
|
||||
selector=[self._node_id, "generation", "thought"],
|
||||
chunk=segment,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
self._flush_thought_segment(buffers, trace_state)
|
||||
aggregate.text += segment
|
||||
buffers.pending_content.append(segment)
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
chunk=segment,
|
||||
is_final=False,
|
||||
)
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "generation", "content"],
|
||||
chunk=segment,
|
||||
is_final=False,
|
||||
)
|
||||
|
||||
if buffers.current_turn_reasoning:
|
||||
buffers.reasoning_per_turn.append("".join(buffers.current_turn_reasoning))
|
||||
|
||||
self._flush_thought_segment(buffers, trace_state)
|
||||
self._flush_content_segment(buffers, trace_state)
|
||||
|
||||
def _close_streams(self) -> Generator[NodeEventBase, None, None]:
    """Emit one terminating event per output stream.

    Every event carries an empty chunk and ``is_final=True``. They are yielded
    in this order: ``text``, ``generation/content``, ``generation/thought``,
    ``generation/tool_calls`` (with an empty ``ToolCall``) and
    ``generation/tool_results`` (with an empty, successful ``ToolResult``).
    """
    for suffix in (("text",), ("generation", "content")):
        yield StreamChunkEvent(
            selector=[self._node_id, *suffix],
            chunk="",
            is_final=True,
        )

    yield ThoughtChunkEvent(
        selector=[self._node_id, "generation", "thought"],
        chunk="",
        is_final=True,
    )

    yield ToolCallChunkEvent(
        selector=[self._node_id, "generation", "tool_calls"],
        chunk="",
        tool_call=ToolCall(id="", name="", arguments=""),
        is_final=True,
    )

    yield ToolResultChunkEvent(
        selector=[self._node_id, "generation", "tool_results"],
        chunk="",
        tool_result=ToolResult(
            id="",
            name="",
            output="",
            files=[],
            status=ToolResultStatus.SUCCESS,
        ),
        is_final=True,
    )
|
||||
|
||||
def _build_generation_data(
|
||||
self,
|
||||
trace_state: TraceState,
|
||||
agent_context: AgentContext,
|
||||
aggregate: AggregatedResult,
|
||||
buffers: StreamBuffers,
|
||||
) -> LLMGenerationData:
|
||||
sequence: list[dict[str, Any]] = []
|
||||
reasoning_index = 0
|
||||
content_position = 0
|
||||
tool_call_seen_index: dict[str, int] = {}
|
||||
for trace_segment in trace_state.trace_segments:
|
||||
if trace_segment.type == "thought":
|
||||
sequence.append({"type": "reasoning", "index": reasoning_index})
|
||||
reasoning_index += 1
|
||||
elif trace_segment.type == "content":
|
||||
segment_text = trace_segment.text or ""
|
||||
start = content_position
|
||||
end = start + len(segment_text)
|
||||
sequence.append({"type": "content", "start": start, "end": end})
|
||||
content_position = end
|
||||
elif trace_segment.type == "tool_call":
|
||||
tool_id = trace_segment.tool_call.id if trace_segment.tool_call and trace_segment.tool_call.id else ""
|
||||
if tool_id not in tool_call_seen_index:
|
||||
tool_call_seen_index[tool_id] = len(tool_call_seen_index)
|
||||
sequence.append({"type": "tool_call", "index": tool_call_seen_index[tool_id]})
|
||||
|
||||
tool_calls_for_generation: list[ToolCallResult] = []
|
||||
for log in agent_context.agent_logs:
|
||||
payload = ToolLogPayload.from_mapping(log.data or {})
|
||||
tool_call_id = payload.tool_call_id
|
||||
if not tool_call_id or log.status == AgentLog.LogStatus.START.value:
|
||||
continue
|
||||
|
||||
tool_args = payload.tool_args
|
||||
log_error = payload.tool_error
|
||||
log_output = payload.tool_output
|
||||
result_text = log_output or log_error or ""
|
||||
status = ToolResultStatus.ERROR if log_error else ToolResultStatus.SUCCESS
|
||||
tool_calls_for_generation.append(
|
||||
ToolCallResult(
|
||||
id=tool_call_id,
|
||||
name=payload.tool_name,
|
||||
arguments=json.dumps(tool_args) if tool_args else "",
|
||||
output=result_text,
|
||||
status=status,
|
||||
)
|
||||
)
|
||||
|
||||
tool_calls_for_generation.sort(
|
||||
key=lambda item: trace_state.tool_call_index_map.get(item.id or "", len(trace_state.tool_call_index_map))
|
||||
)
|
||||
|
||||
return LLMGenerationData(
|
||||
text=aggregate.text,
|
||||
reasoning_contents=buffers.reasoning_per_turn,
|
||||
tool_calls=tool_calls_for_generation,
|
||||
sequence=sequence,
|
||||
usage=aggregate.usage,
|
||||
finish_reason=aggregate.finish_reason,
|
||||
files=aggregate.files,
|
||||
trace=trace_state.trace_segments,
|
||||
)
|
||||
|
||||
def _process_tool_outputs(
    self,
    outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult],
) -> Generator[NodeEventBase, None, LLMGenerationData]:
    """Process strategy outputs and convert to node events.

    Each ``AgentLog`` item is routed to ``_handle_agent_log_output`` and every
    other item to ``_handle_llm_chunk_output``; the events they produce are
    re-yielded. When ``outputs`` is exhausted, its return value (if it is an
    ``AgentResult``) is stored and used to override the aggregated text, files,
    usage and finish reason. Finally the remaining buffered stream content is
    flushed, all streams are closed, and the assembled ``LLMGenerationData`` is
    returned as this generator's return value.

    Args:
        outputs: Strategy generator yielding LLM chunks and agent logs, and
            returning an ``AgentResult`` when it finishes.

    Returns:
        The generation data built by ``_build_generation_data``.
    """
    state = ToolOutputState()

    # The generator is driven manually: a ``for`` loop consumes StopIteration
    # internally, which would discard the generator's return value.
    iterator = iter(outputs)
    while True:
        try:
            output = next(iterator)
        except StopIteration as exception:
            if isinstance(exception.value, AgentResult):
                state.agent.agent_result = exception.value
            break

        if isinstance(output, AgentLog):
            yield from self._handle_agent_log_output(output, state.stream, state.trace, state.agent)
        else:
            yield from self._handle_llm_chunk_output(output, state.stream, state.trace, state.aggregate)

    agent_result = state.agent.agent_result
    if agent_result:
        # The strategy's final result takes precedence over streamed values where it provides them.
        state.aggregate.text = agent_result.text or state.aggregate.text
        state.aggregate.files = agent_result.files
        if agent_result.usage:
            state.aggregate.usage = agent_result.usage
        if agent_result.finish_reason:
            state.aggregate.finish_reason = agent_result.finish_reason

    yield from self._flush_remaining_stream(state.stream, state.trace, state.aggregate)
    yield from self._close_streams()

    return self._build_generation_data(state.trace, state.agent, state.aggregate, state.stream)
|
||||
|
||||
def _accumulate_usage(self, total_usage: LLMUsage, delta_usage: LLMUsage) -> None:
|
||||
"""Accumulate LLM usage statistics."""
|
||||
total_usage.prompt_tokens += delta_usage.prompt_tokens
|
||||
total_usage.completion_tokens += delta_usage.completion_tokens
|
||||
total_usage.total_tokens += delta_usage.total_tokens
|
||||
total_usage.prompt_price += delta_usage.prompt_price
|
||||
total_usage.completion_price += delta_usage.completion_price
|
||||
total_usage.total_price += delta_usage.total_price
|
||||
|
||||
|
||||
def _combine_message_content_with_role(
|
||||
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
|
||||
|
||||
Reference in New Issue
Block a user