feat: add agent package

This commit is contained in:
Novice
2025-12-09 11:26:02 +08:00
parent 15fec024c0
commit 2b23c43434
71 changed files with 5945 additions and 1213 deletions

View File

@ -0,0 +1,324 @@
"""Tests for AgentPattern base class."""
from decimal import Decimal
from unittest.mock import MagicMock
import pytest
from core.agent.entities import AgentLog, ExecutionContext
from core.agent.patterns.base import AgentPattern
from core.model_runtime.entities.llm_entities import LLMUsage
class ConcreteAgentPattern(AgentPattern):
"""Concrete implementation of AgentPattern for testing."""
def run(self, prompt_messages, model_parameters, stop=[], stream=True):
"""Minimal implementation for testing."""
yield from []
@pytest.fixture
def mock_model_instance():
"""Create a mock model instance."""
model_instance = MagicMock()
model_instance.model = "test-model"
model_instance.provider = "test-provider"
return model_instance
@pytest.fixture
def mock_context():
"""Create a mock execution context."""
return ExecutionContext(
user_id="test-user",
app_id="test-app",
conversation_id="test-conversation",
message_id="test-message",
tenant_id="test-tenant",
)
@pytest.fixture
def agent_pattern(mock_model_instance, mock_context):
"""Create a concrete agent pattern for testing."""
return ConcreteAgentPattern(
model_instance=mock_model_instance,
tools=[],
context=mock_context,
max_iterations=10,
)
class TestAccumulateUsage:
"""Tests for _accumulate_usage method."""
def test_accumulate_usage_to_empty_dict(self, agent_pattern):
"""Test accumulating usage to an empty dict creates a copy."""
total_usage: dict = {"usage": None}
delta_usage = LLMUsage(
prompt_tokens=100,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("0.001"),
prompt_price=Decimal("0.1"),
completion_tokens=50,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("0.001"),
completion_price=Decimal("0.1"),
total_tokens=150,
total_price=Decimal("0.2"),
currency="USD",
latency=0.5,
)
agent_pattern._accumulate_usage(total_usage, delta_usage)
assert total_usage["usage"] is not None
assert total_usage["usage"].total_tokens == 150
assert total_usage["usage"].prompt_tokens == 100
assert total_usage["usage"].completion_tokens == 50
# Verify it's a copy, not a reference
assert total_usage["usage"] is not delta_usage
def test_accumulate_usage_adds_to_existing(self, agent_pattern):
"""Test accumulating usage adds to existing values."""
initial_usage = LLMUsage(
prompt_tokens=100,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("0.001"),
prompt_price=Decimal("0.1"),
completion_tokens=50,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("0.001"),
completion_price=Decimal("0.1"),
total_tokens=150,
total_price=Decimal("0.2"),
currency="USD",
latency=0.5,
)
total_usage: dict = {"usage": initial_usage}
delta_usage = LLMUsage(
prompt_tokens=200,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("0.001"),
prompt_price=Decimal("0.2"),
completion_tokens=100,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("0.001"),
completion_price=Decimal("0.2"),
total_tokens=300,
total_price=Decimal("0.4"),
currency="USD",
latency=0.5,
)
agent_pattern._accumulate_usage(total_usage, delta_usage)
assert total_usage["usage"].total_tokens == 450 # 150 + 300
assert total_usage["usage"].prompt_tokens == 300 # 100 + 200
assert total_usage["usage"].completion_tokens == 150 # 50 + 100
def test_accumulate_usage_multiple_rounds(self, agent_pattern):
"""Test accumulating usage across multiple rounds."""
total_usage: dict = {"usage": None}
# Round 1: 100 tokens
round1_usage = LLMUsage(
prompt_tokens=70,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("0.001"),
prompt_price=Decimal("0.07"),
completion_tokens=30,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("0.001"),
completion_price=Decimal("0.06"),
total_tokens=100,
total_price=Decimal("0.13"),
currency="USD",
latency=0.3,
)
agent_pattern._accumulate_usage(total_usage, round1_usage)
assert total_usage["usage"].total_tokens == 100
# Round 2: 150 tokens
round2_usage = LLMUsage(
prompt_tokens=100,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("0.001"),
prompt_price=Decimal("0.1"),
completion_tokens=50,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("0.001"),
completion_price=Decimal("0.1"),
total_tokens=150,
total_price=Decimal("0.2"),
currency="USD",
latency=0.4,
)
agent_pattern._accumulate_usage(total_usage, round2_usage)
assert total_usage["usage"].total_tokens == 250 # 100 + 150
# Round 3: 200 tokens
round3_usage = LLMUsage(
prompt_tokens=130,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("0.001"),
prompt_price=Decimal("0.13"),
completion_tokens=70,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("0.001"),
completion_price=Decimal("0.14"),
total_tokens=200,
total_price=Decimal("0.27"),
currency="USD",
latency=0.5,
)
agent_pattern._accumulate_usage(total_usage, round3_usage)
assert total_usage["usage"].total_tokens == 450 # 100 + 150 + 200
class TestCreateLog:
"""Tests for _create_log method."""
def test_create_log_with_label_and_status(self, agent_pattern):
"""Test creating a log with label and status."""
log = agent_pattern._create_log(
label="ROUND 1",
log_type=AgentLog.LogType.ROUND,
status=AgentLog.LogStatus.START,
data={"key": "value"},
)
assert log.label == "ROUND 1"
assert log.log_type == AgentLog.LogType.ROUND
assert log.status == AgentLog.LogStatus.START
assert log.data == {"key": "value"}
assert log.parent_id is None
def test_create_log_with_parent_id(self, agent_pattern):
"""Test creating a log with parent_id."""
parent_log = agent_pattern._create_log(
label="ROUND 1",
log_type=AgentLog.LogType.ROUND,
status=AgentLog.LogStatus.START,
data={},
)
child_log = agent_pattern._create_log(
label="CALL tool",
log_type=AgentLog.LogType.TOOL_CALL,
status=AgentLog.LogStatus.START,
data={},
parent_id=parent_log.id,
)
assert child_log.parent_id == parent_log.id
assert child_log.log_type == AgentLog.LogType.TOOL_CALL
class TestFinishLog:
"""Tests for _finish_log method."""
def test_finish_log_updates_status(self, agent_pattern):
"""Test that finish_log updates status to SUCCESS."""
log = agent_pattern._create_log(
label="ROUND 1",
log_type=AgentLog.LogType.ROUND,
status=AgentLog.LogStatus.START,
data={},
)
finished_log = agent_pattern._finish_log(log, data={"result": "done"})
assert finished_log.status == AgentLog.LogStatus.SUCCESS
assert finished_log.data == {"result": "done"}
def test_finish_log_adds_usage_metadata(self, agent_pattern):
"""Test that finish_log adds usage to metadata."""
log = agent_pattern._create_log(
label="ROUND 1",
log_type=AgentLog.LogType.ROUND,
status=AgentLog.LogStatus.START,
data={},
)
usage = LLMUsage(
prompt_tokens=100,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("0.001"),
prompt_price=Decimal("0.1"),
completion_tokens=50,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("0.001"),
completion_price=Decimal("0.1"),
total_tokens=150,
total_price=Decimal("0.2"),
currency="USD",
latency=0.5,
)
finished_log = agent_pattern._finish_log(log, usage=usage)
assert finished_log.metadata[AgentLog.LogMetadata.TOTAL_TOKENS] == 150
assert finished_log.metadata[AgentLog.LogMetadata.TOTAL_PRICE] == Decimal("0.2")
assert finished_log.metadata[AgentLog.LogMetadata.CURRENCY] == "USD"
assert finished_log.metadata[AgentLog.LogMetadata.LLM_USAGE] == usage
class TestFindToolByName:
"""Tests for _find_tool_by_name method."""
def test_find_existing_tool(self, mock_model_instance, mock_context):
"""Test finding an existing tool by name."""
mock_tool = MagicMock()
mock_tool.entity.identity.name = "test_tool"
pattern = ConcreteAgentPattern(
model_instance=mock_model_instance,
tools=[mock_tool],
context=mock_context,
)
found_tool = pattern._find_tool_by_name("test_tool")
assert found_tool == mock_tool
def test_find_nonexistent_tool_returns_none(self, mock_model_instance, mock_context):
"""Test that finding a nonexistent tool returns None."""
mock_tool = MagicMock()
mock_tool.entity.identity.name = "test_tool"
pattern = ConcreteAgentPattern(
model_instance=mock_model_instance,
tools=[mock_tool],
context=mock_context,
)
found_tool = pattern._find_tool_by_name("nonexistent_tool")
assert found_tool is None
class TestMaxIterationsCapping:
"""Tests for max_iterations capping."""
def test_max_iterations_capped_at_99(self, mock_model_instance, mock_context):
"""Test that max_iterations is capped at 99."""
pattern = ConcreteAgentPattern(
model_instance=mock_model_instance,
tools=[],
context=mock_context,
max_iterations=150,
)
assert pattern.max_iterations == 99
def test_max_iterations_not_capped_when_under_99(self, mock_model_instance, mock_context):
"""Test that max_iterations is not capped when under 99."""
pattern = ConcreteAgentPattern(
model_instance=mock_model_instance,
tools=[],
context=mock_context,
max_iterations=50,
)
assert pattern.max_iterations == 50

