Merge branch 'feat/agent-node-v2' into deploy/dev

This commit is contained in:
Novice
2025-12-30 13:43:40 +08:00
61 changed files with 6527 additions and 1246 deletions

View File

@ -0,0 +1,4 @@
"""
Mark agent test modules as a package to avoid import name collisions.
"""

View File

@ -0,0 +1,324 @@
"""Tests for AgentPattern base class."""
from decimal import Decimal
from unittest.mock import MagicMock
import pytest
from core.agent.entities import AgentLog, ExecutionContext
from core.agent.patterns.base import AgentPattern
from core.model_runtime.entities.llm_entities import LLMUsage
class ConcreteAgentPattern(AgentPattern):
    """Concrete implementation of AgentPattern for testing.

    Supplies the minimal ``run`` override needed to instantiate the
    otherwise-abstract ``AgentPattern`` base class in these tests.
    """

    def run(self, prompt_messages, model_parameters, stop=None, stream=True):
        """Minimal implementation for testing: yields nothing.

        ``stop`` defaults to ``None`` instead of a mutable ``[]`` to avoid
        the shared-mutable-default-argument pitfall; the stub ignores it
        either way, so callers are unaffected.
        """
        yield from []
@pytest.fixture
def mock_model_instance():
    """Provide a MagicMock standing in for a model instance."""
    instance = MagicMock()
    instance.configure_mock(model="test-model", provider="test-provider")
    return instance
@pytest.fixture
def mock_context():
    """Provide an ExecutionContext populated with fixed test identifiers."""
    identifiers = {
        "user_id": "test-user",
        "app_id": "test-app",
        "conversation_id": "test-conversation",
        "message_id": "test-message",
        "tenant_id": "test-tenant",
    }
    return ExecutionContext(**identifiers)
@pytest.fixture
def agent_pattern(mock_model_instance, mock_context):
    """Provide a ConcreteAgentPattern wired to the mock model and context."""
    return ConcreteAgentPattern(
        model_instance=mock_model_instance, tools=[], context=mock_context, max_iterations=10
    )
class TestAccumulateUsage:
    """Tests for _accumulate_usage method."""

    @staticmethod
    def _usage(prompt_tokens, completion_tokens, prompt_price, completion_price, total_price, latency):
        """Build an LLMUsage with the fixed test unit prices and given counts.

        total_tokens is always prompt + completion in these fixtures.
        """
        return LLMUsage(
            prompt_tokens=prompt_tokens,
            prompt_unit_price=Decimal("0.001"),
            prompt_price_unit=Decimal("0.001"),
            prompt_price=Decimal(prompt_price),
            completion_tokens=completion_tokens,
            completion_unit_price=Decimal("0.002"),
            completion_price_unit=Decimal("0.001"),
            completion_price=Decimal(completion_price),
            total_tokens=prompt_tokens + completion_tokens,
            total_price=Decimal(total_price),
            currency="USD",
            latency=latency,
        )

    def test_accumulate_usage_to_empty_dict(self, agent_pattern):
        """Accumulating into an empty slot stores a copy of the delta."""
        totals: dict = {"usage": None}
        delta = self._usage(100, 50, "0.1", "0.1", "0.2", 0.5)
        agent_pattern._accumulate_usage(totals, delta)
        accumulated = totals["usage"]
        assert accumulated is not None
        assert accumulated.total_tokens == 150
        assert accumulated.prompt_tokens == 100
        assert accumulated.completion_tokens == 50
        # Verify it's a copy, not a reference
        assert accumulated is not delta

    def test_accumulate_usage_adds_to_existing(self, agent_pattern):
        """Accumulating onto existing usage sums the token counts."""
        totals: dict = {"usage": self._usage(100, 50, "0.1", "0.1", "0.2", 0.5)}
        delta = self._usage(200, 100, "0.2", "0.2", "0.4", 0.5)
        agent_pattern._accumulate_usage(totals, delta)
        assert totals["usage"].total_tokens == 450  # 150 + 300
        assert totals["usage"].prompt_tokens == 300  # 100 + 200
        assert totals["usage"].completion_tokens == 150  # 50 + 100

    def test_accumulate_usage_multiple_rounds(self, agent_pattern):
        """Usage keeps accumulating correctly across several rounds."""
        totals: dict = {"usage": None}
        # (delta for the round, expected running total after it)
        rounds = [
            (self._usage(70, 30, "0.07", "0.06", "0.13", 0.3), 100),
            (self._usage(100, 50, "0.1", "0.1", "0.2", 0.4), 250),
            (self._usage(130, 70, "0.13", "0.14", "0.27", 0.5), 450),
        ]
        for delta, expected_total in rounds:
            agent_pattern._accumulate_usage(totals, delta)
            assert totals["usage"].total_tokens == expected_total
class TestCreateLog:
    """Tests for _create_log method."""

    def test_create_log_with_label_and_status(self, agent_pattern):
        """A freshly created log carries the given label, type, status, data."""
        log = agent_pattern._create_log(
            label="ROUND 1",
            log_type=AgentLog.LogType.ROUND,
            status=AgentLog.LogStatus.START,
            data={"key": "value"},
        )
        assert (log.label, log.log_type, log.status) == (
            "ROUND 1",
            AgentLog.LogType.ROUND,
            AgentLog.LogStatus.START,
        )
        assert log.data == {"key": "value"}
        assert log.parent_id is None

    def test_create_log_with_parent_id(self, agent_pattern):
        """A child log records the id of its parent log."""
        make = agent_pattern._create_log
        parent = make(
            label="ROUND 1",
            log_type=AgentLog.LogType.ROUND,
            status=AgentLog.LogStatus.START,
            data={},
        )
        child = make(
            label="CALL tool",
            log_type=AgentLog.LogType.TOOL_CALL,
            status=AgentLog.LogStatus.START,
            data={},
            parent_id=parent.id,
        )
        assert child.parent_id == parent.id
        assert child.log_type == AgentLog.LogType.TOOL_CALL
class TestFinishLog:
    """Tests for _finish_log method."""

    @staticmethod
    def _started_round(agent_pattern):
        """Create a ROUND log in the START state."""
        return agent_pattern._create_log(
            label="ROUND 1",
            log_type=AgentLog.LogType.ROUND,
            status=AgentLog.LogStatus.START,
            data={},
        )

    def test_finish_log_updates_status(self, agent_pattern):
        """Finishing a log flips its status to SUCCESS and stores the data."""
        started = self._started_round(agent_pattern)
        done = agent_pattern._finish_log(started, data={"result": "done"})
        assert done.status == AgentLog.LogStatus.SUCCESS
        assert done.data == {"result": "done"}

    def test_finish_log_adds_usage_metadata(self, agent_pattern):
        """Finishing with usage records token, price and currency metadata."""
        started = self._started_round(agent_pattern)
        usage = LLMUsage(
            prompt_tokens=100,
            prompt_unit_price=Decimal("0.001"),
            prompt_price_unit=Decimal("0.001"),
            prompt_price=Decimal("0.1"),
            completion_tokens=50,
            completion_unit_price=Decimal("0.002"),
            completion_price_unit=Decimal("0.001"),
            completion_price=Decimal("0.1"),
            total_tokens=150,
            total_price=Decimal("0.2"),
            currency="USD",
            latency=0.5,
        )
        done = agent_pattern._finish_log(started, usage=usage)
        meta = done.metadata
        assert meta[AgentLog.LogMetadata.TOTAL_TOKENS] == 150
        assert meta[AgentLog.LogMetadata.TOTAL_PRICE] == Decimal("0.2")
        assert meta[AgentLog.LogMetadata.CURRENCY] == "USD"
        assert meta[AgentLog.LogMetadata.LLM_USAGE] == usage
class TestFindToolByName:
    """Tests for _find_tool_by_name method."""

    @staticmethod
    def _pattern_with_tool(model_instance, context, tool_name):
        """Build a pattern holding one mock tool named *tool_name*."""
        tool = MagicMock()
        tool.entity.identity.name = tool_name
        pattern = ConcreteAgentPattern(
            model_instance=model_instance, tools=[tool], context=context
        )
        return pattern, tool

    def test_find_existing_tool(self, mock_model_instance, mock_context):
        """Looking up a registered tool name returns that tool."""
        pattern, tool = self._pattern_with_tool(mock_model_instance, mock_context, "test_tool")
        assert pattern._find_tool_by_name("test_tool") == tool

    def test_find_nonexistent_tool_returns_none(self, mock_model_instance, mock_context):
        """Looking up an unknown tool name yields None."""
        pattern, _ = self._pattern_with_tool(mock_model_instance, mock_context, "test_tool")
        assert pattern._find_tool_by_name("nonexistent_tool") is None
class TestMaxIterationsCapping:
    """Tests for max_iterations capping."""

    @staticmethod
    def _make(model_instance, context, max_iterations):
        """Build a tool-less pattern with the requested iteration limit."""
        return ConcreteAgentPattern(
            model_instance=model_instance,
            tools=[],
            context=context,
            max_iterations=max_iterations,
        )

    def test_max_iterations_capped_at_99(self, mock_model_instance, mock_context):
        """Requests above the ceiling are clamped down to 99."""
        pattern = self._make(mock_model_instance, mock_context, 150)
        assert pattern.max_iterations == 99

    def test_max_iterations_not_capped_when_under_99(self, mock_model_instance, mock_context):
        """Requests below the ceiling pass through unchanged."""
        pattern = self._make(mock_model_instance, mock_context, 50)
        assert pattern.max_iterations == 50

