refactor(llm node): tool call tool result entity

This commit is contained in:
Novice
2025-12-17 10:30:21 +08:00
parent dd0a870969
commit d3486cab31
17 changed files with 300 additions and 169 deletions

View File

@ -1,11 +1,16 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .tool_entities import ToolCall, ToolCallResult, ToolResult, ToolResultStatus
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"ToolCall",
"ToolCallResult",
"ToolResult",
"ToolResultStatus",
"WorkflowExecution",
"WorkflowNodeExecution",
]

View File

@ -0,0 +1,33 @@
from enum import StrEnum
from pydantic import BaseModel, Field
from core.file import File
class ToolResultStatus(StrEnum):
SUCCESS = "success"
ERROR = "error"
class ToolCall(BaseModel):
id: str | None = Field(default=None, description="Unique identifier for this tool call")
name: str | None = Field(default=None, description="Name of the tool being called")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
class ToolResult(BaseModel):
id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to")
name: str | None = Field(default=None, description="Name of the tool")
output: str | None = Field(default=None, description="Tool output text, error or success message")
files: list[str] = Field(default_factory=list, description="File produced by tool")
status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
class ToolCallResult(BaseModel):
id: str | None = Field(default=None, description="Identifier for the tool call")
name: str | None = Field(default=None, description="Name of the tool")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
output: str | None = Field(default=None, description="Tool output text, error or success message")
files: list[File] = Field(default_factory=list, description="File produced by tool")
status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")

View File

@ -16,7 +16,13 @@ from pydantic import BaseModel, Field
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.graph_events import (
ChunkType,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ToolCall,
ToolResult,
)
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from core.workflow.runtime import VariablePool
@ -321,7 +327,9 @@ class ResponseStreamCoordinator:
selector: Sequence[str],
chunk: str,
is_final: bool = False,
**extra_fields,
chunk_type: ChunkType = ChunkType.TEXT,
tool_call: ToolCall | None = None,
tool_result: ToolResult | None = None,
) -> NodeRunStreamChunkEvent:
"""Create a stream chunk event with consistent structure.
@ -334,7 +342,9 @@ class ResponseStreamCoordinator:
selector: The variable selector
chunk: The chunk content
is_final: Whether this is the final chunk
**extra_fields: Additional fields for specialized events (chunk_type, tool_call_id, etc.)
chunk_type: The semantic type of the chunk being streamed
tool_call: Structured data for tool_call chunks
tool_result: Structured data for tool_result chunks
"""
# Check if this is a special selector that doesn't correspond to a node
if selector and selector[0] not in self._graph.nodes and self._active_session:
@ -347,7 +357,9 @@ class ResponseStreamCoordinator:
selector=selector,
chunk=chunk,
is_final=is_final,
**extra_fields,
chunk_type=chunk_type,
tool_call=tool_call,
tool_result=tool_result,
)
# Standard case: selector refers to an actual node
@ -359,7 +371,9 @@ class ResponseStreamCoordinator:
selector=selector,
chunk=chunk,
is_final=is_final,
**extra_fields,
chunk_type=chunk_type,
tool_call=tool_call,
tool_result=tool_result,
)
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
@ -436,11 +450,8 @@ class ResponseStreamCoordinator:
chunk=event.chunk,
is_final=event.is_final,
chunk_type=event.chunk_type,
tool_call_id=event.tool_call_id,
tool_name=event.tool_name,
tool_arguments=event.tool_arguments,
tool_files=event.tool_files,
tool_error=event.tool_error,
tool_call=event.tool_call,
tool_result=event.tool_result,
)
events.append(updated_event)
else:

View File

