mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 14:38:06 +08:00
fix: resolve import errors and test failures after segment 4 merge
- Update BaseNodeData import path to dify_graph.entities.base_node_data - Change NodeType.COMMAND/FILE_UPLOAD to BuiltinNodeTypes constants - Fix system_oauth_encryption -> system_encryption rename in commands - Remove tests for deleted agent runner modules - Fix Avatar: named import + string size API in collaboration files - Add missing skill feature deps: @monaco-editor/react, react-arborist, @tanstack/react-virtual - Fix frontend test mocks: add useUserProfile, useLeaderRestoreListener, next/navigation mock, and nodeOutputVars to expected payload Made-with: Cursor
This commit is contained in:
@ -1,551 +0,0 @@
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cot_agent_runner import CotAgentRunner
|
||||
from core.agent.entities import AgentScratchpadUnit
|
||||
from core.agent.errors import AgentMaxIterationError
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
|
||||
|
||||
class DummyRunner(CotAgentRunner):
|
||||
"""Concrete implementation for testing abstract methods."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Completely bypass BaseAgentRunner __init__ to avoid DB/session usage
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
# Minimal required defaults
|
||||
self.history_prompt_messages = []
|
||||
self.memory = None
|
||||
|
||||
def _organize_prompt_messages(self):
|
||||
return []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker):
|
||||
# Prevent BaseAgentRunner __init__ from hitting database
|
||||
mocker.patch(
|
||||
"core.agent.base_agent_runner.BaseAgentRunner.organize_agent_history",
|
||||
return_value=[],
|
||||
)
|
||||
# Prepare required constructor dependencies for BaseAgentRunner
|
||||
application_generate_entity = MagicMock()
|
||||
application_generate_entity.model_conf = MagicMock()
|
||||
application_generate_entity.model_conf.stop = []
|
||||
application_generate_entity.model_conf.provider = "openai"
|
||||
application_generate_entity.model_conf.parameters = {}
|
||||
application_generate_entity.trace_manager = None
|
||||
application_generate_entity.invoke_from = "test"
|
||||
|
||||
app_config = MagicMock()
|
||||
app_config.agent = MagicMock()
|
||||
app_config.agent.max_iteration = 1
|
||||
app_config.prompt_template.simple_prompt_template = "Hello {{name}}"
|
||||
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.model_name = "test-model"
|
||||
model_instance.invoke_llm.return_value = []
|
||||
|
||||
model_config = MagicMock()
|
||||
model_config.model = "test-model"
|
||||
|
||||
queue_manager = MagicMock()
|
||||
message = MagicMock()
|
||||
|
||||
runner = DummyRunner(
|
||||
tenant_id="tenant",
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=MagicMock(),
|
||||
app_config=app_config,
|
||||
model_config=model_config,
|
||||
config=MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
user_id="user",
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
# Patch internal methods to isolate behavior
|
||||
runner._repack_app_generate_entity = MagicMock()
|
||||
runner._init_prompt_tools = MagicMock(return_value=({}, []))
|
||||
runner.recalc_llm_max_tokens = MagicMock()
|
||||
runner.create_agent_thought = MagicMock(return_value="thought-id")
|
||||
runner.save_agent_thought = MagicMock()
|
||||
runner.update_prompt_message_tool = MagicMock()
|
||||
runner.agent_callback = None
|
||||
runner.memory = None
|
||||
runner.history_prompt_messages = []
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
class TestFillInputs:
|
||||
@pytest.mark.parametrize(
|
||||
("instruction", "inputs", "expected"),
|
||||
[
|
||||
("Hello {{name}}", {"name": "John"}, "Hello John"),
|
||||
("No placeholders", {"name": "John"}, "No placeholders"),
|
||||
("{{a}}{{b}}", {"a": 1, "b": 2}, "12"),
|
||||
("{{x}}", {"x": None}, "None"),
|
||||
("", {"x": "y"}, ""),
|
||||
],
|
||||
)
|
||||
def test_fill_in_inputs(self, runner, instruction, inputs, expected):
|
||||
result = runner._fill_in_inputs_from_external_data_tools(instruction, inputs)
|
||||
assert result == expected
|
||||
|
||||
|
||||
class TestConvertDictToAction:
|
||||
def test_convert_valid_dict(self, runner):
|
||||
action_dict = {"action": "test", "action_input": {"a": 1}}
|
||||
action = runner._convert_dict_to_action(action_dict)
|
||||
assert action.action_name == "test"
|
||||
assert action.action_input == {"a": 1}
|
||||
|
||||
def test_convert_missing_keys(self, runner):
|
||||
with pytest.raises(KeyError):
|
||||
runner._convert_dict_to_action({"invalid": 1})
|
||||
|
||||
|
||||
class TestFormatAssistantMessage:
|
||||
def test_format_assistant_message_multiple_scratchpads(self, runner):
|
||||
sp1 = AgentScratchpadUnit(
|
||||
agent_response="resp1",
|
||||
thought="thought1",
|
||||
action_str="action1",
|
||||
action=AgentScratchpadUnit.Action(action_name="tool", action_input={}),
|
||||
observation="obs1",
|
||||
)
|
||||
sp2 = AgentScratchpadUnit(
|
||||
agent_response="final",
|
||||
thought="",
|
||||
action_str="",
|
||||
action=AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done"),
|
||||
observation=None,
|
||||
)
|
||||
result = runner._format_assistant_message([sp1, sp2])
|
||||
assert "Final Answer:" in result
|
||||
|
||||
def test_format_with_final(self, runner):
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="Done",
|
||||
thought="",
|
||||
action_str="",
|
||||
action=None,
|
||||
observation=None,
|
||||
)
|
||||
# Simulate final state via action name
|
||||
scratchpad.action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="Done")
|
||||
result = runner._format_assistant_message([scratchpad])
|
||||
assert "Final Answer" in result
|
||||
|
||||
def test_format_with_action_and_observation(self, runner):
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="resp",
|
||||
thought="thinking",
|
||||
action_str="action",
|
||||
action=None,
|
||||
observation="obs",
|
||||
)
|
||||
# Non-final state: provide a non-final action
|
||||
scratchpad.action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
result = runner._format_assistant_message([scratchpad])
|
||||
assert "Thought:" in result
|
||||
assert "Action:" in result
|
||||
assert "Observation:" in result
|
||||
|
||||
|
||||
class TestHandleInvokeAction:
|
||||
def test_handle_invoke_action_tool_not_present(self, runner):
|
||||
action = AgentScratchpadUnit.Action(action_name="missing", action_input={})
|
||||
response, meta = runner._handle_invoke_action(action, {}, [])
|
||||
assert "there is not a tool named" in response
|
||||
|
||||
def test_tool_with_json_string_args(self, runner, mocker):
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input=json.dumps({"a": 1}))
|
||||
tool_instance = MagicMock()
|
||||
tool_instances = {"tool": tool_instance}
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("result", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
response, meta = runner._handle_invoke_action(action, tool_instances, [])
|
||||
assert response == "result"
|
||||
|
||||
|
||||
class TestOrganizeHistoricPromptMessages:
|
||||
def test_empty_history(self, runner, mocker):
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.AgentHistoryPromptTransform.get_prompt",
|
||||
return_value=[],
|
||||
)
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestRun:
|
||||
def test_run_handles_empty_parser_output(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert isinstance(results, list)
|
||||
|
||||
def test_run_with_action_and_tool_invocation(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
runner.agent_callback = None
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {"tool": MagicMock()}))
|
||||
|
||||
def test_run_respects_max_iteration_boundary(self, runner, mocker):
|
||||
runner.app_config.agent.max_iteration = 1
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
runner.agent_callback = None
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {"tool": MagicMock()}))
|
||||
|
||||
def test_run_basic_flow(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {"name": "John"}))
|
||||
assert results
|
||||
|
||||
def test_run_max_iteration_error(self, runner, mocker):
|
||||
runner.app_config.agent.max_iteration = 0
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
def test_run_increase_usage_aggregation(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
runner.app_config.agent.max_iteration = 2
|
||||
|
||||
usage_1 = LLMUsage.empty_usage()
|
||||
usage_1.prompt_tokens = 1
|
||||
usage_1.completion_tokens = 1
|
||||
usage_1.total_tokens = 2
|
||||
usage_1.prompt_price = 1
|
||||
usage_1.completion_price = 1
|
||||
usage_1.total_price = 2
|
||||
|
||||
usage_2 = LLMUsage.empty_usage()
|
||||
usage_2.prompt_tokens = 1
|
||||
usage_2.completion_tokens = 1
|
||||
usage_2.total_tokens = 2
|
||||
usage_2.prompt_price = 1
|
||||
usage_2.completion_price = 1
|
||||
usage_2.total_price = 2
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
handle_output = mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
side_effect=[
|
||||
[action],
|
||||
[],
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_side_effect(chunks, usage_dict):
|
||||
call_index = handle_output.call_count
|
||||
usage_dict["usage"] = usage_1 if call_index == 1 else usage_2
|
||||
return [action] if call_index == 1 else []
|
||||
|
||||
handle_output.side_effect = _handle_side_effect
|
||||
runner.model_instance.invoke_llm = MagicMock(return_value=[])
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
fake_prompt_tool = MagicMock()
|
||||
fake_prompt_tool.name = "tool"
|
||||
runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool]))
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
final_usage = results[-1].delta.usage
|
||||
assert final_usage is not None
|
||||
assert final_usage.prompt_tokens == 2
|
||||
assert final_usage.completion_tokens == 2
|
||||
assert final_usage.total_tokens == 4
|
||||
assert final_usage.prompt_price == 2
|
||||
assert final_usage.completion_price == 2
|
||||
assert final_usage.total_price == 4
|
||||
|
||||
def test_run_when_no_action_branch(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == ""
|
||||
|
||||
def test_run_usage_missing_key_branch(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
runner.model_instance.invoke_llm = MagicMock(return_value=[])
|
||||
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
def test_run_prompt_tool_update_branch(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
|
||||
|
||||
# First iteration → action
|
||||
# Second iteration → no action (empty list)
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
side_effect=[[action], []],
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", [], MagicMock(to_dict=lambda: {})),
|
||||
)
|
||||
|
||||
runner.app_config.agent.max_iteration = 5
|
||||
|
||||
fake_prompt_tool = MagicMock()
|
||||
fake_prompt_tool.name = "tool"
|
||||
|
||||
runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool]))
|
||||
|
||||
runner.update_prompt_message_tool = MagicMock()
|
||||
runner.agent_callback = None
|
||||
|
||||
list(runner.run(message, "query", {}))
|
||||
|
||||
runner.update_prompt_message_tool.assert_called_once()
|
||||
|
||||
def test_historic_with_assistant_and_tool_calls(self, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage
|
||||
|
||||
assistant = AssistantPromptMessage(content="thinking")
|
||||
assistant.tool_calls = [MagicMock(function=MagicMock(name="tool", arguments='{"a":1}'))]
|
||||
|
||||
tool_msg = ToolPromptMessage(content="obs", tool_call_id="1")
|
||||
|
||||
runner.history_prompt_messages = [assistant, tool_msg]
|
||||
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_historic_final_flush_branch(self, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
|
||||
assistant = AssistantPromptMessage(content="final")
|
||||
runner.history_prompt_messages = [assistant]
|
||||
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert isinstance(result, list)
|
||||
|
||||
|
||||
class TestInitReactState:
|
||||
def test_init_react_state_resets_state(self, runner, mocker):
|
||||
mocker.patch.object(runner, "_organize_historic_prompt_messages", return_value=["historic"])
|
||||
runner._agent_scratchpad = ["old"]
|
||||
runner._query = "old"
|
||||
|
||||
runner._init_react_state("new-query")
|
||||
|
||||
assert runner._query == "new-query"
|
||||
assert runner._agent_scratchpad == []
|
||||
assert runner._historic_prompt_messages == ["historic"]
|
||||
|
||||
|
||||
class TestHandleInvokeActionExtended:
|
||||
def test_tool_with_invalid_json_string_args(self, runner, mocker):
|
||||
action = AgentScratchpadUnit.Action(action_name="tool", action_input="not-json")
|
||||
tool_instance = MagicMock()
|
||||
tool_instances = {"tool": tool_instance}
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", ["file1"], MagicMock(to_dict=lambda: {"k": "v"})),
|
||||
)
|
||||
|
||||
message_file_ids = []
|
||||
response, meta = runner._handle_invoke_action(action, tool_instances, message_file_ids)
|
||||
|
||||
assert response == "ok"
|
||||
assert message_file_ids == ["file1"]
|
||||
runner.queue_manager.publish.assert_called()
|
||||
|
||||
|
||||
class TestFillInputsEdgeCases:
|
||||
def test_fill_inputs_with_empty_inputs(self, runner):
|
||||
result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {})
|
||||
assert result == "Hello {{x}}"
|
||||
|
||||
def test_fill_inputs_with_exception_in_replace(self, runner):
|
||||
class BadValue:
|
||||
def __str__(self):
|
||||
raise Exception("fail")
|
||||
|
||||
# Should silently continue on exception
|
||||
result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {"x": BadValue()})
|
||||
assert result == "Hello {{x}}"
|
||||
|
||||
|
||||
class TestOrganizeHistoricPromptMessagesExtended:
|
||||
def test_user_message_flushes_scratchpad(self, runner, mocker):
|
||||
from dify_graph.model_runtime.entities.message_entities import UserPromptMessage
|
||||
|
||||
user_message = UserPromptMessage(content="Hi")
|
||||
|
||||
runner.history_prompt_messages = [user_message]
|
||||
|
||||
mock_transform = mocker.patch(
|
||||
"core.agent.cot_agent_runner.AgentHistoryPromptTransform",
|
||||
)
|
||||
mock_transform.return_value.get_prompt.return_value = ["final"]
|
||||
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert result == ["final"]
|
||||
|
||||
def test_tool_message_without_scratchpad_raises(self, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import ToolPromptMessage
|
||||
|
||||
runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")]
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
runner._organize_historic_prompt_messages([])
|
||||
|
||||
def test_agent_history_transform_invocation(self, runner, mocker):
|
||||
mock_transform = MagicMock()
|
||||
mock_transform.get_prompt.return_value = []
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.AgentHistoryPromptTransform",
|
||||
return_value=mock_transform,
|
||||
)
|
||||
|
||||
runner.history_prompt_messages = []
|
||||
result = runner._organize_historic_prompt_messages([])
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestRunAdditionalBranches:
|
||||
def test_run_with_no_action_final_answer_empty(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=["thinking"],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert any(hasattr(r, "delta") for r in results)
|
||||
|
||||
def test_run_with_final_answer_action_string(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done")
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == "done"
|
||||
|
||||
def test_run_with_final_answer_action_dict(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input={"a": 1})
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert json.loads(results[-1].delta.message.content) == {"a": 1}
|
||||
|
||||
def test_run_with_string_final_answer(self, runner, mocker):
|
||||
message = MagicMock()
|
||||
message.id = "msg-id"
|
||||
|
||||
# Remove invalid branch: Pydantic enforces str|dict for action_input
|
||||
action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="12345")
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
|
||||
return_value=[action],
|
||||
)
|
||||
|
||||
results = list(runner.run(message, "query", {}))
|
||||
assert results[-1].delta.message.content == "12345"
|
||||
@ -1,215 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
|
||||
from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from tests.unit_tests.core.agent.conftest import (
|
||||
DummyAgentConfig,
|
||||
DummyAppConfig,
|
||||
DummyTool,
|
||||
)
|
||||
from tests.unit_tests.core.agent.conftest import (
|
||||
DummyPromptEntity as DummyPrompt,
|
||||
)
|
||||
|
||||
|
||||
class DummyFileUploadConfig:
|
||||
def __init__(self, image_config=None):
|
||||
self.image_config = image_config
|
||||
|
||||
|
||||
class DummyImageConfig:
|
||||
def __init__(self, detail=None):
|
||||
self.detail = detail
|
||||
|
||||
|
||||
class DummyGenerateEntity:
|
||||
def __init__(self, file_upload_config=None):
|
||||
self.file_upload_config = file_upload_config
|
||||
|
||||
|
||||
class DummyUnit:
|
||||
def __init__(self, final=False, thought=None, action_str=None, observation=None, agent_response=None):
|
||||
self._final = final
|
||||
self.thought = thought
|
||||
self.action_str = action_str
|
||||
self.observation = observation
|
||||
self.agent_response = agent_response
|
||||
|
||||
def is_final(self):
|
||||
return self._final
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
runner = CotChatAgentRunner.__new__(CotChatAgentRunner)
|
||||
runner._instruction = "test_instruction"
|
||||
runner._prompt_messages_tools = [DummyTool("tool1"), DummyTool("tool2")]
|
||||
runner._query = "user query"
|
||||
runner._agent_scratchpad = []
|
||||
runner.files = []
|
||||
runner.application_generate_entity = DummyGenerateEntity()
|
||||
runner._organize_historic_prompt_messages = MagicMock(return_value=["historic"])
|
||||
return runner
|
||||
|
||||
|
||||
class TestOrganizeSystemPrompt:
|
||||
def test_organize_system_prompt_success(self, runner, mocker):
|
||||
first_prompt = "Instruction: {{instruction}}, Tools: {{tools}}, Names: {{tool_names}}"
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt(first_prompt)))
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_chat_agent_runner.jsonable_encoder",
|
||||
return_value=[{"name": "tool1"}, {"name": "tool2"}],
|
||||
)
|
||||
|
||||
result = runner._organize_system_prompt()
|
||||
|
||||
assert "test_instruction" in result.content
|
||||
assert "tool1" in result.content
|
||||
assert "tool2" in result.content
|
||||
assert "tool1, tool2" in result.content
|
||||
|
||||
def test_organize_system_prompt_missing_agent(self, runner):
|
||||
runner.app_config = DummyAppConfig(agent=None)
|
||||
with pytest.raises(AssertionError):
|
||||
runner._organize_system_prompt()
|
||||
|
||||
def test_organize_system_prompt_missing_prompt(self, runner):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(prompt_entity=None))
|
||||
with pytest.raises(AssertionError):
|
||||
runner._organize_system_prompt()
|
||||
|
||||
|
||||
class TestOrganizeUserQuery:
|
||||
@pytest.mark.parametrize("files", [None, pytest.param([], id="empty_list")])
|
||||
def test_organize_user_query_no_files(self, runner, files):
|
||||
runner.files = files
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "query"
|
||||
|
||||
@patch("core.agent.cot_chat_agent_runner.UserPromptMessage")
|
||||
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
|
||||
def test_organize_user_query_with_image_file_default_config(self, mock_to_prompt, mock_user_prompt, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
mock_content = ImagePromptMessageContent(
|
||||
url="http://test",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
mock_to_prompt.return_value = mock_content
|
||||
mock_user_prompt.side_effect = lambda content: MagicMock(content=content)
|
||||
|
||||
runner.files = ["file1"]
|
||||
runner.application_generate_entity = DummyGenerateEntity(None)
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
assert mock_content in result[0].content
|
||||
mock_to_prompt.assert_called_once_with(
|
||||
"file1",
|
||||
image_detail_config=ImagePromptMessageContent.DETAIL.LOW,
|
||||
)
|
||||
|
||||
@patch("core.agent.cot_chat_agent_runner.UserPromptMessage")
|
||||
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
|
||||
def test_organize_user_query_with_image_file_high_detail(self, mock_to_prompt, mock_user_prompt, runner):
|
||||
from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
mock_content = ImagePromptMessageContent(
|
||||
url="http://test",
|
||||
format="png",
|
||||
mime_type="image/png",
|
||||
)
|
||||
mock_to_prompt.return_value = mock_content
|
||||
mock_user_prompt.side_effect = lambda content: MagicMock(content=content)
|
||||
|
||||
runner.files = ["file1"]
|
||||
|
||||
image_config = DummyImageConfig(detail="high")
|
||||
runner.application_generate_entity = DummyGenerateEntity(DummyFileUploadConfig(image_config))
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
assert mock_content in result[0].content
|
||||
mock_to_prompt.assert_called_once_with(
|
||||
"file1",
|
||||
image_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
|
||||
)
|
||||
|
||||
@patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
|
||||
def test_organize_user_query_with_text_file_no_config(self, mock_to_prompt, runner):
|
||||
mock_to_prompt.return_value = TextPromptMessageContent(data="file_content")
|
||||
runner.files = ["file1"]
|
||||
runner.application_generate_entity = DummyGenerateEntity(None)
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
|
||||
|
||||
class TestOrganizePromptMessages:
|
||||
def test_no_scratchpad(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assert "system" in result
|
||||
assert "query" in result
|
||||
runner._organize_historic_prompt_messages.assert_called_once()
|
||||
|
||||
def test_with_final_scratchpad(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
unit = DummyUnit(final=True, agent_response="done")
|
||||
runner._agent_scratchpad = [unit]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assistant_msgs = [m for m in result if hasattr(m, "content")]
|
||||
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
|
||||
assert "Final Answer: done" in combined
|
||||
|
||||
def test_with_thought_action_observation(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
unit = DummyUnit(
|
||||
final=False,
|
||||
thought="thinking",
|
||||
action_str="action",
|
||||
observation="observe",
|
||||
)
|
||||
runner._agent_scratchpad = [unit]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assistant_msgs = [m for m in result if hasattr(m, "content")]
|
||||
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
|
||||
assert "Thought: thinking" in combined
|
||||
assert "Action: action" in combined
|
||||
assert "Observation: observe" in combined
|
||||
|
||||
def test_multiple_units_mixed(self, runner, mocker):
|
||||
runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
|
||||
runner._organize_system_prompt = MagicMock(return_value="system")
|
||||
runner._organize_user_query = MagicMock(return_value=["query"])
|
||||
|
||||
units = [
|
||||
DummyUnit(final=False, thought="t1"),
|
||||
DummyUnit(final=True, agent_response="done"),
|
||||
]
|
||||
runner._agent_scratchpad = units
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
assistant_msgs = [m for m in result if hasattr(m, "content")]
|
||||
combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
|
||||
assert "Thought: t1" in combined
|
||||
assert "Final Answer: done" in combined
|
||||
@ -1,234 +0,0 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
# -----------------------------
|
||||
# Fixtures
|
||||
# -----------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker, dummy_tool_factory):
|
||||
runner = CotCompletionAgentRunner.__new__(CotCompletionAgentRunner)
|
||||
|
||||
runner._instruction = "Test instruction"
|
||||
runner._prompt_messages_tools = [dummy_tool_factory("toolA"), dummy_tool_factory("toolB")]
|
||||
runner._query = "What is Python?"
|
||||
runner._agent_scratchpad = []
|
||||
|
||||
mocker.patch(
|
||||
"core.agent.cot_completion_agent_runner.jsonable_encoder",
|
||||
side_effect=lambda tools: [{"name": t.name} for t in tools],
|
||||
)
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
# ======================================================
|
||||
# _organize_instruction_prompt Tests
|
||||
# ======================================================
|
||||
|
||||
|
||||
class TestOrganizeInstructionPrompt:
|
||||
def test_success_all_placeholders(
|
||||
self, runner, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
|
||||
):
|
||||
template = (
|
||||
"{{instruction}} | {{tools}} | {{tool_names}} | {{historic_messages}} | {{agent_scratchpad}} | {{query}}"
|
||||
)
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
result = runner._organize_instruction_prompt()
|
||||
|
||||
assert "Test instruction" in result
|
||||
assert "toolA" in result
|
||||
assert "toolB" in result
|
||||
tools_payload = json.loads(result.split(" | ")[1])
|
||||
assert {item["name"] for item in tools_payload} == {"toolA", "toolB"}
|
||||
|
||||
def test_agent_none_raises(self, runner, dummy_app_config_factory):
|
||||
runner.app_config = dummy_app_config_factory(agent=None)
|
||||
with pytest.raises(ValueError, match="Agent configuration is not set"):
|
||||
runner._organize_instruction_prompt()
|
||||
|
||||
def test_prompt_entity_none_raises(self, runner, dummy_app_config_factory, dummy_agent_config_factory):
|
||||
runner.app_config = dummy_app_config_factory(agent=dummy_agent_config_factory(prompt_entity=None))
|
||||
with pytest.raises(ValueError, match="prompt entity is not set"):
|
||||
runner._organize_instruction_prompt()
|
||||
|
||||
|
||||
# ======================================================
|
||||
# _organize_historic_prompt Tests
|
||||
# ======================================================
|
||||
|
||||
|
||||
class TestOrganizeHistoricPrompt:
|
||||
def test_with_user_and_assistant_string(self, runner, mocker):
|
||||
user_msg = UserPromptMessage(content="Hello")
|
||||
assistant_msg = AssistantPromptMessage(content="Hi there")
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[user_msg, assistant_msg],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
|
||||
assert "Question: Hello" in result
|
||||
assert "Hi there" in result
|
||||
|
||||
def test_assistant_list_with_text_content(self, runner, mocker):
|
||||
text_content = TextPromptMessageContent(data="Partial answer")
|
||||
assistant_msg = AssistantPromptMessage(content=[text_content])
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[assistant_msg],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
|
||||
assert "Partial answer" in result
|
||||
|
||||
def test_assistant_list_with_non_text_content_ignored(self, runner, mocker):
|
||||
non_text_content = ImagePromptMessageContent(format="url", mime_type="image/png")
|
||||
assistant_msg = AssistantPromptMessage(content=[non_text_content])
|
||||
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[assistant_msg],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
assert result == ""
|
||||
|
||||
def test_empty_history(self, runner, mocker):
|
||||
mocker.patch.object(
|
||||
runner,
|
||||
"_organize_historic_prompt_messages",
|
||||
return_value=[],
|
||||
)
|
||||
|
||||
result = runner._organize_historic_prompt()
|
||||
assert result == ""
|
||||
|
||||
|
||||
# ======================================================
|
||||
# _organize_prompt_messages Tests
|
||||
# ======================================================
|
||||
|
||||
|
||||
class TestOrganizePromptMessages:
|
||||
def test_full_flow_with_scratchpad(
|
||||
self,
|
||||
runner,
|
||||
mocker,
|
||||
dummy_app_config_factory,
|
||||
dummy_agent_config_factory,
|
||||
dummy_prompt_entity_factory,
|
||||
dummy_scratchpad_unit_factory,
|
||||
):
|
||||
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
mocker.patch.object(runner, "_organize_historic_prompt", return_value="History\n")
|
||||
|
||||
runner._agent_scratchpad = [
|
||||
dummy_scratchpad_unit_factory(final=False, thought="Thinking", action_str="Act", observation="Obs"),
|
||||
dummy_scratchpad_unit_factory(final=True, agent_response="Done"),
|
||||
]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], UserPromptMessage)
|
||||
|
||||
content = result[0].content
|
||||
|
||||
assert "History" in content
|
||||
assert "Thought: Thinking" in content
|
||||
assert "Action: Act" in content
|
||||
assert "Observation: Obs" in content
|
||||
assert "Final Answer: Done" in content
|
||||
assert "Question: What is Python?" in content
|
||||
|
||||
def test_no_scratchpad(
|
||||
self, runner, mocker, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
|
||||
):
|
||||
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
mocker.patch.object(runner, "_organize_historic_prompt", return_value="")
|
||||
|
||||
runner._agent_scratchpad = None
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
|
||||
assert "Question: What is Python?" in result[0].content
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("thought", "action", "observation"),
|
||||
[
|
||||
("T", None, None),
|
||||
("T", "A", None),
|
||||
("T", None, "O"),
|
||||
],
|
||||
)
|
||||
def test_partial_scratchpad_units(
|
||||
self,
|
||||
runner,
|
||||
mocker,
|
||||
thought,
|
||||
action,
|
||||
observation,
|
||||
dummy_app_config_factory,
|
||||
dummy_agent_config_factory,
|
||||
dummy_prompt_entity_factory,
|
||||
dummy_scratchpad_unit_factory,
|
||||
):
|
||||
template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"
|
||||
|
||||
runner.app_config = dummy_app_config_factory(
|
||||
agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
|
||||
)
|
||||
|
||||
mocker.patch.object(runner, "_organize_historic_prompt", return_value="")
|
||||
|
||||
runner._agent_scratchpad = [
|
||||
dummy_scratchpad_unit_factory(
|
||||
final=False,
|
||||
thought=thought,
|
||||
action_str=action,
|
||||
observation=observation,
|
||||
)
|
||||
]
|
||||
|
||||
result = runner._organize_prompt_messages()
|
||||
content = result[0].content
|
||||
|
||||
assert "Thought:" in content
|
||||
if action:
|
||||
assert "Action:" in content
|
||||
if observation:
|
||||
assert "Observation:" in content
|
||||
@ -1,452 +0,0 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.agent.errors import AgentMaxIterationError
|
||||
from core.agent.fc_agent_runner import FunctionCallAgentRunner
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueMessageFileEvent
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
DocumentPromptMessageContent,
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
# ==============================
|
||||
# Dummy Helper Classes
|
||||
# ==============================
|
||||
|
||||
|
||||
def build_usage(pt=1, ct=1, tt=2) -> LLMUsage:
|
||||
usage = LLMUsage.empty_usage()
|
||||
usage.prompt_tokens = pt
|
||||
usage.completion_tokens = ct
|
||||
usage.total_tokens = tt
|
||||
usage.prompt_price = 0
|
||||
usage.completion_price = 0
|
||||
usage.total_price = 0
|
||||
return usage
|
||||
|
||||
|
||||
class DummyMessage:
|
||||
def __init__(self, content: str | None = None, tool_calls: list[Any] | None = None):
|
||||
self.content: str | None = content
|
||||
self.tool_calls: list[Any] = tool_calls or []
|
||||
|
||||
|
||||
class DummyDelta:
|
||||
def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None):
|
||||
self.message: DummyMessage | None = message
|
||||
self.usage: LLMUsage | None = usage
|
||||
|
||||
|
||||
class DummyChunk:
|
||||
def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None):
|
||||
self.delta: DummyDelta = DummyDelta(message=message, usage=usage)
|
||||
|
||||
|
||||
class DummyResult:
|
||||
def __init__(
|
||||
self,
|
||||
message: DummyMessage | None = None,
|
||||
usage: LLMUsage | None = None,
|
||||
prompt_messages: list[DummyMessage] | None = None,
|
||||
):
|
||||
self.message: DummyMessage | None = message
|
||||
self.usage: LLMUsage | None = usage
|
||||
self.prompt_messages: list[DummyMessage] = prompt_messages or []
|
||||
self.system_fingerprint: str = ""
|
||||
|
||||
|
||||
# ==============================
|
||||
# Fixtures
|
||||
# ==============================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner(mocker):
|
||||
# Completely bypass BaseAgentRunner __init__ to avoid DB / Flask context
|
||||
mocker.patch(
|
||||
"core.agent.base_agent_runner.BaseAgentRunner.__init__",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
# Patch streaming chunk models to avoid validation on dummy message objects
|
||||
mocker.patch("core.agent.fc_agent_runner.LLMResultChunk", MagicMock)
|
||||
mocker.patch("core.agent.fc_agent_runner.LLMResultChunkDelta", MagicMock)
|
||||
|
||||
app_config = MagicMock()
|
||||
app_config.agent = MagicMock(max_iteration=2)
|
||||
app_config.prompt_template = MagicMock(simple_prompt_template="system")
|
||||
|
||||
application_generate_entity = MagicMock()
|
||||
application_generate_entity.model_conf = MagicMock(parameters={}, stop=None)
|
||||
application_generate_entity.trace_manager = MagicMock()
|
||||
application_generate_entity.invoke_from = "test"
|
||||
application_generate_entity.app_config = MagicMock(app_id="app")
|
||||
application_generate_entity.file_upload_config = None
|
||||
|
||||
queue_manager = MagicMock()
|
||||
model_instance = MagicMock()
|
||||
model_instance.model = "test-model"
|
||||
model_instance.model_name = "test-model"
|
||||
|
||||
message = MagicMock(id="msg1")
|
||||
conversation = MagicMock(id="conv1")
|
||||
|
||||
runner = FunctionCallAgentRunner(
|
||||
tenant_id="tenant",
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation,
|
||||
app_config=app_config,
|
||||
model_config=MagicMock(),
|
||||
config=MagicMock(),
|
||||
queue_manager=queue_manager,
|
||||
message=message,
|
||||
user_id="user",
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
# Manually inject required attributes normally set by BaseAgentRunner
|
||||
runner.tenant_id = "tenant"
|
||||
runner.application_generate_entity = application_generate_entity
|
||||
runner.conversation = conversation
|
||||
runner.app_config = app_config
|
||||
runner.model_config = MagicMock()
|
||||
runner.config = MagicMock()
|
||||
runner.queue_manager = queue_manager
|
||||
runner.message = message
|
||||
runner.user_id = "user"
|
||||
runner.model_instance = model_instance
|
||||
|
||||
runner.stream_tool_call = False
|
||||
runner.memory = None
|
||||
runner.history_prompt_messages = []
|
||||
runner._current_thoughts = []
|
||||
runner.files = []
|
||||
runner.agent_callback = MagicMock()
|
||||
|
||||
runner._init_prompt_tools = MagicMock(return_value=({}, []))
|
||||
runner.create_agent_thought = MagicMock(return_value="thought1")
|
||||
runner.save_agent_thought = MagicMock()
|
||||
runner.recalc_llm_max_tokens = MagicMock()
|
||||
runner.update_prompt_message_tool = MagicMock()
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
# ==============================
|
||||
# Tool Call Checks
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestToolCallChecks:
|
||||
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
|
||||
def test_check_tool_calls(self, runner, tool_calls, expected):
|
||||
chunk = DummyChunk(message=DummyMessage(tool_calls=tool_calls))
|
||||
assert runner.check_tool_calls(chunk) is expected
|
||||
|
||||
@pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
|
||||
def test_check_blocking_tool_calls(self, runner, tool_calls, expected):
|
||||
result = DummyResult(message=DummyMessage(tool_calls=tool_calls))
|
||||
assert runner.check_blocking_tool_calls(result) is expected
|
||||
|
||||
|
||||
# ==============================
|
||||
# Extract Tool Calls
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestExtractToolCalls:
|
||||
def test_extract_tool_calls_with_valid_json(self, runner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call]))
|
||||
calls = runner.extract_tool_calls(chunk)
|
||||
|
||||
assert calls == [("1", "tool", {"a": 1})]
|
||||
|
||||
def test_extract_tool_calls_empty_arguments(self, runner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = ""
|
||||
|
||||
chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call]))
|
||||
calls = runner.extract_tool_calls(chunk)
|
||||
|
||||
assert calls == [("1", "tool", {})]
|
||||
|
||||
def test_extract_blocking_tool_calls(self, runner):
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "2"
|
||||
tool_call.function.name = "block"
|
||||
tool_call.function.arguments = json.dumps({"x": 2})
|
||||
|
||||
result = DummyResult(message=DummyMessage(tool_calls=[tool_call]))
|
||||
calls = runner.extract_blocking_tool_calls(result)
|
||||
|
||||
assert calls == [("2", "block", {"x": 2})]
|
||||
|
||||
|
||||
# ==============================
|
||||
# System Message Initialization
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestInitSystemMessage:
|
||||
def test_init_system_message_empty_prompt_messages(self, runner):
|
||||
result = runner._init_system_message("system", [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_init_system_message_insert_at_start(self, runner):
|
||||
msgs = [MagicMock()]
|
||||
result = runner._init_system_message("system", msgs)
|
||||
assert result[0].content == "system"
|
||||
|
||||
def test_init_system_message_no_template(self, runner):
|
||||
result = runner._init_system_message("", [])
|
||||
assert result == []
|
||||
|
||||
|
||||
# ==============================
|
||||
# Organize User Query
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestOrganizeUserQuery:
|
||||
def test_without_files(self, runner):
|
||||
result = runner._organize_user_query("query", [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_with_none_query(self, runner):
|
||||
result = runner._organize_user_query(None, [])
|
||||
assert len(result) == 1
|
||||
|
||||
def test_with_files_uses_image_detail_config(self, runner, mocker):
|
||||
file_content = TextPromptMessageContent(data="file-content")
|
||||
mock_to_prompt = mocker.patch(
|
||||
"core.agent.fc_agent_runner.file_manager.to_prompt_message_content",
|
||||
return_value=file_content,
|
||||
)
|
||||
|
||||
image_config = MagicMock(detail=ImagePromptMessageContent.DETAIL.HIGH)
|
||||
runner.application_generate_entity.file_upload_config = MagicMock(image_config=image_config)
|
||||
runner.files = ["file1"]
|
||||
|
||||
result = runner._organize_user_query("query", [])
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0].content, list)
|
||||
mock_to_prompt.assert_called_once_with("file1", image_detail_config=ImagePromptMessageContent.DETAIL.HIGH)
|
||||
|
||||
|
||||
# ==============================
|
||||
# Clear User Prompt Images
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestClearUserPromptImageMessages:
|
||||
def test_clear_text_and_image_content(self, runner):
|
||||
text = MagicMock()
|
||||
text.type = "text"
|
||||
text.data = "hello"
|
||||
|
||||
image = MagicMock()
|
||||
image.type = "image"
|
||||
image.data = "img"
|
||||
|
||||
user_msg = MagicMock()
|
||||
user_msg.__class__.__name__ = "UserPromptMessage"
|
||||
user_msg.content = [text, image]
|
||||
|
||||
result = runner._clear_user_prompt_image_messages([user_msg])
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_clear_includes_file_placeholder(self, runner):
|
||||
text = TextPromptMessageContent(data="hello")
|
||||
image = ImagePromptMessageContent(format="url", mime_type="image/png")
|
||||
document = DocumentPromptMessageContent(format="url", mime_type="application/pdf")
|
||||
|
||||
user_msg = UserPromptMessage(content=[text, image, document])
|
||||
|
||||
result = runner._clear_user_prompt_image_messages([user_msg])
|
||||
|
||||
assert result[0].content == "hello\n[image]\n[file]"
|
||||
|
||||
|
||||
# ==============================
|
||||
# Run Method Tests
|
||||
# ==============================
|
||||
|
||||
|
||||
class TestRunMethod:
|
||||
def test_run_non_streaming_no_tool_calls(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
dummy_message = DummyMessage(content="hello")
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = result
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
runner.queue_manager.publish.assert_called()
|
||||
|
||||
queue_calls = runner.queue_manager.publish.call_args_list
|
||||
assert any(call.args and call.args[0].__class__.__name__ == "QueueMessageEndEvent" for call in queue_calls)
|
||||
|
||||
def test_run_streaming_branch(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
content = [TextPromptMessageContent(data="hi")]
|
||||
chunk = DummyChunk(message=DummyMessage(content=content), usage=build_usage())
|
||||
|
||||
def generator():
|
||||
yield chunk
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = generator()
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
|
||||
def test_run_streaming_tool_calls_list_content(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
content = [TextPromptMessageContent(data="hi")]
|
||||
chunk = DummyChunk(message=DummyMessage(content=content, tool_calls=[tool_call]), usage=build_usage())
|
||||
|
||||
def generator():
|
||||
yield chunk
|
||||
|
||||
final_message = DummyMessage(content="done", tool_calls=[])
|
||||
final_result = DummyResult(message=final_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.side_effect = [generator(), final_result]
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
|
||||
def test_run_non_streaming_list_content(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
content = [TextPromptMessageContent(data="hi")]
|
||||
dummy_message = DummyMessage(content=content)
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = result
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
assert runner.save_agent_thought.call_args.kwargs["thought"] == "hi"
|
||||
|
||||
def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker):
|
||||
message = MagicMock(id="m1")
|
||||
runner.stream_tool_call = True
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
chunk = DummyChunk(message=DummyMessage(content="hi", tool_calls=[tool_call]), usage=build_usage())
|
||||
|
||||
def generator():
|
||||
yield chunk
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = generator()
|
||||
|
||||
real_dumps = json.dumps
|
||||
|
||||
def flaky_dumps(obj, *args, **kwargs):
|
||||
if kwargs.get("ensure_ascii") is False:
|
||||
return real_dumps(obj, *args, **kwargs)
|
||||
raise TypeError("boom")
|
||||
|
||||
mocker.patch("core.agent.fc_agent_runner.json.dumps", side_effect=flaky_dumps)
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) == 1
|
||||
|
||||
def test_run_with_missing_tool_instance(self, runner):
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "missing"
|
||||
tool_call.function.arguments = json.dumps({})
|
||||
|
||||
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
final_message = DummyMessage(content="done", tool_calls=[])
|
||||
final_result = DummyResult(message=final_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.side_effect = [result, final_result]
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
|
||||
def test_run_with_tool_instance_and_files(self, runner, mocker):
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = json.dumps({"a": 1})
|
||||
|
||||
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
final_result = DummyResult(message=DummyMessage(content="done", tool_calls=[]), usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.side_effect = [result, final_result]
|
||||
|
||||
tool_instance = MagicMock()
|
||||
prompt_tool = MagicMock()
|
||||
prompt_tool.name = "tool"
|
||||
runner._init_prompt_tools.return_value = ({"tool": tool_instance}, [prompt_tool])
|
||||
|
||||
tool_invoke_meta = MagicMock()
|
||||
tool_invoke_meta.to_dict.return_value = {"ok": True}
|
||||
mocker.patch(
|
||||
"core.agent.fc_agent_runner.ToolEngine.agent_invoke",
|
||||
return_value=("ok", ["file1"], tool_invoke_meta),
|
||||
)
|
||||
|
||||
outputs = list(runner.run(message, "query"))
|
||||
assert len(outputs) >= 1
|
||||
assert any(
|
||||
isinstance(call.args[0], QueueMessageFileEvent)
|
||||
and call.args[0].message_file_id == "file1"
|
||||
and call.args[1] == PublishFrom.APPLICATION_MANAGER
|
||||
for call in runner.queue_manager.publish.call_args_list
|
||||
)
|
||||
|
||||
def test_run_max_iteration_error(self, runner):
|
||||
runner.app_config.agent.max_iteration = 0
|
||||
|
||||
message = MagicMock(id="m1")
|
||||
|
||||
tool_call = MagicMock()
|
||||
tool_call.id = "1"
|
||||
tool_call.function.name = "tool"
|
||||
tool_call.function.arguments = "{}"
|
||||
|
||||
dummy_message = DummyMessage(content="", tool_calls=[tool_call])
|
||||
result = DummyResult(message=dummy_message, usage=build_usage())
|
||||
|
||||
runner.model_instance.invoke_llm.return_value = result
|
||||
|
||||
with pytest.raises(AgentMaxIterationError):
|
||||
list(runner.run(message, "query"))
|
||||
Reference in New Issue
Block a user