View File

@ -0,0 +1,332 @@
"""Tests for FunctionCallStrategy."""
from decimal import Decimal
from unittest.mock import MagicMock
import pytest
from core.agent.entities import AgentLog, ExecutionContext
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
PromptMessageTool,
SystemPromptMessage,
UserPromptMessage,
)
@pytest.fixture
def mock_model_instance():
"""Create a mock model instance."""
model_instance = MagicMock()
model_instance.model = "test-model"
model_instance.provider = "test-provider"
return model_instance
@pytest.fixture
def mock_context():
"""Create a mock execution context."""
return ExecutionContext(
user_id="test-user",
app_id="test-app",
conversation_id="test-conversation",
message_id="test-message",
tenant_id="test-tenant",
)
@pytest.fixture
def mock_tool():
"""Create a mock tool."""
tool = MagicMock()
tool.entity.identity.name = "test_tool"
tool.to_prompt_message_tool.return_value = PromptMessageTool(
name="test_tool",
description="A test tool",
parameters={
"type": "object",
"properties": {"param1": {"type": "string", "description": "A parameter"}},
"required": ["param1"],
},
)
return tool
class TestFunctionCallStrategyInit:
"""Tests for FunctionCallStrategy initialization."""
def test_initialization(self, mock_model_instance, mock_context, mock_tool):
"""Test basic initialization."""
from core.agent.patterns.function_call import FunctionCallStrategy
strategy = FunctionCallStrategy(
model_instance=mock_model_instance,
tools=[mock_tool],
context=mock_context,
max_iterations=10,
)
assert strategy.model_instance == mock_model_instance
assert strategy.context == mock_context
assert strategy.max_iterations == 10
assert len(strategy.tools) == 1
def test_initialization_with_tool_invoke_hook(self, mock_model_instance, mock_context, mock_tool):
"""Test initialization with tool_invoke_hook."""
from core.agent.patterns.function_call import FunctionCallStrategy
mock_hook = MagicMock()
strategy = FunctionCallStrategy(
model_instance=mock_model_instance,
tools=[mock_tool],
context=mock_context,
tool_invoke_hook=mock_hook,
)
assert strategy.tool_invoke_hook == mock_hook
class TestConvertToolsToPromptFormat:
"""Tests for _convert_tools_to_prompt_format method."""
def test_convert_tools_returns_prompt_message_tools(self, mock_model_instance, mock_context, mock_tool):
"""Test that _convert_tools_to_prompt_format returns PromptMessageTool list."""
from core.agent.patterns.function_call import FunctionCallStrategy
strategy = FunctionCallStrategy(
model_instance=mock_model_instance,
tools=[mock_tool],
context=mock_context,
)
tools = strategy._convert_tools_to_prompt_format()
assert len(tools) == 1
assert isinstance(tools[0], PromptMessageTool)
assert tools[0].name == "test_tool"
def test_convert_tools_empty_when_no_tools(self, mock_model_instance, mock_context):
"""Test that _convert_tools_to_prompt_format returns empty list when no tools."""
from core.agent.patterns.function_call import FunctionCallStrategy
strategy = FunctionCallStrategy(
model_instance=mock_model_instance,
tools=[],
context=mock_context,
)
tools = strategy._convert_tools_to_prompt_format()
assert tools == []
class TestAgentLogGeneration:
"""Tests for AgentLog generation during run."""
def test_round_log_structure(self, mock_model_instance, mock_context, mock_tool):
"""Test that round logs have correct structure."""
from core.agent.patterns.function_call import FunctionCallStrategy
strategy = FunctionCallStrategy(
model_instance=mock_model_instance,
tools=[mock_tool],
context=mock_context,
max_iterations=1,
)
# Create a round log
round_log = strategy._create_log(
label="ROUND 1",
log_type=AgentLog.LogType.ROUND,
status=AgentLog.LogStatus.START,
data={"inputs": {"query": "test"}},
)
assert round_log.label == "ROUND 1"
assert round_log.log_type == AgentLog.LogType.ROUND
assert round_log.status == AgentLog.LogStatus.START
assert "inputs" in round_log.data
def test_tool_call_log_structure(self, mock_model_instance, mock_context, mock_tool):
"""Test that tool call logs have correct structure."""
from core.agent.patterns.function_call import FunctionCallStrategy
strategy = FunctionCallStrategy(
model_instance=mock_model_instance,
tools=[mock_tool],
context=mock_context,
)
# Create a parent round log
round_log = strategy._create_log(
label="ROUND 1",
log_type=AgentLog.LogType.ROUND,
status=AgentLog.LogStatus.START,
data={},
)
# Create a tool call log
tool_log = strategy._create_log(
label="CALL test_tool",
log_type=AgentLog.LogType.TOOL_CALL,
status=AgentLog.LogStatus.START,
data={"tool_name": "test_tool", "tool_args": {"param1": "value1"}},
parent_id=round_log.id,
)
assert tool_log.label == "CALL test_tool"
assert tool_log.log_type == AgentLog.LogType.TOOL_CALL
assert tool_log.parent_id == round_log.id
assert tool_log.data["tool_name"] == "test_tool"
class TestToolInvocation:
"""Tests for tool invocation."""
def test_invoke_tool_with_hook(self, mock_model_instance, mock_context, mock_tool):
"""Test that tool invocation uses hook when provided."""
from core.agent.patterns.function_call import FunctionCallStrategy
from core.tools.entities.tool_entities import ToolInvokeMeta
mock_hook = MagicMock()
mock_meta = ToolInvokeMeta(
time_cost=0.5,
error=None,
tool_config={"tool_provider_type": "test", "tool_provider": "test_id"},
)
mock_hook.return_value = ("Tool result", ["file-1"], mock_meta)
strategy = FunctionCallStrategy(
model_instance=mock_model_instance,
tools=[mock_tool],
context=mock_context,
tool_invoke_hook=mock_hook,
)
result, files, meta = strategy._invoke_tool(mock_tool, {"param1": "value"}, "test_tool")
mock_hook.assert_called_once()
assert result == "Tool result"
assert files == [] # Hook returns file IDs, but _invoke_tool returns empty File list
assert meta == mock_meta
def test_invoke_tool_without_hook_attribute_set(self, mock_model_instance, mock_context, mock_tool):
"""Test that tool_invoke_hook is None when not provided."""
from core.agent.patterns.function_call import FunctionCallStrategy
strategy = FunctionCallStrategy(
model_instance=mock_model_instance,
tools=[mock_tool],
context=mock_context,
tool_invoke_hook=None,
)
# Verify that tool_invoke_hook is None
assert strategy.tool_invoke_hook is None
class TestUsageTracking:
"""Tests for usage tracking across rounds."""
def test_round_usage_is_separate_from_total(self, mock_model_instance, mock_context):
"""Test that round usage is tracked separately from total."""
from core.agent.patterns.function_call import FunctionCallStrategy
strategy = FunctionCallStrategy(
model_instance=mock_model_instance,
tools=[],
context=mock_context,
)
# Simulate two rounds of usage
total_usage: dict = {"usage": None}
round1_usage: dict = {"usage": None}
round2_usage: dict = {"usage": None}
# Round 1
usage1 = LLMUsage(
prompt_tokens=100,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("0.001"),
prompt_price=Decimal("0.1"),
completion_tokens=50,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("0.001"),
completion_price=Decimal("0.1"),
total_tokens=150,
total_price=Decimal("0.2"),
currency="USD",
latency=0.5,
)
strategy._accumulate_usage(round1_usage, usage1)
strategy._accumulate_usage(total_usage, usage1)
# Round 2
usage2 = LLMUsage(
prompt_tokens=200,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("0.001"),
prompt_price=Decimal("0.2"),
completion_tokens=100,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("0.001"),
completion_price=Decimal("0.2"),
total_tokens=300,
total_price=Decimal("0.4"),
currency="USD",
latency=0.5,
)
strategy._accumulate_usage(round2_usage, usage2)
strategy._accumulate_usage(total_usage, usage2)
# Verify round usage is separate
assert round1_usage["usage"].total_tokens == 150
assert round2_usage["usage"].total_tokens == 300
# Verify total is accumulated
assert total_usage["usage"].total_tokens == 450
class TestPromptMessageHandling:
"""Tests for prompt message handling."""
def test_messages_include_system_and_user(self, mock_model_instance, mock_context, mock_tool):
"""Test that messages include system and user prompts."""
from core.agent.patterns.function_call import FunctionCallStrategy
strategy = FunctionCallStrategy(
model_instance=mock_model_instance,
tools=[mock_tool],
context=mock_context,
)
messages = [
SystemPromptMessage(content="You are a helpful assistant."),
UserPromptMessage(content="Hello"),
]
# Just verify the messages can be processed
assert len(messages) == 2
assert isinstance(messages[0], SystemPromptMessage)
assert isinstance(messages[1], UserPromptMessage)
def test_assistant_message_with_tool_calls(self, mock_model_instance, mock_context, mock_tool):
"""Test that assistant messages can contain tool calls."""
from core.model_runtime.entities.message_entities import AssistantPromptMessage
tool_call = AssistantPromptMessage.ToolCall(
id="call_123",
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name="test_tool",
arguments='{"param1": "value1"}',
),
)
assistant_message = AssistantPromptMessage(
content="I'll help you with that.",
tool_calls=[tool_call],
)
assert len(assistant_message.tool_calls) == 1
assert assistant_message.tool_calls[0].function.name == "test_tool"

