mirror of https://github.com/langgenius/dify.git (synced 2026-04-25 21:26:15 +08:00)
fix: resolve import errors and test failures after segment 4 merge
- Update BaseNodeData import path to dify_graph.entities.base_node_data
- Change NodeType.COMMAND/FILE_UPLOAD to BuiltinNodeTypes constants
- Fix system_oauth_encryption -> system_encryption rename in commands
- Remove tests for deleted agent runner modules
- Fix Avatar: named import + string size API in collaboration files
- Add missing skill feature deps: @monaco-editor/react, react-arborist, @tanstack/react-virtual
- Fix frontend test mocks: add useUserProfile, useLeaderRestoreListener, next/navigation mock, and nodeOutputVars to expected payload

Made-with: Cursor
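A minimal sketch of the mechanical renames the first two bullets describe; the new import path and the old NodeType spellings are taken from the message above, while the exact BuiltinNodeTypes member names are an assumption not confirmed by this diff:

# Import path named in the commit message:
from dify_graph.entities.base_node_data import BaseNodeData

# Constant swap described in the message (new member names assumed):
# NodeType.COMMAND      -> BuiltinNodeTypes.COMMAND
# NodeType.FILE_UPLOAD  -> BuiltinNodeTypes.FILE_UPLOAD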
@@ -1,551 +0,0 @@
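Deleted file below (path not shown in this extract): unit tests for CotAgentRunner covering input filling, action conversion, scratchpad formatting, tool invocation, and the run() loop with usage aggregation.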
import json
from unittest.mock import MagicMock

import pytest

from core.agent.cot_agent_runner import CotAgentRunner
from core.agent.entities import AgentScratchpadUnit
from core.agent.errors import AgentMaxIterationError
from dify_graph.model_runtime.entities.llm_entities import LLMUsage


class DummyRunner(CotAgentRunner):
    """Concrete implementation for testing abstract methods."""

    def __init__(self, **kwargs):
        # Completely bypass BaseAgentRunner __init__ to avoid DB/session usage
        for k, v in kwargs.items():
            setattr(self, k, v)
        # Minimal required defaults
        self.history_prompt_messages = []
        self.memory = None

    def _organize_prompt_messages(self):
        return []


@pytest.fixture
def runner(mocker):
    # Prevent BaseAgentRunner __init__ from hitting the database
    mocker.patch(
        "core.agent.base_agent_runner.BaseAgentRunner.organize_agent_history",
        return_value=[],
    )
    # Prepare required constructor dependencies for BaseAgentRunner
    application_generate_entity = MagicMock()
    application_generate_entity.model_conf = MagicMock()
    application_generate_entity.model_conf.stop = []
    application_generate_entity.model_conf.provider = "openai"
    application_generate_entity.model_conf.parameters = {}
    application_generate_entity.trace_manager = None
    application_generate_entity.invoke_from = "test"

    app_config = MagicMock()
    app_config.agent = MagicMock()
    app_config.agent.max_iteration = 1
    app_config.prompt_template.simple_prompt_template = "Hello {{name}}"

    model_instance = MagicMock()
    model_instance.model = "test-model"
    model_instance.model_name = "test-model"
    model_instance.invoke_llm.return_value = []

    model_config = MagicMock()
    model_config.model = "test-model"

    queue_manager = MagicMock()
    message = MagicMock()

    runner = DummyRunner(
        tenant_id="tenant",
        application_generate_entity=application_generate_entity,
        conversation=MagicMock(),
        app_config=app_config,
        model_config=model_config,
        config=MagicMock(),
        queue_manager=queue_manager,
        message=message,
        user_id="user",
        model_instance=model_instance,
    )

    # Patch internal methods to isolate behavior
    runner._repack_app_generate_entity = MagicMock()
    runner._init_prompt_tools = MagicMock(return_value=({}, []))
    runner.recalc_llm_max_tokens = MagicMock()
    runner.create_agent_thought = MagicMock(return_value="thought-id")
    runner.save_agent_thought = MagicMock()
    runner.update_prompt_message_tool = MagicMock()
    runner.agent_callback = None
    runner.memory = None
    runner.history_prompt_messages = []

    return runner


class TestFillInputs:
    @pytest.mark.parametrize(
        ("instruction", "inputs", "expected"),
        [
            ("Hello {{name}}", {"name": "John"}, "Hello John"),
            ("No placeholders", {"name": "John"}, "No placeholders"),
            ("{{a}}{{b}}", {"a": 1, "b": 2}, "12"),
            ("{{x}}", {"x": None}, "None"),
            ("", {"x": "y"}, ""),
        ],
    )
    def test_fill_in_inputs(self, runner, instruction, inputs, expected):
        result = runner._fill_in_inputs_from_external_data_tools(instruction, inputs)
        assert result == expected


class TestConvertDictToAction:
    def test_convert_valid_dict(self, runner):
        action_dict = {"action": "test", "action_input": {"a": 1}}
        action = runner._convert_dict_to_action(action_dict)
        assert action.action_name == "test"
        assert action.action_input == {"a": 1}

    def test_convert_missing_keys(self, runner):
        with pytest.raises(KeyError):
            runner._convert_dict_to_action({"invalid": 1})


class TestFormatAssistantMessage:
    def test_format_assistant_message_multiple_scratchpads(self, runner):
        sp1 = AgentScratchpadUnit(
            agent_response="resp1",
            thought="thought1",
            action_str="action1",
            action=AgentScratchpadUnit.Action(action_name="tool", action_input={}),
            observation="obs1",
        )
        sp2 = AgentScratchpadUnit(
            agent_response="final",
            thought="",
            action_str="",
            action=AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done"),
            observation=None,
        )
        result = runner._format_assistant_message([sp1, sp2])
        assert "Final Answer:" in result

    def test_format_with_final(self, runner):
        scratchpad = AgentScratchpadUnit(
            agent_response="Done",
            thought="",
            action_str="",
            action=None,
            observation=None,
        )
        # Simulate final state via action name
        scratchpad.action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="Done")
        result = runner._format_assistant_message([scratchpad])
        assert "Final Answer" in result

    def test_format_with_action_and_observation(self, runner):
        scratchpad = AgentScratchpadUnit(
            agent_response="resp",
            thought="thinking",
            action_str="action",
            action=None,
            observation="obs",
        )
        # Non-final state: provide a non-final action
        scratchpad.action = AgentScratchpadUnit.Action(action_name="tool", action_input={})
        result = runner._format_assistant_message([scratchpad])
        assert "Thought:" in result
        assert "Action:" in result
        assert "Observation:" in result


class TestHandleInvokeAction:
    def test_handle_invoke_action_tool_not_present(self, runner):
        action = AgentScratchpadUnit.Action(action_name="missing", action_input={})
        response, meta = runner._handle_invoke_action(action, {}, [])
        assert "there is not a tool named" in response

    def test_tool_with_json_string_args(self, runner, mocker):
        action = AgentScratchpadUnit.Action(action_name="tool", action_input=json.dumps({"a": 1}))
        tool_instance = MagicMock()
        tool_instances = {"tool": tool_instance}

        mocker.patch(
            "core.agent.cot_agent_runner.ToolEngine.agent_invoke",
            return_value=("result", [], MagicMock(to_dict=lambda: {})),
        )

        response, meta = runner._handle_invoke_action(action, tool_instances, [])
        assert response == "result"


class TestOrganizeHistoricPromptMessages:
    def test_empty_history(self, runner, mocker):
        mocker.patch(
            "core.agent.cot_agent_runner.AgentHistoryPromptTransform.get_prompt",
            return_value=[],
        )
        result = runner._organize_historic_prompt_messages([])
        assert result == []


class TestRun:
    def test_run_handles_empty_parser_output(self, runner, mocker):
        message = MagicMock()
        message.id = "msg-id"

        mocker.patch(
            "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
            return_value=[],
        )

        results = list(runner.run(message, "query", {}))
        assert isinstance(results, list)

    def test_run_with_action_and_tool_invocation(self, runner, mocker):
        message = MagicMock()
        message.id = "msg-id"

        action = AgentScratchpadUnit.Action(action_name="tool", action_input={})

        mocker.patch(
            "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
            return_value=[action],
        )

        mocker.patch(
            "core.agent.cot_agent_runner.ToolEngine.agent_invoke",
            return_value=("ok", [], MagicMock(to_dict=lambda: {})),
        )

        runner.agent_callback = None

        with pytest.raises(AgentMaxIterationError):
            list(runner.run(message, "query", {"tool": MagicMock()}))

    def test_run_respects_max_iteration_boundary(self, runner, mocker):
        runner.app_config.agent.max_iteration = 1
        message = MagicMock()
        message.id = "msg-id"

        action = AgentScratchpadUnit.Action(action_name="tool", action_input={})

        mocker.patch(
            "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
            return_value=[action],
        )

        mocker.patch(
            "core.agent.cot_agent_runner.ToolEngine.agent_invoke",
            return_value=("ok", [], MagicMock(to_dict=lambda: {})),
        )

        runner.agent_callback = None

        with pytest.raises(AgentMaxIterationError):
            list(runner.run(message, "query", {"tool": MagicMock()}))

    def test_run_basic_flow(self, runner, mocker):
        message = MagicMock()
        message.id = "msg-id"

        mocker.patch(
            "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
            return_value=[],
        )

        results = list(runner.run(message, "query", {"name": "John"}))
        assert results

    def test_run_max_iteration_error(self, runner, mocker):
        runner.app_config.agent.max_iteration = 0
        message = MagicMock()
        message.id = "msg-id"

        action = AgentScratchpadUnit.Action(action_name="tool", action_input={})

        mocker.patch(
            "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
            return_value=[action],
        )

        with pytest.raises(AgentMaxIterationError):
            list(runner.run(message, "query", {}))

    def test_run_increase_usage_aggregation(self, runner, mocker):
        message = MagicMock()
        message.id = "msg-id"
        runner.app_config.agent.max_iteration = 2

        usage_1 = LLMUsage.empty_usage()
        usage_1.prompt_tokens = 1
        usage_1.completion_tokens = 1
        usage_1.total_tokens = 2
        usage_1.prompt_price = 1
        usage_1.completion_price = 1
        usage_1.total_price = 2

        usage_2 = LLMUsage.empty_usage()
        usage_2.prompt_tokens = 1
        usage_2.completion_tokens = 1
        usage_2.total_tokens = 2
        usage_2.prompt_price = 1
        usage_2.completion_price = 1
        usage_2.total_price = 2

        action = AgentScratchpadUnit.Action(action_name="tool", action_input={})

        handle_output = mocker.patch(
            "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
            side_effect=[
                [action],
                [],
            ],
        )

        def _handle_side_effect(chunks, usage_dict):
            call_index = handle_output.call_count
            usage_dict["usage"] = usage_1 if call_index == 1 else usage_2
            return [action] if call_index == 1 else []

        handle_output.side_effect = _handle_side_effect
        runner.model_instance.invoke_llm = MagicMock(return_value=[])
        mocker.patch(
            "core.agent.cot_agent_runner.ToolEngine.agent_invoke",
            return_value=("ok", [], MagicMock(to_dict=lambda: {})),
        )

        fake_prompt_tool = MagicMock()
        fake_prompt_tool.name = "tool"
        runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool]))

        results = list(runner.run(message, "query", {}))
        final_usage = results[-1].delta.usage
        assert final_usage is not None
        assert final_usage.prompt_tokens == 2
        assert final_usage.completion_tokens == 2
        assert final_usage.total_tokens == 4
        assert final_usage.prompt_price == 2
        assert final_usage.completion_price == 2
        assert final_usage.total_price == 4

    def test_run_when_no_action_branch(self, runner, mocker):
        message = MagicMock()
        message.id = "msg-id"

        mocker.patch(
            "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
            return_value=[],
        )

        results = list(runner.run(message, "query", {}))
        assert results[-1].delta.message.content == ""

    def test_run_usage_missing_key_branch(self, runner, mocker):
        message = MagicMock()
        message.id = "msg-id"

        mocker.patch(
            "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
            return_value=[],
        )

        runner.model_instance.invoke_llm = MagicMock(return_value=[])

        list(runner.run(message, "query", {}))

    def test_run_prompt_tool_update_branch(self, runner, mocker):
        message = MagicMock()
        message.id = "msg-id"

        action = AgentScratchpadUnit.Action(action_name="tool", action_input={})

        # First iteration -> action; second iteration -> no action (empty list)
        mocker.patch(
            "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
            side_effect=[[action], []],
        )

        mocker.patch(
            "core.agent.cot_agent_runner.ToolEngine.agent_invoke",
            return_value=("ok", [], MagicMock(to_dict=lambda: {})),
        )

        runner.app_config.agent.max_iteration = 5

        fake_prompt_tool = MagicMock()
        fake_prompt_tool.name = "tool"

        runner._init_prompt_tools = MagicMock(return_value=({"tool": MagicMock()}, [fake_prompt_tool]))

        runner.update_prompt_message_tool = MagicMock()
        runner.agent_callback = None

        list(runner.run(message, "query", {}))

        runner.update_prompt_message_tool.assert_called_once()

    def test_historic_with_assistant_and_tool_calls(self, runner):
        from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage

        assistant = AssistantPromptMessage(content="thinking")
        assistant.tool_calls = [MagicMock(function=MagicMock(name="tool", arguments='{"a":1}'))]

        tool_msg = ToolPromptMessage(content="obs", tool_call_id="1")

        runner.history_prompt_messages = [assistant, tool_msg]

        result = runner._organize_historic_prompt_messages([])
        assert isinstance(result, list)

    def test_historic_final_flush_branch(self, runner):
        from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage

        assistant = AssistantPromptMessage(content="final")
        runner.history_prompt_messages = [assistant]

        result = runner._organize_historic_prompt_messages([])
        assert isinstance(result, list)


class TestInitReactState:
    def test_init_react_state_resets_state(self, runner, mocker):
        mocker.patch.object(runner, "_organize_historic_prompt_messages", return_value=["historic"])
        runner._agent_scratchpad = ["old"]
        runner._query = "old"

        runner._init_react_state("new-query")

        assert runner._query == "new-query"
        assert runner._agent_scratchpad == []
        assert runner._historic_prompt_messages == ["historic"]


class TestHandleInvokeActionExtended:
    def test_tool_with_invalid_json_string_args(self, runner, mocker):
        action = AgentScratchpadUnit.Action(action_name="tool", action_input="not-json")
        tool_instance = MagicMock()
        tool_instances = {"tool": tool_instance}

        mocker.patch(
            "core.agent.cot_agent_runner.ToolEngine.agent_invoke",
            return_value=("ok", ["file1"], MagicMock(to_dict=lambda: {"k": "v"})),
        )

        message_file_ids = []
        response, meta = runner._handle_invoke_action(action, tool_instances, message_file_ids)

        assert response == "ok"
        assert message_file_ids == ["file1"]
        runner.queue_manager.publish.assert_called()


class TestFillInputsEdgeCases:
    def test_fill_inputs_with_empty_inputs(self, runner):
        result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {})
        assert result == "Hello {{x}}"

    def test_fill_inputs_with_exception_in_replace(self, runner):
        class BadValue:
            def __str__(self):
                raise Exception("fail")

        # Should silently continue on exception
        result = runner._fill_in_inputs_from_external_data_tools("Hello {{x}}", {"x": BadValue()})
        assert result == "Hello {{x}}"


class TestOrganizeHistoricPromptMessagesExtended:
    def test_user_message_flushes_scratchpad(self, runner, mocker):
        from dify_graph.model_runtime.entities.message_entities import UserPromptMessage

        user_message = UserPromptMessage(content="Hi")

        runner.history_prompt_messages = [user_message]

        mock_transform = mocker.patch(
            "core.agent.cot_agent_runner.AgentHistoryPromptTransform",
        )
        mock_transform.return_value.get_prompt.return_value = ["final"]

        result = runner._organize_historic_prompt_messages([])
        assert result == ["final"]

    def test_tool_message_without_scratchpad_raises(self, runner):
        from dify_graph.model_runtime.entities.message_entities import ToolPromptMessage

        runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")]

        with pytest.raises(NotImplementedError):
            runner._organize_historic_prompt_messages([])

    def test_agent_history_transform_invocation(self, runner, mocker):
        mock_transform = MagicMock()
        mock_transform.get_prompt.return_value = []

        mocker.patch(
            "core.agent.cot_agent_runner.AgentHistoryPromptTransform",
            return_value=mock_transform,
        )

        runner.history_prompt_messages = []
        result = runner._organize_historic_prompt_messages([])
        assert result == []


class TestRunAdditionalBranches:
    def test_run_with_no_action_final_answer_empty(self, runner, mocker):
        message = MagicMock()
        message.id = "msg-id"

        mocker.patch(
            "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
            return_value=["thinking"],
        )

        results = list(runner.run(message, "query", {}))
        assert any(hasattr(r, "delta") for r in results)

    def test_run_with_final_answer_action_string(self, runner, mocker):
        message = MagicMock()
        message.id = "msg-id"

        action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="done")

        mocker.patch(
            "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
            return_value=[action],
        )

        results = list(runner.run(message, "query", {}))
        assert results[-1].delta.message.content == "done"

    def test_run_with_final_answer_action_dict(self, runner, mocker):
        message = MagicMock()
        message.id = "msg-id"

        action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input={"a": 1})

        mocker.patch(
            "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
            return_value=[action],
        )

        results = list(runner.run(message, "query", {}))
        assert json.loads(results[-1].delta.message.content) == {"a": 1}

    def test_run_with_string_final_answer(self, runner, mocker):
        message = MagicMock()
        message.id = "msg-id"

        # Pydantic enforces str | dict for action_input, so a plain string is used here
        action = AgentScratchpadUnit.Action(action_name="Final Answer", action_input="12345")

        mocker.patch(
            "core.agent.cot_agent_runner.CotAgentOutputParser.handle_react_stream_output",
            return_value=[action],
        )

        results = list(runner.run(message, "query", {}))
        assert results[-1].delta.message.content == "12345"
@@ -1,215 +0,0 @@
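Deleted file below (path not shown in this extract): unit tests for CotChatAgentRunner's system prompt, user-query/file handling, and scratchpad-to-message assembly.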
from unittest.mock import MagicMock, patch

import pytest

from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent
from tests.unit_tests.core.agent.conftest import (
    DummyAgentConfig,
    DummyAppConfig,
    DummyTool,
)
from tests.unit_tests.core.agent.conftest import (
    DummyPromptEntity as DummyPrompt,
)


class DummyFileUploadConfig:
    def __init__(self, image_config=None):
        self.image_config = image_config


class DummyImageConfig:
    def __init__(self, detail=None):
        self.detail = detail


class DummyGenerateEntity:
    def __init__(self, file_upload_config=None):
        self.file_upload_config = file_upload_config


class DummyUnit:
    def __init__(self, final=False, thought=None, action_str=None, observation=None, agent_response=None):
        self._final = final
        self.thought = thought
        self.action_str = action_str
        self.observation = observation
        self.agent_response = agent_response

    def is_final(self):
        return self._final


@pytest.fixture
def runner():
    runner = CotChatAgentRunner.__new__(CotChatAgentRunner)
    runner._instruction = "test_instruction"
    runner._prompt_messages_tools = [DummyTool("tool1"), DummyTool("tool2")]
    runner._query = "user query"
    runner._agent_scratchpad = []
    runner.files = []
    runner.application_generate_entity = DummyGenerateEntity()
    runner._organize_historic_prompt_messages = MagicMock(return_value=["historic"])
    return runner


class TestOrganizeSystemPrompt:
    def test_organize_system_prompt_success(self, runner, mocker):
        first_prompt = "Instruction: {{instruction}}, Tools: {{tools}}, Names: {{tool_names}}"
        runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt(first_prompt)))

        mocker.patch(
            "core.agent.cot_chat_agent_runner.jsonable_encoder",
            return_value=[{"name": "tool1"}, {"name": "tool2"}],
        )

        result = runner._organize_system_prompt()

        assert "test_instruction" in result.content
        assert "tool1" in result.content
        assert "tool2" in result.content
        assert "tool1, tool2" in result.content

    def test_organize_system_prompt_missing_agent(self, runner):
        runner.app_config = DummyAppConfig(agent=None)
        with pytest.raises(AssertionError):
            runner._organize_system_prompt()

    def test_organize_system_prompt_missing_prompt(self, runner):
        runner.app_config = DummyAppConfig(DummyAgentConfig(prompt_entity=None))
        with pytest.raises(AssertionError):
            runner._organize_system_prompt()


class TestOrganizeUserQuery:
    @pytest.mark.parametrize("files", [None, pytest.param([], id="empty_list")])
    def test_organize_user_query_no_files(self, runner, files):
        runner.files = files
        result = runner._organize_user_query("query", [])
        assert len(result) == 1
        assert result[0].content == "query"

    @patch("core.agent.cot_chat_agent_runner.UserPromptMessage")
    @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
    def test_organize_user_query_with_image_file_default_config(self, mock_to_prompt, mock_user_prompt, runner):
        from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent

        mock_content = ImagePromptMessageContent(
            url="http://test",
            format="png",
            mime_type="image/png",
        )
        mock_to_prompt.return_value = mock_content
        mock_user_prompt.side_effect = lambda content: MagicMock(content=content)

        runner.files = ["file1"]
        runner.application_generate_entity = DummyGenerateEntity(None)

        result = runner._organize_user_query("query", [])
        assert len(result) == 1
        assert isinstance(result[0].content, list)
        assert mock_content in result[0].content
        mock_to_prompt.assert_called_once_with(
            "file1",
            image_detail_config=ImagePromptMessageContent.DETAIL.LOW,
        )

    @patch("core.agent.cot_chat_agent_runner.UserPromptMessage")
    @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
    def test_organize_user_query_with_image_file_high_detail(self, mock_to_prompt, mock_user_prompt, runner):
        from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent

        mock_content = ImagePromptMessageContent(
            url="http://test",
            format="png",
            mime_type="image/png",
        )
        mock_to_prompt.return_value = mock_content
        mock_user_prompt.side_effect = lambda content: MagicMock(content=content)

        runner.files = ["file1"]

        image_config = DummyImageConfig(detail="high")
        runner.application_generate_entity = DummyGenerateEntity(DummyFileUploadConfig(image_config))

        result = runner._organize_user_query("query", [])
        assert len(result) == 1
        assert isinstance(result[0].content, list)
        assert mock_content in result[0].content
        mock_to_prompt.assert_called_once_with(
            "file1",
            image_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
        )

    @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content")
    def test_organize_user_query_with_text_file_no_config(self, mock_to_prompt, runner):
        mock_to_prompt.return_value = TextPromptMessageContent(data="file_content")
        runner.files = ["file1"]
        runner.application_generate_entity = DummyGenerateEntity(None)

        result = runner._organize_user_query("query", [])
        assert len(result) == 1
        assert isinstance(result[0].content, list)


class TestOrganizePromptMessages:
    def test_no_scratchpad(self, runner, mocker):
        runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
        runner._organize_system_prompt = MagicMock(return_value="system")
        runner._organize_user_query = MagicMock(return_value=["query"])

        result = runner._organize_prompt_messages()
        assert "system" in result
        assert "query" in result
        runner._organize_historic_prompt_messages.assert_called_once()

    def test_with_final_scratchpad(self, runner, mocker):
        runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
        runner._organize_system_prompt = MagicMock(return_value="system")
        runner._organize_user_query = MagicMock(return_value=["query"])

        unit = DummyUnit(final=True, agent_response="done")
        runner._agent_scratchpad = [unit]

        result = runner._organize_prompt_messages()
        assistant_msgs = [m for m in result if hasattr(m, "content")]
        combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
        assert "Final Answer: done" in combined

    def test_with_thought_action_observation(self, runner, mocker):
        runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
        runner._organize_system_prompt = MagicMock(return_value="system")
        runner._organize_user_query = MagicMock(return_value=["query"])

        unit = DummyUnit(
            final=False,
            thought="thinking",
            action_str="action",
            observation="observe",
        )
        runner._agent_scratchpad = [unit]

        result = runner._organize_prompt_messages()
        assistant_msgs = [m for m in result if hasattr(m, "content")]
        combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
        assert "Thought: thinking" in combined
        assert "Action: action" in combined
        assert "Observation: observe" in combined

    def test_multiple_units_mixed(self, runner, mocker):
        runner.app_config = DummyAppConfig(DummyAgentConfig(DummyPrompt("{{instruction}}")))
        runner._organize_system_prompt = MagicMock(return_value="system")
        runner._organize_user_query = MagicMock(return_value=["query"])

        units = [
            DummyUnit(final=False, thought="t1"),
            DummyUnit(final=True, agent_response="done"),
        ]
        runner._agent_scratchpad = units

        result = runner._organize_prompt_messages()
        assistant_msgs = [m for m in result if hasattr(m, "content")]
        combined = "".join([m.content for m in assistant_msgs if isinstance(m.content, str)])
        assert "Thought: t1" in combined
        assert "Final Answer: done" in combined
@@ -1,234 +0,0 @@
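Deleted file below (path not shown in this extract): unit tests for CotCompletionAgentRunner's instruction prompt, historic-prompt rendering, and prompt-message organization.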
import json

import pytest

from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from dify_graph.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    TextPromptMessageContent,
    UserPromptMessage,
)

# -----------------------------
# Fixtures
# -----------------------------


@pytest.fixture
def runner(mocker, dummy_tool_factory):
    runner = CotCompletionAgentRunner.__new__(CotCompletionAgentRunner)

    runner._instruction = "Test instruction"
    runner._prompt_messages_tools = [dummy_tool_factory("toolA"), dummy_tool_factory("toolB")]
    runner._query = "What is Python?"
    runner._agent_scratchpad = []

    mocker.patch(
        "core.agent.cot_completion_agent_runner.jsonable_encoder",
        side_effect=lambda tools: [{"name": t.name} for t in tools],
    )

    return runner


# ======================================================
# _organize_instruction_prompt Tests
# ======================================================


class TestOrganizeInstructionPrompt:
    def test_success_all_placeholders(
        self, runner, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
    ):
        template = (
            "{{instruction}} | {{tools}} | {{tool_names}} | {{historic_messages}} | {{agent_scratchpad}} | {{query}}"
        )

        runner.app_config = dummy_app_config_factory(
            agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
        )

        result = runner._organize_instruction_prompt()

        assert "Test instruction" in result
        assert "toolA" in result
        assert "toolB" in result
        tools_payload = json.loads(result.split(" | ")[1])
        assert {item["name"] for item in tools_payload} == {"toolA", "toolB"}

    def test_agent_none_raises(self, runner, dummy_app_config_factory):
        runner.app_config = dummy_app_config_factory(agent=None)
        with pytest.raises(ValueError, match="Agent configuration is not set"):
            runner._organize_instruction_prompt()

    def test_prompt_entity_none_raises(self, runner, dummy_app_config_factory, dummy_agent_config_factory):
        runner.app_config = dummy_app_config_factory(agent=dummy_agent_config_factory(prompt_entity=None))
        with pytest.raises(ValueError, match="prompt entity is not set"):
            runner._organize_instruction_prompt()


# ======================================================
# _organize_historic_prompt Tests
# ======================================================


class TestOrganizeHistoricPrompt:
    def test_with_user_and_assistant_string(self, runner, mocker):
        user_msg = UserPromptMessage(content="Hello")
        assistant_msg = AssistantPromptMessage(content="Hi there")

        mocker.patch.object(
            runner,
            "_organize_historic_prompt_messages",
            return_value=[user_msg, assistant_msg],
        )

        result = runner._organize_historic_prompt()

        assert "Question: Hello" in result
        assert "Hi there" in result

    def test_assistant_list_with_text_content(self, runner, mocker):
        text_content = TextPromptMessageContent(data="Partial answer")
        assistant_msg = AssistantPromptMessage(content=[text_content])

        mocker.patch.object(
            runner,
            "_organize_historic_prompt_messages",
            return_value=[assistant_msg],
        )

        result = runner._organize_historic_prompt()

        assert "Partial answer" in result

    def test_assistant_list_with_non_text_content_ignored(self, runner, mocker):
        non_text_content = ImagePromptMessageContent(format="url", mime_type="image/png")
        assistant_msg = AssistantPromptMessage(content=[non_text_content])

        mocker.patch.object(
            runner,
            "_organize_historic_prompt_messages",
            return_value=[assistant_msg],
        )

        result = runner._organize_historic_prompt()
        assert result == ""

    def test_empty_history(self, runner, mocker):
        mocker.patch.object(
            runner,
            "_organize_historic_prompt_messages",
            return_value=[],
        )

        result = runner._organize_historic_prompt()
        assert result == ""


# ======================================================
# _organize_prompt_messages Tests
# ======================================================


class TestOrganizePromptMessages:
    def test_full_flow_with_scratchpad(
        self,
        runner,
        mocker,
        dummy_app_config_factory,
        dummy_agent_config_factory,
        dummy_prompt_entity_factory,
        dummy_scratchpad_unit_factory,
    ):
        template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"

        runner.app_config = dummy_app_config_factory(
            agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
        )

        mocker.patch.object(runner, "_organize_historic_prompt", return_value="History\n")

        runner._agent_scratchpad = [
            dummy_scratchpad_unit_factory(final=False, thought="Thinking", action_str="Act", observation="Obs"),
            dummy_scratchpad_unit_factory(final=True, agent_response="Done"),
        ]

        result = runner._organize_prompt_messages()

        assert isinstance(result, list)
        assert len(result) == 1
        assert isinstance(result[0], UserPromptMessage)

        content = result[0].content

        assert "History" in content
        assert "Thought: Thinking" in content
        assert "Action: Act" in content
        assert "Observation: Obs" in content
        assert "Final Answer: Done" in content
        assert "Question: What is Python?" in content

    def test_no_scratchpad(
        self, runner, mocker, dummy_app_config_factory, dummy_agent_config_factory, dummy_prompt_entity_factory
    ):
        template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"

        runner.app_config = dummy_app_config_factory(
            agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
        )

        mocker.patch.object(runner, "_organize_historic_prompt", return_value="")

        runner._agent_scratchpad = None

        result = runner._organize_prompt_messages()

        assert "Question: What is Python?" in result[0].content

    @pytest.mark.parametrize(
        ("thought", "action", "observation"),
        [
            ("T", None, None),
            ("T", "A", None),
            ("T", None, "O"),
        ],
    )
    def test_partial_scratchpad_units(
        self,
        runner,
        mocker,
        thought,
        action,
        observation,
        dummy_app_config_factory,
        dummy_agent_config_factory,
        dummy_prompt_entity_factory,
        dummy_scratchpad_unit_factory,
    ):
        template = "SYS {{historic_messages}} {{agent_scratchpad}} {{query}}"

        runner.app_config = dummy_app_config_factory(
            agent=dummy_agent_config_factory(prompt_entity=dummy_prompt_entity_factory(template))
        )

        mocker.patch.object(runner, "_organize_historic_prompt", return_value="")

        runner._agent_scratchpad = [
            dummy_scratchpad_unit_factory(
                final=False,
                thought=thought,
                action_str=action,
                observation=observation,
            )
        ]

        result = runner._organize_prompt_messages()
        content = result[0].content

        assert "Thought:" in content
        if action:
            assert "Action:" in content
        if observation:
            assert "Observation:" in content
@@ -1,452 +0,0 @@
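Deleted file below (path not shown in this extract): unit tests for FunctionCallAgentRunner covering tool-call detection and extraction, system-message and user-query handling, and the run() loop in streaming and blocking modes.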
import json
from typing import Any
from unittest.mock import MagicMock

import pytest

from core.agent.errors import AgentMaxIterationError
from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueMessageFileEvent
from dify_graph.model_runtime.entities.llm_entities import LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
    DocumentPromptMessageContent,
    ImagePromptMessageContent,
    TextPromptMessageContent,
    UserPromptMessage,
)

# ==============================
# Dummy Helper Classes
# ==============================


def build_usage(pt=1, ct=1, tt=2) -> LLMUsage:
    usage = LLMUsage.empty_usage()
    usage.prompt_tokens = pt
    usage.completion_tokens = ct
    usage.total_tokens = tt
    usage.prompt_price = 0
    usage.completion_price = 0
    usage.total_price = 0
    return usage


class DummyMessage:
    def __init__(self, content: str | None = None, tool_calls: list[Any] | None = None):
        self.content: str | None = content
        self.tool_calls: list[Any] = tool_calls or []


class DummyDelta:
    def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None):
        self.message: DummyMessage | None = message
        self.usage: LLMUsage | None = usage


class DummyChunk:
    def __init__(self, message: DummyMessage | None = None, usage: LLMUsage | None = None):
        self.delta: DummyDelta = DummyDelta(message=message, usage=usage)


class DummyResult:
    def __init__(
        self,
        message: DummyMessage | None = None,
        usage: LLMUsage | None = None,
        prompt_messages: list[DummyMessage] | None = None,
    ):
        self.message: DummyMessage | None = message
        self.usage: LLMUsage | None = usage
        self.prompt_messages: list[DummyMessage] = prompt_messages or []
        self.system_fingerprint: str = ""


# ==============================
# Fixtures
# ==============================


@pytest.fixture
def runner(mocker):
    # Completely bypass BaseAgentRunner __init__ to avoid DB / Flask context
    mocker.patch(
        "core.agent.base_agent_runner.BaseAgentRunner.__init__",
        return_value=None,
    )

    # Patch streaming chunk models to avoid validation on dummy message objects
    mocker.patch("core.agent.fc_agent_runner.LLMResultChunk", MagicMock)
    mocker.patch("core.agent.fc_agent_runner.LLMResultChunkDelta", MagicMock)

    app_config = MagicMock()
    app_config.agent = MagicMock(max_iteration=2)
    app_config.prompt_template = MagicMock(simple_prompt_template="system")

    application_generate_entity = MagicMock()
    application_generate_entity.model_conf = MagicMock(parameters={}, stop=None)
    application_generate_entity.trace_manager = MagicMock()
    application_generate_entity.invoke_from = "test"
    application_generate_entity.app_config = MagicMock(app_id="app")
    application_generate_entity.file_upload_config = None

    queue_manager = MagicMock()
    model_instance = MagicMock()
    model_instance.model = "test-model"
    model_instance.model_name = "test-model"

    message = MagicMock(id="msg1")
    conversation = MagicMock(id="conv1")

    runner = FunctionCallAgentRunner(
        tenant_id="tenant",
        application_generate_entity=application_generate_entity,
        conversation=conversation,
        app_config=app_config,
        model_config=MagicMock(),
        config=MagicMock(),
        queue_manager=queue_manager,
        message=message,
        user_id="user",
        model_instance=model_instance,
    )

    # Manually inject required attributes normally set by BaseAgentRunner
    runner.tenant_id = "tenant"
    runner.application_generate_entity = application_generate_entity
    runner.conversation = conversation
    runner.app_config = app_config
    runner.model_config = MagicMock()
    runner.config = MagicMock()
    runner.queue_manager = queue_manager
    runner.message = message
    runner.user_id = "user"
    runner.model_instance = model_instance

    runner.stream_tool_call = False
    runner.memory = None
    runner.history_prompt_messages = []
    runner._current_thoughts = []
    runner.files = []
    runner.agent_callback = MagicMock()

    runner._init_prompt_tools = MagicMock(return_value=({}, []))
    runner.create_agent_thought = MagicMock(return_value="thought1")
    runner.save_agent_thought = MagicMock()
    runner.recalc_llm_max_tokens = MagicMock()
    runner.update_prompt_message_tool = MagicMock()

    return runner


# ==============================
# Tool Call Checks
# ==============================


class TestToolCallChecks:
    @pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
    def test_check_tool_calls(self, runner, tool_calls, expected):
        chunk = DummyChunk(message=DummyMessage(tool_calls=tool_calls))
        assert runner.check_tool_calls(chunk) is expected

    @pytest.mark.parametrize(("tool_calls", "expected"), [([], False), ([MagicMock()], True)])
    def test_check_blocking_tool_calls(self, runner, tool_calls, expected):
        result = DummyResult(message=DummyMessage(tool_calls=tool_calls))
        assert runner.check_blocking_tool_calls(result) is expected


# ==============================
# Extract Tool Calls
# ==============================


class TestExtractToolCalls:
    def test_extract_tool_calls_with_valid_json(self, runner):
        tool_call = MagicMock()
        tool_call.id = "1"
        tool_call.function.name = "tool"
        tool_call.function.arguments = json.dumps({"a": 1})

        chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call]))
        calls = runner.extract_tool_calls(chunk)

        assert calls == [("1", "tool", {"a": 1})]

    def test_extract_tool_calls_empty_arguments(self, runner):
        tool_call = MagicMock()
        tool_call.id = "1"
        tool_call.function.name = "tool"
        tool_call.function.arguments = ""

        chunk = DummyChunk(message=DummyMessage(tool_calls=[tool_call]))
        calls = runner.extract_tool_calls(chunk)

        assert calls == [("1", "tool", {})]

    def test_extract_blocking_tool_calls(self, runner):
        tool_call = MagicMock()
        tool_call.id = "2"
        tool_call.function.name = "block"
        tool_call.function.arguments = json.dumps({"x": 2})

        result = DummyResult(message=DummyMessage(tool_calls=[tool_call]))
        calls = runner.extract_blocking_tool_calls(result)

        assert calls == [("2", "block", {"x": 2})]


# ==============================
# System Message Initialization
# ==============================


class TestInitSystemMessage:
    def test_init_system_message_empty_prompt_messages(self, runner):
        result = runner._init_system_message("system", [])
        assert len(result) == 1

    def test_init_system_message_insert_at_start(self, runner):
        msgs = [MagicMock()]
        result = runner._init_system_message("system", msgs)
        assert result[0].content == "system"

    def test_init_system_message_no_template(self, runner):
        result = runner._init_system_message("", [])
        assert result == []


# ==============================
# Organize User Query
# ==============================


class TestOrganizeUserQuery:
    def test_without_files(self, runner):
        result = runner._organize_user_query("query", [])
        assert len(result) == 1

    def test_with_none_query(self, runner):
        result = runner._organize_user_query(None, [])
        assert len(result) == 1

    def test_with_files_uses_image_detail_config(self, runner, mocker):
        file_content = TextPromptMessageContent(data="file-content")
        mock_to_prompt = mocker.patch(
            "core.agent.fc_agent_runner.file_manager.to_prompt_message_content",
            return_value=file_content,
        )

        image_config = MagicMock(detail=ImagePromptMessageContent.DETAIL.HIGH)
        runner.application_generate_entity.file_upload_config = MagicMock(image_config=image_config)
        runner.files = ["file1"]

        result = runner._organize_user_query("query", [])

        assert len(result) == 1
        assert isinstance(result[0].content, list)
        mock_to_prompt.assert_called_once_with("file1", image_detail_config=ImagePromptMessageContent.DETAIL.HIGH)


# ==============================
# Clear User Prompt Images
# ==============================


class TestClearUserPromptImageMessages:
    def test_clear_text_and_image_content(self, runner):
        text = MagicMock()
        text.type = "text"
        text.data = "hello"

        image = MagicMock()
        image.type = "image"
        image.data = "img"

        user_msg = MagicMock()
        user_msg.__class__.__name__ = "UserPromptMessage"
        user_msg.content = [text, image]

        result = runner._clear_user_prompt_image_messages([user_msg])
        assert isinstance(result, list)

    def test_clear_includes_file_placeholder(self, runner):
        text = TextPromptMessageContent(data="hello")
        image = ImagePromptMessageContent(format="url", mime_type="image/png")
        document = DocumentPromptMessageContent(format="url", mime_type="application/pdf")

        user_msg = UserPromptMessage(content=[text, image, document])

        result = runner._clear_user_prompt_image_messages([user_msg])

        assert result[0].content == "hello\n[image]\n[file]"


# ==============================
# Run Method Tests
# ==============================


class TestRunMethod:
    def test_run_non_streaming_no_tool_calls(self, runner):
        message = MagicMock(id="m1")
        dummy_message = DummyMessage(content="hello")
        result = DummyResult(message=dummy_message, usage=build_usage())

        runner.model_instance.invoke_llm.return_value = result

        outputs = list(runner.run(message, "query"))
        assert len(outputs) == 1
        runner.queue_manager.publish.assert_called()

        queue_calls = runner.queue_manager.publish.call_args_list
        assert any(call.args and call.args[0].__class__.__name__ == "QueueMessageEndEvent" for call in queue_calls)

    def test_run_streaming_branch(self, runner):
        message = MagicMock(id="m1")
        runner.stream_tool_call = True

        content = [TextPromptMessageContent(data="hi")]
        chunk = DummyChunk(message=DummyMessage(content=content), usage=build_usage())

        def generator():
            yield chunk

        runner.model_instance.invoke_llm.return_value = generator()

        outputs = list(runner.run(message, "query"))
        assert len(outputs) == 1

    def test_run_streaming_tool_calls_list_content(self, runner):
        message = MagicMock(id="m1")
        runner.stream_tool_call = True

        tool_call = MagicMock()
        tool_call.id = "1"
        tool_call.function.name = "tool"
        tool_call.function.arguments = json.dumps({"a": 1})

        content = [TextPromptMessageContent(data="hi")]
        chunk = DummyChunk(message=DummyMessage(content=content, tool_calls=[tool_call]), usage=build_usage())

        def generator():
            yield chunk

        final_message = DummyMessage(content="done", tool_calls=[])
        final_result = DummyResult(message=final_message, usage=build_usage())

        runner.model_instance.invoke_llm.side_effect = [generator(), final_result]

        outputs = list(runner.run(message, "query"))
        assert len(outputs) >= 1

    def test_run_non_streaming_list_content(self, runner):
        message = MagicMock(id="m1")
        content = [TextPromptMessageContent(data="hi")]
        dummy_message = DummyMessage(content=content)
        result = DummyResult(message=dummy_message, usage=build_usage())

        runner.model_instance.invoke_llm.return_value = result

        outputs = list(runner.run(message, "query"))
        assert len(outputs) == 1
        assert runner.save_agent_thought.call_args.kwargs["thought"] == "hi"

    def test_run_streaming_tool_call_inputs_type_error(self, runner, mocker):
        message = MagicMock(id="m1")
        runner.stream_tool_call = True

        tool_call = MagicMock()
        tool_call.id = "1"
        tool_call.function.name = "tool"
        tool_call.function.arguments = json.dumps({"a": 1})

        chunk = DummyChunk(message=DummyMessage(content="hi", tool_calls=[tool_call]), usage=build_usage())

        def generator():
            yield chunk

        runner.model_instance.invoke_llm.return_value = generator()

        real_dumps = json.dumps

        def flaky_dumps(obj, *args, **kwargs):
            if kwargs.get("ensure_ascii") is False:
                return real_dumps(obj, *args, **kwargs)
            raise TypeError("boom")

        mocker.patch("core.agent.fc_agent_runner.json.dumps", side_effect=flaky_dumps)

        outputs = list(runner.run(message, "query"))
        assert len(outputs) == 1

    def test_run_with_missing_tool_instance(self, runner):
        message = MagicMock(id="m1")

        tool_call = MagicMock()
        tool_call.id = "1"
        tool_call.function.name = "missing"
        tool_call.function.arguments = json.dumps({})

        dummy_message = DummyMessage(content="", tool_calls=[tool_call])
        result = DummyResult(message=dummy_message, usage=build_usage())
        final_message = DummyMessage(content="done", tool_calls=[])
        final_result = DummyResult(message=final_message, usage=build_usage())

        runner.model_instance.invoke_llm.side_effect = [result, final_result]

        outputs = list(runner.run(message, "query"))
        assert len(outputs) >= 1

    def test_run_with_tool_instance_and_files(self, runner, mocker):
        message = MagicMock(id="m1")

        tool_call = MagicMock()
        tool_call.id = "1"
        tool_call.function.name = "tool"
        tool_call.function.arguments = json.dumps({"a": 1})

        dummy_message = DummyMessage(content="", tool_calls=[tool_call])
        result = DummyResult(message=dummy_message, usage=build_usage())
        final_result = DummyResult(message=DummyMessage(content="done", tool_calls=[]), usage=build_usage())

        runner.model_instance.invoke_llm.side_effect = [result, final_result]

        tool_instance = MagicMock()
        prompt_tool = MagicMock()
        prompt_tool.name = "tool"
        runner._init_prompt_tools.return_value = ({"tool": tool_instance}, [prompt_tool])

        tool_invoke_meta = MagicMock()
        tool_invoke_meta.to_dict.return_value = {"ok": True}
        mocker.patch(
            "core.agent.fc_agent_runner.ToolEngine.agent_invoke",
            return_value=("ok", ["file1"], tool_invoke_meta),
        )

        outputs = list(runner.run(message, "query"))
        assert len(outputs) >= 1
        assert any(
            isinstance(call.args[0], QueueMessageFileEvent)
            and call.args[0].message_file_id == "file1"
            and call.args[1] == PublishFrom.APPLICATION_MANAGER
            for call in runner.queue_manager.publish.call_args_list
        )

    def test_run_max_iteration_error(self, runner):
        runner.app_config.agent.max_iteration = 0

        message = MagicMock(id="m1")

        tool_call = MagicMock()
        tool_call.id = "1"
        tool_call.function.name = "tool"
        tool_call.function.arguments = "{}"

        dummy_message = DummyMessage(content="", tool_calls=[tool_call])
        result = DummyResult(message=dummy_message, usage=build_usage())

        runner.model_instance.invoke_llm.return_value = result

        with pytest.raises(AgentMaxIterationError):
            list(runner.run(message, "query"))
@ -422,7 +422,10 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
pause_state_config = SimpleNamespace(session_factory="session-factory", state_owner_user_id="owner")
|
||||
|
||||
response = generator._generate(
|
||||
workflow=SimpleNamespace(features={"feature": True}),
|
||||
workflow=SimpleNamespace(
|
||||
features={"feature": True},
|
||||
get_feature=lambda key: SimpleNamespace(enabled=False),
|
||||
),
|
||||
user=SimpleNamespace(id="user"),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
application_generate_entity=application_generate_entity,
|
||||
@ -517,7 +520,10 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
)
|
||||
|
||||
response = generator._generate(
|
||||
workflow=SimpleNamespace(features={}),
|
||||
workflow=SimpleNamespace(
|
||||
features={},
|
||||
get_feature=lambda key: SimpleNamespace(enabled=False),
|
||||
),
|
||||
user=SimpleNamespace(id="user"),
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
application_generate_entity=application_generate_entity,
|
||||
|
||||
@ -125,7 +125,20 @@ class TestAdvancedChatGenerateTaskPipeline:
            )
        )

        event = SimpleNamespace(text="hi", from_variable_selector=None)
        event = SimpleNamespace(
            text="hi",
            from_variable_selector=None,
            tool_call=None,
            tool_result=None,
            chunk_type=None,
            node_id=None,
            model_provider=None,
            model_name=None,
            model_icon=None,
            model_icon_dark=None,
            model_usage=None,
            model_duration=None,
        )

        responses = list(pipeline._handle_text_chunk_event(event))

@ -389,7 +402,20 @@ class TestAdvancedChatGenerateTaskPipeline:

        pipeline._message_cycle_manager = SimpleNamespace(message_to_stream_response=lambda **kwargs: "chunk")

        event = SimpleNamespace(text="hi", from_variable_selector=["a"])
        event = SimpleNamespace(
            text="hi",
            from_variable_selector=["a"],
            tool_call=None,
            tool_result=None,
            chunk_type=None,
            node_id=None,
            model_provider=None,
            model_name=None,
            model_icon=None,
            model_icon_dark=None,
            model_usage=None,
            model_duration=None,
        )
        queue_message = SimpleNamespace(event=event)

        responses = list(

@ -134,13 +134,10 @@ class TestAgentChatAppRunnerRun:
        runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(), mocker.MagicMock())

    @pytest.mark.parametrize(
        ("mode", "expected_runner"),
        [
            (LLMMode.CHAT, "CotChatAgentRunner"),
            (LLMMode.COMPLETION, "CotCompletionAgentRunner"),
        ],
        "mode",
        [LLMMode.CHAT, LLMMode.COMPLETION],
    )
    def test_run_chain_of_thought_modes(self, runner, mocker, mode, expected_runner):
    def test_run_chain_of_thought_modes(self, runner, mocker, mode):
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
@ -184,7 +181,7 @@ class TestAgentChatAppRunnerRun:
        )

        runner_cls = mocker.MagicMock()
        mocker.patch(f"core.app.apps.agent_chat.app_runner.{expected_runner}", runner_cls)
        mocker.patch("core.app.apps.agent_chat.app_runner.AgentAppRunner", runner_cls)

        runner_instance = mocker.MagicMock()
        runner_cls.return_value = runner_instance
@ -196,7 +193,8 @@ class TestAgentChatAppRunnerRun:
        runner_instance.run.assert_called_once()
        runner._handle_invoke_result.assert_called_once()

    def test_run_invalid_llm_mode_raises(self, runner, mocker):
    def test_run_uses_agent_app_runner_regardless_of_mode(self, runner, mocker):
        """After refactoring, AgentAppRunner is used for all strategies and LLM modes."""
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
@ -226,7 +224,7 @@ class TestAgentChatAppRunnerRun:

        model_schema = mocker.MagicMock()
        model_schema.features = []
        model_schema.model_properties = {ModelPropertyKey.MODE: "invalid"}
        model_schema.model_properties = {ModelPropertyKey.MODE: LLMMode.CHAT}

        llm_instance = mocker.MagicMock()
        llm_instance.model_type_instance.get_model_schema.return_value = model_schema
@ -239,8 +237,16 @@ class TestAgentChatAppRunnerRun:
            side_effect=[app_record, conversation, message],
        )

        with pytest.raises(ValueError):
            runner.run(generate_entity, mocker.MagicMock(), conversation, message)
        runner_cls = mocker.MagicMock()
        mocker.patch("core.app.apps.agent_chat.app_runner.AgentAppRunner", runner_cls)
        runner_instance = mocker.MagicMock()
        runner_cls.return_value = runner_instance
        runner_instance.run.return_value = []
        mocker.patch.object(runner, "_handle_invoke_result")

        runner.run(generate_entity, mocker.MagicMock(), conversation, message)

        runner_instance.run.assert_called_once()

    def test_run_function_calling_strategy_selected_by_features(self, runner, mocker):
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
@ -286,7 +292,7 @@ class TestAgentChatAppRunnerRun:
        )

        runner_cls = mocker.MagicMock()
        mocker.patch("core.app.apps.agent_chat.app_runner.FunctionCallAgentRunner", runner_cls)
        mocker.patch("core.app.apps.agent_chat.app_runner.AgentAppRunner", runner_cls)

        runner_instance = mocker.MagicMock()
        runner_cls.return_value = runner_instance
@ -366,10 +372,11 @@ class TestAgentChatAppRunnerRun:
        with pytest.raises(ValueError):
            runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))

    def test_run_invalid_agent_strategy_raises(self, runner, mocker):
    def test_run_any_strategy_uses_agent_app_runner(self, runner, mocker):
        """After refactoring, any agent strategy uses AgentAppRunner."""
        app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
        app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
        app_config.agent = mocker.MagicMock(strategy="invalid", provider="p", model="m")
        app_config.agent = mocker.MagicMock(strategy="custom", provider="p", model="m")

        generate_entity = mocker.MagicMock(
            app_config=app_config,
@ -409,5 +416,13 @@ class TestAgentChatAppRunnerRun:
            side_effect=[app_record, conversation, message],
        )

        with pytest.raises(ValueError):
            runner.run(generate_entity, mocker.MagicMock(), conversation, message)
        runner_cls = mocker.MagicMock()
        mocker.patch("core.app.apps.agent_chat.app_runner.AgentAppRunner", runner_cls)
        runner_instance = mocker.MagicMock()
        runner_cls.return_value = runner_instance
        runner_instance.run.return_value = []
        mocker.patch.object(runner, "_handle_invoke_result")

        runner.run(generate_entity, mocker.MagicMock(), conversation, message)

        runner_instance.run.assert_called_once()

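The rewritten runner tests above all share one patch-and-assert shape. Below is a minimal, self-contained sketch of that shape, using a throwaway module in place of core.app.apps.agent_chat.app_runner; the names here are stand-ins, not the real Dify API.

import sys
from types import ModuleType
from unittest.mock import MagicMock, patch

fake_app = ModuleType("fake_app")  # stand-in for core.app.apps.agent_chat.app_runner
fake_app.AgentAppRunner = object  # placeholder the test patches below
sys.modules["fake_app"] = fake_app

def run_app():
    # Stand-in for AgentChatAppRunner.run: it always delegates to AgentAppRunner,
    # with no per-strategy or per-LLM-mode branching left to exercise.
    runner = sys.modules["fake_app"].AgentAppRunner()
    return runner.run()

def test_any_strategy_uses_agent_app_runner():
    runner_cls = MagicMock()
    runner_cls.return_value.run.return_value = []
    with patch.object(fake_app, "AgentAppRunner", runner_cls):
        run_app()
    runner_cls.return_value.run.assert_called_once()

test_any_strategy_uses_agent_app_runner()
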
@ -2,7 +2,7 @@ from unittest.mock import MagicMock

from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from dify_graph.enums import NodeType
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.graph_events import NodeRunStreamChunkEvent


@ -21,7 +21,7 @@ def test_skip_empty_final_chunk() -> None:
    empty_final_event = NodeRunStreamChunkEvent(
        id="exec",
        node_id="node",
        node_type=NodeType.LLM,
        node_type=BuiltinNodeTypes.LLM,
        selector=["node", "text"],
        chunk="",
        is_final=True,
@ -33,7 +33,7 @@ def test_skip_empty_final_chunk() -> None:
    normal_event = NodeRunStreamChunkEvent(
        id="exec",
        node_id="node",
        node_type=NodeType.LLM,
        node_type=BuiltinNodeTypes.LLM,
        selector=["node", "text"],
        chunk="hi",
        is_final=False,

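The NodeType to BuiltinNodeTypes change in this file (and in the files further down) is a rename, not a behavioral change; call sites only swap the enum's name. A sketch of the pattern with illustrative values, since the real members live in dify_graph.enums:

from enum import Enum

class BuiltinNodeTypes(str, Enum):
    # Illustrative values only; the real enum is defined in dify_graph.enums.
    LLM = "llm"
    ANSWER = "answer"
    COMMAND = "command"

# Attribute access at call sites is unchanged by the rename.
assert BuiltinNodeTypes.LLM.value == "llm"
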
@ -153,7 +153,7 @@ class TestOnToolEnd:

class TestReturnRetrieverResourceInfo:
    def test_publish_called(self, handler, mock_queue_manager, mocker):
        mock_event = mocker.patch("core.callback_handler.index_tool_callback_handler.QueueRetrieverResourcesEvent")
        mock_event = mocker.patch("core.app.entities.queue_entities.QueueRetrieverResourcesEvent")

        resources = [mocker.Mock()]


@ -18,8 +18,6 @@ from core.llm_generator.output_parser.structured_output import (
from core.model_manager import ModelInstance
from dify_graph.model_runtime.entities.llm_entities import (
    LLMResult,
    LLMResultChunk,
    LLMResultChunkDelta,
    LLMResultWithStructuredOutput,
    LLMUsage,
)
@ -257,7 +255,6 @@ class TestStructuredOutput:
            model_instance=model_instance,
            prompt_messages=[UserPromptMessage(content="hi")],
            json_schema={"type": "object"},
            stream=False,
        )

        assert isinstance(result, LLMResultWithStructuredOutput)
@ -267,6 +264,7 @@ class TestStructuredOutput:
    def test_invoke_llm_with_structured_output_no_stream_prompt_based(self):
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.support_structure_output = False
        model_schema.features = []
        model_schema.parameter_rules = [
            ParameterRule(
                name="response_format",
@ -294,7 +292,6 @@ class TestStructuredOutput:
            model_instance=model_instance,
            prompt_messages=[UserPromptMessage(content="hi")],
            json_schema={"type": "object"},
            stream=False,
        )

        assert isinstance(result, LLMResultWithStructuredOutput)
@ -304,6 +301,7 @@ class TestStructuredOutput:
    def test_invoke_llm_with_structured_output_no_string_error(self):
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.support_structure_output = False
        model_schema.features = []
        model_schema.parameter_rules = []

        model_instance = MagicMock(spec=ModelInstance)
@ -319,83 +317,66 @@ class TestStructuredOutput:
                model_instance=model_instance,
                prompt_messages=[],
                json_schema={},
                stream=False,
            )
        assert "Failed to parse structured output, LLM result is not a string" in str(excinfo.value)
        assert "Failed to parse structured output" in str(excinfo.value)

    def test_invoke_llm_with_structured_output_stream(self):
    def test_invoke_llm_with_structured_output_returns_result(self):
        """After stream removal, invoke_llm_with_structured_output always returns non-streaming."""
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.support_structure_output = False
        model_schema.features = []
        model_schema.parameter_rules = []
        model_schema.model = "gpt-4"

        model_instance = MagicMock(spec=ModelInstance)
        mock_result = MagicMock(spec=LLMResult)
        mock_result.message = AssistantPromptMessage(content='{"key": "value"}')
        mock_result.model = "gpt-4"
        mock_result.usage = LLMUsage.empty_usage()
        mock_result.system_fingerprint = "fp1"
        mock_result.prompt_messages = [UserPromptMessage(content="hi")]

        # Mock chunks
        chunk1 = MagicMock(spec=LLMResultChunk)
        chunk1.delta = LLMResultChunkDelta(
            index=0, message=AssistantPromptMessage(content='{"key": '), usage=LLMUsage.empty_usage()
        )
        chunk1.prompt_messages = [UserPromptMessage(content="hi")]
        chunk1.system_fingerprint = "fp1"
        model_instance.invoke_llm.return_value = mock_result

        chunk2 = MagicMock(spec=LLMResultChunk)
        chunk2.delta = LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content='"value"}'))
        chunk2.prompt_messages = [UserPromptMessage(content="hi")]
        chunk2.system_fingerprint = "fp1"

        chunk3 = MagicMock(spec=LLMResultChunk)
        chunk3.delta = LLMResultChunkDelta(
            index=0,
            message=AssistantPromptMessage(
                content=[
                    TextPromptMessageContent(data=" "),
                ]
            ),
        )
        chunk3.prompt_messages = [UserPromptMessage(content="hi")]
        chunk3.system_fingerprint = "fp1"

        event4 = MagicMock()
        event4.delta = LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=""))

        model_instance.invoke_llm.return_value = [chunk1, chunk2, chunk3, event4]

        generator = invoke_llm_with_structured_output(
        result = invoke_llm_with_structured_output(
            provider="openai",
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=[UserPromptMessage(content="hi")],
            json_schema={},
            stream=True,
        )

        chunks = list(generator)
        assert len(chunks) == 5
        assert chunks[-1].structured_output == {"key": "value"}
        assert chunks[-1].system_fingerprint == "fp1"
        assert chunks[-1].prompt_messages == [UserPromptMessage(content="hi")]
        assert isinstance(result, LLMResultWithStructuredOutput)
        assert result.structured_output == {"key": "value"}
        assert result.system_fingerprint == "fp1"
        assert result.prompt_messages == [UserPromptMessage(content="hi")]

    def test_invoke_llm_with_structured_output_stream_no_id_events(self):
    def test_invoke_llm_with_structured_output_empty_response_error(self):
        """When the model returns a non-parseable result, an error is raised."""
        model_schema = MagicMock(spec=AIModelEntity)
        model_schema.support_structure_output = False
        model_schema.features = []
        model_schema.parameter_rules = []
        model_schema.model = "gpt-4"

        model_instance = MagicMock(spec=ModelInstance)
        model_instance.invoke_llm.return_value = []
        mock_result = MagicMock(spec=LLMResult)
        mock_result.message = AssistantPromptMessage(content="")
        mock_result.model = "gpt-4"
        mock_result.usage = LLMUsage.empty_usage()
        mock_result.system_fingerprint = "fp1"
        mock_result.prompt_messages = []

        generator = invoke_llm_with_structured_output(
            provider="openai",
            model_schema=model_schema,
            model_instance=model_instance,
            prompt_messages=[],
            json_schema={},
            stream=True,
        )
        model_instance.invoke_llm.return_value = mock_result

        with pytest.raises(OutputParserError):
            list(generator)
            invoke_llm_with_structured_output(
                provider="openai",
                model_schema=model_schema,
                model_instance=model_instance,
                prompt_messages=[],
                json_schema={},
            )

    def test_parse_structured_output_empty_string(self):
        with pytest.raises(OutputParserError):

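Taken together, these hunks pin down the helper's new contract: no stream flag, a single parsed result, and an OutputParserError when the text cannot be parsed. A self-contained sketch of that contract, where invoke and its error type are illustrative stand-ins rather than Dify's real signatures:

import json

class OutputParserError(Exception):
    pass

def invoke(llm_text: str) -> dict:
    # Parse the model's final text into structured output, or raise.
    try:
        parsed = json.loads(llm_text)
    except (TypeError, ValueError) as exc:
        raise OutputParserError("Failed to parse structured output") from exc
    if not isinstance(parsed, dict):
        raise OutputParserError("Failed to parse structured output")
    return parsed

assert invoke('{"key": "value"}') == {"key": "value"}
try:
    invoke("")  # the empty-response error path exercised above
except OutputParserError:
    pass
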
@ -6,6 +6,7 @@ import pytest
from core.app.app_config.entities import ModelConfig
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.llm_generator import LLMGenerator
from core.llm_generator.output_models import InstructionModifyOutput
from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult
from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError

@ -317,15 +318,15 @@ class TestLLMGenerator:
        with patch("extensions.ext_database.db.session.query") as mock_query:
            mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None

            # Mock __instruction_modify_common call via invoke_llm
            mock_response = MagicMock()
            mock_response.message.get_text_content.return_value = '{"modified": "prompt"}'
            mock_model_instance.invoke_llm.return_value = mock_response

            result = LLMGenerator.instruction_modify_legacy(
                "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
            )
            assert result == {"modified": "prompt"}
            pydantic_response = InstructionModifyOutput(modified="prompt", message="done")
            with patch(
                "core.llm_generator.output_parser.structured_output.invoke_llm_with_pydantic_model",
                return_value=pydantic_response,
            ):
                result = LLMGenerator.instruction_modify_legacy(
                    "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
                )
            assert result == {"modified": "prompt", "message": "done"}

    def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity):
        with patch("extensions.ext_database.db.session.query") as mock_query:
@ -335,14 +336,15 @@ class TestLLMGenerator:
            last_run.error = "e"
            mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run

            mock_response = MagicMock()
            mock_response.message.get_text_content.return_value = '{"modified": "prompt"}'
            mock_model_instance.invoke_llm.return_value = mock_response

            result = LLMGenerator.instruction_modify_legacy(
                "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
            )
            assert result == {"modified": "prompt"}
            pydantic_response = InstructionModifyOutput(modified="prompt", message="done")
            with patch(
                "core.llm_generator.output_parser.structured_output.invoke_llm_with_pydantic_model",
                return_value=pydantic_response,
            ):
                result = LLMGenerator.instruction_modify_legacy(
                    "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
                )
            assert result == {"modified": "prompt", "message": "done"}

    def test_instruction_modify_workflow_app_not_found(self):
        with patch("extensions.ext_database.db.session") as mock_session:
@ -371,27 +373,27 @@ class TestLLMGenerator:
            last_run.node_type = "llm"
            last_run.status = "s"
            last_run.error = "e"
            # Return regular values, not Mocks
            last_run.execution_metadata_dict = {"agent_log": [{"status": "s", "error": "e", "data": {}}]}
            last_run.load_full_inputs.return_value = {"in": "val"}

            workflow_service.get_node_last_run.return_value = last_run

            mock_response = MagicMock()
            mock_response.message.get_text_content.return_value = '{"modified": "workflow"}'
            mock_model_instance.invoke_llm.return_value = mock_response

            result = LLMGenerator.instruction_modify_workflow(
                "tenant_id",
                "flow_id",
                "node_id",
                "current",
                "instruction",
                model_config_entity,
                "ideal",
                workflow_service,
            )
            assert result == {"modified": "workflow"}
            pydantic_response = InstructionModifyOutput(modified="workflow", message="done")
            with patch(
                "core.llm_generator.output_parser.structured_output.invoke_llm_with_pydantic_model",
                return_value=pydantic_response,
            ):
                result = LLMGenerator.instruction_modify_workflow(
                    "tenant_id",
                    "flow_id",
                    "node_id",
                    "current",
                    "instruction",
                    model_config_entity,
                    "ideal",
                    workflow_service,
                )
            assert result == {"modified": "workflow", "message": "done"}

    def test_instruction_modify_workflow_no_last_run_fallback(self, mock_model_instance, model_config_entity):
        with patch("extensions.ext_database.db.session") as mock_session:
@ -403,48 +405,49 @@ class TestLLMGenerator:
            workflow_service.get_draft_workflow.return_value = workflow
            workflow_service.get_node_last_run.return_value = None

            mock_response = MagicMock()
            mock_response.message.get_text_content.return_value = '{"modified": "fallback"}'
            mock_model_instance.invoke_llm.return_value = mock_response

            result = LLMGenerator.instruction_modify_workflow(
                "tenant_id",
                "flow_id",
                "node_id",
                "current",
                "instruction",
                model_config_entity,
                "ideal",
                workflow_service,
            )
            assert result == {"modified": "fallback"}
            pydantic_response = InstructionModifyOutput(modified="fallback", message="done")
            with patch(
                "core.llm_generator.output_parser.structured_output.invoke_llm_with_pydantic_model",
                return_value=pydantic_response,
            ):
                result = LLMGenerator.instruction_modify_workflow(
                    "tenant_id",
                    "flow_id",
                    "node_id",
                    "current",
                    "instruction",
                    model_config_entity,
                    "ideal",
                    workflow_service,
                )
            assert result == {"modified": "fallback", "message": "done"}

    def test_instruction_modify_workflow_node_type_fallback(self, mock_model_instance, model_config_entity):
        with patch("extensions.ext_database.db.session") as mock_session:
            mock_session.return_value.query.return_value.where.return_value.first.return_value = MagicMock()
            workflow = MagicMock()
            # Cause exception in node_type logic
            workflow.graph_dict = {"graph": {"nodes": []}}

            workflow_service = MagicMock()
            workflow_service.get_draft_workflow.return_value = workflow
            workflow_service.get_node_last_run.return_value = None

            mock_response = MagicMock()
            mock_response.message.get_text_content.return_value = '{"modified": "fallback"}'
            mock_model_instance.invoke_llm.return_value = mock_response

            result = LLMGenerator.instruction_modify_workflow(
                "tenant_id",
                "flow_id",
                "node_id",
                "current",
                "instruction",
                model_config_entity,
                "ideal",
                workflow_service,
            )
            assert result == {"modified": "fallback"}
            pydantic_response = InstructionModifyOutput(modified="fallback", message="done")
            with patch(
                "core.llm_generator.output_parser.structured_output.invoke_llm_with_pydantic_model",
                return_value=pydantic_response,
            ):
                result = LLMGenerator.instruction_modify_workflow(
                    "tenant_id",
                    "flow_id",
                    "node_id",
                    "current",
                    "instruction",
                    model_config_entity,
                    "ideal",
                    workflow_service,
                )
            assert result == {"modified": "fallback", "message": "done"}

    def test_instruction_modify_workflow_empty_agent_log(self, mock_model_instance, model_config_entity):
        with patch("extensions.ext_database.db.session") as mock_session:
@ -459,27 +462,27 @@ class TestLLMGenerator:
            last_run.node_type = "llm"
            last_run.status = "s"
            last_run.error = "e"
            # Return regular empty list, not a Mock
            last_run.execution_metadata_dict = {"agent_log": []}
            last_run.load_full_inputs.return_value = {}

            workflow_service.get_node_last_run.return_value = last_run

            mock_response = MagicMock()
            mock_response.message.get_text_content.return_value = '{"modified": "workflow"}'
            mock_model_instance.invoke_llm.return_value = mock_response

            result = LLMGenerator.instruction_modify_workflow(
                "tenant_id",
                "flow_id",
                "node_id",
                "current",
                "instruction",
                model_config_entity,
                "ideal",
                workflow_service,
            )
            assert result == {"modified": "workflow"}
            pydantic_response = InstructionModifyOutput(modified="workflow", message="done")
            with patch(
                "core.llm_generator.output_parser.structured_output.invoke_llm_with_pydantic_model",
                return_value=pydantic_response,
            ):
                result = LLMGenerator.instruction_modify_workflow(
                    "tenant_id",
                    "flow_id",
                    "node_id",
                    "current",
                    "instruction",
                    model_config_entity,
                    "ideal",
                    workflow_service,
                )
            assert result == {"modified": "workflow", "message": "done"}

    def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity):
        # Testing placeholders replacement via instruction_modify_legacy for convenience
@ -513,7 +516,7 @@ class TestLLMGenerator:
            "tenant_id", "flow_id", "current", "instruction", model_config_entity, "ideal"
        )
        assert "An unexpected error occurred" in result["error"]
        assert "Could not find a valid JSON object" in result["error"]
        assert "Failed to parse structured output" in result["error"]

    def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity):
        with patch("extensions.ext_database.db.session.query") as mock_query:

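The expected dicts above gained a "message" key because the generator now returns the dump of a pydantic output model instead of raw parsed JSON. A minimal sketch, assuming only the two fields these tests assert on (the real model lives in core.llm_generator.output_models):

from pydantic import BaseModel

class InstructionModifyOutput(BaseModel):
    # Stand-in with just the fields the tests assert on.
    modified: str
    message: str

result = InstructionModifyOutput(modified="prompt", message="done").model_dump()
assert result == {"modified": "prompt", "message": "done"}
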
@ -1,51 +0,0 @@
from __future__ import annotations

import pytest

from core.tools.utils import system_oauth_encryption as oauth_encryption
from core.tools.utils.system_oauth_encryption import OAuthEncryptionError, SystemOAuthEncrypter


def test_system_oauth_encrypter_roundtrip():
    encrypter = SystemOAuthEncrypter(secret_key="test-secret")
    payload = {"client_id": "cid", "client_secret": "csecret", "grant_type": "authorization_code"}

    encrypted = encrypter.encrypt_oauth_params(payload)
    decrypted = encrypter.decrypt_oauth_params(encrypted)

    assert encrypted
    assert dict(decrypted) == payload


def test_system_oauth_encrypter_decrypt_validates_input():
    encrypter = SystemOAuthEncrypter(secret_key="test-secret")

    with pytest.raises(ValueError, match="must be a string"):
        encrypter.decrypt_oauth_params(123)  # type: ignore[arg-type]

    with pytest.raises(ValueError, match="cannot be empty"):
        encrypter.decrypt_oauth_params("")


def test_system_oauth_encrypter_raises_oauth_error_for_invalid_ciphertext():
    encrypter = SystemOAuthEncrypter(secret_key="test-secret")

    with pytest.raises(OAuthEncryptionError, match="Decryption failed"):
        encrypter.decrypt_oauth_params("not-base64")


def test_system_oauth_helpers_use_global_cached_instance(monkeypatch):
    monkeypatch.setattr(oauth_encryption, "_oauth_encrypter", None)
    monkeypatch.setattr("core.tools.utils.system_oauth_encryption.dify_config.SECRET_KEY", "global-secret")

    first = oauth_encryption.get_system_oauth_encrypter()
    second = oauth_encryption.get_system_oauth_encrypter()
    assert first is second

    encrypted = oauth_encryption.encrypt_system_oauth_params({"k": "v"})
    assert oauth_encryption.decrypt_system_oauth_params(encrypted) == {"k": "v"}


def test_create_system_oauth_encrypter_factory():
    encrypter = oauth_encryption.create_system_oauth_encrypter(secret_key="factory-secret")
    assert isinstance(encrypter, SystemOAuthEncrypter)
@ -3,7 +3,7 @@
from unittest.mock import MagicMock

from dify_graph.entities.tool_entities import ToolResultStatus
from dify_graph.enums import NodeType
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.graph.graph import Graph
from dify_graph.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator
from dify_graph.graph_engine.response_coordinator.session import ResponseSession
@ -30,13 +30,13 @@ class TestResponseCoordinatorObjectStreaming:
        # Mock nodes
        llm_node = MagicMock()
        llm_node.id = "llm_node"
        llm_node.node_type = NodeType.LLM
        llm_node.node_type = BuiltinNodeTypes.LLM
        llm_node.execution_type = MagicMock()
        llm_node.blocks_variable_output = MagicMock(return_value=False)

        response_node = MagicMock()
        response_node.id = "response_node"
        response_node.node_type = NodeType.ANSWER
        response_node.node_type = BuiltinNodeTypes.ANSWER
        response_node.execution_type = MagicMock()
        response_node.blocks_variable_output = MagicMock(return_value=False)

@ -63,7 +63,7 @@ class TestResponseCoordinatorObjectStreaming:
        content_event_1 = NodeRunStreamChunkEvent(
            id="exec_123",
            node_id="llm_node",
            node_type=NodeType.LLM,
            node_type=BuiltinNodeTypes.LLM,
            selector=["llm_node", "generation", "content"],
            chunk="Hello",
            is_final=False,
@ -72,7 +72,7 @@ class TestResponseCoordinatorObjectStreaming:
        content_event_2 = NodeRunStreamChunkEvent(
            id="exec_123",
            node_id="llm_node",
            node_type=NodeType.LLM,
            node_type=BuiltinNodeTypes.LLM,
            selector=["llm_node", "generation", "content"],
            chunk=" world",
            is_final=True,
@ -83,7 +83,7 @@ class TestResponseCoordinatorObjectStreaming:
        tool_call_event = NodeRunStreamChunkEvent(
            id="exec_123",
            node_id="llm_node",
            node_type=NodeType.LLM,
            node_type=BuiltinNodeTypes.LLM,
            selector=["llm_node", "generation", "tool_calls"],
            chunk='{"query": "test"}',
            is_final=True,
@ -99,7 +99,7 @@ class TestResponseCoordinatorObjectStreaming:
        tool_result_event = NodeRunStreamChunkEvent(
            id="exec_123",
            node_id="llm_node",
            node_type=NodeType.LLM,
            node_type=BuiltinNodeTypes.LLM,
            selector=["llm_node", "generation", "tool_results"],
            chunk="Found 10 results",
            is_final=True,
@ -196,7 +196,7 @@ class TestResponseCoordinatorObjectStreaming:

        response_node = MagicMock()
        response_node.id = "response_node"
        response_node.node_type = NodeType.ANSWER
        response_node.node_type = BuiltinNodeTypes.ANSWER
        graph.nodes = {"response_node": response_node}
        graph.root_node = response_node

@ -211,7 +211,7 @@ class TestResponseCoordinatorObjectStreaming:
        event = NodeRunStreamChunkEvent(
            id="stream_1",
            node_id="llm_node",
            node_type=NodeType.LLM,
            node_type=BuiltinNodeTypes.LLM,
            selector=["sys", "foo"],
            chunk="hi",
            is_final=True,

@ -19,7 +19,7 @@ from core.virtual_environment.channel.transport import NopTransportWriteCloser
from core.workflow.nodes.command.node import CommandNode
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable

@ -161,6 +161,7 @@ def _make_node(
        config={
            "id": "node-config-id",
            "data": {
                "type": BuiltinNodeTypes.COMMAND,
                "title": "Command",
                "command": command,
                "working_directory": working_directory,