mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
refactor(llm node): tool call tool result entity
This commit is contained in:
@ -2,10 +2,16 @@
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.workflow.entities import ToolResultStatus
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
|
||||
from core.workflow.graph_events import ChunkType, NodeRunStreamChunkEvent
|
||||
from core.workflow.graph_events import (
|
||||
ChunkType,
|
||||
NodeRunStreamChunkEvent,
|
||||
ToolCall,
|
||||
ToolResult,
|
||||
)
|
||||
from core.workflow.nodes.base.entities import BaseNodeData
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
@ -80,9 +86,11 @@ class TestResponseCoordinatorObjectStreaming:
|
||||
chunk='{"query": "test"}',
|
||||
is_final=True,
|
||||
chunk_type=ChunkType.TOOL_CALL,
|
||||
tool_call_id="call_123",
|
||||
tool_name="search",
|
||||
tool_arguments='{"query": "test"}',
|
||||
tool_call=ToolCall(
|
||||
id="call_123",
|
||||
name="search",
|
||||
arguments='{"query": "test"}',
|
||||
),
|
||||
)
|
||||
|
||||
# 3. Tool result stream
|
||||
@ -94,10 +102,13 @@ class TestResponseCoordinatorObjectStreaming:
|
||||
chunk="Found 10 results",
|
||||
is_final=True,
|
||||
chunk_type=ChunkType.TOOL_RESULT,
|
||||
tool_call_id="call_123",
|
||||
tool_name="search",
|
||||
tool_files=[],
|
||||
tool_error=None,
|
||||
tool_result=ToolResult(
|
||||
id="call_123",
|
||||
name="search",
|
||||
output="Found 10 results",
|
||||
files=[],
|
||||
status=ToolResultStatus.SUCCESS,
|
||||
),
|
||||
)
|
||||
|
||||
# Intercept these events
|
||||
@ -111,6 +122,14 @@ class TestResponseCoordinatorObjectStreaming:
|
||||
assert ("llm_node", "generation", "tool_calls") in coordinator._stream_buffers
|
||||
assert ("llm_node", "generation", "tool_results") in coordinator._stream_buffers
|
||||
|
||||
# Verify payloads are preserved in buffered events
|
||||
buffered_call = coordinator._stream_buffers[("llm_node", "generation", "tool_calls")][0]
|
||||
assert buffered_call.tool_call is not None
|
||||
assert buffered_call.tool_call.id == "call_123"
|
||||
buffered_result = coordinator._stream_buffers[("llm_node", "generation", "tool_results")][0]
|
||||
assert buffered_result.tool_result is not None
|
||||
assert buffered_result.tool_result.status == "success"
|
||||
|
||||
# Verify we can find child streams
|
||||
child_streams = coordinator._find_child_streams(["llm_node", "generation"])
|
||||
assert len(child_streams) == 3
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
"""Tests for StreamChunkEvent and its subclasses."""
|
||||
|
||||
from core.workflow.entities import ToolCall, ToolResult, ToolResultStatus
|
||||
from core.workflow.node_events import (
|
||||
ChunkType,
|
||||
StreamChunkEvent,
|
||||
@ -87,14 +88,13 @@ class TestToolCallChunkEvent:
|
||||
event = ToolCallChunkEvent(
|
||||
selector=["node1", "tool_calls"],
|
||||
chunk='{"city": "Beijing"}',
|
||||
tool_call_id="call_123",
|
||||
tool_name="weather",
|
||||
tool_call=ToolCall(id="call_123", name="weather", arguments=None),
|
||||
)
|
||||
|
||||
assert event.selector == ["node1", "tool_calls"]
|
||||
assert event.chunk == '{"city": "Beijing"}'
|
||||
assert event.tool_call_id == "call_123"
|
||||
assert event.tool_name == "weather"
|
||||
assert event.tool_call.id == "call_123"
|
||||
assert event.tool_call.name == "weather"
|
||||
assert event.chunk_type == ChunkType.TOOL_CALL
|
||||
|
||||
def test_chunk_type_is_tool_call(self):
|
||||
@ -102,8 +102,7 @@ class TestToolCallChunkEvent:
|
||||
event = ToolCallChunkEvent(
|
||||
selector=["node1", "tool_calls"],
|
||||
chunk="",
|
||||
tool_call_id="call_123",
|
||||
tool_name="test_tool",
|
||||
tool_call=ToolCall(id="call_123", name="test_tool", arguments=None),
|
||||
)
|
||||
|
||||
assert event.chunk_type == ChunkType.TOOL_CALL
|
||||
@ -113,30 +112,34 @@ class TestToolCallChunkEvent:
|
||||
event = ToolCallChunkEvent(
|
||||
selector=["node1", "tool_calls"],
|
||||
chunk='{"param": "value"}',
|
||||
tool_call_id="call_123",
|
||||
tool_name="test_tool",
|
||||
tool_arguments='{"param": "value"}',
|
||||
tool_call=ToolCall(
|
||||
id="call_123",
|
||||
name="test_tool",
|
||||
arguments='{"param": "value"}',
|
||||
),
|
||||
)
|
||||
|
||||
assert event.tool_arguments == '{"param": "value"}'
|
||||
assert event.tool_call.arguments == '{"param": "value"}'
|
||||
|
||||
def test_serialization(self):
|
||||
"""Test that event can be serialized to dict."""
|
||||
event = ToolCallChunkEvent(
|
||||
selector=["node1", "tool_calls"],
|
||||
chunk='{"city": "Beijing"}',
|
||||
tool_call_id="call_123",
|
||||
tool_name="weather",
|
||||
tool_arguments='{"city": "Beijing"}',
|
||||
tool_call=ToolCall(
|
||||
id="call_123",
|
||||
name="weather",
|
||||
arguments='{"city": "Beijing"}',
|
||||
),
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
data = event.model_dump()
|
||||
|
||||
assert data["chunk_type"] == "tool_call"
|
||||
assert data["tool_call_id"] == "call_123"
|
||||
assert data["tool_name"] == "weather"
|
||||
assert data["tool_arguments"] == '{"city": "Beijing"}'
|
||||
assert data["tool_call"]["id"] == "call_123"
|
||||
assert data["tool_call"]["name"] == "weather"
|
||||
assert data["tool_call"]["arguments"] == '{"city": "Beijing"}'
|
||||
assert data["is_final"] is True
|
||||
|
||||
|
||||
@ -148,14 +151,13 @@ class TestToolResultChunkEvent:
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="Weather: Sunny, 25°C",
|
||||
tool_call_id="call_123",
|
||||
tool_name="weather",
|
||||
tool_result=ToolResult(id="call_123", name="weather", output="Weather: Sunny, 25°C"),
|
||||
)
|
||||
|
||||
assert event.selector == ["node1", "tool_results"]
|
||||
assert event.chunk == "Weather: Sunny, 25°C"
|
||||
assert event.tool_call_id == "call_123"
|
||||
assert event.tool_name == "weather"
|
||||
assert event.tool_result.id == "call_123"
|
||||
assert event.tool_result.name == "weather"
|
||||
assert event.chunk_type == ChunkType.TOOL_RESULT
|
||||
|
||||
def test_chunk_type_is_tool_result(self):
|
||||
@ -163,8 +165,7 @@ class TestToolResultChunkEvent:
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="result",
|
||||
tool_call_id="call_123",
|
||||
tool_name="test_tool",
|
||||
tool_result=ToolResult(id="call_123", name="test_tool"),
|
||||
)
|
||||
|
||||
assert event.chunk_type == ChunkType.TOOL_RESULT
|
||||
@ -174,55 +175,62 @@ class TestToolResultChunkEvent:
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="result",
|
||||
tool_call_id="call_123",
|
||||
tool_name="test_tool",
|
||||
tool_result=ToolResult(id="call_123", name="test_tool"),
|
||||
)
|
||||
|
||||
assert event.tool_files == []
|
||||
assert event.tool_result.files == []
|
||||
|
||||
def test_tool_files_with_values(self):
|
||||
"""Test tool_files with file IDs."""
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="result",
|
||||
tool_call_id="call_123",
|
||||
tool_name="test_tool",
|
||||
tool_files=["file_1", "file_2"],
|
||||
tool_result=ToolResult(
|
||||
id="call_123",
|
||||
name="test_tool",
|
||||
files=["file_1", "file_2"],
|
||||
),
|
||||
)
|
||||
|
||||
assert event.tool_files == ["file_1", "file_2"]
|
||||
assert event.tool_result.files == ["file_1", "file_2"]
|
||||
|
||||
def test_tool_error_field(self):
|
||||
"""Test tool_error field."""
|
||||
def test_tool_error_output(self):
|
||||
"""Test error output captured in tool_result."""
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="",
|
||||
tool_call_id="call_123",
|
||||
tool_name="test_tool",
|
||||
tool_error="Tool execution failed",
|
||||
tool_result=ToolResult(
|
||||
id="call_123",
|
||||
name="test_tool",
|
||||
output="Tool execution failed",
|
||||
status=ToolResultStatus.ERROR,
|
||||
),
|
||||
)
|
||||
|
||||
assert event.tool_error == "Tool execution failed"
|
||||
assert event.tool_result.output == "Tool execution failed"
|
||||
assert event.tool_result.status == ToolResultStatus.ERROR
|
||||
|
||||
def test_serialization(self):
|
||||
"""Test that event can be serialized to dict."""
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="Weather: Sunny",
|
||||
tool_call_id="call_123",
|
||||
tool_name="weather",
|
||||
tool_files=["file_1"],
|
||||
tool_error=None,
|
||||
tool_result=ToolResult(
|
||||
id="call_123",
|
||||
name="weather",
|
||||
output="Weather: Sunny",
|
||||
files=["file_1"],
|
||||
status=ToolResultStatus.SUCCESS,
|
||||
),
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
data = event.model_dump()
|
||||
|
||||
assert data["chunk_type"] == "tool_result"
|
||||
assert data["tool_call_id"] == "call_123"
|
||||
assert data["tool_name"] == "weather"
|
||||
assert data["tool_files"] == ["file_1"]
|
||||
assert data["tool_error"] is None
|
||||
assert data["tool_result"]["id"] == "call_123"
|
||||
assert data["tool_result"]["name"] == "weather"
|
||||
assert data["tool_result"]["files"] == ["file_1"]
|
||||
assert data["is_final"] is True
|
||||
|
||||
|
||||
@ -272,8 +280,7 @@ class TestEventInheritance:
|
||||
event = ToolCallChunkEvent(
|
||||
selector=["node1", "tool_calls"],
|
||||
chunk="",
|
||||
tool_call_id="call_123",
|
||||
tool_name="test",
|
||||
tool_call=ToolCall(id="call_123", name="test", arguments=None),
|
||||
)
|
||||
|
||||
assert isinstance(event, StreamChunkEvent)
|
||||
@ -283,8 +290,7 @@ class TestEventInheritance:
|
||||
event = ToolResultChunkEvent(
|
||||
selector=["node1", "tool_results"],
|
||||
chunk="result",
|
||||
tool_call_id="call_123",
|
||||
tool_name="test",
|
||||
tool_result=ToolResult(id="call_123", name="test"),
|
||||
)
|
||||
|
||||
assert isinstance(event, StreamChunkEvent)
|
||||
@ -302,8 +308,16 @@ class TestEventInheritance:
|
||||
"""Test that all events have common StreamChunkEvent fields."""
|
||||
events = [
|
||||
StreamChunkEvent(selector=["n", "t"], chunk="a"),
|
||||
ToolCallChunkEvent(selector=["n", "t"], chunk="b", tool_call_id="1", tool_name="t"),
|
||||
ToolResultChunkEvent(selector=["n", "t"], chunk="c", tool_call_id="1", tool_name="t"),
|
||||
ToolCallChunkEvent(
|
||||
selector=["n", "t"],
|
||||
chunk="b",
|
||||
tool_call=ToolCall(id="1", name="t", arguments=None),
|
||||
),
|
||||
ToolResultChunkEvent(
|
||||
selector=["n", "t"],
|
||||
chunk="c",
|
||||
tool_result=ToolResult(id="1", name="t"),
|
||||
),
|
||||
ThoughtChunkEvent(selector=["n", "t"], chunk="d"),
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user