View File

@ -0,0 +1,224 @@
"""Tests for ReActStrategy."""
from unittest.mock import MagicMock
import pytest
from core.agent.entities import ExecutionContext
from core.agent.patterns.react import ReActStrategy
from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage
@pytest.fixture
def mock_model_instance():
"""Create a mock model instance."""
model_instance = MagicMock()
model_instance.model = "test-model"
model_instance.provider = "test-provider"
return model_instance
@pytest.fixture
def mock_context():
"""Create a mock execution context."""
return ExecutionContext(
user_id="test-user",
app_id="test-app",
conversation_id="test-conversation",
message_id="test-message",
tenant_id="test-tenant",
)
@pytest.fixture
def mock_tool():
"""Create a mock tool."""
from core.model_runtime.entities.message_entities import PromptMessageTool
tool = MagicMock()
tool.entity.identity.name = "test_tool"
tool.entity.identity.provider = "test_provider"
# Use real PromptMessageTool for proper serialization
prompt_tool = PromptMessageTool(
name="test_tool",
description="A test tool",
parameters={"type": "object", "properties": {}},
)
tool.to_prompt_message_tool.return_value = prompt_tool
return tool
class TestReActStrategyInit:
"""Tests for ReActStrategy initialization."""
def test_init_with_instruction(self, mock_model_instance, mock_context):
"""Test that instruction is stored correctly."""
instruction = "You are a helpful assistant."
strategy = ReActStrategy(
model_instance=mock_model_instance,
tools=[],
context=mock_context,
instruction=instruction,
)
assert strategy.instruction == instruction
def test_init_with_empty_instruction(self, mock_model_instance, mock_context):
"""Test that empty instruction is handled correctly."""
strategy = ReActStrategy(
model_instance=mock_model_instance,
tools=[],
context=mock_context,
)
assert strategy.instruction == ""
class TestBuildPromptWithReactFormat:
"""Tests for _build_prompt_with_react_format method."""
def test_replace_tools_placeholder(self, mock_model_instance, mock_context, mock_tool):
"""Test that {{tools}} placeholder is replaced."""
strategy = ReActStrategy(
model_instance=mock_model_instance,
tools=[mock_tool],
context=mock_context,
)
system_content = "You have access to: {{tools}}"
messages = [
SystemPromptMessage(content=system_content),
UserPromptMessage(content="Hello"),
]
result = strategy._build_prompt_with_react_format(messages, [], True)
# The tools placeholder should be replaced with JSON
assert "{{tools}}" not in result[0].content
assert "test_tool" in result[0].content
def test_replace_tool_names_placeholder(self, mock_model_instance, mock_context, mock_tool):
"""Test that {{tool_names}} placeholder is replaced."""
strategy = ReActStrategy(
model_instance=mock_model_instance,
tools=[mock_tool],
context=mock_context,
)
system_content = "Valid actions: {{tool_names}}"
messages = [
SystemPromptMessage(content=system_content),
]
result = strategy._build_prompt_with_react_format(messages, [], True)
assert "{{tool_names}}" not in result[0].content
assert '"test_tool"' in result[0].content
def test_replace_instruction_placeholder(self, mock_model_instance, mock_context):
"""Test that {{instruction}} placeholder is replaced."""
instruction = "You are a helpful coding assistant."
strategy = ReActStrategy(
model_instance=mock_model_instance,
tools=[],
context=mock_context,
instruction=instruction,
)
system_content = "{{instruction}}\n\nYou have access to: {{tools}}"
messages = [
SystemPromptMessage(content=system_content),
]
result = strategy._build_prompt_with_react_format(messages, [], True, instruction)
assert "{{instruction}}" not in result[0].content
assert instruction in result[0].content
def test_no_tools_available_message(self, mock_model_instance, mock_context):
"""Test that 'No tools available' is shown when include_tools is False."""
strategy = ReActStrategy(
model_instance=mock_model_instance,
tools=[],
context=mock_context,
)
system_content = "You have access to: {{tools}}"
messages = [
SystemPromptMessage(content=system_content),
]
result = strategy._build_prompt_with_react_format(messages, [], False)
assert "No tools available" in result[0].content
def test_scratchpad_appended_as_assistant_message(self, mock_model_instance, mock_context):
"""Test that agent scratchpad is appended as AssistantPromptMessage."""
from core.agent.entities import AgentScratchpadUnit
from core.model_runtime.entities import AssistantPromptMessage
strategy = ReActStrategy(
model_instance=mock_model_instance,
tools=[],
context=mock_context,
)
messages = [
SystemPromptMessage(content="System prompt"),
UserPromptMessage(content="User query"),
]
scratchpad = [
AgentScratchpadUnit(
thought="I need to search for information",
action_str='{"action": "search", "action_input": "query"}',
observation="Search results here",
)
]
result = strategy._build_prompt_with_react_format(messages, scratchpad, True)
# The last message should be an AssistantPromptMessage with scratchpad content
assert len(result) == 3
assert isinstance(result[-1], AssistantPromptMessage)
assert "I need to search for information" in result[-1].content
assert "Search results here" in result[-1].content
def test_empty_scratchpad_no_extra_message(self, mock_model_instance, mock_context):
"""Test that empty scratchpad doesn't add extra message."""
strategy = ReActStrategy(
model_instance=mock_model_instance,
tools=[],
context=mock_context,
)
messages = [
SystemPromptMessage(content="System prompt"),
UserPromptMessage(content="User query"),
]
result = strategy._build_prompt_with_react_format(messages, [], True)
# Should only have the original 2 messages
assert len(result) == 2
def test_original_messages_not_modified(self, mock_model_instance, mock_context):
"""Test that original messages list is not modified."""
strategy = ReActStrategy(
model_instance=mock_model_instance,
tools=[],
context=mock_context,
)
original_content = "Original system prompt {{tools}}"
messages = [
SystemPromptMessage(content=original_content),
]
strategy._build_prompt_with_react_format(messages, [], True)
# Original message should not be modified
assert messages[0].content == original_content