View File

@ -0,0 +1,332 @@
"""Tests for FunctionCallStrategy."""
from decimal import Decimal
from unittest.mock import MagicMock
import pytest
from core.agent.entities import AgentLog, ExecutionContext
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
PromptMessageTool,
SystemPromptMessage,
UserPromptMessage,
)
@pytest.fixture
def mock_model_instance():
    """Provide a MagicMock standing in for a model instance."""
    instance = MagicMock()
    instance.configure_mock(model="test-model", provider="test-provider")
    return instance
@pytest.fixture
def mock_context():
    """Provide an ExecutionContext populated with fixed test identifiers."""
    identifiers = {
        "user_id": "test-user",
        "app_id": "test-app",
        "conversation_id": "test-conversation",
        "message_id": "test-message",
        "tenant_id": "test-tenant",
    }
    return ExecutionContext(**identifiers)
@pytest.fixture
def mock_tool():
    """Provide a mock tool that converts to a single-parameter PromptMessageTool."""
    schema = {
        "type": "object",
        "properties": {"param1": {"type": "string", "description": "A parameter"}},
        "required": ["param1"],
    }
    tool = MagicMock()
    tool.entity.identity.name = "test_tool"
    tool.to_prompt_message_tool.return_value = PromptMessageTool(
        name="test_tool", description="A test tool", parameters=schema
    )
    return tool
class TestFunctionCallStrategyInit:
    """Tests for FunctionCallStrategy initialization."""

    def test_initialization(self, mock_model_instance, mock_context, mock_tool):
        """The constructor stores model, context, tools and iteration limit."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        strategy = FunctionCallStrategy(
            model_instance=mock_model_instance,
            tools=[mock_tool],
            context=mock_context,
            max_iterations=10,
        )
        assert (strategy.model_instance, strategy.context) == (mock_model_instance, mock_context)
        assert strategy.max_iterations == 10
        assert len(strategy.tools) == 1

    def test_initialization_with_tool_invoke_hook(self, mock_model_instance, mock_context, mock_tool):
        """A supplied tool_invoke_hook is kept on the strategy."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        hook = MagicMock()
        strategy = FunctionCallStrategy(
            model_instance=mock_model_instance,
            tools=[mock_tool],
            context=mock_context,
            tool_invoke_hook=hook,
        )
        assert strategy.tool_invoke_hook == hook
class TestConvertToolsToPromptFormat:
    """Tests for _convert_tools_to_prompt_format method."""

    def test_convert_tools_returns_prompt_message_tools(self, mock_model_instance, mock_context, mock_tool):
        """Conversion yields one PromptMessageTool per configured tool."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        strategy = FunctionCallStrategy(
            model_instance=mock_model_instance, tools=[mock_tool], context=mock_context
        )
        converted = strategy._convert_tools_to_prompt_format()
        assert len(converted) == 1
        first = converted[0]
        assert isinstance(first, PromptMessageTool)
        assert first.name == "test_tool"

    def test_convert_tools_empty_when_no_tools(self, mock_model_instance, mock_context):
        """Conversion of an empty tool list yields an empty list."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        strategy = FunctionCallStrategy(
            model_instance=mock_model_instance, tools=[], context=mock_context
        )
        assert strategy._convert_tools_to_prompt_format() == []
class TestAgentLogGeneration:
    """Tests for AgentLog generation during run."""

    def test_round_log_structure(self, mock_model_instance, mock_context, mock_tool):
        """Round logs carry their label, ROUND type, START status and data."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        strategy = FunctionCallStrategy(
            model_instance=mock_model_instance,
            tools=[mock_tool],
            context=mock_context,
            max_iterations=1,
        )
        round_log = strategy._create_log(
            label="ROUND 1",
            log_type=AgentLog.LogType.ROUND,
            status=AgentLog.LogStatus.START,
            data={"inputs": {"query": "test"}},
        )
        assert (round_log.label, round_log.log_type, round_log.status) == (
            "ROUND 1",
            AgentLog.LogType.ROUND,
            AgentLog.LogStatus.START,
        )
        assert "inputs" in round_log.data

    def test_tool_call_log_structure(self, mock_model_instance, mock_context, mock_tool):
        """Tool-call logs reference their parent round and carry call data."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        strategy = FunctionCallStrategy(
            model_instance=mock_model_instance, tools=[mock_tool], context=mock_context
        )
        new_log = strategy._create_log
        # Parent round log first, then the child tool-call log.
        round_log = new_log(
            label="ROUND 1",
            log_type=AgentLog.LogType.ROUND,
            status=AgentLog.LogStatus.START,
            data={},
        )
        tool_log = new_log(
            label="CALL test_tool",
            log_type=AgentLog.LogType.TOOL_CALL,
            status=AgentLog.LogStatus.START,
            data={"tool_name": "test_tool", "tool_args": {"param1": "value1"}},
            parent_id=round_log.id,
        )
        assert tool_log.label == "CALL test_tool"
        assert tool_log.log_type == AgentLog.LogType.TOOL_CALL
        assert tool_log.parent_id == round_log.id
        assert tool_log.data["tool_name"] == "test_tool"
class TestToolInvocation:
    """Tests for tool invocation."""

    def test_invoke_tool_with_hook(self, mock_model_instance, mock_context, mock_tool):
        """When a hook is configured, _invoke_tool delegates to it."""
        from core.agent.patterns.function_call import FunctionCallStrategy
        from core.tools.entities.tool_entities import ToolInvokeMeta

        meta_stub = ToolInvokeMeta(
            time_cost=0.5,
            error=None,
            tool_config={"tool_provider_type": "test", "tool_provider": "test_id"},
        )
        hook = MagicMock(return_value=("Tool result", ["file-1"], meta_stub))
        strategy = FunctionCallStrategy(
            model_instance=mock_model_instance,
            tools=[mock_tool],
            context=mock_context,
            tool_invoke_hook=hook,
        )
        result, files, meta = strategy._invoke_tool(mock_tool, {"param1": "value"}, "test_tool")
        hook.assert_called_once()
        assert result == "Tool result"
        assert files == []  # Hook returns file IDs, but _invoke_tool returns empty File list
        assert meta == meta_stub

    def test_invoke_tool_without_hook_attribute_set(self, mock_model_instance, mock_context, mock_tool):
        """When no hook is supplied the attribute stays None."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        strategy = FunctionCallStrategy(
            model_instance=mock_model_instance,
            tools=[mock_tool],
            context=mock_context,
            tool_invoke_hook=None,
        )
        assert strategy.tool_invoke_hook is None
class TestUsageTracking:
    """Tests for usage tracking across rounds."""

    @staticmethod
    def _usage(prompt_tokens, completion_tokens, prompt_price, completion_price, total_price):
        """Build an LLMUsage with the fixed test unit prices and given counts."""
        return LLMUsage(
            prompt_tokens=prompt_tokens,
            prompt_unit_price=Decimal("0.001"),
            prompt_price_unit=Decimal("0.001"),
            prompt_price=Decimal(prompt_price),
            completion_tokens=completion_tokens,
            completion_unit_price=Decimal("0.002"),
            completion_price_unit=Decimal("0.001"),
            completion_price=Decimal(completion_price),
            total_tokens=prompt_tokens + completion_tokens,
            total_price=Decimal(total_price),
            currency="USD",
            latency=0.5,
        )

    def test_round_usage_is_separate_from_total(self, mock_model_instance, mock_context):
        """Per-round accumulators stay independent while the total sums both."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        strategy = FunctionCallStrategy(
            model_instance=mock_model_instance, tools=[], context=mock_context
        )
        total: dict = {"usage": None}
        per_round: list[dict] = [{"usage": None}, {"usage": None}]
        deltas = [
            self._usage(100, 50, "0.1", "0.1", "0.2"),
            self._usage(200, 100, "0.2", "0.2", "0.4"),
        ]
        # Each round accumulates into its own bucket AND into the total.
        for bucket, delta in zip(per_round, deltas):
            strategy._accumulate_usage(bucket, delta)
            strategy._accumulate_usage(total, delta)
        # Verify round usage is separate
        assert per_round[0]["usage"].total_tokens == 150
        assert per_round[1]["usage"].total_tokens == 300
        # Verify total is accumulated
        assert total["usage"].total_tokens == 450
class TestPromptMessageHandling:
    """Tests for prompt message handling."""

    def test_messages_include_system_and_user(self, mock_model_instance, mock_context, mock_tool):
        """System and user prompts can be assembled into a message list."""
        from core.agent.patterns.function_call import FunctionCallStrategy

        strategy = FunctionCallStrategy(
            model_instance=mock_model_instance, tools=[mock_tool], context=mock_context
        )
        messages = [
            SystemPromptMessage(content="You are a helpful assistant."),
            UserPromptMessage(content="Hello"),
        ]
        # Just verify the messages can be processed
        assert len(messages) == 2
        assert isinstance(messages[0], SystemPromptMessage)
        assert isinstance(messages[1], UserPromptMessage)

    def test_assistant_message_with_tool_calls(self, mock_model_instance, mock_context, mock_tool):
        """Assistant messages can carry structured tool calls."""
        from core.model_runtime.entities.message_entities import AssistantPromptMessage

        call_fn = AssistantPromptMessage.ToolCall.ToolCallFunction(
            name="test_tool",
            arguments='{"param1": "value1"}',
        )
        tool_call = AssistantPromptMessage.ToolCall(id="call_123", type="function", function=call_fn)
        assistant_message = AssistantPromptMessage(
            content="I'll help you with that.",
            tool_calls=[tool_call],
        )
        assert len(assistant_message.tool_calls) == 1
        assert assistant_message.tool_calls[0].function.name == "test_tool"