@ -45,6 +45,8 @@ from .node import (
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ToolCall,
ToolResult,
)
__all__ = [
@ -75,4 +77,6 @@ __all__ = [
"NodeRunStartedEvent",
"NodeRunStreamChunkEvent",
"NodeRunSucceededEvent",
"ToolCall",
"ToolResult",
]

View File

@ -5,7 +5,7 @@ from enum import StrEnum
from pydantic import Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
from core.workflow.entities.pause_reason import PauseReason
from .base import GraphNodeEventBase
@ -43,13 +43,16 @@ class NodeRunStreamChunkEvent(GraphNodeEventBase):
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
# Tool call fields (when chunk_type == TOOL_CALL)
tool_call_id: str | None = Field(default=None, description="unique identifier for this tool call")
tool_name: str | None = Field(default=None, description="name of the tool being called")
tool_arguments: str | None = Field(default=None, description="accumulated tool arguments JSON")
tool_call: ToolCall | None = Field(
default=None,
description="structured payload for tool_call chunks",
)
# Tool result fields (when chunk_type == TOOL_RESULT)
tool_files: list[str] = Field(default_factory=list, description="file IDs produced by tool")
tool_error: str | None = Field(default=None, description="error message if tool failed")
tool_result: ToolResult | None = Field(
default=None,
description="structured payload for tool_result chunks",
)
class NodeRunRetrieverResourceEvent(GraphNodeEventBase):

View File

@ -7,6 +7,7 @@ from pydantic import Field
from core.file import File
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import ToolCall, ToolResult
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.node_events import NodeRunResult
@ -51,25 +52,22 @@ class StreamChunkEvent(NodeEventBase):
chunk: str = Field(..., description="the actual chunk content")
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
tool_call: ToolCall | None = Field(default=None, description="structured payload for tool_call chunks")
tool_result: ToolResult | None = Field(default=None, description="structured payload for tool_result chunks")
class ToolCallChunkEvent(StreamChunkEvent):
"""Tool call streaming event - tool call arguments streaming output."""
chunk_type: ChunkType = Field(default=ChunkType.TOOL_CALL, frozen=True)
tool_call_id: str = Field(..., description="unique identifier for this tool call")
tool_name: str = Field(..., description="name of the tool being called")
tool_arguments: str = Field(default="", description="accumulated tool arguments JSON")
tool_call: ToolCall | None = Field(default=None, description="structured tool call payload")
class ToolResultChunkEvent(StreamChunkEvent):
"""Tool result event - tool execution result."""
chunk_type: ChunkType = Field(default=ChunkType.TOOL_RESULT, frozen=True)
tool_call_id: str = Field(..., description="identifier of the tool call this result belongs to")
tool_name: str = Field(..., description="name of the tool")
tool_files: list[str] = Field(default_factory=list, description="file IDs produced by tool")
tool_error: str | None = Field(default=None, description="error message if tool failed")
tool_result: ToolResult | None = Field(default=None, description="structured tool result payload")
class ThoughtChunkEvent(StreamChunkEvent):

View File

@ -556,6 +556,8 @@ class Node(Generic[NodeDataT]):
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType(event.chunk_type.value),
tool_call=event.tool_call,
tool_result=event.tool_result,
)
@_dispatch.register
@ -570,14 +572,18 @@ class Node(Generic[NodeDataT]):
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType.TOOL_CALL,
tool_call_id=event.tool_call_id,
tool_name=event.tool_name,
tool_arguments=event.tool_arguments,
tool_call=event.tool_call,
)
@_dispatch.register
def _(self, event: ToolResultChunkEvent) -> NodeRunStreamChunkEvent:
from core.workflow.graph_events import ChunkType
from core.workflow.entities import ToolResult
from core.workflow.graph_events import ChunkType, ToolResultStatus
tool_result = event.tool_result
status: ToolResultStatus = (
tool_result.status if tool_result and tool_result.status is not None else ToolResultStatus.SUCCESS
)
return NodeRunStreamChunkEvent(
id=self._node_execution_id,
@ -587,10 +593,13 @@ class Node(Generic[NodeDataT]):
chunk=event.chunk,
is_final=event.is_final,
chunk_type=ChunkType.TOOL_RESULT,
tool_call_id=event.tool_call_id,
tool_name=event.tool_name,
tool_files=event.tool_files,
tool_error=event.tool_error,
tool_result=ToolResult(
id=tool_result.id if tool_result else None,
name=tool_result.name if tool_result else None,
output=tool_result.output if tool_result else None,
files=tool_result.files if tool_result else [],
status=status,
),
)
@_dispatch.register

View File

@ -8,6 +8,7 @@ from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.model_runtime.entities.llm_entities import LLMUsage
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.entities import ToolCallResult
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.entities import VariableSelector
@ -33,12 +34,10 @@ class LLMTraceSegment(BaseModel):
text: str | None = Field(None, description="Text chunk for thought/content")
# Tool call fields (combined start + result)
tool_call_id: str | None = None
tool_name: str | None = None
tool_arguments: str | None = None
tool_output: str | None = None
tool_error: str | None = None
files: list[str] = Field(default_factory=list, description="File IDs from tool result if any")
tool_call: ToolCallResult | None = Field(
default=None,
description="Combined tool call arguments and result for this segment",
)
class LLMGenerationData(BaseModel):
@ -51,7 +50,7 @@ class LLMGenerationData(BaseModel):
text: str = Field(..., description="Accumulated text content from all turns")
reasoning_contents: list[str] = Field(default_factory=list, description="Reasoning content per turn")
tool_calls: list[dict[str, Any]] = Field(default_factory=list, description="Tool calls with results")
tool_calls: list[ToolCallResult] = Field(default_factory=list, description="Tool calls with results")
sequence: list[dict[str, Any]] = Field(default_factory=list, description="Ordered segments for rendering")
usage: LLMUsage = Field(..., description="LLM usage statistics")
finish_reason: str | None = Field(None, description="Finish reason from LLM")

View File