View File

@ -0,0 +1,203 @@
"""Tests for StrategyFactory."""
from unittest.mock import MagicMock
import pytest
from core.agent.entities import AgentEntity, ExecutionContext
from core.agent.patterns.function_call import FunctionCallStrategy
from core.agent.patterns.react import ReActStrategy
from core.agent.patterns.strategy_factory import StrategyFactory
from core.model_runtime.entities.model_entities import ModelFeature
@pytest.fixture
def mock_model_instance():
"""Create a mock model instance."""
model_instance = MagicMock()
model_instance.model = "test-model"
model_instance.provider = "test-provider"
return model_instance
@pytest.fixture
def mock_context():
"""Create a mock execution context."""
return ExecutionContext(
user_id="test-user",
app_id="test-app",
conversation_id="test-conversation",
message_id="test-message",
tenant_id="test-tenant",
)
class TestStrategyFactory:
"""Tests for StrategyFactory.create_strategy method."""
def test_create_function_call_strategy_with_tool_call_feature(self, mock_model_instance, mock_context):
"""Test that FunctionCallStrategy is created when model supports TOOL_CALL."""
model_features = [ModelFeature.TOOL_CALL]
strategy = StrategyFactory.create_strategy(
model_features=model_features,
model_instance=mock_model_instance,
context=mock_context,
tools=[],
files=[],
)
assert isinstance(strategy, FunctionCallStrategy)
def test_create_function_call_strategy_with_multi_tool_call_feature(self, mock_model_instance, mock_context):
"""Test that FunctionCallStrategy is created when model supports MULTI_TOOL_CALL."""
model_features = [ModelFeature.MULTI_TOOL_CALL]
strategy = StrategyFactory.create_strategy(
model_features=model_features,
model_instance=mock_model_instance,
context=mock_context,
tools=[],
files=[],
)
assert isinstance(strategy, FunctionCallStrategy)
def test_create_function_call_strategy_with_stream_tool_call_feature(self, mock_model_instance, mock_context):
"""Test that FunctionCallStrategy is created when model supports STREAM_TOOL_CALL."""
model_features = [ModelFeature.STREAM_TOOL_CALL]
strategy = StrategyFactory.create_strategy(
model_features=model_features,
model_instance=mock_model_instance,
context=mock_context,
tools=[],
files=[],
)
assert isinstance(strategy, FunctionCallStrategy)
def test_create_react_strategy_without_tool_call_features(self, mock_model_instance, mock_context):
"""Test that ReActStrategy is created when model doesn't support tool calling."""
model_features = [ModelFeature.VISION] # Only vision, no tool calling
strategy = StrategyFactory.create_strategy(
model_features=model_features,
model_instance=mock_model_instance,
context=mock_context,
tools=[],
files=[],
)
assert isinstance(strategy, ReActStrategy)
def test_create_react_strategy_with_empty_features(self, mock_model_instance, mock_context):
"""Test that ReActStrategy is created when model has no features."""
model_features: list[ModelFeature] = []
strategy = StrategyFactory.create_strategy(
model_features=model_features,
model_instance=mock_model_instance,
context=mock_context,
tools=[],
files=[],
)
assert isinstance(strategy, ReActStrategy)
def test_explicit_function_calling_strategy_with_support(self, mock_model_instance, mock_context):
"""Test explicit FUNCTION_CALLING strategy selection with model support."""
model_features = [ModelFeature.TOOL_CALL]
strategy = StrategyFactory.create_strategy(
model_features=model_features,
model_instance=mock_model_instance,
context=mock_context,
tools=[],
files=[],
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
)
assert isinstance(strategy, FunctionCallStrategy)
def test_explicit_function_calling_strategy_without_support_falls_back_to_react(
self, mock_model_instance, mock_context
):
"""Test that explicit FUNCTION_CALLING falls back to ReAct when not supported."""
model_features: list[ModelFeature] = [] # No tool calling support
strategy = StrategyFactory.create_strategy(
model_features=model_features,
model_instance=mock_model_instance,
context=mock_context,
tools=[],
files=[],
agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING,
)
# Should fall back to ReAct since FC is not supported
assert isinstance(strategy, ReActStrategy)
def test_explicit_chain_of_thought_strategy(self, mock_model_instance, mock_context):
"""Test explicit CHAIN_OF_THOUGHT strategy selection."""
model_features = [ModelFeature.TOOL_CALL] # Even with tool call support
strategy = StrategyFactory.create_strategy(
model_features=model_features,
model_instance=mock_model_instance,
context=mock_context,
tools=[],
files=[],
agent_strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT,
)
assert isinstance(strategy, ReActStrategy)
def test_react_strategy_with_instruction(self, mock_model_instance, mock_context):
"""Test that ReActStrategy receives instruction parameter."""
model_features: list[ModelFeature] = []
instruction = "You are a helpful assistant."
strategy = StrategyFactory.create_strategy(
model_features=model_features,
model_instance=mock_model_instance,
context=mock_context,
tools=[],
files=[],
instruction=instruction,
)
assert isinstance(strategy, ReActStrategy)
assert strategy.instruction == instruction
def test_max_iterations_passed_to_strategy(self, mock_model_instance, mock_context):
"""Test that max_iterations is passed to the strategy."""
model_features = [ModelFeature.TOOL_CALL]
max_iterations = 5
strategy = StrategyFactory.create_strategy(
model_features=model_features,
model_instance=mock_model_instance,
context=mock_context,
tools=[],
files=[],
max_iterations=max_iterations,
)
assert strategy.max_iterations == max_iterations
def test_tool_invoke_hook_passed_to_strategy(self, mock_model_instance, mock_context):
"""Test that tool_invoke_hook is passed to the strategy."""
model_features = [ModelFeature.TOOL_CALL]
mock_hook = MagicMock()
strategy = StrategyFactory.create_strategy(
model_features=model_features,
model_instance=mock_model_instance,
context=mock_context,
tools=[],
files=[],
tool_invoke_hook=mock_hook,
)
assert strategy.tool_invoke_hook == mock_hook