View File

@ -0,0 +1,224 @@
"""Tests for ReActStrategy."""
from unittest.mock import MagicMock
import pytest
from core.agent.entities import ExecutionContext
from core.agent.patterns.react import ReActStrategy
from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage
@pytest.fixture
def mock_model_instance():
    """Provide a MagicMock standing in for a model instance."""
    instance = MagicMock()
    instance.configure_mock(model="test-model", provider="test-provider")
    return instance
@pytest.fixture
def mock_context():
    """Provide an ExecutionContext populated with fixed test identifiers."""
    identifiers = {
        "user_id": "test-user",
        "app_id": "test-app",
        "conversation_id": "test-conversation",
        "message_id": "test-message",
        "tenant_id": "test-tenant",
    }
    return ExecutionContext(**identifiers)
@pytest.fixture
def mock_tool():
    """Provide a mock tool exposing a real PromptMessageTool conversion."""
    from core.model_runtime.entities.message_entities import PromptMessageTool

    tool = MagicMock()
    tool.entity.identity.name = "test_tool"
    tool.entity.identity.provider = "test_provider"
    # Use real PromptMessageTool for proper serialization
    tool.to_prompt_message_tool.return_value = PromptMessageTool(
        name="test_tool",
        description="A test tool",
        parameters={"type": "object", "properties": {}},
    )
    return tool
class TestReActStrategyInit:
    """Tests for ReActStrategy initialization."""

    def test_init_with_instruction(self, mock_model_instance, mock_context):
        """A supplied instruction is stored verbatim."""
        text = "You are a helpful assistant."
        strategy = ReActStrategy(
            model_instance=mock_model_instance,
            tools=[],
            context=mock_context,
            instruction=text,
        )
        assert strategy.instruction == text

    def test_init_with_empty_instruction(self, mock_model_instance, mock_context):
        """Omitting the instruction defaults it to an empty string."""
        strategy = ReActStrategy(
            model_instance=mock_model_instance, tools=[], context=mock_context
        )
        assert strategy.instruction == ""
class TestBuildPromptWithReactFormat:
    """Tests for _build_prompt_with_react_format method."""

    @staticmethod
    def _strategy(model_instance, context, tools=(), instruction=None):
        """Build a ReActStrategy; forward instruction only when given."""
        extra = {} if instruction is None else {"instruction": instruction}
        return ReActStrategy(
            model_instance=model_instance, tools=list(tools), context=context, **extra
        )

    def test_replace_tools_placeholder(self, mock_model_instance, mock_context, mock_tool):
        """The {{tools}} placeholder is expanded into tool JSON."""
        strategy = self._strategy(mock_model_instance, mock_context, tools=[mock_tool])
        messages = [
            SystemPromptMessage(content="You have access to: {{tools}}"),
            UserPromptMessage(content="Hello"),
        ]
        rendered = strategy._build_prompt_with_react_format(messages, [], True)
        system_text = rendered[0].content
        # The tools placeholder should be replaced with JSON
        assert "{{tools}}" not in system_text
        assert "test_tool" in system_text

    def test_replace_tool_names_placeholder(self, mock_model_instance, mock_context, mock_tool):
        """The {{tool_names}} placeholder becomes a quoted name list."""
        strategy = self._strategy(mock_model_instance, mock_context, tools=[mock_tool])
        messages = [SystemPromptMessage(content="Valid actions: {{tool_names}}")]
        rendered = strategy._build_prompt_with_react_format(messages, [], True)
        assert "{{tool_names}}" not in rendered[0].content
        assert '"test_tool"' in rendered[0].content

    def test_replace_instruction_placeholder(self, mock_model_instance, mock_context):
        """The {{instruction}} placeholder is replaced with the instruction."""
        text = "You are a helpful coding assistant."
        strategy = self._strategy(mock_model_instance, mock_context, instruction=text)
        messages = [SystemPromptMessage(content="{{instruction}}\n\nYou have access to: {{tools}}")]
        rendered = strategy._build_prompt_with_react_format(messages, [], True, text)
        assert "{{instruction}}" not in rendered[0].content
        assert text in rendered[0].content

    def test_no_tools_available_message(self, mock_model_instance, mock_context):
        """'No tools available' is rendered when include_tools is False."""
        strategy = self._strategy(mock_model_instance, mock_context)
        messages = [SystemPromptMessage(content="You have access to: {{tools}}")]
        rendered = strategy._build_prompt_with_react_format(messages, [], False)
        assert "No tools available" in rendered[0].content

    def test_scratchpad_appended_as_assistant_message(self, mock_model_instance, mock_context):
        """A non-empty scratchpad is appended as an AssistantPromptMessage."""
        from core.agent.entities import AgentScratchpadUnit
        from core.model_runtime.entities import AssistantPromptMessage

        strategy = self._strategy(mock_model_instance, mock_context)
        messages = [
            SystemPromptMessage(content="System prompt"),
            UserPromptMessage(content="User query"),
        ]
        scratchpad = [
            AgentScratchpadUnit(
                thought="I need to search for information",
                action_str='{"action": "search", "action_input": "query"}',
                observation="Search results here",
            )
        ]
        rendered = strategy._build_prompt_with_react_format(messages, scratchpad, True)
        # The last message should be an AssistantPromptMessage with scratchpad content
        assert len(rendered) == 3
        tail = rendered[-1]
        assert isinstance(tail, AssistantPromptMessage)
        assert "I need to search for information" in tail.content
        assert "Search results here" in tail.content

    def test_empty_scratchpad_no_extra_message(self, mock_model_instance, mock_context):
        """An empty scratchpad adds no trailing message."""
        strategy = self._strategy(mock_model_instance, mock_context)
        messages = [
            SystemPromptMessage(content="System prompt"),
            UserPromptMessage(content="User query"),
        ]
        # Should only have the original 2 messages
        assert len(strategy._build_prompt_with_react_format(messages, [], True)) == 2

    def test_original_messages_not_modified(self, mock_model_instance, mock_context):
        """The input message list is left untouched by formatting."""
        strategy = self._strategy(mock_model_instance, mock_context)
        original_content = "Original system prompt {{tools}}"
        messages = [SystemPromptMessage(content=original_content)]
        strategy._build_prompt_with_react_format(messages, [], True)
        # Original message should not be modified
        assert messages[0].content == original_content

View File

@ -0,0 +1,203 @@
"""Tests for StrategyFactory."""
from unittest.mock import MagicMock
import pytest
from core.agent.entities import AgentEntity, ExecutionContext
from core.agent.patterns.function_call import FunctionCallStrategy
from core.agent.patterns.react import ReActStrategy
from core.agent.patterns.strategy_factory import StrategyFactory
from core.model_runtime.entities.model_entities import ModelFeature
@pytest.fixture
def mock_model_instance():
    """Provide a MagicMock standing in for a model instance."""
    instance = MagicMock()
    instance.configure_mock(model="test-model", provider="test-provider")
    return instance
@pytest.fixture
def mock_context():
    """Provide an ExecutionContext populated with fixed test identifiers."""
    identifiers = {
        "user_id": "test-user",
        "app_id": "test-app",
        "conversation_id": "test-conversation",
        "message_id": "test-message",
        "tenant_id": "test-tenant",
    }
    return ExecutionContext(**identifiers)
