refactor(llm node): tool call tool result entity

This commit is contained in:
Novice
2025-12-17 10:30:21 +08:00
parent dd0a870969
commit d3486cab31
17 changed files with 300 additions and 169 deletions

View File

@ -2,10 +2,16 @@
from unittest.mock import MagicMock
from core.workflow.entities import ToolResultStatus
from core.workflow.enums import NodeType
from core.workflow.graph import Graph
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
from core.workflow.graph_events import ChunkType, NodeRunStreamChunkEvent
from core.workflow.graph_events import (
ChunkType,
NodeRunStreamChunkEvent,
ToolCall,
ToolResult,
)
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.runtime import VariablePool
@ -80,9 +86,11 @@ class TestResponseCoordinatorObjectStreaming:
chunk='{"query": "test"}',
is_final=True,
chunk_type=ChunkType.TOOL_CALL,
tool_call_id="call_123",
tool_name="search",
tool_arguments='{"query": "test"}',
tool_call=ToolCall(
id="call_123",
name="search",
arguments='{"query": "test"}',
),
)
# 3. Tool result stream
@ -94,10 +102,13 @@ class TestResponseCoordinatorObjectStreaming:
chunk="Found 10 results",
is_final=True,
chunk_type=ChunkType.TOOL_RESULT,
tool_call_id="call_123",
tool_name="search",
tool_files=[],
tool_error=None,
tool_result=ToolResult(
id="call_123",
name="search",
output="Found 10 results",
files=[],
status=ToolResultStatus.SUCCESS,
),
)
# Intercept these events
@ -111,6 +122,14 @@ class TestResponseCoordinatorObjectStreaming:
assert ("llm_node", "generation", "tool_calls") in coordinator._stream_buffers
assert ("llm_node", "generation", "tool_results") in coordinator._stream_buffers
# Verify payloads are preserved in buffered events
buffered_call = coordinator._stream_buffers[("llm_node", "generation", "tool_calls")][0]
assert buffered_call.tool_call is not None
assert buffered_call.tool_call.id == "call_123"
buffered_result = coordinator._stream_buffers[("llm_node", "generation", "tool_results")][0]
assert buffered_result.tool_result is not None
assert buffered_result.tool_result.status == "success"
# Verify we can find child streams
child_streams = coordinator._find_child_streams(["llm_node", "generation"])
assert len(child_streams) == 3

View File