View File

@ -0,0 +1,388 @@
"""Tests for AgentAppRunner."""
from decimal import Decimal
from unittest.mock import MagicMock, patch
import pytest
from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentResult
from core.model_runtime.entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.llm_entities import LLMUsage
class TestOrganizePromptMessages:
"""Tests for _organize_prompt_messages method."""
@pytest.fixture
def mock_runner(self):
"""Create a mock AgentAppRunner for testing."""
# We'll patch the class to avoid complex initialization
with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
from core.agent.agent_app_runner import AgentAppRunner
runner = AgentAppRunner.__new__(AgentAppRunner)
# Set up required attributes
runner.config = MagicMock(spec=AgentEntity)
runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING
runner.config.prompt = None
runner.app_config = MagicMock()
runner.app_config.prompt_template = MagicMock()
runner.app_config.prompt_template.simple_prompt_template = "You are a helpful assistant."
runner.history_prompt_messages = []
runner.query = "Hello"
runner._current_thoughts = []
runner.files = []
runner.model_config = MagicMock()
runner.memory = None
runner.application_generate_entity = MagicMock()
runner.application_generate_entity.file_upload_config = None
return runner
def test_function_calling_uses_simple_prompt(self, mock_runner):
"""Test that function calling strategy uses simple_prompt_template."""
mock_runner.config.strategy = AgentEntity.Strategy.FUNCTION_CALLING
with patch.object(mock_runner, "_init_system_message") as mock_init:
mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")]
with patch.object(mock_runner, "_organize_user_query") as mock_query:
mock_query.return_value = [UserPromptMessage(content="Hello")]
with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
mock_transform.return_value.get_prompt.return_value = [
SystemPromptMessage(content="You are a helpful assistant.")
]
result = mock_runner._organize_prompt_messages()
# Verify _init_system_message was called with simple_prompt_template
mock_init.assert_called_once()
call_args = mock_init.call_args[0]
assert call_args[0] == "You are a helpful assistant."
def test_chain_of_thought_uses_agent_prompt(self, mock_runner):
"""Test that chain of thought strategy uses agent prompt template."""
mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
mock_runner.config.prompt = AgentPromptEntity(
first_prompt="ReAct prompt template with {{tools}}",
next_iteration="Continue...",
)
with patch.object(mock_runner, "_init_system_message") as mock_init:
mock_init.return_value = [SystemPromptMessage(content="ReAct prompt template with {{tools}}")]
with patch.object(mock_runner, "_organize_user_query") as mock_query:
mock_query.return_value = [UserPromptMessage(content="Hello")]
with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
mock_transform.return_value.get_prompt.return_value = [
SystemPromptMessage(content="ReAct prompt template with {{tools}}")
]
result = mock_runner._organize_prompt_messages()
# Verify _init_system_message was called with agent prompt
mock_init.assert_called_once()
call_args = mock_init.call_args[0]
assert call_args[0] == "ReAct prompt template with {{tools}}"
def test_chain_of_thought_without_prompt_falls_back(self, mock_runner):
"""Test that chain of thought without prompt falls back to simple_prompt_template."""
mock_runner.config.strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
mock_runner.config.prompt = None
with patch.object(mock_runner, "_init_system_message") as mock_init:
mock_init.return_value = [SystemPromptMessage(content="You are a helpful assistant.")]
with patch.object(mock_runner, "_organize_user_query") as mock_query:
mock_query.return_value = [UserPromptMessage(content="Hello")]
with patch("core.agent.agent_app_runner.AgentHistoryPromptTransform") as mock_transform:
mock_transform.return_value.get_prompt.return_value = [
SystemPromptMessage(content="You are a helpful assistant.")
]
result = mock_runner._organize_prompt_messages()
# Verify _init_system_message was called with simple_prompt_template
mock_init.assert_called_once()
call_args = mock_init.call_args[0]
assert call_args[0] == "You are a helpful assistant."
class TestInitSystemMessage:
"""Tests for _init_system_message method."""
@pytest.fixture
def mock_runner(self):
"""Create a mock AgentAppRunner for testing."""
with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
from core.agent.agent_app_runner import AgentAppRunner
runner = AgentAppRunner.__new__(AgentAppRunner)
return runner
def test_empty_messages_with_template(self, mock_runner):
"""Test that system message is created when messages are empty."""
result = mock_runner._init_system_message("System template", [])
assert len(result) == 1
assert isinstance(result[0], SystemPromptMessage)
assert result[0].content == "System template"
def test_empty_messages_without_template(self, mock_runner):
"""Test that empty list is returned when no template and no messages."""
result = mock_runner._init_system_message("", [])
assert result == []
def test_existing_system_message_not_duplicated(self, mock_runner):
"""Test that system message is not duplicated if already present."""
existing_messages = [
SystemPromptMessage(content="Existing system"),
UserPromptMessage(content="User message"),
]
result = mock_runner._init_system_message("New template", existing_messages)
# Should not insert new system message
assert len(result) == 2
assert result[0].content == "Existing system"
def test_system_message_inserted_when_missing(self, mock_runner):
"""Test that system message is inserted when first message is not system."""
existing_messages = [
UserPromptMessage(content="User message"),
]
result = mock_runner._init_system_message("System template", existing_messages)
assert len(result) == 2
assert isinstance(result[0], SystemPromptMessage)
assert result[0].content == "System template"
class TestClearUserPromptImageMessages:
"""Tests for _clear_user_prompt_image_messages method."""
@pytest.fixture
def mock_runner(self):
"""Create a mock AgentAppRunner for testing."""
with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
from core.agent.agent_app_runner import AgentAppRunner
runner = AgentAppRunner.__new__(AgentAppRunner)
return runner
def test_text_content_unchanged(self, mock_runner):
"""Test that text content is unchanged."""
messages = [
UserPromptMessage(content="Plain text message"),
]
result = mock_runner._clear_user_prompt_image_messages(messages)
assert len(result) == 1
assert result[0].content == "Plain text message"
def test_original_messages_not_modified(self, mock_runner):
"""Test that original messages are not modified (deep copy)."""
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
TextPromptMessageContent,
)
messages = [
UserPromptMessage(
content=[
TextPromptMessageContent(data="Text part"),
ImagePromptMessageContent(
data="http://example.com/image.jpg",
format="url",
mime_type="image/jpeg",
),
]
),
]
result = mock_runner._clear_user_prompt_image_messages(messages)
# Original should still have list content
assert isinstance(messages[0].content, list)
# Result should have string content
assert isinstance(result[0].content, str)
class TestToolInvokeHook:
"""Tests for _create_tool_invoke_hook method."""
@pytest.fixture
def mock_runner(self):
"""Create a mock AgentAppRunner for testing."""
with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
from core.agent.agent_app_runner import AgentAppRunner
runner = AgentAppRunner.__new__(AgentAppRunner)
runner.user_id = "test-user"
runner.tenant_id = "test-tenant"
runner.application_generate_entity = MagicMock()
runner.application_generate_entity.trace_manager = None
runner.application_generate_entity.invoke_from = "api"
runner.application_generate_entity.app_config = MagicMock()
runner.application_generate_entity.app_config.app_id = "test-app"
runner.agent_callback = MagicMock()
runner.conversation = MagicMock()
runner.conversation.id = "test-conversation"
runner.queue_manager = MagicMock()
runner._current_message_file_ids = []
return runner
def test_hook_calls_agent_invoke(self, mock_runner):
"""Test that the hook calls ToolEngine.agent_invoke."""
from core.tools.entities.tool_entities import ToolInvokeMeta
mock_message = MagicMock()
mock_message.id = "test-message"
mock_tool = MagicMock()
mock_tool_meta = ToolInvokeMeta(
time_cost=0.5,
error=None,
tool_config={
"tool_provider_type": "test_provider",
"tool_provider": "test_id",
},
)
with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine:
mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta)
hook = mock_runner._create_tool_invoke_hook(mock_message)
result_content, result_files, result_meta = hook(mock_tool, {"arg": "value"}, "test_tool")
# Verify ToolEngine.agent_invoke was called
mock_engine.agent_invoke.assert_called_once()
# Verify return values
assert result_content == "Tool result"
assert result_files == ["file-1", "file-2"]
assert result_meta == mock_tool_meta
def test_hook_publishes_file_events(self, mock_runner):
"""Test that the hook publishes QueueMessageFileEvent for files."""
from core.tools.entities.tool_entities import ToolInvokeMeta
mock_message = MagicMock()
mock_message.id = "test-message"
mock_tool = MagicMock()
mock_tool_meta = ToolInvokeMeta(
time_cost=0.5,
error=None,
tool_config={},
)
with patch("core.agent.agent_app_runner.ToolEngine") as mock_engine:
mock_engine.agent_invoke.return_value = ("Tool result", ["file-1", "file-2"], mock_tool_meta)
hook = mock_runner._create_tool_invoke_hook(mock_message)
hook(mock_tool, {}, "test_tool")
# Verify file events were published
assert mock_runner.queue_manager.publish.call_count == 2
assert mock_runner._current_message_file_ids == ["file-1", "file-2"]
class TestAgentLogProcessing:
"""Tests for AgentLog processing in run method."""
def test_agent_log_status_enum(self):
"""Test AgentLog status enum values."""
assert AgentLog.LogStatus.START == "start"
assert AgentLog.LogStatus.SUCCESS == "success"
assert AgentLog.LogStatus.ERROR == "error"
def test_agent_log_metadata_enum(self):
"""Test AgentLog metadata enum values."""
assert AgentLog.LogMetadata.STARTED_AT == "started_at"
assert AgentLog.LogMetadata.FINISHED_AT == "finished_at"
assert AgentLog.LogMetadata.ELAPSED_TIME == "elapsed_time"
assert AgentLog.LogMetadata.TOTAL_PRICE == "total_price"
assert AgentLog.LogMetadata.TOTAL_TOKENS == "total_tokens"
assert AgentLog.LogMetadata.LLM_USAGE == "llm_usage"
def test_agent_result_structure(self):
"""Test AgentResult structure."""
usage = LLMUsage(
prompt_tokens=100,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("0.001"),
prompt_price=Decimal("0.1"),
completion_tokens=50,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("0.001"),
completion_price=Decimal("0.1"),
total_tokens=150,
total_price=Decimal("0.2"),
currency="USD",
latency=0.5,
)
result = AgentResult(
text="Final answer",
files=[],
usage=usage,
finish_reason="stop",
)
assert result.text == "Final answer"
assert result.files == []
assert result.usage == usage
assert result.finish_reason == "stop"
class TestOrganizeUserQuery:
"""Tests for _organize_user_query method."""
@pytest.fixture
def mock_runner(self):
"""Create a mock AgentAppRunner for testing."""
with patch("core.agent.agent_app_runner.BaseAgentRunner.__init__", return_value=None):
from core.agent.agent_app_runner import AgentAppRunner
runner = AgentAppRunner.__new__(AgentAppRunner)
runner.files = []
runner.application_generate_entity = MagicMock()
runner.application_generate_entity.file_upload_config = None
return runner
def test_simple_query_without_files(self, mock_runner):
"""Test organizing a simple query without files."""
result = mock_runner._organize_user_query("Hello world", [])
assert len(result) == 1
assert isinstance(result[0], UserPromptMessage)
assert result[0].content == "Hello world"
def test_query_with_files(self, mock_runner):
"""Test organizing a query with files."""
from core.file.models import File
mock_file = MagicMock(spec=File)
mock_runner.files = [mock_file]
with patch("core.agent.agent_app_runner.file_manager") as mock_fm:
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
mock_fm.to_prompt_message_content.return_value = ImagePromptMessageContent(
data="http://example.com/image.jpg",
format="url",
mime_type="image/jpeg",
)
result = mock_runner._organize_user_query("Describe this image", [])
assert len(result) == 1
assert isinstance(result[0], UserPromptMessage)
assert isinstance(result[0].content, list)
assert len(result[0].content) == 2 # Image + Text