class TestStrategyFactory:
    """Tests for StrategyFactory.create_strategy method."""

    @staticmethod
    def _create(model_instance, context, features, **extra):
        """Call the factory with empty tools/files plus any extra kwargs."""
        return StrategyFactory.create_strategy(
            model_features=features,
            model_instance=model_instance,
            context=context,
            tools=[],
            files=[],
            **extra,
        )

    def test_create_function_call_strategy_with_tool_call_feature(self, mock_model_instance, mock_context):
        """TOOL_CALL support selects FunctionCallStrategy."""
        strategy = self._create(mock_model_instance, mock_context, [ModelFeature.TOOL_CALL])
        assert isinstance(strategy, FunctionCallStrategy)

    def test_create_function_call_strategy_with_multi_tool_call_feature(self, mock_model_instance, mock_context):
        """MULTI_TOOL_CALL support selects FunctionCallStrategy."""
        strategy = self._create(mock_model_instance, mock_context, [ModelFeature.MULTI_TOOL_CALL])
        assert isinstance(strategy, FunctionCallStrategy)

    def test_create_function_call_strategy_with_stream_tool_call_feature(self, mock_model_instance, mock_context):
        """STREAM_TOOL_CALL support selects FunctionCallStrategy."""
        strategy = self._create(mock_model_instance, mock_context, [ModelFeature.STREAM_TOOL_CALL])
        assert isinstance(strategy, FunctionCallStrategy)

    def test_create_react_strategy_without_tool_call_features(self, mock_model_instance, mock_context):
        """Models without any tool-calling feature fall back to ReActStrategy."""
        # Only vision, no tool calling
        strategy = self._create(mock_model_instance, mock_context, [ModelFeature.VISION])
        assert isinstance(strategy, ReActStrategy)

    def test_create_react_strategy_with_empty_features(self, mock_model_instance, mock_context):
        """A feature-less model gets ReActStrategy."""
        no_features: list[ModelFeature] = []
        strategy = self._create(mock_model_instance, mock_context, no_features)
        assert isinstance(strategy, ReActStrategy)

    def test_explicit_function_calling_strategy_with_support(self, mock_model_instance, mock_context):
        """Explicit FUNCTION_CALLING with model support yields FunctionCallStrategy."""
        strategy = self._create(
            mock_model_instance,
            mock_context,
            [ModelFeature.TOOL_CALL],
            agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
        )
        assert isinstance(strategy, FunctionCallStrategy)

    def test_explicit_function_calling_strategy_without_support_falls_back_to_react(
        self, mock_model_instance, mock_context
    ):
        """Explicit FUNCTION_CALLING degrades to ReAct when unsupported."""
        no_features: list[ModelFeature] = []  # No tool calling support
        strategy = self._create(
            mock_model_instance,
            mock_context,
            no_features,
            agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
        )
        # Should fall back to ReAct since FC is not supported
        assert isinstance(strategy, ReActStrategy)

    def test_explicit_chain_of_thought_strategy(self, mock_model_instance, mock_context):
        """Explicit CHAIN_OF_THOUGHT wins even when tool calling is supported."""
        strategy = self._create(
            mock_model_instance,
            mock_context,
            [ModelFeature.TOOL_CALL],  # Even with tool call support
            agent_strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT,
        )
        assert isinstance(strategy, ReActStrategy)

    def test_react_strategy_with_instruction(self, mock_model_instance, mock_context):
        """The instruction parameter is forwarded to ReActStrategy."""
        text = "You are a helpful assistant."
        no_features: list[ModelFeature] = []
        strategy = self._create(mock_model_instance, mock_context, no_features, instruction=text)
        assert isinstance(strategy, ReActStrategy)
        assert strategy.instruction == text

    def test_max_iterations_passed_to_strategy(self, mock_model_instance, mock_context):
        """max_iterations is forwarded to the created strategy."""
        strategy = self._create(
            mock_model_instance, mock_context, [ModelFeature.TOOL_CALL], max_iterations=5
        )
        assert strategy.max_iterations == 5

    def test_tool_invoke_hook_passed_to_strategy(self, mock_model_instance, mock_context):
        """tool_invoke_hook is forwarded to the created strategy."""
        hook = MagicMock()
        strategy = self._create(
            mock_model_instance, mock_context, [ModelFeature.TOOL_CALL], tool_invoke_hook=hook
        )
        assert strategy.tool_invoke_hook == hook

View File

@ -0,0 +1,388 @@
"""Tests for AgentAppRunner."""
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentResult
from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.llm_entities import LLMUsage
class TestOrganizePromptMessages:
"""Tests for _organize_prompt_messages method."""
@pytest.fixture
def mock_runner(self):
    """Create a mock AgentAppRunner for testing.

    Bypasses BaseAgentRunner.__init__ entirely (patched to a no-op) and
    constructs the instance via __new__, then hand-assigns only the
    attributes that _organize_prompt_messages reads.
    """
    # We'll patch the class to avoid complex initialization
    with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
        from core.agent.agent_app_runner import AgentAppRunner
        runner = AgentAppRunner.__new__(AgentAppRunner)
        # Set up required attributes
        runner.config = MagicMock(spec=AgentEntity)
        runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING
        runner.config.prompt = None  # no explicit agent prompt configured by default
        runner.app_config = MagicMock()
        runner.app_config.prompt_template = MagicMock()
        runner.app_config.prompt_template.simple_prompt_template = "You are a helpful assistant."
        runner.history_prompt_messages = []
        runner.query = "Hello"
        runner._current_thoughts = []
        runner.files = []
        runner.model_config = MagicMock()
        runner.memory = None  # no conversation memory in these tests
        runner.application_generate_entity = MagicMock()
        runner.application_generate_entity.file_upload_config = None
        return runner
def test_function_calling_uses_simple_prompt(self, mock_runner):
"""Test that function calling strategy uses simple_prompt_template."""
mock_runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING
with patch.object(mock_runner, "_init_system_message") as mock_init:
mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")]
with patch.object(mock_runner, "_organize_user_query") as mock_query:
mock_query.return_value = [UserPromptMessage(content="Hello")]
with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
mock_transform.return_value.get_prompt.return_value = [
SystemPromptMessage(content="You are a helpful assistant.")
]
result = mock_runner._organize_prompt_messages()
# Verify _init_system_message was called with simple_prompt_template
mock_init.assert_called_once()
call_args = mock_init.call_args[0]
assert call_args[0] == "You are a helpful assistant."
def test_chain_of_thought_uses_agent_prompt(self, mock_runner):
"""Test that chain of thought strategy uses agent prompt template."""
mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
mock_runner.config.prompt = AgentPromptEntity(
first_prompt="ReAct prompt template with {{tools}}",
next_iteration="Continue...",
)
with patch.object(mock_runner, "_init_system_message") as mock_init:
mock_init.return_value = [SystemPromptMessage(content="ReAct prompt template with {{tools}}")]
with patch.object(mock_runner, "_organize_user_query") as mock_query:
mock_query.return_value = [UserPromptMessage(content="Hello")]
with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
mock_transform.return_value.get_prompt.return_value = [
SystemPromptMessage(content="ReAct prompt template with {{tools}}")
]
result = mock_runner._organize_prompt_messages()
# Verify _init_system_message was called with agent prompt
mock_init.assert_called_once()
call_args = mock_init.call_args[0]
assert call_args[0] == "ReAct prompt template with {{tools}}"
def test_chain_of_thought_without_prompt_falls_back(self, mock_runner):
"""Test that chain of thought without prompt falls back to simple_prompt_template."""
mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
mock_runner.config.prompt = None
with patch.object(mock_runner, "_init_system_message") as mock_init:
mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")]
with patch.object(mock_runner, "_organize_user_query") as mock_query:
mock_query.return_value = [UserPromptMessage(content="Hello")]
with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
mock_transform.return_value.get_prompt.return_value = [
SystemPromptMessage(content="You are a helpful assistant.")
]
result = mock_runner._organize_prompt_messages()
# Verify _init_system_message was called with simple_prompt_template
mock_init.assert_called_once()
call_args = mock_init.call_args[0]
assert call_args[0] == "You are a helpful assistant."
class TestInitSystemMessage:
    """Tests for _init_system_message method."""

    @pytest.fixture
    def mock_runner(self):
        """Create a mock AgentAppRunner for testing."""
        # Skip BaseAgentRunner.__init__ entirely; _init_system_message is
        # state-free, so a bare instance is enough.
        with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
            from core.agent.agent_app_runner import AgentAppRunner

            return AgentAppRunner.__new__(AgentAppRunner)

    def test_empty_messages_with_template(self, mock_runner):
        """A system message is created when the message list is empty."""
        messages = mock_runner._init_system_message("System template", [])
        assert len(messages) == 1
        first = messages[0]
        assert isinstance(first, SystemPromptMessage)
        assert first.content == "System template"

    def test_empty_messages_without_template(self, mock_runner):
        """No template and no messages yields an empty list."""
        assert mock_runner._init_system_message("", []) == []

    def test_existing_system_message_not_duplicated(self, mock_runner):
        """An existing leading system message is kept, not duplicated."""
        prompt = [
            SystemPromptMessage(content="Existing system"),
            UserPromptMessage(content="User message"),
        ]
        messages = mock_runner._init_system_message("New template", prompt)
        # No extra system message should have been inserted.
        assert len(messages) == 2
        assert messages[0].content == "Existing system"

    def test_system_message_inserted_when_missing(self, mock_runner):
        """A system message is prepended when the first message is not system."""
        prompt = [UserPromptMessage(content="User message")]
        messages = mock_runner._init_system_message("System template", prompt)
        assert len(messages) == 2
        assert isinstance(messages[0], SystemPromptMessage)
        assert messages[0].content == "System template"
class TestClearUserPromptImageMessages:
    """Tests for _clear_user_prompt_image_messages method."""

    @pytest.fixture
    def mock_runner(self):
        """Create a mock AgentAppRunner for testing."""
        with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
            from core.agent.agent_app_runner import AgentAppRunner

            return AgentAppRunner.__new__(AgentAppRunner)

    def test_text_content_unchanged(self, mock_runner):
        """Plain-text message content passes through untouched."""
        cleared = mock_runner._clear_user_prompt_image_messages(
            [UserPromptMessage(content="Plain text message")]
        )
        assert len(cleared) == 1
        assert cleared[0].content == "Plain text message"

    def test_original_messages_not_modified(self, mock_runner):
        """The method deep-copies, leaving the input messages intact."""
        from core.model_runtime.entities.message_entities import (
            ImagePromptMessageContent,
            TextPromptMessageContent,
        )

        original = [
            UserPromptMessage(
                content=[
                    TextPromptMessageContent(data="Text part"),
                    ImagePromptMessageContent(
                        data="http://example.com/image.jpg",
                        format="url",
                        mime_type="image/jpeg",
                    ),
                ]
            ),
        ]
        cleared = mock_runner._clear_user_prompt_image_messages(original)
        # The input keeps its multi-part list content...
        assert isinstance(original[0].content, list)
        # ...while the cleared copy is collapsed to a plain string.
        assert isinstance(cleared[0].content, str)
