mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 09:58:04 +08:00
feat: add agent package
This commit is contained in:
324
api/tests/unit_tests/core/agent/patterns/test_base.py
Normal file
324
api/tests/unit_tests/core/agent/patterns/test_base.py
Normal file
@ -0,0 +1,324 @@
|
||||
"""Tests for AgentPattern base class."""
|
||||
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentLog, ExecutionContext
|
||||
from core.agent.patterns.base import AgentPattern
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
|
||||
class ConcreteAgentPattern(AgentPattern):
|
||||
"""Concrete implementation of AgentPattern for testing."""
|
||||
|
||||
def run(self, prompt_messages, model_parameters, stop=[], stream=True):
|
||||
"""Minimal implementation for testing."""
|
||||
yield from []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_instance():
|
||||
"""Create a mock model instance."""
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.provider = "test-provider"
|
||||
return model_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context():
|
||||
"""Create a mock execution context."""
|
||||
return ExecutionContext(
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
conversation_id="test-conversation",
|
||||
message_id="test-message",
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_pattern(mock_model_instance, mock_context):
|
||||
"""Create a concrete agent pattern for testing."""
|
||||
return ConcreteAgentPattern(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
max_iterations=10,
|
||||
)
|
||||
|
||||
|
||||
class TestAccumulateUsage:
|
||||
"""Tests for _accumulate_usage method."""
|
||||
|
||||
def test_accumulate_usage_to_empty_dict(self, agent_pattern):
|
||||
"""Test accumulating usage to an empty dict creates a copy."""
|
||||
total_usage: dict = {"usage": None}
|
||||
delta_usage = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.1"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.1"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.2"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
agent_pattern._accumulate_usage(total_usage, delta_usage)
|
||||
|
||||
assert total_usage["usage"] is not None
|
||||
assert total_usage["usage"].total_tokens == 150
|
||||
assert total_usage["usage"].prompt_tokens == 100
|
||||
assert total_usage["usage"].completion_tokens == 50
|
||||
# Verify it's a copy, not a reference
|
||||
assert total_usage["usage"] is not delta_usage
|
||||
|
||||
def test_accumulate_usage_adds_to_existing(self, agent_pattern):
|
||||
"""Test accumulating usage adds to existing values."""
|
||||
initial_usage = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.1"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.1"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.2"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
total_usage: dict = {"usage": initial_usage}
|
||||
|
||||
delta_usage = LLMUsage(
|
||||
prompt_tokens=200,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.2"),
|
||||
completion_tokens=100,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.2"),
|
||||
total_tokens=300,
|
||||
total_price=Decimal("0.4"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
agent_pattern._accumulate_usage(total_usage, delta_usage)
|
||||
|
||||
assert total_usage["usage"].total_tokens == 450 # 150 + 300
|
||||
assert total_usage["usage"].prompt_tokens == 300 # 100 + 200
|
||||
assert total_usage["usage"].completion_tokens == 150 # 50 + 100
|
||||
|
||||
def test_accumulate_usage_multiple_rounds(self, agent_pattern):
|
||||
"""Test accumulating usage across multiple rounds."""
|
||||
total_usage: dict = {"usage": None}
|
||||
|
||||
# Round 1: 100 tokens
|
||||
round1_usage = LLMUsage(
|
||||
prompt_tokens=70,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.07"),
|
||||
completion_tokens=30,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.06"),
|
||||
total_tokens=100,
|
||||
total_price=Decimal("0.13"),
|
||||
currency="USD",
|
||||
latency=0.3,
|
||||
)
|
||||
agent_pattern._accumulate_usage(total_usage, round1_usage)
|
||||
assert total_usage["usage"].total_tokens == 100
|
||||
|
||||
# Round 2: 150 tokens
|
||||
round2_usage = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.1"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.1"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.2"),
|
||||
currency="USD",
|
||||
latency=0.4,
|
||||
)
|
||||
agent_pattern._accumulate_usage(total_usage, round2_usage)
|
||||
assert total_usage["usage"].total_tokens == 250 # 100 + 150
|
||||
|
||||
# Round 3: 200 tokens
|
||||
round3_usage = LLMUsage(
|
||||
prompt_tokens=130,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.13"),
|
||||
completion_tokens=70,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.14"),
|
||||
total_tokens=200,
|
||||
total_price=Decimal("0.27"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
agent_pattern._accumulate_usage(total_usage, round3_usage)
|
||||
assert total_usage["usage"].total_tokens == 450 # 100 + 150 + 200
|
||||
|
||||
|
||||
class TestCreateLog:
|
||||
"""Tests for _create_log method."""
|
||||
|
||||
def test_create_log_with_label_and_status(self, agent_pattern):
|
||||
"""Test creating a log with label and status."""
|
||||
log = agent_pattern._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={"key": "value"},
|
||||
)
|
||||
|
||||
assert log.label == "ROUND 1"
|
||||
assert log.log_type == AgentLog.LogType.ROUND
|
||||
assert log.status == AgentLog.LogStatus.START
|
||||
assert log.data == {"key": "value"}
|
||||
assert log.parent_id is None
|
||||
|
||||
def test_create_log_with_parent_id(self, agent_pattern):
|
||||
"""Test creating a log with parent_id."""
|
||||
parent_log = agent_pattern._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
|
||||
child_log = agent_pattern._create_log(
|
||||
label="CALL tool",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=parent_log.id,
|
||||
)
|
||||
|
||||
assert child_log.parent_id == parent_log.id
|
||||
assert child_log.log_type == AgentLog.LogType.TOOL_CALL
|
||||
|
||||
|
||||
class TestFinishLog:
|
||||
"""Tests for _finish_log method."""
|
||||
|
||||
def test_finish_log_updates_status(self, agent_pattern):
|
||||
"""Test that finish_log updates status to SUCCESS."""
|
||||
log = agent_pattern._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
|
||||
finished_log = agent_pattern._finish_log(log, data={"result": "done"})
|
||||
|
||||
assert finished_log.status == AgentLog.LogStatus.SUCCESS
|
||||
assert finished_log.data == {"result": "done"}
|
||||
|
||||
def test_finish_log_adds_usage_metadata(self, agent_pattern):
|
||||
"""Test that finish_log adds usage to metadata."""
|
||||
log = agent_pattern._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
|
||||
usage = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.1"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.1"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.2"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
finished_log = agent_pattern._finish_log(log, usage=usage)
|
||||
|
||||
assert finished_log.metadata[AgentLog.LogMetadata.TOTAL_TOKENS] == 150
|
||||
assert finished_log.metadata[AgentLog.LogMetadata.TOTAL_PRICE] == Decimal("0.2")
|
||||
assert finished_log.metadata[AgentLog.LogMetadata.CURRENCY] == "USD"
|
||||
assert finished_log.metadata[AgentLog.LogMetadata.LLM_USAGE] == usage
|
||||
|
||||
|
||||
class TestFindToolByName:
|
||||
"""Tests for _find_tool_by_name method."""
|
||||
|
||||
def test_find_existing_tool(self, mock_model_instance, mock_context):
|
||||
"""Test finding an existing tool by name."""
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
|
||||
pattern = ConcreteAgentPattern(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
found_tool = pattern._find_tool_by_name("test_tool")
|
||||
assert found_tool == mock_tool
|
||||
|
||||
def test_find_nonexistent_tool_returns_none(self, mock_model_instance, mock_context):
|
||||
"""Test that finding a nonexistent tool returns None."""
|
||||
mock_tool = MagicMock()
|
||||
mock_tool.entity.identity.name = "test_tool"
|
||||
|
||||
pattern = ConcreteAgentPattern(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
found_tool = pattern._find_tool_by_name("nonexistent_tool")
|
||||
assert found_tool is None
|
||||
|
||||
|
||||
class TestMaxIterationsCapping:
|
||||
"""Tests for max_iterations capping."""
|
||||
|
||||
def test_max_iterations_capped_at_99(self, mock_model_instance, mock_context):
|
||||
"""Test that max_iterations is capped at 99."""
|
||||
pattern = ConcreteAgentPattern(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
max_iterations=150,
|
||||
)
|
||||
|
||||
assert pattern.max_iterations == 99
|
||||
|
||||
def test_max_iterations_not_capped_when_under_99(self, mock_model_instance, mock_context):
|
||||
"""Test that max_iterations is not capped when under 99."""
|
||||
pattern = ConcreteAgentPattern(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
max_iterations=50,
|
||||
)
|
||||
|
||||
assert pattern.max_iterations == 50
|
||||
332
api/tests/unit_tests/core/agent/patterns/test_function_call.py
Normal file
332
api/tests/unit_tests/core/agent/patterns/test_function_call.py
Normal file
@ -0,0 +1,332 @@
|
||||
"""Tests for FunctionCallStrategy."""
|
||||
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentLog, ExecutionContext
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessageTool,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_instance():
|
||||
"""Create a mock model instance."""
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.provider = "test-provider"
|
||||
return model_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context():
|
||||
"""Create a mock execution context."""
|
||||
return ExecutionContext(
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
conversation_id="test-conversation",
|
||||
message_id="test-message",
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool():
|
||||
"""Create a mock tool."""
|
||||
tool = MagicMock()
|
||||
tool.entity.identity.name = "test_tool"
|
||||
tool.to_prompt_message_tool.return_value = PromptMessageTool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {"param1": {"type": "string", "description": "A parameter"}},
|
||||
"required": ["param1"],
|
||||
},
|
||||
)
|
||||
return tool
|
||||
|
||||
|
||||
class TestFunctionCallStrategyInit:
|
||||
"""Tests for FunctionCallStrategy initialization."""
|
||||
|
||||
def test_initialization(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test basic initialization."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
max_iterations=10,
|
||||
)
|
||||
|
||||
assert strategy.model_instance == mock_model_instance
|
||||
assert strategy.context == mock_context
|
||||
assert strategy.max_iterations == 10
|
||||
assert len(strategy.tools) == 1
|
||||
|
||||
def test_initialization_with_tool_invoke_hook(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test initialization with tool_invoke_hook."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
mock_hook = MagicMock()
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
tool_invoke_hook=mock_hook,
|
||||
)
|
||||
|
||||
assert strategy.tool_invoke_hook == mock_hook
|
||||
|
||||
|
||||
class TestConvertToolsToPromptFormat:
|
||||
"""Tests for _convert_tools_to_prompt_format method."""
|
||||
|
||||
def test_convert_tools_returns_prompt_message_tools(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that _convert_tools_to_prompt_format returns PromptMessageTool list."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
tools = strategy._convert_tools_to_prompt_format()
|
||||
|
||||
assert len(tools) == 1
|
||||
assert isinstance(tools[0], PromptMessageTool)
|
||||
assert tools[0].name == "test_tool"
|
||||
|
||||
def test_convert_tools_empty_when_no_tools(self, mock_model_instance, mock_context):
|
||||
"""Test that _convert_tools_to_prompt_format returns empty list when no tools."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
tools = strategy._convert_tools_to_prompt_format()
|
||||
|
||||
assert tools == []
|
||||
|
||||
|
||||
class TestAgentLogGeneration:
|
||||
"""Tests for AgentLog generation during run."""
|
||||
|
||||
def test_round_log_structure(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that round logs have correct structure."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
max_iterations=1,
|
||||
)
|
||||
|
||||
# Create a round log
|
||||
round_log = strategy._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={"inputs": {"query": "test"}},
|
||||
)
|
||||
|
||||
assert round_log.label == "ROUND 1"
|
||||
assert round_log.log_type == AgentLog.LogType.ROUND
|
||||
assert round_log.status == AgentLog.LogStatus.START
|
||||
assert "inputs" in round_log.data
|
||||
|
||||
def test_tool_call_log_structure(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that tool call logs have correct structure."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
# Create a parent round log
|
||||
round_log = strategy._create_log(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
|
||||
# Create a tool call log
|
||||
tool_log = strategy._create_log(
|
||||
label="CALL test_tool",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={"tool_name": "test_tool", "tool_args": {"param1": "value1"}},
|
||||
parent_id=round_log.id,
|
||||
)
|
||||
|
||||
assert tool_log.label == "CALL test_tool"
|
||||
assert tool_log.log_type == AgentLog.LogType.TOOL_CALL
|
||||
assert tool_log.parent_id == round_log.id
|
||||
assert tool_log.data["tool_name"] == "test_tool"
|
||||
|
||||
|
||||
class TestToolInvocation:
|
||||
"""Tests for tool invocation."""
|
||||
|
||||
def test_invoke_tool_with_hook(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that tool invocation uses hook when provided."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
|
||||
mock_hook = MagicMock()
|
||||
mock_meta = ToolInvokeMeta(
|
||||
time_cost=0.5,
|
||||
error=None,
|
||||
tool_config={"tool_provider_type": "test", "tool_provider": "test_id"},
|
||||
)
|
||||
mock_hook.return_value = ("Tool result", ["file-1"], mock_meta)
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
tool_invoke_hook=mock_hook,
|
||||
)
|
||||
|
||||
result, files, meta = strategy._invoke_tool(mock_tool, {"param1": "value"}, "test_tool")
|
||||
|
||||
mock_hook.assert_called_once()
|
||||
assert result == "Tool result"
|
||||
assert files == [] # Hook returns file IDs, but _invoke_tool returns empty File list
|
||||
assert meta == mock_meta
|
||||
|
||||
def test_invoke_tool_without_hook_attribute_set(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that tool_invoke_hook is None when not provided."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
tool_invoke_hook=None,
|
||||
)
|
||||
|
||||
# Verify that tool_invoke_hook is None
|
||||
assert strategy.tool_invoke_hook is None
|
||||
|
||||
|
||||
class TestUsageTracking:
|
||||
"""Tests for usage tracking across rounds."""
|
||||
|
||||
def test_round_usage_is_separate_from_total(self, mock_model_instance, mock_context):
|
||||
"""Test that round usage is tracked separately from total."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
# Simulate two rounds of usage
|
||||
total_usage: dict = {"usage": None}
|
||||
round1_usage: dict = {"usage": None}
|
||||
round2_usage: dict = {"usage": None}
|
||||
|
||||
# Round 1
|
||||
usage1 = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.1"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.1"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.2"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
strategy._accumulate_usage(round1_usage, usage1)
|
||||
strategy._accumulate_usage(total_usage, usage1)
|
||||
|
||||
# Round 2
|
||||
usage2 = LLMUsage(
|
||||
prompt_tokens=200,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.2"),
|
||||
completion_tokens=100,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.2"),
|
||||
total_tokens=300,
|
||||
total_price=Decimal("0.4"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
strategy._accumulate_usage(round2_usage, usage2)
|
||||
strategy._accumulate_usage(total_usage, usage2)
|
||||
|
||||
# Verify round usage is separate
|
||||
assert round1_usage["usage"].total_tokens == 150
|
||||
assert round2_usage["usage"].total_tokens == 300
|
||||
# Verify total is accumulated
|
||||
assert total_usage["usage"].total_tokens == 450
|
||||
|
||||
|
||||
class TestPromptMessageHandling:
|
||||
"""Tests for prompt message handling."""
|
||||
|
||||
def test_messages_include_system_and_user(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that messages include system and user prompts."""
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
|
||||
strategy = FunctionCallStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemPromptMessage(content="You are a helpful assistant."),
|
||||
UserPromptMessage(content="Hello"),
|
||||
]
|
||||
|
||||
# Just verify the messages can be processed
|
||||
assert len(messages) == 2
|
||||
assert isinstance(messages[0], SystemPromptMessage)
|
||||
assert isinstance(messages[1], UserPromptMessage)
|
||||
|
||||
def test_assistant_message_with_tool_calls(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that assistant messages can contain tool calls."""
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
|
||||
tool_call = AssistantPromptMessage.ToolCall(
|
||||
id="call_123",
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name="test_tool",
|
||||
arguments='{"param1": "value1"}',
|
||||
),
|
||||
)
|
||||
|
||||
assistant_message = AssistantPromptMessage(
|
||||
content="I'll help you with that.",
|
||||
tool_calls=[tool_call],
|
||||
)
|
||||
|
||||
assert len(assistant_message.tool_calls) == 1
|
||||
assert assistant_message.tool_calls[0].function.name == "test_tool"
|
||||
224
api/tests/unit_tests/core/agent/patterns/test_react.py
Normal file
224
api/tests/unit_tests/core/agent/patterns/test_react.py
Normal file
@ -0,0 +1,224 @@
|
||||
"""Tests for ReActStrategy."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import ExecutionContext
|
||||
from core.agent.patterns.react import ReActStrategy
|
||||
from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_instance():
|
||||
"""Create a mock model instance."""
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.provider = "test-provider"
|
||||
return model_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context():
|
||||
"""Create a mock execution context."""
|
||||
return ExecutionContext(
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
conversation_id="test-conversation",
|
||||
message_id="test-message",
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool():
|
||||
"""Create a mock tool."""
|
||||
from core.model_runtime.entities.message_entities import PromptMessageTool
|
||||
|
||||
tool = MagicMock()
|
||||
tool.entity.identity.name = "test_tool"
|
||||
tool.entity.identity.provider = "test_provider"
|
||||
|
||||
# Use real PromptMessageTool for proper serialization
|
||||
prompt_tool = PromptMessageTool(
|
||||
name="test_tool",
|
||||
description="A test tool",
|
||||
parameters={"type": "object", "properties": {}},
|
||||
)
|
||||
tool.to_prompt_message_tool.return_value = prompt_tool
|
||||
|
||||
return tool
|
||||
|
||||
|
||||
class TestReActStrategyInit:
|
||||
"""Tests for ReActStrategy initialization."""
|
||||
|
||||
def test_init_with_instruction(self, mock_model_instance, mock_context):
|
||||
"""Test that instruction is stored correctly."""
|
||||
instruction = "You are a helpful assistant."
|
||||
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
assert strategy.instruction == instruction
|
||||
|
||||
def test_init_with_empty_instruction(self, mock_model_instance, mock_context):
|
||||
"""Test that empty instruction is handled correctly."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
assert strategy.instruction == ""
|
||||
|
||||
|
||||
class TestBuildPromptWithReactFormat:
|
||||
"""Tests for _build_prompt_with_react_format method."""
|
||||
|
||||
def test_replace_tools_placeholder(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that {{tools}} placeholder is replaced."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
system_content = "You have access to: {{tools}}"
|
||||
messages = [
|
||||
SystemPromptMessage(content=system_content),
|
||||
UserPromptMessage(content="Hello"),
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, [], True)
|
||||
|
||||
# The tools placeholder should be replaced with JSON
|
||||
assert "{{tools}}" not in result[0].content
|
||||
assert "test_tool" in result[0].content
|
||||
|
||||
def test_replace_tool_names_placeholder(self, mock_model_instance, mock_context, mock_tool):
|
||||
"""Test that {{tool_names}} placeholder is replaced."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[mock_tool],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
system_content = "Valid actions: {{tool_names}}"
|
||||
messages = [
|
||||
SystemPromptMessage(content=system_content),
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, [], True)
|
||||
|
||||
assert "{{tool_names}}" not in result[0].content
|
||||
assert '"test_tool"' in result[0].content
|
||||
|
||||
def test_replace_instruction_placeholder(self, mock_model_instance, mock_context):
|
||||
"""Test that {{instruction}} placeholder is replaced."""
|
||||
instruction = "You are a helpful coding assistant."
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
system_content = "{{instruction}}\n\nYou have access to: {{tools}}"
|
||||
messages = [
|
||||
SystemPromptMessage(content=system_content),
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, [], True, instruction)
|
||||
|
||||
assert "{{instruction}}" not in result[0].content
|
||||
assert instruction in result[0].content
|
||||
|
||||
def test_no_tools_available_message(self, mock_model_instance, mock_context):
|
||||
"""Test that 'No tools available' is shown when include_tools is False."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
system_content = "You have access to: {{tools}}"
|
||||
messages = [
|
||||
SystemPromptMessage(content=system_content),
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, [], False)
|
||||
|
||||
assert "No tools available" in result[0].content
|
||||
|
||||
def test_scratchpad_appended_as_assistant_message(self, mock_model_instance, mock_context):
|
||||
"""Test that agent scratchpad is appended as AssistantPromptMessage."""
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.model_runtime.entities import AssistantPromptMessage
|
||||
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemPromptMessage(content="System prompt"),
|
||||
UserPromptMessage(content="User query"),
|
||||
]
|
||||
|
||||
scratchpad = [
|
||||
AgentScratchpadUnit(
|
||||
thought="I need to search for information",
|
||||
action_str='{"action": "search", "action_input": "query"}',
|
||||
observation="Search results here",
|
||||
)
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, scratchpad, True)
|
||||
|
||||
# The last message should be an AssistantPromptMessage with scratchpad content
|
||||
assert len(result) == 3
|
||||
assert isinstance(result[-1], AssistantPromptMessage)
|
||||
assert "I need to search for information" in result[-1].content
|
||||
assert "Search results here" in result[-1].content
|
||||
|
||||
def test_empty_scratchpad_no_extra_message(self, mock_model_instance, mock_context):
|
||||
"""Test that empty scratchpad doesn't add extra message."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
messages = [
|
||||
SystemPromptMessage(content="System prompt"),
|
||||
UserPromptMessage(content="User query"),
|
||||
]
|
||||
|
||||
result = strategy._build_prompt_with_react_format(messages, [], True)
|
||||
|
||||
# Should only have the original 2 messages
|
||||
assert len(result) == 2
|
||||
|
||||
def test_original_messages_not_modified(self, mock_model_instance, mock_context):
|
||||
"""Test that original messages list is not modified."""
|
||||
strategy = ReActStrategy(
|
||||
model_instance=mock_model_instance,
|
||||
tools=[],
|
||||
context=mock_context,
|
||||
)
|
||||
|
||||
original_content = "Original system prompt {{tools}}"
|
||||
messages = [
|
||||
SystemPromptMessage(content=original_content),
|
||||
]
|
||||
|
||||
strategy._build_prompt_with_react_format(messages, [], True)
|
||||
|
||||
# Original message should not be modified
|
||||
assert messages[0].content == original_content
|
||||
@ -0,0 +1,203 @@
|
||||
"""Tests for StrategyFactory."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentEntity, ExecutionContext
|
||||
from core.agent.patterns.function_call import FunctionCallStrategy
|
||||
from core.agent.patterns.react import ReActStrategy
|
||||
from core.agent.patterns.strategy_factory import StrategyFactory
|
||||
from core.model_runtime.entities.model_entities import ModelFeature
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_model_instance():
|
||||
"""Create a mock model instance."""
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.provider = "test-provider"
|
||||
return model_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context():
|
||||
"""Create a mock execution context."""
|
||||
return ExecutionContext(
|
||||
user_id="test-user",
|
||||
app_id="test-app",
|
||||
conversation_id="test-conversation",
|
||||
message_id="test-message",
|
||||
tenant_id="test-tenant",
|
||||
)
|
||||
|
||||
|
||||
class TestStrategyFactory:
|
||||
"""Tests for StrategyFactory.create_strategy method."""
|
||||
|
||||
def test_create_function_call_strategy_with_tool_call_feature(self, mock_model_instance, mock_context):
|
||||
"""Test that FunctionCallStrategy is created when model supports TOOL_CALL."""
|
||||
model_features = [ModelFeature.TOOL_CALL]
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert isinstance(strategy, FunctionCallStrategy)
|
||||
|
||||
def test_create_function_call_strategy_with_multi_tool_call_feature(self, mock_model_instance, mock_context):
|
||||
"""Test that FunctionCallStrategy is created when model supports MULTI_TOOL_CALL."""
|
||||
model_features = [ModelFeature.MULTI_TOOL_CALL]
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert isinstance(strategy, FunctionCallStrategy)
|
||||
|
||||
def test_create_function_call_strategy_with_stream_tool_call_feature(self, mock_model_instance, mock_context):
|
||||
"""Test that FunctionCallStrategy is created when model supports STREAM_TOOL_CALL."""
|
||||
model_features = [ModelFeature.STREAM_TOOL_CALL]
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert isinstance(strategy, FunctionCallStrategy)
|
||||
|
||||
def test_create_react_strategy_without_tool_call_features(self, mock_model_instance, mock_context):
|
||||
"""Test that ReActStrategy is created when model doesn't support tool calling."""
|
||||
model_features = [ModelFeature.VISION] # Only vision, no tool calling
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert isinstance(strategy, ReActStrategy)
|
||||
|
||||
def test_create_react_strategy_with_empty_features(self, mock_model_instance, mock_context):
|
||||
"""Test that ReActStrategy is created when model has no features."""
|
||||
model_features: list[ModelFeature] = []
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert isinstance(strategy, ReActStrategy)
|
||||
|
||||
def test_explicit_function_calling_strategy_with_support(self, mock_model_instance, mock_context):
|
||||
"""Test explicit FUNCTION_CALLING strategy selection with model support."""
|
||||
model_features = [ModelFeature.TOOL_CALL]
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
|
||||
)
|
||||
|
||||
assert isinstance(strategy, FunctionCallStrategy)
|
||||
|
||||
def test_explicit_function_calling_strategy_without_support_falls_back_to_react(
|
||||
self, mock_model_instance, mock_context
|
||||
):
|
||||
"""Test that explicit FUNCTION_CALLING falls back to ReAct when not supported."""
|
||||
model_features: list[ModelFeature] = [] # No tool calling support
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
|
||||
)
|
||||
|
||||
# Should fall back to ReAct since FC is not supported
|
||||
assert isinstance(strategy, ReActStrategy)
|
||||
|
||||
def test_explicit_chain_of_thought_strategy(self, mock_model_instance, mock_context):
|
||||
"""Test explicit CHAIN_OF_THOUGHT strategy selection."""
|
||||
model_features = [ModelFeature.TOOL_CALL] # Even with tool call support
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
agent_strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT,
|
||||
)
|
||||
|
||||
assert isinstance(strategy, ReActStrategy)
|
||||
|
||||
def test_react_strategy_with_instruction(self, mock_model_instance, mock_context):
|
||||
"""Test that ReActStrategy receives instruction parameter."""
|
||||
model_features: list[ModelFeature] = []
|
||||
instruction = "You are a helpful assistant."
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
assert isinstance(strategy, ReActStrategy)
|
||||
assert strategy.instruction == instruction
|
||||
|
||||
def test_max_iterations_passed_to_strategy(self, mock_model_instance, mock_context):
|
||||
"""Test that max_iterations is passed to the strategy."""
|
||||
model_features = [ModelFeature.TOOL_CALL]
|
||||
max_iterations = 5
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
|
||||
assert strategy.max_iterations == max_iterations
|
||||
|
||||
def test_tool_invoke_hook_passed_to_strategy(self, mock_model_instance, mock_context):
|
||||
"""Test that tool_invoke_hook is passed to the strategy."""
|
||||
model_features = [ModelFeature.TOOL_CALL]
|
||||
mock_hook = MagicMock()
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=mock_model_instance,
|
||||
context=mock_context,
|
||||
tools=[],
|
||||
files=[],
|
||||
tool_invoke_hook=mock_hook,
|
||||
)
|
||||
|
||||
assert strategy.tool_invoke_hook == mock_hook
|
||||
388
api/tests/unit_tests/core/agent/test_agent_app_runner.py
Normal file
388
api/tests/unit_tests/core/agent/test_agent_app_runner.py
Normal file
@ -0,0 +1,388 @@
|
||||
"""Tests for AgentAppRunner."""
|
||||
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentResult
|
||||
from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
|
||||
class TestOrganizePromptMessages:
|
||||
"""Tests for _organize_prompt_messages method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runner(self):
|
||||
"""Create a mock AgentAppRunner for testing."""
|
||||
# We'll patch the class to avoid complex initialization
|
||||
with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
|
||||
runner = AgentAppRunner.__new__(AgentAppRunner)
|
||||
|
||||
# Set up required attributes
|
||||
runner.config = MagicMock(spec=AgentEntity)
|
||||
runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
runner.config.prompt = None
|
||||
|
||||
runner.app_config = MagicMock()
|
||||
runner.app_config.prompt_template = MagicMock()
|
||||
runner.app_config.prompt_template.simple_prompt_template = "You are a helpful assistant."
|
||||
|
||||
runner.history_prompt_messages = []
|
||||
runner.query = "Hello"
|
||||
runner._current_thoughts = []
|
||||
runner.files = []
|
||||
runner.model_config = MagicMock()
|
||||
runner.memory = None
|
||||
runner.application_generate_entity = MagicMock()
|
||||
runner.application_generate_entity.file_upload_config = None
|
||||
|
||||
return runner
|
||||
|
||||
def test_function_calling_uses_simple_prompt(self, mock_runner):
|
||||
"""Test that function calling strategy uses simple_prompt_template."""
|
||||
mock_runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING
|
||||
|
||||
with patch.object(mock_runner, "_init_system_message") as mock_init:
|
||||
mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")]
|
||||
with patch.object(mock_runner, "_organize_user_query") as mock_query:
|
||||
mock_query.return_value = [UserPromptMessage(content="Hello")]
|
||||
with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
|
||||
mock_transform.return_value.get_prompt.return_value = [
|
||||
SystemPromptMessage(content="You are a helpful assistant.")
|
||||
]
|
||||
|
||||
result = mock_runner._organize_prompt_messages()
|
||||
|
||||
# Verify _init_system_message was called with simple_prompt_template
|
||||
mock_init.assert_called_once()
|
||||
call_args = mock_init.call_args[0]
|
||||
assert call_args[0] == "You are a helpful assistant."
|
||||
|
||||
def test_chain_of_thought_uses_agent_prompt(self, mock_runner):
|
||||
"""Test that chain of thought strategy uses agent prompt template."""
|
||||
mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
mock_runner.config.prompt = AgentPromptEntity(
|
||||
first_prompt="ReAct prompt template with {{tools}}",
|
||||
next_iteration="Continue...",
|
||||
)
|
||||
|
||||
with patch.object(mock_runner, "_init_system_message") as mock_init:
|
||||
mock_init.return_value = [SystemPromptMessage(content="ReAct prompt template with {{tools}}")]
|
||||
with patch.object(mock_runner, "_organize_user_query") as mock_query:
|
||||
mock_query.return_value = [UserPromptMessage(content="Hello")]
|
||||
with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
|
||||
mock_transform.return_value.get_prompt.return_value = [
|
||||
SystemPromptMessage(content="ReAct prompt template with {{tools}}")
|
||||
]
|
||||
|
||||
result = mock_runner._organize_prompt_messages()
|
||||
|
||||
# Verify _init_system_message was called with agent prompt
|
||||
mock_init.assert_called_once()
|
||||
call_args = mock_init.call_args[0]
|
||||
assert call_args[0] == "ReAct prompt template with {{tools}}"
|
||||
|
||||
def test_chain_of_thought_without_prompt_falls_back(self, mock_runner):
|
||||
"""Test that chain of thought without prompt falls back to simple_prompt_template."""
|
||||
mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
mock_runner.config.prompt = None
|
||||
|
||||
with patch.object(mock_runner, "_init_system_message") as mock_init:
|
||||
mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")]
|
||||
with patch.object(mock_runner, "_organize_user_query") as mock_query:
|
||||
mock_query.return_value = [UserPromptMessage(content="Hello")]
|
||||
with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
|
||||
mock_transform.return_value.get_prompt.return_value = [
|
||||
SystemPromptMessage(content="You are a helpful assistant.")
|
||||
]
|
||||
|
||||
result = mock_runner._organize_prompt_messages()
|
||||
|
||||
# Verify _init_system_message was called with simple_prompt_template
|
||||
mock_init.assert_called_once()
|
||||
call_args = mock_init.call_args[0]
|
||||
assert call_args[0] == "You are a helpful assistant."
|
||||
|
||||
|
||||
class TestInitSystemMessage:
|
||||
"""Tests for _init_system_message method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runner(self):
|
||||
"""Create a mock AgentAppRunner for testing."""
|
||||
with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
|
||||
runner = AgentAppRunner.__new__(AgentAppRunner)
|
||||
return runner
|
||||
|
||||
def test_empty_messages_with_template(self, mock_runner):
|
||||
"""Test that system message is created when messages are empty."""
|
||||
result = mock_runner._init_system_message("System template", [])
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], SystemPromptMessage)
|
||||
assert result[0].content == "System template"
|
||||
|
||||
def test_empty_messages_without_template(self, mock_runner):
|
||||
"""Test that empty list is returned when no template and no messages."""
|
||||
result = mock_runner._init_system_message("", [])
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_existing_system_message_not_duplicated(self, mock_runner):
|
||||
"""Test that system message is not duplicated if already present."""
|
||||
existing_messages = [
|
||||
SystemPromptMessage(content="Existing system"),
|
||||
UserPromptMessage(content="User message"),
|
||||
]
|
||||
|
||||
result = mock_runner._init_system_message("New template", existing_messages)
|
||||
|
||||
# Should not insert new system message
|
||||
assert len(result) == 2
|
||||
assert result[0].content == "Existing system"
|
||||
|
||||
def test_system_message_inserted_when_missing(self, mock_runner):
|
||||
"""Test that system message is inserted when first message is not system."""
|
||||
existing_messages = [
|
||||
UserPromptMessage(content="User message"),
|
||||
]
|
||||
|
||||
result = mock_runner._init_system_message("System template", existing_messages)
|
||||
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], SystemPromptMessage)
|
||||
assert result[0].content == "System template"
|
||||
|
||||
|
||||
class TestClearUserPromptImageMessages:
|
||||
"""Tests for _clear_user_prompt_image_messages method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runner(self):
|
||||
"""Create a mock AgentAppRunner for testing."""
|
||||
with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
|
||||
runner = AgentAppRunner.__new__(AgentAppRunner)
|
||||
return runner
|
||||
|
||||
def test_text_content_unchanged(self, mock_runner):
|
||||
"""Test that text content is unchanged."""
|
||||
messages = [
|
||||
UserPromptMessage(content="Plain text message"),
|
||||
]
|
||||
|
||||
result = mock_runner._clear_user_prompt_image_messages(messages)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "Plain text message"
|
||||
|
||||
def test_original_messages_not_modified(self, mock_runner):
|
||||
"""Test that original messages are not modified (deep copy)."""
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
|
||||
messages = [
|
||||
UserPromptMessage(
|
||||
content=[
|
||||
TextPromptMessageContent(data="Text part"),
|
||||
ImagePromptMessageContent(
|
||||
data="http://example.com/image.jpg",
|
||||
format="url",
|
||||
mime_type="image/jpeg",
|
||||
),
|
||||
]
|
||||
),
|
||||
]
|
||||
|
||||
result = mock_runner._clear_user_prompt_image_messages(messages)
|
||||
|
||||
# Original should still have list content
|
||||
assert isinstance(messages[0].content, list)
|
||||
# Result should have string content
|
||||
assert isinstance(result[0].content, str)
|
||||
|
||||
|
||||
class TestToolInvokeHook:
|
||||
"""Tests for _create_tool_invoke_hook method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runner(self):
|
||||
"""Create a mock AgentAppRunner for testing."""
|
||||
with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
|
||||
runner = AgentAppRunner.__new__(AgentAppRunner)
|
||||
|
||||
runner.user_id = "test-user"
|
||||
runner.tenant_id = "test-tenant"
|
||||
runner.application_generate_entity = MagicMock()
|
||||
runner.application_generate_entity.trace_manager = None
|
||||
runner.application_generate_entity.invoke_from = "api"
|
||||
runner.application_generate_entity.app_config = MagicMock()
|
||||
runner.application_generate_entity.app_config.app_id = "test-app"
|
||||
runner.agent_callback = MagicMock()
|
||||
runner.conversation = MagicMock()
|
||||
runner.conversation.id = "test-conversation"
|
||||
runner.queue_manager = MagicMock()
|
||||
runner._current_message_file_ids = []
|
||||
|
||||
return runner
|
||||
|
||||
def test_hook_calls_agent_invoke(self, mock_runner):
|
||||
"""Test that the hook calls ToolEngine.agent_invoke."""
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.id = "test-message"
|
||||
|
||||
mock_tool = MagicMock()
|
||||
mock_tool_meta = ToolInvokeMeta(
|
||||
time_cost=0.5,
|
||||
error=None,
|
||||
tool_config={
|
||||
"tool_provider_type": "test_provider",
|
||||
"tool_provider": "test_id",
|
||||
},
|
||||
)
|
||||
|
||||
with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine:
|
||||
mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta)
|
||||
|
||||
hook = mock_runner._create_tool_invoke_hook(mock_message)
|
||||
result_content, result_files, result_meta = hook(mock_tool, {"arg": "value"}, "test_tool")
|
||||
|
||||
# Verify ToolEngine.agent_invoke was called
|
||||
mock_engine.agent_invoke.assert_called_once()
|
||||
|
||||
# Verify return values
|
||||
assert result_content == "Tool result"
|
||||
assert result_files == ["file-1", "file-2"]
|
||||
assert result_meta == mock_tool_meta
|
||||
|
||||
def test_hook_publishes_file_events(self, mock_runner):
|
||||
"""Test that the hook publishes QueueMessageFileEvent for files."""
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.id = "test-message"
|
||||
|
||||
mock_tool = MagicMock()
|
||||
mock_tool_meta = ToolInvokeMeta(
|
||||
time_cost=0.5,
|
||||
error=None,
|
||||
tool_config={},
|
||||
)
|
||||
|
||||
with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine:
|
||||
mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta)
|
||||
|
||||
hook = mock_runner._create_tool_invoke_hook(mock_message)
|
||||
hook(mock_tool, {}, "test_tool")
|
||||
|
||||
# Verify file events were published
|
||||
assert mock_runner.queue_manager.publish.call_count == 2
|
||||
assert mock_runner._current_message_file_ids == ["file-1", "file-2"]
|
||||
|
||||
|
||||
class TestAgentLogProcessing:
|
||||
"""Tests for AgentLog processing in run method."""
|
||||
|
||||
def test_agent_log_status_enum(self):
|
||||
"""Test AgentLog status enum values."""
|
||||
assert AgentLog.LogStatus.START == "start"
|
||||
assert AgentLog.LogStatus.SUCCESS == "success"
|
||||
assert AgentLog.LogStatus.ERROR == "error"
|
||||
|
||||
def test_agent_log_metadata_enum(self):
|
||||
"""Test AgentLog metadata enum values."""
|
||||
assert AgentLog.LogMetadata.STARTED_AT == "started_at"
|
||||
assert AgentLog.LogMetadata.FINISHED_AT == "finished_at"
|
||||
assert AgentLog.LogMetadata.ELAPSED_TIME == "elapsed_time"
|
||||
assert AgentLog.LogMetadata.TOTAL_PRICE == "total_price"
|
||||
assert AgentLog.LogMetadata.TOTAL_TOKENS == "total_tokens"
|
||||
assert AgentLog.LogMetadata.LLM_USAGE == "llm_usage"
|
||||
|
||||
def test_agent_result_structure(self):
|
||||
"""Test AgentResult structure."""
|
||||
usage = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal("0.001"),
|
||||
prompt_price=Decimal("0.1"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal("0.001"),
|
||||
completion_price=Decimal("0.1"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.2"),
|
||||
currency="USD",
|
||||
latency=0.5,
|
||||
)
|
||||
|
||||
result = AgentResult(
|
||||
text="Final answer",
|
||||
files=[],
|
||||
usage=usage,
|
||||
finish_reason="stop",
|
||||
)
|
||||
|
||||
assert result.text == "Final answer"
|
||||
assert result.files == []
|
||||
assert result.usage == usage
|
||||
assert result.finish_reason == "stop"
|
||||
|
||||
|
||||
class TestOrganizeUserQuery:
|
||||
"""Tests for _organize_user_query method."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runner(self):
|
||||
"""Create a mock AgentAppRunner for testing."""
|
||||
with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
|
||||
runner = AgentAppRunner.__new__(AgentAppRunner)
|
||||
runner.files = []
|
||||
runner.application_generate_entity = MagicMock()
|
||||
runner.application_generate_entity.file_upload_config = None
|
||||
return runner
|
||||
|
||||
def test_simple_query_without_files(self, mock_runner):
|
||||
"""Test organizing a simple query without files."""
|
||||
result = mock_runner._organize_user_query("Hello world", [])
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], UserPromptMessage)
|
||||
assert result[0].content == "Hello world"
|
||||
|
||||
def test_query_with_files(self, mock_runner):
|
||||
"""Test organizing a query with files."""
|
||||
from core.file.models import File
|
||||
|
||||
mock_file = MagicMock(spec=File)
|
||||
mock_runner.files = [mock_file]
|
||||
|
||||
with patch("core.agent.agent_app_runner.file_manager") as mock_fm:
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
mock_fm.to_prompt_message_content.return_value = ImagePromptMessageContent(
|
||||
data="http://example.com/image.jpg",
|
||||
format="url",
|
||||
mime_type="image/jpeg",
|
||||
)
|
||||
|
||||
result = mock_runner._organize_user_query("Describe this image", [])
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], UserPromptMessage)
|
||||
assert isinstance(result[0].content, list)
|
||||
assert len(result[0].content) == 2 # Image + Text
|
||||
191
api/tests/unit_tests/core/agent/test_entities.py
Normal file
191
api/tests/unit_tests/core/agent/test_entities.py
Normal file
@ -0,0 +1,191 @@
|
||||
"""Tests for agent entities."""
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentScratchpadUnit, ExecutionContext
|
||||
|
||||
|
||||
class TestExecutionContext:
|
||||
"""Tests for ExecutionContext entity."""
|
||||
|
||||
def test_create_with_all_fields(self):
|
||||
"""Test creating ExecutionContext with all fields."""
|
||||
context = ExecutionContext(
|
||||
user_id="user-123",
|
||||
app_id="app-456",
|
||||
conversation_id="conv-789",
|
||||
message_id="msg-012",
|
||||
tenant_id="tenant-345",
|
||||
)
|
||||
|
||||
assert context.user_id == "user-123"
|
||||
assert context.app_id == "app-456"
|
||||
assert context.conversation_id == "conv-789"
|
||||
assert context.message_id == "msg-012"
|
||||
assert context.tenant_id == "tenant-345"
|
||||
|
||||
def test_create_minimal(self):
|
||||
"""Test creating minimal ExecutionContext."""
|
||||
context = ExecutionContext.create_minimal(user_id="user-123")
|
||||
|
||||
assert context.user_id == "user-123"
|
||||
assert context.app_id is None
|
||||
assert context.conversation_id is None
|
||||
assert context.message_id is None
|
||||
assert context.tenant_id is None
|
||||
|
||||
def test_to_dict(self):
|
||||
"""Test converting ExecutionContext to dictionary."""
|
||||
context = ExecutionContext(
|
||||
user_id="user-123",
|
||||
app_id="app-456",
|
||||
conversation_id="conv-789",
|
||||
message_id="msg-012",
|
||||
tenant_id="tenant-345",
|
||||
)
|
||||
|
||||
result = context.to_dict()
|
||||
|
||||
assert result == {
|
||||
"user_id": "user-123",
|
||||
"app_id": "app-456",
|
||||
"conversation_id": "conv-789",
|
||||
"message_id": "msg-012",
|
||||
"tenant_id": "tenant-345",
|
||||
}
|
||||
|
||||
def test_with_updates(self):
|
||||
"""Test creating new context with updates."""
|
||||
original = ExecutionContext(
|
||||
user_id="user-123",
|
||||
app_id="app-456",
|
||||
)
|
||||
|
||||
updated = original.with_updates(message_id="msg-789")
|
||||
|
||||
# Original should be unchanged
|
||||
assert original.message_id is None
|
||||
# Updated should have new value
|
||||
assert updated.message_id == "msg-789"
|
||||
assert updated.user_id == "user-123"
|
||||
assert updated.app_id == "app-456"
|
||||
|
||||
|
||||
class TestAgentLog:
|
||||
"""Tests for AgentLog entity."""
|
||||
|
||||
def test_create_log_with_required_fields(self):
|
||||
"""Test creating AgentLog with required fields."""
|
||||
log = AgentLog(
|
||||
label="ROUND 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={"key": "value"},
|
||||
)
|
||||
|
||||
assert log.label == "ROUND 1"
|
||||
assert log.log_type == AgentLog.LogType.ROUND
|
||||
assert log.status == AgentLog.LogStatus.START
|
||||
assert log.data == {"key": "value"}
|
||||
assert log.id is not None # Auto-generated
|
||||
assert log.parent_id is None
|
||||
assert log.error is None
|
||||
|
||||
def test_log_type_enum(self):
|
||||
"""Test LogType enum values."""
|
||||
assert AgentLog.LogType.ROUND == "round"
|
||||
assert AgentLog.LogType.THOUGHT == "thought"
|
||||
assert AgentLog.LogType.TOOL_CALL == "tool_call"
|
||||
|
||||
def test_log_status_enum(self):
|
||||
"""Test LogStatus enum values."""
|
||||
assert AgentLog.LogStatus.START == "start"
|
||||
assert AgentLog.LogStatus.SUCCESS == "success"
|
||||
assert AgentLog.LogStatus.ERROR == "error"
|
||||
|
||||
def test_log_metadata_enum(self):
|
||||
"""Test LogMetadata enum values."""
|
||||
assert AgentLog.LogMetadata.STARTED_AT == "started_at"
|
||||
assert AgentLog.LogMetadata.FINISHED_AT == "finished_at"
|
||||
assert AgentLog.LogMetadata.ELAPSED_TIME == "elapsed_time"
|
||||
assert AgentLog.LogMetadata.TOTAL_PRICE == "total_price"
|
||||
assert AgentLog.LogMetadata.TOTAL_TOKENS == "total_tokens"
|
||||
assert AgentLog.LogMetadata.LLM_USAGE == "llm_usage"
|
||||
|
||||
|
||||
class TestAgentScratchpadUnit:
|
||||
"""Tests for AgentScratchpadUnit entity."""
|
||||
|
||||
def test_is_final_with_final_answer_action(self):
|
||||
"""Test is_final returns True for Final Answer action."""
|
||||
unit = AgentScratchpadUnit(
|
||||
thought="I know the answer",
|
||||
action=AgentScratchpadUnit.Action(
|
||||
action_name="Final Answer",
|
||||
action_input="The answer is 42",
|
||||
),
|
||||
)
|
||||
|
||||
assert unit.is_final() is True
|
||||
|
||||
def test_is_final_with_tool_action(self):
|
||||
"""Test is_final returns False for tool action."""
|
||||
unit = AgentScratchpadUnit(
|
||||
thought="I need to search",
|
||||
action=AgentScratchpadUnit.Action(
|
||||
action_name="search",
|
||||
action_input={"query": "test"},
|
||||
),
|
||||
)
|
||||
|
||||
assert unit.is_final() is False
|
||||
|
||||
def test_is_final_with_no_action(self):
|
||||
"""Test is_final returns True when no action."""
|
||||
unit = AgentScratchpadUnit(
|
||||
thought="Just thinking",
|
||||
)
|
||||
|
||||
assert unit.is_final() is True
|
||||
|
||||
def test_action_to_dict(self):
|
||||
"""Test Action.to_dict method."""
|
||||
action = AgentScratchpadUnit.Action(
|
||||
action_name="search",
|
||||
action_input={"query": "test"},
|
||||
)
|
||||
|
||||
result = action.to_dict()
|
||||
|
||||
assert result == {
|
||||
"action": "search",
|
||||
"action_input": {"query": "test"},
|
||||
}
|
||||
|
||||
|
||||
class TestAgentEntity:
|
||||
"""Tests for AgentEntity."""
|
||||
|
||||
def test_strategy_enum(self):
|
||||
"""Test Strategy enum values."""
|
||||
assert AgentEntity.Strategy.CHAIN_OF_THOUGHT == "chain-of-thought"
|
||||
assert AgentEntity.Strategy.FUNCTION_CALLING == "function-calling"
|
||||
|
||||
def test_create_with_prompt(self):
|
||||
"""Test creating AgentEntity with prompt."""
|
||||
prompt = AgentPromptEntity(
|
||||
first_prompt="You are a helpful assistant.",
|
||||
next_iteration="Continue thinking...",
|
||||
)
|
||||
|
||||
entity = AgentEntity(
|
||||
provider="openai",
|
||||
model="gpt-4",
|
||||
strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT,
|
||||
prompt=prompt,
|
||||
max_iteration=5,
|
||||
)
|
||||
|
||||
assert entity.provider == "openai"
|
||||
assert entity.model == "gpt-4"
|
||||
assert entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT
|
||||
assert entity.prompt == prompt
|
||||
assert entity.max_iteration == 5
|
||||
Reference in New Issue
Block a user