View File

@ -0,0 +1,191 @@
"""Tests for agent entities."""
from core.agent.entities import AgentEntity, AgentLog, AgentPromptEntity, AgentScratchpadUnit, ExecutionContext
class TestExecutionContext:
"""Tests for ExecutionContext entity."""
def test_create_with_all_fields(self):
"""Test creating ExecutionContext with all fields."""
context = ExecutionContext(
user_id="user-123",
app_id="app-456",
conversation_id="conv-789",
message_id="msg-012",
tenant_id="tenant-345",
)
assert context.user_id == "user-123"
assert context.app_id == "app-456"
assert context.conversation_id == "conv-789"
assert context.message_id == "msg-012"
assert context.tenant_id == "tenant-345"
def test_create_minimal(self):
"""Test creating minimal ExecutionContext."""
context = ExecutionContext.create_minimal(user_id="user-123")
assert context.user_id == "user-123"
assert context.app_id is None
assert context.conversation_id is None
assert context.message_id is None
assert context.tenant_id is None
def test_to_dict(self):
"""Test converting ExecutionContext to dictionary."""
context = ExecutionContext(
user_id="user-123",
app_id="app-456",
conversation_id="conv-789",
message_id="msg-012",
tenant_id="tenant-345",
)
result = context.to_dict()
assert result == {
"user_id": "user-123",
"app_id": "app-456",
"conversation_id": "conv-789",
"message_id": "msg-012",
"tenant_id": "tenant-345",
}
def test_with_updates(self):
"""Test creating new context with updates."""
original = ExecutionContext(
user_id="user-123",
app_id="app-456",
)
updated = original.with_updates(message_id="msg-789")
# Original should be unchanged
assert original.message_id is None
# Updated should have new value
assert updated.message_id == "msg-789"
assert updated.user_id == "user-123"
assert updated.app_id == "app-456"
class TestAgentLog:
"""Tests for AgentLog entity."""
def test_create_log_with_required_fields(self):
"""Test creating AgentLog with required fields."""
log = AgentLog(
label="ROUND 1",
log_type=AgentLog.LogType.ROUND,
status=AgentLog.LogStatus.START,
data={"key": "value"},
)
assert log.label == "ROUND 1"
assert log.log_type == AgentLog.LogType.ROUND
assert log.status == AgentLog.LogStatus.START
assert log.data == {"key": "value"}
assert log.id is not None # Auto-generated
assert log.parent_id is None
assert log.error is None
def test_log_type_enum(self):
"""Test LogType enum values."""
assert AgentLog.LogType.ROUND == "round"
assert AgentLog.LogType.THOUGHT == "thought"
assert AgentLog.LogType.TOOL_CALL == "tool_call"
def test_log_status_enum(self):
"""Test LogStatus enum values."""
assert AgentLog.LogStatus.START == "start"
assert AgentLog.LogStatus.SUCCESS == "success"
assert AgentLog.LogStatus.ERROR == "error"
def test_log_metadata_enum(self):
"""Test LogMetadata enum values."""
assert AgentLog.LogMetadata.STARTED_AT == "started_at"
assert AgentLog.LogMetadata.FINISHED_AT == "finished_at"
assert AgentLog.LogMetadata.ELAPSED_TIME == "elapsed_time"
assert AgentLog.LogMetadata.TOTAL_PRICE == "total_price"
assert AgentLog.LogMetadata.TOTAL_TOKENS == "total_tokens"
assert AgentLog.LogMetadata.LLM_USAGE == "llm_usage"
class TestAgentScratchpadUnit:
"""Tests for AgentScratchpadUnit entity."""
def test_is_final_with_final_answer_action(self):
"""Test is_final returns True for Final Answer action."""
unit = AgentScratchpadUnit(
thought="I know the answer",
action=AgentScratchpadUnit.Action(
action_name="Final Answer",
action_input="The answer is 42",
),
)
assert unit.is_final() is True
def test_is_final_with_tool_action(self):
"""Test is_final returns False for tool action."""
unit = AgentScratchpadUnit(
thought="I need to search",
action=AgentScratchpadUnit.Action(
action_name="search",
action_input={"query": "test"},
),
)
assert unit.is_final() is False
def test_is_final_with_no_action(self):
"""Test is_final returns True when no action."""
unit = AgentScratchpadUnit(
thought="Just thinking",
)
assert unit.is_final() is True
def test_action_to_dict(self):
"""Test Action.to_dict method."""
action = AgentScratchpadUnit.Action(
action_name="search",
action_input={"query": "test"},
)
result = action.to_dict()
assert result == {
"action": "search",
"action_input": {"query": "test"},
}
class TestAgentEntity:
"""Tests for AgentEntity."""
def test_strategy_enum(self):
"""Test Strategy enum values."""
assert AgentEntity.Strategy.CHAIN_OF_THOUGHT == "chain-of-thought"
assert AgentEntity.Strategy.FUNCTION_CALLING == "function-calling"
def test_create_with_prompt(self):
"""Test creating AgentEntity with prompt."""
prompt = AgentPromptEntity(
first_prompt="You are a helpful assistant.",
next_iteration="Continue thinking...",
)
entity = AgentEntity(
provider="openai",
model="gpt-4",
strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT,
prompt=prompt,
max_iteration=5,
)
assert entity.provider == "openai"
assert entity.model == "gpt-4"
assert entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT
assert entity.prompt == prompt
assert entity.max_iteration == 5