class TestToolInvokeHook:
    """Tests for _create_tool_invoke_hook method."""

    @pytest.fixture
    def mock_runner(self):
        """Create a mock AgentAppRunner for testing."""
        # Bypass BaseAgentRunner.__init__ (it needs queue/db wiring), then
        # hand-populate only the attributes the hook implementation reads.
        with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
            from core.agent.agent_app_runner import AgentAppRunner

            runner = AgentAppRunner.__new__(AgentAppRunner)
            runner.user_id = "test-user"
            runner.tenant_id = "test-tenant"
            runner.application_generate_entity = MagicMock()
            runner.application_generate_entity.trace_manager = None
            runner.application_generate_entity.invoke_from = "api"
            runner.application_generate_entity.app_config = MagicMock()
            runner.application_generate_entity.app_config.app_id = "test-app"
            runner.agent_callback = MagicMock()
            runner.conversation = MagicMock()
            runner.conversation.id = "test-conversation"
            runner.queue_manager = MagicMock()
            # Accumulates file IDs produced by tool invocations.
            runner._current_message_file_ids = []
            return runner

    def test_hook_calls_agent_invoke(self, mock_runner):
        """Test that the hook calls ToolEngine.agent_invoke."""
        from core.tools.entities.tool_entities import ToolInvokeMeta

        mock_message = MagicMock()
        mock_message.id = "test-message"
        mock_tool = MagicMock()
        mock_tool_meta = ToolInvokeMeta(
            time_cost=0.5,
            error=None,
            tool_config={
                "tool_provider_type": "test_provider",
                "tool_provider": "test_id",
            },
        )
        # Patch ToolEngine where the runner module looks it up so the hook
        # hits the stub instead of invoking a real tool.
        with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine:
            mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta)
            hook = mock_runner._create_tool_invoke_hook(mock_message)
            result_content, result_files, result_meta = hook(mock_tool, {"arg": "value"}, "test_tool")
            # Verify ToolEngine.agent_invoke was called
            mock_engine.agent_invoke.assert_called_once()
            # Verify the hook passes through the engine's (content, files, meta) triple.
            assert result_content == "Tool result"
            assert result_files == ["file-1", "file-2"]
            assert result_meta == mock_tool_meta

    def test_hook_publishes_file_events(self, mock_runner):
        """Test that the hook publishes QueueMessageFileEvent for files."""
        from core.tools.entities.tool_entities import ToolInvokeMeta

        mock_message = MagicMock()
        mock_message.id = "test-message"
        mock_tool = MagicMock()
        mock_tool_meta = ToolInvokeMeta(
            time_cost=0.5,
            error=None,
            tool_config={},
        )
        with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine:
            mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta)
            hook = mock_runner._create_tool_invoke_hook(mock_message)
            hook(mock_tool, {}, "test_tool")
            # One publish per returned file, and the IDs are recorded on the runner.
            assert mock_runner.queue_manager.publish.call_count == 2
            assert mock_runner._current_message_file_ids == ["file-1", "file-2"]
class TestAgentLogProcessing:
    """Tests for AgentLog processing in run method."""

    def test_agent_log_status_enum(self):
        """AgentLog status enum members compare equal to their raw strings."""
        for member, raw in [
            (AgentLog.LogStatus.START, "start"),
            (AgentLog.LogStatus.SUCCESS, "success"),
            (AgentLog.LogStatus.ERROR, "error"),
        ]:
            assert member == raw

    def test_agent_log_metadata_enum(self):
        """AgentLog metadata enum members compare equal to their raw strings."""
        for member, raw in [
            (AgentLog.LogMetadata.STARTED_AT, "started_at"),
            (AgentLog.LogMetadata.FINISHED_AT, "finished_at"),
            (AgentLog.LogMetadata.ELAPSED_TIME, "elapsed_time"),
            (AgentLog.LogMetadata.TOTAL_PRICE, "total_price"),
            (AgentLog.LogMetadata.TOTAL_TOKENS, "total_tokens"),
            (AgentLog.LogMetadata.LLM_USAGE, "llm_usage"),
        ]:
            assert member == raw

    def test_agent_result_structure(self):
        """AgentResult stores text, files, usage, and finish reason as given."""
        llm_usage = LLMUsage(
            prompt_tokens=100,
            prompt_unit_price=Decimal("0.001"),
            prompt_price_unit=Decimal("0.001"),
            prompt_price=Decimal("0.1"),
            completion_tokens=50,
            completion_unit_price=Decimal("0.002"),
            completion_price_unit=Decimal("0.001"),
            completion_price=Decimal("0.1"),
            total_tokens=150,
            total_price=Decimal("0.2"),
            currency="USD",
            latency=0.5,
        )
        agent_result = AgentResult(
            text="Final answer",
            files=[],
            usage=llm_usage,
            finish_reason="stop",
        )
        assert agent_result.text == "Final answer"
        assert agent_result.files == []
        assert agent_result.usage == llm_usage
        assert agent_result.finish_reason == "stop"
class TestOrganizeUserQuery:
    """Tests for _organize_user_query method."""

    @pytest.fixture
    def mock_runner(self):
        """Create a mock AgentAppRunner for testing."""
        with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
            from core.agent.agent_app_runner import AgentAppRunner

            runner = AgentAppRunner.__new__(AgentAppRunner)
            runner.files = []
            runner.application_generate_entity = MagicMock()
            runner.application_generate_entity.file_upload_config = None
            return runner

    def test_simple_query_without_files(self, mock_runner):
        """A bare query becomes a single text-only UserPromptMessage."""
        messages = mock_runner._organize_user_query("Hello world", [])
        assert len(messages) == 1
        assert isinstance(messages[0], UserPromptMessage)
        assert messages[0].content == "Hello world"

    def test_query_with_files(self, mock_runner):
        """A query with an attached file becomes multi-part content."""
        from core.file.models import File

        mock_runner.files = [MagicMock(spec=File)]
        with patch("core.agent.agent_app_runner.file_manager") as mock_fm:
            from core.model_runtime.entities.message_entities import ImagePromptMessageContent

            mock_fm.to_prompt_message_content.return_value = ImagePromptMessageContent(
                data="http://example.com/image.jpg",
                format="url",
                mime_type="image/jpeg",
            )
            messages = mock_runner._organize_user_query("Describe this image", [])
            assert len(messages) == 1
            assert isinstance(messages[0], UserPromptMessage)
            content = messages[0].content
            assert isinstance(content, list)
            assert len(content) == 2  # Image + Text

View File

@ -0,0 +1,191 @@
"""Tests for agent entities."""
from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentScratchpadUnit, ExecutionContext
class TestExecutionContext:
    """Tests for ExecutionContext entity."""

    def test_create_with_all_fields(self):
        """Every constructor field is stored on the instance."""
        fields = {
            "user_id": "user-123",
            "app_id": "app-456",
            "conversation_id": "conv-789",
            "message_id": "msg-012",
            "tenant_id": "tenant-345",
        }
        context = ExecutionContext(**fields)
        for name, expected in fields.items():
            assert getattr(context, name) == expected

    def test_create_minimal(self):
        """create_minimal sets only user_id; the rest default to None."""
        context = ExecutionContext.create_minimal(user_id="user-123")
        assert context.user_id == "user-123"
        for name in ("app_id", "conversation_id", "message_id", "tenant_id"):
            assert getattr(context, name) is None

    def test_to_dict(self):
        """to_dict round-trips all fields into a plain dictionary."""
        fields = {
            "user_id": "user-123",
            "app_id": "app-456",
            "conversation_id": "conv-789",
            "message_id": "msg-012",
            "tenant_id": "tenant-345",
        }
        assert ExecutionContext(**fields).to_dict() == fields

    def test_with_updates(self):
        """with_updates returns a new context and leaves the original alone."""
        original = ExecutionContext(
            user_id="user-123",
            app_id="app-456",
        )
        updated = original.with_updates(message_id="msg-789")
        # Original should be unchanged
        assert original.message_id is None
        # Updated carries the new value plus the untouched originals.
        assert updated.message_id == "msg-789"
        assert updated.user_id == "user-123"
        assert updated.app_id == "app-456"
class TestAgentLog:
    """Tests for AgentLog entity."""

    def test_create_log_with_required_fields(self):
        """Required fields are stored; optional fields take their defaults."""
        log = AgentLog(
            label="ROUND 1",
            log_type=AgentLog.LogType.ROUND,
            status=AgentLog.LogStatus.START,
            data={"key": "value"},
        )
        assert log.label == "ROUND 1"
        assert log.log_type == AgentLog.LogType.ROUND
        assert log.status == AgentLog.LogStatus.START
        assert log.data == {"key": "value"}
        # id is auto-generated; parent_id/error default to None.
        assert log.id is not None
        assert log.parent_id is None
        assert log.error is None

    def test_log_type_enum(self):
        """LogType members compare equal to their raw string values."""
        for member, raw in [
            (AgentLog.LogType.ROUND, "round"),
            (AgentLog.LogType.THOUGHT, "thought"),
            (AgentLog.LogType.TOOL_CALL, "tool_call"),
        ]:
            assert member == raw

    def test_log_status_enum(self):
        """LogStatus members compare equal to their raw string values."""
        for member, raw in [
            (AgentLog.LogStatus.START, "start"),
            (AgentLog.LogStatus.SUCCESS, "success"),
            (AgentLog.LogStatus.ERROR, "error"),
        ]:
            assert member == raw

    def test_log_metadata_enum(self):
        """LogMetadata members compare equal to their raw string values."""
        for member, raw in [
            (AgentLog.LogMetadata.STARTED_AT, "started_at"),
            (AgentLog.LogMetadata.FINISHED_AT, "finished_at"),
            (AgentLog.LogMetadata.ELAPSED_TIME, "elapsed_time"),
            (AgentLog.LogMetadata.TOTAL_PRICE, "total_price"),
            (AgentLog.LogMetadata.TOTAL_TOKENS, "total_tokens"),
            (AgentLog.LogMetadata.LLM_USAGE, "llm_usage"),
        ]:
            assert member == raw