@ -1,5 +1,6 @@
"""Tests for StreamChunkEvent and its subclasses."""
from core.workflow.entities import ToolCall, ToolResult, ToolResultStatus
from core.workflow.node_events import (
ChunkType,
StreamChunkEvent,
@ -87,14 +88,13 @@ class TestToolCallChunkEvent:
event = ToolCallChunkEvent(
selector=["node1", "tool_calls"],
chunk='{"city": "Beijing"}',
tool_call_id="call_123",
tool_name="weather",
tool_call=ToolCall(id="call_123", name="weather", arguments=None),
)
assert event.selector == ["node1", "tool_calls"]
assert event.chunk == '{"city": "Beijing"}'
assert event.tool_call_id == "call_123"
assert event.tool_name == "weather"
assert event.tool_call.id == "call_123"
assert event.tool_call.name == "weather"
assert event.chunk_type == ChunkType.TOOL_CALL
def test_chunk_type_is_tool_call(self):
@ -102,8 +102,7 @@ class TestToolCallChunkEvent:
event = ToolCallChunkEvent(
selector=["node1", "tool_calls"],
chunk="",
tool_call_id="call_123",
tool_name="test_tool",
tool_call=ToolCall(id="call_123", name="test_tool", arguments=None),
)
assert event.chunk_type == ChunkType.TOOL_CALL
@ -113,30 +112,34 @@ class TestToolCallChunkEvent:
event = ToolCallChunkEvent(
selector=["node1", "tool_calls"],
chunk='{"param": "value"}',
tool_call_id="call_123",
tool_name="test_tool",
tool_arguments='{"param": "value"}',
tool_call=ToolCall(
id="call_123",
name="test_tool",
arguments='{"param": "value"}',
),
)
assert event.tool_arguments == '{"param": "value"}'
assert event.tool_call.arguments == '{"param": "value"}'
def test_serialization(self):
"""Test that event can be serialized to dict."""
event = ToolCallChunkEvent(
selector=["node1", "tool_calls"],
chunk='{"city": "Beijing"}',
tool_call_id="call_123",
tool_name="weather",
tool_arguments='{"city": "Beijing"}',
tool_call=ToolCall(
id="call_123",
name="weather",
arguments='{"city": "Beijing"}',
),
is_final=True,
)
data = event.model_dump()
assert data["chunk_type"] == "tool_call"
assert data["tool_call_id"] == "call_123"
assert data["tool_name"] == "weather"
assert data["tool_arguments"] == '{"city": "Beijing"}'
assert data["tool_call"]["id"] == "call_123"
assert data["tool_call"]["name"] == "weather"
assert data["tool_call"]["arguments"] == '{"city": "Beijing"}'
assert data["is_final"] is True
@ -148,14 +151,13 @@ class TestToolResultChunkEvent:
event = ToolResultChunkEvent(
selector=["node1", "tool_results"],
chunk="Weather: Sunny, 25°C",
tool_call_id="call_123",
tool_name="weather",
tool_result=ToolResult(id="call_123", name="weather", output="Weather: Sunny, 25°C"),
)
assert event.selector == ["node1", "tool_results"]
assert event.chunk == "Weather: Sunny, 25°C"
assert event.tool_call_id == "call_123"
assert event.tool_name == "weather"
assert event.tool_result.id == "call_123"
assert event.tool_result.name == "weather"
assert event.chunk_type == ChunkType.TOOL_RESULT
def test_chunk_type_is_tool_result(self):
@ -163,8 +165,7 @@ class TestToolResultChunkEvent:
event = ToolResultChunkEvent(
selector=["node1", "tool_results"],
chunk="result",
tool_call_id="call_123",
tool_name="test_tool",
tool_result=ToolResult(id="call_123", name="test_tool"),
)
assert event.chunk_type == ChunkType.TOOL_RESULT
@ -174,55 +175,62 @@ class TestToolResultChunkEvent:
event = ToolResultChunkEvent(
selector=["node1", "tool_results"],
chunk="result",
tool_call_id="call_123",
tool_name="test_tool",
tool_result=ToolResult(id="call_123", name="test_tool"),
)
assert event.tool_files == []
assert event.tool_result.files == []
def test_tool_files_with_values(self):
"""Test tool_files with file IDs."""
event = ToolResultChunkEvent(
selector=["node1", "tool_results"],
chunk="result",
tool_call_id="call_123",
tool_name="test_tool",
tool_files=["file_1", "file_2"],
tool_result=ToolResult(
id="call_123",
name="test_tool",
files=["file_1", "file_2"],
),
)
assert event.tool_files == ["file_1", "file_2"]
assert event.tool_result.files == ["file_1", "file_2"]
def test_tool_error_field(self):
"""Test tool_error field."""
def test_tool_error_output(self):
"""Test error output captured in tool_result."""
event = ToolResultChunkEvent(
selector=["node1", "tool_results"],
chunk="",
tool_call_id="call_123",
tool_name="test_tool",
tool_error="Tool execution failed",
tool_result=ToolResult(
id="call_123",
name="test_tool",
output="Tool execution failed",
status=ToolResultStatus.ERROR,
),
)
assert event.tool_error == "Tool execution failed"
assert event.tool_result.output == "Tool execution failed"
assert event.tool_result.status == ToolResultStatus.ERROR
def test_serialization(self):
"""Test that event can be serialized to dict."""
event = ToolResultChunkEvent(
selector=["node1", "tool_results"],
chunk="Weather: Sunny",
tool_call_id="call_123",
tool_name="weather",
tool_files=["file_1"],
tool_error=None,
tool_result=ToolResult(
id="call_123",
name="weather",
output="Weather: Sunny",
files=["file_1"],
status=ToolResultStatus.SUCCESS,
),
is_final=True,
)
data = event.model_dump()
assert data["chunk_type"] == "tool_result"
assert data["tool_call_id"] == "call_123"
assert data["tool_name"] == "weather"
assert data["tool_files"] == ["file_1"]
assert data["tool_error"] is None
assert data["tool_result"]["id"] == "call_123"
assert data["tool_result"]["name"] == "weather"
assert data["tool_result"]["files"] == ["file_1"]
assert data["is_final"] is True
@ -272,8 +280,7 @@ class TestEventInheritance:
event = ToolCallChunkEvent(
selector=["node1", "tool_calls"],
chunk="",
tool_call_id="call_123",
tool_name="test",
tool_call=ToolCall(id="call_123", name="test", arguments=None),
)
assert isinstance(event, StreamChunkEvent)
@ -283,8 +290,7 @@ class TestEventInheritance:
event = ToolResultChunkEvent(
selector=["node1", "tool_results"],
chunk="result",
tool_call_id="call_123",
tool_name="test",
tool_result=ToolResult(id="call_123", name="test"),
)
assert isinstance(event, StreamChunkEvent)
@ -302,8 +308,16 @@ class TestEventInheritance:
"""Test that all events have common StreamChunkEvent fields."""
events = [
StreamChunkEvent(selector=["n", "t"], chunk="a"),
ToolCallChunkEvent(selector=["n", "t"], chunk="b", tool_call_id="1", tool_name="t"),
ToolResultChunkEvent(selector=["n", "t"], chunk="c", tool_call_id="1", tool_name="t"),
ToolCallChunkEvent(
selector=["n", "t"],
chunk="b",
tool_call=ToolCall(id="1", name="t", arguments=None),
),
ToolResultChunkEvent(
selector=["n", "t"],
chunk="c",
tool_result=ToolResult(id="1", name="t"),
),
ThoughtChunkEvent(selector=["n", "t"], chunk="d"),
]