@ -60,7 +60,8 @@ from core.variables import (
StringSegment,
)
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.entities import GraphInitParams, ToolCall, ToolResult, ToolResultStatus
from core.workflow.entities.tool_entities import ToolCallResult
from core.workflow.enums import (
NodeType,
SystemVariableKey,
@ -1671,9 +1672,11 @@ class LLMNode(Node[LLMNodeData]):
tool_call_segment = LLMTraceSegment(
type="tool_call",
text=None,
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=tool_arguments,
tool_call=ToolCallResult(
id=tool_call_id,
name=tool_name,
arguments=tool_arguments,
),
)
trace_segments.append(tool_call_segment)
if tool_call_id:
@ -1682,9 +1685,11 @@ class LLMNode(Node[LLMNodeData]):
yield ToolCallChunkEvent(
selector=[self._node_id, "generation", "tool_calls"],
chunk=tool_arguments,
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=tool_arguments,
tool_call=ToolCall(
id=tool_call_id,
name=tool_name,
arguments=tool_arguments,
),
is_final=False,
)
@ -1724,27 +1729,50 @@ class LLMNode(Node[LLMNodeData]):
tool_call_segment = existing_tool_segment or LLMTraceSegment(
type="tool_call",
text=None,
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=None,
tool_call=ToolCallResult(
id=tool_call_id,
name=tool_name,
arguments=None,
),
)
if existing_tool_segment is None:
trace_segments.append(tool_call_segment)
if tool_call_id:
tool_trace_map[tool_call_id] = tool_call_segment
tool_call_segment.tool_output = str(tool_output) if tool_output is not None else None
tool_call_segment.tool_error = str(tool_error) if tool_error is not None else None
tool_call_segment.files = [str(f) for f in tool_files] if tool_files else []
if tool_call_segment.tool_call is None:
tool_call_segment.tool_call = ToolCallResult(
id=tool_call_id,
name=tool_name,
arguments=None,
)
tool_call_segment.tool_call.output = (
str(tool_output)
if tool_output is not None
else str(tool_error)
if tool_error is not None
else None
)
tool_call_segment.tool_call.files = []
tool_call_segment.tool_call.status = (
ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS
)
current_turn += 1
result_output = (
str(tool_output) if tool_output is not None else str(tool_error) if tool_error else None
)
yield ToolResultChunkEvent(
selector=[self._node_id, "generation", "tool_results"],
chunk=str(tool_output) if tool_output else "",
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_files=tool_files,
tool_error=tool_error,
chunk=result_output or "",
tool_result=ToolResult(
id=tool_call_id,
name=tool_name,
output=result_output,
files=tool_files,
status=ToolResultStatus.ERROR if tool_error else ToolResultStatus.SUCCESS,
),
is_final=False,
)
@ -1865,7 +1893,7 @@ class LLMNode(Node[LLMNodeData]):
sequence.append({"type": "content", "start": start, "end": end})
content_position = end
elif trace_segment.type == "tool_call":
tool_id = trace_segment.tool_call_id or ""
tool_id = trace_segment.tool_call.id if trace_segment.tool_call and trace_segment.tool_call.id else ""
if tool_id not in tool_call_seen_index:
tool_call_seen_index[tool_id] = len(tool_call_seen_index)
sequence.append({"type": "tool_call", "index": tool_call_seen_index[tool_id]})
@ -1893,9 +1921,11 @@ class LLMNode(Node[LLMNodeData]):
yield ToolCallChunkEvent(
selector=[self._node_id, "generation", "tool_calls"],
chunk="",
tool_call_id="",
tool_name="",
tool_arguments="",
tool_call=ToolCall(
id="",
name="",
arguments="",
),
is_final=True,
)
@ -1903,33 +1933,40 @@ class LLMNode(Node[LLMNodeData]):
yield ToolResultChunkEvent(
selector=[self._node_id, "generation", "tool_results"],
chunk="",
tool_call_id="",
tool_name="",
tool_files=[],
tool_error=None,
tool_result=ToolResult(
id="",
name="",
output="",
files=[],
status=ToolResultStatus.SUCCESS,
),
is_final=True,
)
# Build tool_calls from agent_logs (with results)
tool_calls_for_generation = []
tool_calls_for_generation: list[ToolCallResult] = []
for log in agent_logs:
tool_call_id = log.data.get("tool_call_id")
if not tool_call_id or log.status == AgentLog.LogStatus.START.value:
continue
tool_args = log.data.get("tool_args") or {}
log_error = log.data.get("error")
log_output = log.data.get("output")
result_text = log_output or log_error or ""
status = ToolResultStatus.ERROR if log_error else ToolResultStatus.SUCCESS
tool_calls_for_generation.append(
{
"id": tool_call_id,
"name": log.data.get("tool_name", ""),
"arguments": json.dumps(tool_args) if tool_args else "",
# Prefer output, fall back to error text if present
"result": log.data.get("output") or log.data.get("error") or "",
}
ToolCallResult(
id=tool_call_id,
name=log.data.get("tool_name", ""),
arguments=json.dumps(tool_args) if tool_args else "",
output=result_text,
status=status,
)
)
tool_calls_for_generation.sort(
key=lambda item: tool_call_index_map.get(item.get("id", ""), len(tool_call_index_map))
key=lambda item: tool_call_index_map.get(item.id or "", len(tool_call_index_map))
)
# Return generation data for caller