class TestAgentScratchpadUnit:
    """Tests for AgentScratchpadUnit entity."""

    def test_is_final_with_final_answer_action(self):
        """A 'Final Answer' action marks the unit as final."""
        final_action = AgentScratchpadUnit.Action(
            action_name="Final Answer",
            action_input="The answer is 42",
        )
        unit = AgentScratchpadUnit(thought="I know the answer", action=final_action)
        assert unit.is_final() is True

    def test_is_final_with_tool_action(self):
        """A tool action means the unit is not final."""
        tool_action = AgentScratchpadUnit.Action(
            action_name="search",
            action_input={"query": "test"},
        )
        unit = AgentScratchpadUnit(thought="I need to search", action=tool_action)
        assert unit.is_final() is False

    def test_is_final_with_no_action(self):
        """A unit with no action at all counts as final."""
        assert AgentScratchpadUnit(thought="Just thinking").is_final() is True

    def test_action_to_dict(self):
        """Action.to_dict serializes name and input under fixed keys."""
        serialized = AgentScratchpadUnit.Action(
            action_name="search",
            action_input={"query": "test"},
        ).to_dict()
        assert serialized == {
            "action": "search",
            "action_input": {"query": "test"},
        }
class TestAgentEntity:
    """Tests for AgentEntity."""

    def test_strategy_enum(self):
        """Strategy members compare equal to their raw string values."""
        for member, raw in [
            (AgentEntity.Strategy.CHAIN_OF_THOUGHT, "chain-of-thought"),
            (AgentEntity.Strategy.FUNCTION_CALLING, "function-calling"),
        ]:
            assert member == raw

    def test_create_with_prompt(self):
        """All constructor fields, including the prompt, are stored."""
        react_prompt = AgentPromptEntity(
            first_prompt="You are a helpful assistant.",
            next_iteration="Continue thinking...",
        )
        entity = AgentEntity(
            provider="openai",
            model="gpt-4",
            strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT,
            prompt=react_prompt,
            max_iteration=5,
        )
        assert entity.provider == "openai"
        assert entity.model == "gpt-4"
        assert entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT
        assert entity.prompt == react_prompt
        assert entity.max_iteration == 5

View File

@ -0,0 +1,48 @@
from unittest.mock import MagicMock
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.workflow.graph_events import NodeRunStreamChunkEvent
from core.workflow.nodes import NodeType
class DummyQueueManager:
    """Queue-manager stub that records published events instead of dispatching them."""

    def __init__(self) -> None:
        # Each entry is a (event, publish_from) pair, in publish order.
        self.published = []

    def publish(self, event, publish_from: PublishFrom) -> None:
        """Capture the event and its origin for later assertions."""
        self.published.append((event, publish_from))
def test_skip_empty_final_chunk() -> None:
    """Empty final stream chunks are dropped; normal chunks are forwarded."""
    qm = DummyQueueManager()
    runner = WorkflowBasedAppRunner(queue_manager=qm, app_id="app")

    def make_event(chunk: str, is_final: bool) -> NodeRunStreamChunkEvent:
        # All other fields are constant for this scenario.
        return NodeRunStreamChunkEvent(
            id="exec",
            node_id="node",
            node_type=NodeType.LLM,
            selector=["node", "text"],
            chunk=chunk,
            is_final=is_final,
        )

    # An empty final chunk must not reach the queue manager at all.
    runner._handle_event(workflow_entry=MagicMock(), event=make_event("", True))
    assert qm.published == []

    # A regular non-empty chunk is published once, from the application manager.
    runner._handle_event(workflow_entry=MagicMock(), event=make_event("hi", False))
    assert len(qm.published) == 1
    forwarded, origin = qm.published[0]
    assert origin == PublishFrom.APPLICATION_MANAGER
    assert forwarded.text == "hi"

View File

@ -0,0 +1,231 @@
"""Tests for ResponseStreamCoordinator object field streaming."""
from unittest.mock import MagicMock
from core.workflow.entities.tool_entities import ToolResultStatus
from core.workflow.enums import NodeType
from core.workflow.graph.graph import Graph
from core.workflow.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
from core.workflow.graph_engine.response_coordinator.session import ResponseSession
from core.workflow.graph_events import (
ChunkType,
NodeRunStreamChunkEvent,
ToolCall,
ToolResult,
)
from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.base.template import Template, VariableSegment
from core.workflow.runtime import VariablePool
class TestResponseCoordinatorObjectStreaming:
    """Test streaming of object-type variables with child fields."""

    def test_object_field_streaming(self):
        """Test that when selecting an object variable, all child field streams are forwarded."""
        # Create mock graph and variable pool
        graph = MagicMock(spec=Graph)
        variable_pool = MagicMock(spec=VariablePool)
        # Mock nodes: an LLM producer and an answer node that consumes it.
        llm_node = MagicMock()
        llm_node.id = "llm_node"
        llm_node.node_type = NodeType.LLM
        llm_node.execution_type = MagicMock()
        llm_node.blocks_variable_output = MagicMock(return_value=False)
        response_node = MagicMock()
        response_node.id = "response_node"
        response_node.node_type = NodeType.ANSWER
        response_node.execution_type = MagicMock()
        response_node.blocks_variable_output = MagicMock(return_value=False)
        # Mock template for response node: selects the whole generation object.
        response_node.node_data = MagicMock(spec=BaseNodeData)
        response_node.node_data.answer = "{{#llm_node.generation#}}"
        graph.nodes = {
            "llm_node": llm_node,
            "response_node": response_node,
        }
        graph.root_node = llm_node
        graph.get_outgoing_edges = MagicMock(return_value=[])
        # Create coordinator
        coordinator = ResponseStreamCoordinator(variable_pool, graph)
        # Track execution so events can be attributed to execution IDs.
        coordinator.track_node_execution("llm_node", "exec_123")
        coordinator.track_node_execution("response_node", "exec_456")
        # Simulate streaming events for child fields of generation object
        # 1. Content stream (two chunks; second is final)
        content_event_1 = NodeRunStreamChunkEvent(
            id="exec_123",
            node_id="llm_node",
            node_type=NodeType.LLM,
            selector=["llm_node", "generation", "content"],
            chunk="Hello",
            is_final=False,
            chunk_type=ChunkType.TEXT,
        )
        content_event_2 = NodeRunStreamChunkEvent(
            id="exec_123",
            node_id="llm_node",
            node_type=NodeType.LLM,
            selector=["llm_node", "generation", "content"],
            chunk=" world",
            is_final=True,
            chunk_type=ChunkType.TEXT,
        )
        # 2. Tool call stream
        tool_call_event = NodeRunStreamChunkEvent(
            id="exec_123",
            node_id="llm_node",
            node_type=NodeType.LLM,
            selector=["llm_node", "generation", "tool_calls"],
            chunk='{"query": "test"}',
            is_final=True,
            chunk_type=ChunkType.TOOL_CALL,
            tool_call=ToolCall(
                id="call_123",
                name="search",
                arguments='{"query": "test"}',
            ),
        )
        # 3. Tool result stream
        tool_result_event = NodeRunStreamChunkEvent(
            id="exec_123",
            node_id="llm_node",
            node_type=NodeType.LLM,
            selector=["llm_node", "generation", "tool_results"],
            chunk="Found 10 results",
            is_final=True,
            chunk_type=ChunkType.TOOL_RESULT,
            tool_result=ToolResult(
                id="call_123",
                name="search",
                output="Found 10 results",
                files=[],
                status=ToolResultStatus.SUCCESS,
            ),
        )
        # Intercept these events (interleaved order, as they would arrive live).
        coordinator.intercept_event(content_event_1)
        coordinator.intercept_event(tool_call_event)
        coordinator.intercept_event(tool_result_event)
        coordinator.intercept_event(content_event_2)
        # Verify that all child streams are buffered
        # NOTE(review): asserts against the private _stream_buffers attribute —
        # intentionally white-box to observe buffering before forwarding.
        assert ("llm_node", "generation", "content") in coordinator._stream_buffers
        assert ("llm_node", "generation", "tool_calls") in coordinator._stream_buffers
        assert ("llm_node", "generation", "tool_results") in coordinator._stream_buffers
        # Verify payloads are preserved in buffered events
        buffered_call = coordinator._stream_buffers[("llm_node", "generation", "tool_calls")][0]
        assert buffered_call.tool_call is not None
        assert buffered_call.tool_call.id == "call_123"
        buffered_result = coordinator._stream_buffers[("llm_node", "generation", "tool_results")][0]
        assert buffered_result.tool_result is not None
        assert buffered_result.tool_result.status == "success"
        # Verify we can find child streams
        child_streams = coordinator._find_child_streams(["llm_node", "generation"])
        assert len(child_streams) == 3
        assert ("llm_node", "generation", "content") in child_streams
        assert ("llm_node", "generation", "tool_calls") in child_streams
        assert ("llm_node", "generation", "tool_results") in child_streams

    def test_find_child_streams(self):
        """Test the _find_child_streams method."""
        graph = MagicMock(spec=Graph)
        variable_pool = MagicMock(spec=VariablePool)
        coordinator = ResponseStreamCoordinator(variable_pool, graph)
        # Add some mock streams directly into the private buffer map.
        coordinator._stream_buffers = {
            ("node1", "generation", "content"): [],
            ("node1", "generation", "tool_calls"): [],
            ("node1", "generation", "thought"): [],
            ("node1", "text"): [],  # Not a child of generation
            ("node2", "generation", "content"): [],  # Different node
        }
        # Find children of node1.generation
        children = coordinator._find_child_streams(["node1", "generation"])
        assert len(children) == 3
        assert ("node1", "generation", "content") in children
        assert ("node1", "generation", "tool_calls") in children
        assert ("node1", "generation", "thought") in children
        assert ("node1", "text") not in children
        assert ("node2", "generation", "content") not in children

    def test_find_child_streams_with_closed_streams(self):
        """Test that _find_child_streams also considers closed streams."""
        graph = MagicMock(spec=Graph)
        variable_pool = MagicMock(spec=VariablePool)
        coordinator = ResponseStreamCoordinator(variable_pool, graph)
        # Add some streams - some buffered, some closed
        coordinator._stream_buffers = {
            ("node1", "generation", "content"): [],
        }
        coordinator._closed_streams = {
            ("node1", "generation", "tool_calls"),
            ("node1", "generation", "thought"),
        }
        # Should find all children regardless of whether they're in buffers or closed
        children = coordinator._find_child_streams(["node1", "generation"])
        assert len(children) == 3
        assert ("node1", "generation", "content") in children
        assert ("node1", "generation", "tool_calls") in children
        assert ("node1", "generation", "thought") in children

    def test_special_selector_rewrites_to_active_response_node(self):
        """Ensure special selectors attribute streams to the active response node."""
        graph = MagicMock(spec=Graph)
        variable_pool = MagicMock(spec=VariablePool)
        response_node = MagicMock()
        response_node.id = "response_node"
        response_node.node_type = NodeType.ANSWER
        graph.nodes = {"response_node": response_node}
        graph.root_node = response_node
        coordinator = ResponseStreamCoordinator(variable_pool, graph)
        coordinator.track_node_execution("response_node", "exec_resp")
        # Pretend the response node's session is active and its template
        # references the special ["sys", "foo"] selector.
        coordinator._active_session = ResponseSession(
            node_id="response_node",
            template=Template(segments=[VariableSegment(selector=["sys", "foo"])]),
        )
        event = NodeRunStreamChunkEvent(
            id="stream_1",
            node_id="llm_node",
            node_type=NodeType.LLM,
            selector=["sys", "foo"],
            chunk="hi",
            is_final=True,
            chunk_type=ChunkType.TEXT,
        )
        # Seed the private stream state as if the event had been intercepted.
        coordinator._stream_buffers[("sys", "foo")] = [event]
        coordinator._stream_positions[("sys", "foo")] = 0
        coordinator._closed_streams.add(("sys", "foo"))
        events, is_complete = coordinator._process_variable_segment(VariableSegment(selector=["sys", "foo"]))
        assert is_complete
        assert len(events) == 1
        # The emitted event must be re-attributed to the response node's execution.
        rewritten = events[0]
        assert rewritten.node_id == "response_node"
        assert rewritten.id == "exec_resp"

View File

@ -0,0 +1,328 @@
"""Tests for StreamChunkEvent and its subclasses."""
from core.workflow.entities import ToolCall, ToolResult, ToolResultStatus
from core.workflow.node_events import (
ChunkType,
StreamChunkEvent,
ThoughtChunkEvent,
ToolCallChunkEvent,
ToolResultChunkEvent,
)
class TestChunkType:
    """Tests for ChunkType enum."""

    def test_chunk_type_values(self):
        """Each ChunkType member compares equal to its raw string value."""
        for member, raw in [
            (ChunkType.TEXT, "text"),
            (ChunkType.TOOL_CALL, "tool_call"),
            (ChunkType.TOOL_RESULT, "tool_result"),
            (ChunkType.THOUGHT, "thought"),
        ]:
            assert member == raw

    def test_chunk_type_is_str_enum(self):
        """Every ChunkType value is a string."""
        assert all(isinstance(member.value, str) for member in ChunkType)
class TestStreamChunkEvent:
    """Tests for base StreamChunkEvent."""

    def test_create_with_required_fields(self):
        """Constructing with only required fields applies the defaults."""
        event = StreamChunkEvent(selector=["node1", "text"], chunk="Hello")
        assert event.selector == ["node1", "text"]
        assert event.chunk == "Hello"
        # Defaults: not final, plain text.
        assert event.is_final is False
        assert event.chunk_type == ChunkType.TEXT

    def test_create_with_all_fields(self):
        """Every field can be supplied explicitly."""
        event = StreamChunkEvent(
            selector=["node1", "output"],
            chunk="World",
            is_final=True,
            chunk_type=ChunkType.TEXT,
        )
        assert event.selector == ["node1", "output"]
        assert event.chunk == "World"
        assert event.is_final is True
        assert event.chunk_type == ChunkType.TEXT

    def test_default_chunk_type_is_text(self):
        """chunk_type defaults to TEXT when omitted."""
        event = StreamChunkEvent(selector=["node1", "text"], chunk="test")
        assert event.chunk_type == ChunkType.TEXT

    def test_serialization(self):
        """model_dump renders every field, with the enum as its raw string."""
        dumped = StreamChunkEvent(
            selector=["node1", "text"],
            chunk="Hello",
            is_final=True,
        ).model_dump()
        assert dumped["selector"] == ["node1", "text"]
        assert dumped["chunk"] == "Hello"
        assert dumped["is_final"] is True
        assert dumped["chunk_type"] == "text"
class TestToolCallChunkEvent:
    """Tests for ToolCallChunkEvent."""

    def test_create_with_required_fields(self):
        """Required fields are stored and chunk_type is forced to TOOL_CALL."""
        event = ToolCallChunkEvent(
            selector=["node1", "tool_calls"],
            chunk='{"city": "Beijing"}',
            tool_call=ToolCall(id="call_123", name="weather", arguments=None),
        )
        assert event.selector == ["node1", "tool_calls"]
        assert event.chunk == '{"city": "Beijing"}'
        assert event.tool_call.id == "call_123"
        assert event.tool_call.name == "weather"
        assert event.chunk_type == ChunkType.TOOL_CALL

    def test_chunk_type_is_tool_call(self):
        """chunk_type is always TOOL_CALL, even for an empty chunk."""
        event = ToolCallChunkEvent(
            selector=["node1", "tool_calls"],
            chunk="",
            tool_call=ToolCall(id="call_123", name="test_tool", arguments=None),
        )
        assert event.chunk_type == ChunkType.TOOL_CALL

    def test_tool_arguments_field(self):
        """The nested ToolCall preserves its raw arguments string."""
        call = ToolCall(
            id="call_123",
            name="test_tool",
            arguments='{"param": "value"}',
        )
        event = ToolCallChunkEvent(
            selector=["node1", "tool_calls"],
            chunk='{"param": "value"}',
            tool_call=call,
        )
        assert event.tool_call.arguments == '{"param": "value"}'

    def test_serialization(self):
        """model_dump nests the tool_call payload as a plain dict."""
        dumped = ToolCallChunkEvent(
            selector=["node1", "tool_calls"],
            chunk='{"city": "Beijing"}',
            tool_call=ToolCall(
                id="call_123",
                name="weather",
                arguments='{"city": "Beijing"}',
            ),
            is_final=True,
        ).model_dump()
        assert dumped["chunk_type"] == "tool_call"
        assert dumped["tool_call"]["id"] == "call_123"
        assert dumped["tool_call"]["name"] == "weather"
        assert dumped["tool_call"]["arguments"] == '{"city": "Beijing"}'
        assert dumped["is_final"] is True
class TestToolResultChunkEvent:
    """Tests for ToolResultChunkEvent."""

    def test_create_with_required_fields(self):
        """Required fields are stored and chunk_type is forced to TOOL_RESULT."""
        event = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="Weather: Sunny, 25°C",
            tool_result=ToolResult(id="call_123", name="weather", output="Weather: Sunny, 25°C"),
        )
        assert event.selector == ["node1", "tool_results"]
        assert event.chunk == "Weather: Sunny, 25°C"
        assert event.tool_result.id == "call_123"
        assert event.tool_result.name == "weather"
        assert event.chunk_type == ChunkType.TOOL_RESULT

    def test_chunk_type_is_tool_result(self):
        """chunk_type is always TOOL_RESULT."""
        event = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="result",
            tool_result=ToolResult(id="call_123", name="test_tool"),
        )
        assert event.chunk_type == ChunkType.TOOL_RESULT

    def test_tool_files_default_empty(self):
        """The nested ToolResult's files default to an empty list."""
        event = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="result",
            tool_result=ToolResult(id="call_123", name="test_tool"),
        )
        assert event.tool_result.files == []

    def test_tool_files_with_values(self):
        """File IDs supplied on the ToolResult are preserved."""
        result = ToolResult(
            id="call_123",
            name="test_tool",
            files=["file_1", "file_2"],
        )
        event = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="result",
            tool_result=result,
        )
        assert event.tool_result.files == ["file_1", "file_2"]

    def test_tool_error_output(self):
        """Error output and status are carried on the nested ToolResult."""
        failed = ToolResult(
            id="call_123",
            name="test_tool",
            output="Tool execution failed",
            status=ToolResultStatus.ERROR,
        )
        event = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="",
            tool_result=failed,
        )
        assert event.tool_result.output == "Tool execution failed"
        assert event.tool_result.status == ToolResultStatus.ERROR

    def test_serialization(self):
        """model_dump nests the tool_result payload as a plain dict."""
        dumped = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="Weather: Sunny",
            tool_result=ToolResult(
                id="call_123",
                name="weather",
                output="Weather: Sunny",
                files=["file_1"],
                status=ToolResultStatus.SUCCESS,
            ),
            is_final=True,
        ).model_dump()
        assert dumped["chunk_type"] == "tool_result"
        assert dumped["tool_result"]["id"] == "call_123"
        assert dumped["tool_result"]["name"] == "weather"
        assert dumped["tool_result"]["files"] == ["file_1"]
        assert dumped["is_final"] is True
class TestThoughtChunkEvent:
    """Tests for ThoughtChunkEvent."""

    def test_create_with_required_fields(self):
        """Required fields are stored and chunk_type is THOUGHT."""
        event = ThoughtChunkEvent(
            selector=["node1", "thought"],
            chunk="I need to query the weather...",
        )
        assert event.selector == ["node1", "thought"]
        assert event.chunk == "I need to query the weather..."
        assert event.chunk_type == ChunkType.THOUGHT

    def test_chunk_type_is_thought(self):
        """chunk_type is always THOUGHT."""
        event = ThoughtChunkEvent(
            selector=["node1", "thought"],
            chunk="thinking...",
        )
        assert event.chunk_type == ChunkType.THOUGHT

    def test_serialization(self):
        """model_dump renders the enum and flags as plain values."""
        dumped = ThoughtChunkEvent(
            selector=["node1", "thought"],
            chunk="I need to analyze this...",
            is_final=False,
        ).model_dump()
        assert dumped["chunk_type"] == "thought"
        assert dumped["chunk"] == "I need to analyze this..."
        assert dumped["is_final"] is False
class TestEventInheritance:
    """Tests for event inheritance relationships."""

    def test_tool_call_is_stream_chunk(self):
        """ToolCallChunkEvent subclasses StreamChunkEvent."""
        tool_call_event = ToolCallChunkEvent(
            selector=["node1", "tool_calls"],
            chunk="",
            tool_call=ToolCall(id="call_123", name="test", arguments=None),
        )
        assert isinstance(tool_call_event, StreamChunkEvent)

    def test_tool_result_is_stream_chunk(self):
        """ToolResultChunkEvent subclasses StreamChunkEvent."""
        tool_result_event = ToolResultChunkEvent(
            selector=["node1", "tool_results"],
            chunk="result",
            tool_result=ToolResult(id="call_123", name="test"),
        )
        assert isinstance(tool_result_event, StreamChunkEvent)

    def test_thought_is_stream_chunk(self):
        """ThoughtChunkEvent subclasses StreamChunkEvent."""
        thought_event = ThoughtChunkEvent(
            selector=["node1", "thought"],
            chunk="thinking...",
        )
        assert isinstance(thought_event, StreamChunkEvent)

    def test_all_events_have_common_fields(self):
        """Every event variant exposes the shared StreamChunkEvent fields."""
        samples = [
            StreamChunkEvent(selector=["n", "t"], chunk="a"),
            ToolCallChunkEvent(
                selector=["n", "t"],
                chunk="b",
                tool_call=ToolCall(id="1", name="t", arguments=None),
            ),
            ToolResultChunkEvent(
                selector=["n", "t"],
                chunk="c",
                tool_result=ToolResult(id="1", name="t"),
            ),
            ThoughtChunkEvent(selector=["n", "t"], chunk="d"),
        ]
        shared_fields = ("selector", "chunk", "is_final", "chunk_type")
        for sample in samples:
            for field_name in shared_fields:
                assert hasattr(sample, field_name)

View File

@ -0,0 +1,149 @@
import types
from collections.abc import Generator
from typing import Any
import pytest
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities import ToolCallResult
from core.workflow.entities.tool_entities import ToolResultStatus
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeEventBase
from core.workflow.nodes.llm.node import LLMNode
class _StubModelInstance:
"""Minimal stub to satisfy _stream_llm_events signature."""
provider_model_bundle = None
def _drain(generator: Generator[NodeEventBase, None, Any]):
    """Exhaust *generator*, returning (yielded events, StopIteration return value)."""
    collected: list = []
    while True:
        try:
            collected.append(next(generator))
        except StopIteration as stop:
            # Generators surface their `return` value on StopIteration.value.
            return collected, stop.value
@pytest.fixture(autouse=True)
def patch_deduct_llm_quota(monkeypatch):
    """Neutralize quota deduction so unit tests never touch real billing logic."""

    def _noop_deduct(**_kwargs):
        return None

    monkeypatch.setattr("core.workflow.nodes.llm.node.llm_utils.deduct_llm_quota", _noop_deduct)
def _make_llm_node(reasoning_format: str) -> LLMNode:
    """Build an LLMNode without running __init__, wired with minimal node data."""
    # __new__ plus object.__setattr__ sidesteps both __init__ and any
    # attribute-assignment guards on the node class.
    node = LLMNode.__new__(LLMNode)
    stub_data = types.SimpleNamespace(reasoning_format=reasoning_format, tools=[])
    object.__setattr__(node, "_node_data", stub_data)
    object.__setattr__(node, "tenant_id", "tenant")
    return node
def test_stream_llm_events_extracts_reasoning_for_tagged():
    """Tagged mode strips <think> text into generation reasoning, keeping the original output."""
    node = _make_llm_node(reasoning_format="tagged")
    tagged_text = "<think>Thought</think>Answer"
    usage = LLMUsage.empty_usage()

    def emit():
        yield ModelInvokeCompletedEvent(
            text=tagged_text,
            usage=usage,
            finish_reason="stop",
            reasoning_content="",
            structured_output=None,
        )

    stub_model = types.SimpleNamespace(provider_model_bundle=None)
    events, returned = _drain(node._stream_llm_events(emit(), model_instance=stub_model))
    assert events == []

    (
        clean_text,
        reasoning_content,
        gen_reasoning,
        gen_clean,
        ret_usage,
        finish_reason,
        structured,
        gen_data,
    ) = returned
    assert clean_text == tagged_text  # original text preserved for node output
    assert reasoning_content == ""  # tagged mode keeps reasoning out of this field
    assert gen_reasoning == "Thought"  # reasoning extracted from the <think> tag
    assert gen_clean == "Answer"  # tag-stripped content used for generation
    assert finish_reason == "stop"
    assert ret_usage == usage
    assert structured is None
    assert gen_data is None

    # Generation construction pairs the reasoning block with a content span
    # covering the stripped text.
    generation_content = gen_clean or clean_text
    sequence = [
        {"type": "reasoning", "index": 0},
        {"type": "content", "start": 0, "end": len(generation_content)},
    ]
    assert sequence == [
        {"type": "reasoning", "index": 0},
        {"type": "content", "start": 0, "end": len("Answer")},
    ]
def test_stream_llm_events_no_reasoning_results_in_empty_sequence():
    """Plain text with no <think> tag yields no extracted reasoning."""
    node = _make_llm_node(reasoning_format="tagged")
    plain_text = "Hello world"
    usage = LLMUsage.empty_usage()

    def emit():
        yield ModelInvokeCompletedEvent(
            text=plain_text,
            usage=usage,
            finish_reason=None,
            reasoning_content="",
            structured_output=None,
        )

    stub_model = types.SimpleNamespace(provider_model_bundle=None)
    events, returned = _drain(node._stream_llm_events(emit(), model_instance=stub_model))
    assert events == []

    _, _, gen_reasoning, gen_clean, *_ = returned
    generation_content = gen_clean or plain_text
    assert gen_reasoning == ""
    assert generation_content == plain_text

    # NOTE(review): this mirrors the rule "no reasoning -> empty sequence" but
    # builds the sequence locally, so the assert is vacuous; consider asserting
    # against data produced by the node instead.
    sequence = []
    assert sequence == []
def test_serialize_tool_call_strips_files_to_ids():
    """_serialize_tool_call flattens File objects to their id (or related_id)."""
    # Skip cleanly when the file domain model is unavailable in this environment.
    file_cls = pytest.importorskip("core.file").File
    file_enums = pytest.importorskip("core.file.enums")

    with_explicit_id = file_cls(
        id="f1",
        tenant_id="t",
        type=file_enums.FileType.IMAGE,
        transfer_method=file_enums.FileTransferMethod.REMOTE_URL,
        remote_url="http://example.com/f1",
        storage_key="k1",
    )
    with_related_id = file_cls(
        id=None,
        tenant_id="t",
        type=file_enums.FileType.IMAGE,
        transfer_method=file_enums.FileTransferMethod.REMOTE_URL,
        related_id="rel2",
        remote_url="http://example.com/f2",
        storage_key="k2",
    )

    call = ToolCallResult(
        id="tc",
        name="do",
        arguments='{"a":1}',
        output="ok",
        files=[with_explicit_id, with_related_id],
        status=ToolResultStatus.SUCCESS,
    )
    serialized = LLMNode._serialize_tool_call(call)

    # related_id is the fallback when a file carries no explicit id.
    assert serialized["files"] == ["f1", "rel2"]
    assert serialized["id"] == "tc"
    assert serialized["name"] == "do"
    assert serialized["arguments"] == '{"a":1}'
    assert serialized["output"] == "ok"