mirror of
https://github.com/langgenius/dify.git
synced 2026-03-12 10:38:54 +08:00
test: add unit tests for some services (#32866)
Co-authored-by: akashseth-ifp <akash.seth@infocusp.com>
This commit is contained in:
@ -0,0 +1,214 @@
|
||||
"""
|
||||
Unit tests for services.advanced_prompt_template_service
|
||||
"""
|
||||
|
||||
import copy
|
||||
|
||||
from core.prompt.prompt_templates.advanced_prompt_templates import (
|
||||
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
|
||||
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
|
||||
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
|
||||
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
|
||||
BAICHUAN_CONTEXT,
|
||||
CHAT_APP_CHAT_PROMPT_CONFIG,
|
||||
CHAT_APP_COMPLETION_PROMPT_CONFIG,
|
||||
COMPLETION_APP_CHAT_PROMPT_CONFIG,
|
||||
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
|
||||
CONTEXT,
|
||||
)
|
||||
from models.model import AppMode
|
||||
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
|
||||
|
||||
|
||||
class TestAdvancedPromptTemplateService:
|
||||
"""Test suite for AdvancedPromptTemplateService."""
|
||||
|
||||
def test_get_prompt_should_use_baichuan_prompt_when_model_name_contains_baichuan(self) -> None:
|
||||
"""Test baichuan model names use baichuan context prompt."""
|
||||
# Arrange
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "chat",
|
||||
"model_name": "Baichuan2-13B",
|
||||
"has_context": "true",
|
||||
}
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert
|
||||
assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT)
|
||||
|
||||
def test_get_prompt_should_use_common_prompt_when_model_name_not_baichuan(self) -> None:
|
||||
"""Test non-baichuan model names use common prompt."""
|
||||
# Arrange
|
||||
args = {
|
||||
"app_mode": AppMode.CHAT,
|
||||
"model_mode": "completion",
|
||||
"model_name": "gpt-4",
|
||||
"has_context": "false",
|
||||
}
|
||||
original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_prompt(args)
|
||||
|
||||
# Assert
|
||||
assert result == original_config
|
||||
assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG
|
||||
|
||||
def test_get_common_prompt_should_return_empty_dict_when_app_mode_invalid(self) -> None:
|
||||
"""Test invalid app mode returns empty dict."""
|
||||
# Arrange
|
||||
app_mode = "invalid"
|
||||
model_mode = "chat"
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "true")
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
def test_get_common_prompt_should_prepend_context_for_completion_prompt(self) -> None:
|
||||
"""Test context is prepended for completion prompt when has_context is true."""
|
||||
# Arrange
|
||||
original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true")
|
||||
|
||||
# Assert
|
||||
assert result["completion_prompt_config"]["prompt"]["text"].startswith(CONTEXT)
|
||||
assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG
|
||||
|
||||
def test_get_common_prompt_should_prepend_context_for_chat_prompt(self) -> None:
|
||||
"""Test context is prepended for chat prompt when has_context is true."""
|
||||
# Arrange
|
||||
original_config = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true")
|
||||
|
||||
# Assert
|
||||
assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(CONTEXT)
|
||||
assert original_config == COMPLETION_APP_CHAT_PROMPT_CONFIG
|
||||
|
||||
def test_get_common_prompt_should_return_chat_prompt_without_context_when_has_context_false(self) -> None:
|
||||
"""Test chat prompt remains unchanged when has_context is false."""
|
||||
# Arrange
|
||||
original_config = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "false")
|
||||
|
||||
# Assert
|
||||
assert result == original_config
|
||||
assert original_config == CHAT_APP_CHAT_PROMPT_CONFIG
|
||||
|
||||
def test_get_common_prompt_should_return_completion_prompt_for_completion_app_mode(self) -> None:
|
||||
"""Test completion app mode with completion model returns completion prompt."""
|
||||
# Arrange
|
||||
original_config = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "false")
|
||||
|
||||
# Assert
|
||||
assert result == original_config
|
||||
assert original_config == COMPLETION_APP_COMPLETION_PROMPT_CONFIG
|
||||
|
||||
def test_get_common_prompt_should_return_empty_dict_when_model_mode_invalid(self) -> None:
|
||||
"""Test invalid model mode returns empty dict."""
|
||||
# Arrange
|
||||
app_mode = AppMode.CHAT
|
||||
model_mode = "invalid"
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "false")
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
|
||||
def test_get_completion_prompt_should_not_prepend_context_when_has_context_false(self) -> None:
|
||||
"""Test helper keeps completion prompt unchanged when context is disabled."""
|
||||
# Arrange
|
||||
prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
|
||||
original_text = prompt_template["completion_prompt_config"]["prompt"]["text"]
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT)
|
||||
|
||||
# Assert
|
||||
assert result["completion_prompt_config"]["prompt"]["text"] == original_text
|
||||
|
||||
def test_get_chat_prompt_should_not_prepend_context_when_has_context_false(self) -> None:
|
||||
"""Test helper keeps chat prompt unchanged when context is disabled."""
|
||||
# Arrange
|
||||
prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
|
||||
original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"]
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT)
|
||||
|
||||
# Assert
|
||||
assert result["chat_prompt_config"]["prompt"][0]["text"] == original_text
|
||||
|
||||
def test_get_baichuan_prompt_should_return_chat_completion_config_when_chat_completion(self) -> None:
|
||||
"""Test baichuan chat/completion returns the expected config."""
|
||||
# Arrange
|
||||
original_config = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false")
|
||||
|
||||
# Assert
|
||||
assert result == original_config
|
||||
assert original_config == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG
|
||||
|
||||
def test_get_baichuan_prompt_should_return_completion_chat_config_when_completion_chat(self) -> None:
|
||||
"""Test baichuan completion/chat returns the expected config."""
|
||||
# Arrange
|
||||
original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "false")
|
||||
|
||||
# Assert
|
||||
assert result == original_config
|
||||
assert original_config == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG
|
||||
|
||||
def test_get_baichuan_prompt_should_return_completion_completion_config_when_enabled_context(self) -> None:
|
||||
"""Test baichuan completion/completion prepends baichuan context when enabled."""
|
||||
# Arrange
|
||||
original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true")
|
||||
|
||||
# Assert
|
||||
assert result["completion_prompt_config"]["prompt"]["text"].startswith(BAICHUAN_CONTEXT)
|
||||
assert original_config == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG
|
||||
|
||||
def test_get_baichuan_prompt_should_return_chat_chat_config_when_enabled_context(self) -> None:
|
||||
"""Test baichuan chat/chat prepends baichuan context when enabled."""
|
||||
# Arrange
|
||||
original_config = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG)
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true")
|
||||
|
||||
# Assert
|
||||
assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT)
|
||||
assert original_config == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG
|
||||
|
||||
def test_get_baichuan_prompt_should_return_empty_dict_when_invalid_inputs(self) -> None:
|
||||
"""Test invalid baichuan mode combinations return empty dict."""
|
||||
# Arrange
|
||||
app_mode = "invalid"
|
||||
model_mode = "invalid"
|
||||
|
||||
# Act
|
||||
result = AdvancedPromptTemplateService.get_baichuan_prompt(app_mode, model_mode, "true")
|
||||
|
||||
# Assert
|
||||
assert result == {}
|
||||
346
api/tests/unit_tests/services/test_agent_service.py
Normal file
346
api/tests/unit_tests/services/test_agent_service.py
Normal file
@ -0,0 +1,346 @@
|
||||
"""
|
||||
Unit tests for services.agent_service
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytz
|
||||
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from models import Account
|
||||
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
|
||||
from services.agent_service import AgentService
|
||||
|
||||
|
||||
def _make_current_user_account(timezone: str = "UTC") -> Account:
|
||||
account = Account(name="Test User", email="test@example.com")
|
||||
account.timezone = timezone
|
||||
return account
|
||||
|
||||
|
||||
def _make_app_model(app_model_config: MagicMock | None) -> MagicMock:
|
||||
app_model = MagicMock(spec=App)
|
||||
app_model.id = "app-123"
|
||||
app_model.tenant_id = "tenant-123"
|
||||
app_model.app_model_config = app_model_config
|
||||
return app_model
|
||||
|
||||
|
||||
def _make_conversation(from_end_user_id: str | None, from_account_id: str | None) -> MagicMock:
|
||||
conversation = MagicMock(spec=Conversation)
|
||||
conversation.id = "conv-123"
|
||||
conversation.app_id = "app-123"
|
||||
conversation.from_end_user_id = from_end_user_id
|
||||
conversation.from_account_id = from_account_id
|
||||
return conversation
|
||||
|
||||
|
||||
def _make_message(agent_thoughts: list[MessageAgentThought]) -> MagicMock:
|
||||
message = MagicMock(spec=Message)
|
||||
message.id = "msg-123"
|
||||
message.conversation_id = "conv-123"
|
||||
message.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC)
|
||||
message.provider_response_latency = 1.23
|
||||
message.answer_tokens = 4
|
||||
message.message_tokens = 6
|
||||
message.agent_thoughts = agent_thoughts
|
||||
message.message_files = ["file-a.txt"]
|
||||
return message
|
||||
|
||||
|
||||
def _make_agent_thought() -> MagicMock:
|
||||
agent_thought = MagicMock(spec=MessageAgentThought)
|
||||
agent_thought.tokens = 3
|
||||
agent_thought.tool_input = "raw-input"
|
||||
agent_thought.observation = "raw-output"
|
||||
agent_thought.thought = "thinking"
|
||||
agent_thought.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC)
|
||||
agent_thought.files = []
|
||||
agent_thought.tools = ["tool_a", "dataset_tool"]
|
||||
agent_thought.tool_labels = {"tool_a": "Tool A"}
|
||||
agent_thought.tool_meta = {
|
||||
"tool_a": {
|
||||
"tool_config": {
|
||||
"tool_provider_type": "custom",
|
||||
"tool_provider": "provider-1",
|
||||
},
|
||||
"tool_parameters": {"param": "value"},
|
||||
"time_cost": 2.5,
|
||||
},
|
||||
"dataset_tool": {
|
||||
"tool_config": {
|
||||
"tool_provider_type": "dataset-retrieval",
|
||||
"tool_provider": "dataset-provider",
|
||||
}
|
||||
},
|
||||
}
|
||||
agent_thought.tool_inputs_dict = {"tool_a": {"q": "hello"}, "dataset_tool": {"k": "v"}}
|
||||
agent_thought.tool_outputs_dict = {"tool_a": {"result": "ok"}}
|
||||
return agent_thought
|
||||
|
||||
|
||||
def _build_query_side_effect(
|
||||
conversation: Conversation | None,
|
||||
message: Message | None,
|
||||
executor: EndUser | Account | None,
|
||||
) -> Callable[..., MagicMock]:
|
||||
def _query_side_effect(*args: object, **kwargs: object) -> MagicMock:
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
if any(arg is Conversation for arg in args):
|
||||
query.first.return_value = conversation
|
||||
elif any(arg is Message for arg in args):
|
||||
query.first.return_value = message
|
||||
elif any(arg is EndUser for arg in args) or any(arg is Account for arg in args):
|
||||
query.first.return_value = executor
|
||||
return query
|
||||
|
||||
return _query_side_effect
|
||||
|
||||
|
||||
class TestAgentServiceGetAgentLogs:
|
||||
"""Test suite for AgentService.get_agent_logs."""
|
||||
|
||||
def test_get_agent_logs_should_raise_when_conversation_missing(self) -> None:
|
||||
"""Test missing conversation raises ValueError."""
|
||||
# Arrange
|
||||
app_model = _make_app_model(MagicMock())
|
||||
with patch("services.agent_service.db") as mock_db:
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = None
|
||||
mock_db.session.query.return_value = query
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
AgentService.get_agent_logs(app_model, "missing-conv", "msg-1")
|
||||
|
||||
def test_get_agent_logs_should_raise_when_message_missing(self) -> None:
|
||||
"""Test missing message raises ValueError."""
|
||||
# Arrange
|
||||
app_model = _make_app_model(MagicMock())
|
||||
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
|
||||
with patch("services.agent_service.db") as mock_db:
|
||||
conversation_query = MagicMock()
|
||||
conversation_query.where.return_value = conversation_query
|
||||
conversation_query.first.return_value = conversation
|
||||
|
||||
message_query = MagicMock()
|
||||
message_query.where.return_value = message_query
|
||||
message_query.first.return_value = None
|
||||
|
||||
mock_db.session.query.side_effect = [conversation_query, message_query]
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
AgentService.get_agent_logs(app_model, conversation.id, "missing-msg")
|
||||
|
||||
def test_get_agent_logs_should_raise_when_app_model_config_missing(self) -> None:
|
||||
"""Test missing app model config raises ValueError."""
|
||||
# Arrange
|
||||
app_model = _make_app_model(None)
|
||||
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
|
||||
message = _make_message([])
|
||||
current_user = _make_current_user_account()
|
||||
|
||||
with patch("services.agent_service.db") as mock_db, patch("services.agent_service.current_user", current_user):
|
||||
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock())
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
AgentService.get_agent_logs(app_model, conversation.id, message.id)
|
||||
|
||||
def test_get_agent_logs_should_raise_when_agent_config_missing(self) -> None:
|
||||
"""Test missing agent config raises ValueError."""
|
||||
# Arrange
|
||||
app_model_config = MagicMock()
|
||||
app_model_config.agent_mode_dict = {"strategy": "react"}
|
||||
app_model_config.to_dict.return_value = {"tools": []}
|
||||
app_model = _make_app_model(app_model_config)
|
||||
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
|
||||
message = _make_message([])
|
||||
current_user = _make_current_user_account()
|
||||
|
||||
with (
|
||||
patch("services.agent_service.db") as mock_db,
|
||||
patch("services.agent_service.AgentConfigManager.convert", return_value=None),
|
||||
patch("services.agent_service.current_user", current_user),
|
||||
):
|
||||
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock())
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
AgentService.get_agent_logs(app_model, conversation.id, message.id)
|
||||
|
||||
def test_get_agent_logs_should_return_logs_for_end_user_executor(self) -> None:
|
||||
"""Test agent logs returned for end-user executor with tool icons."""
|
||||
# Arrange
|
||||
agent_thought = _make_agent_thought()
|
||||
message = _make_message([agent_thought])
|
||||
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
|
||||
executor = MagicMock(spec=EndUser)
|
||||
executor.name = "End User"
|
||||
app_model_config = MagicMock()
|
||||
app_model_config.agent_mode_dict = {"strategy": "react"}
|
||||
app_model_config.to_dict.return_value = {"tools": []}
|
||||
app_model = _make_app_model(app_model_config)
|
||||
current_user = _make_current_user_account()
|
||||
agent_tool = MagicMock()
|
||||
agent_tool.tool_name = "tool_a"
|
||||
agent_tool.provider_type = "custom"
|
||||
agent_tool.provider_id = "provider-2"
|
||||
agent_config = MagicMock()
|
||||
agent_config.tools = [agent_tool]
|
||||
|
||||
with (
|
||||
patch("services.agent_service.db") as mock_db,
|
||||
patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config) as mock_convert,
|
||||
patch("services.agent_service.ToolManager.get_tool_icon") as mock_get_icon,
|
||||
patch("services.agent_service.current_user", current_user),
|
||||
):
|
||||
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor)
|
||||
mock_get_icon.side_effect = [None, "icon-a"]
|
||||
|
||||
# Act
|
||||
result = AgentService.get_agent_logs(app_model, conversation.id, message.id)
|
||||
|
||||
# Assert
|
||||
assert result["meta"]["status"] == "success"
|
||||
assert result["meta"]["executor"] == "End User"
|
||||
assert result["meta"]["total_tokens"] == 10
|
||||
assert result["meta"]["agent_mode"] == "react"
|
||||
assert result["meta"]["iterations"] == 1
|
||||
assert result["files"] == ["file-a.txt"]
|
||||
assert len(result["iterations"]) == 1
|
||||
tool_calls = result["iterations"][0]["tool_calls"]
|
||||
assert tool_calls[0]["tool_name"] == "tool_a"
|
||||
assert tool_calls[0]["tool_icon"] == "icon-a"
|
||||
assert tool_calls[1]["tool_name"] == "dataset_tool"
|
||||
assert tool_calls[1]["tool_icon"] == ""
|
||||
mock_convert.assert_called_once()
|
||||
|
||||
def test_get_agent_logs_should_return_account_executor_when_no_end_user(self) -> None:
|
||||
"""Test agent logs fall back to account executor when end user is missing."""
|
||||
# Arrange
|
||||
agent_thought = _make_agent_thought()
|
||||
message = _make_message([agent_thought])
|
||||
conversation = _make_conversation(from_end_user_id=None, from_account_id="account-1")
|
||||
executor = MagicMock(spec=Account)
|
||||
executor.name = "Account User"
|
||||
app_model_config = MagicMock()
|
||||
app_model_config.agent_mode_dict = {"strategy": "react"}
|
||||
app_model_config.to_dict.return_value = {"tools": []}
|
||||
app_model = _make_app_model(app_model_config)
|
||||
current_user = _make_current_user_account()
|
||||
agent_config = MagicMock()
|
||||
agent_config.tools = []
|
||||
|
||||
with (
|
||||
patch("services.agent_service.db") as mock_db,
|
||||
patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config),
|
||||
patch("services.agent_service.ToolManager.get_tool_icon", return_value=""),
|
||||
patch("services.agent_service.current_user", current_user),
|
||||
):
|
||||
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor)
|
||||
|
||||
# Act
|
||||
result = AgentService.get_agent_logs(app_model, conversation.id, message.id)
|
||||
|
||||
# Assert
|
||||
assert result["meta"]["executor"] == "Account User"
|
||||
|
||||
def test_get_agent_logs_should_use_defaults_when_executor_and_tool_data_missing(self) -> None:
|
||||
"""Test unknown executor and missing tool details fall back to defaults."""
|
||||
# Arrange
|
||||
agent_thought = _make_agent_thought()
|
||||
agent_thought.tool_labels = {}
|
||||
agent_thought.tool_inputs_dict = {}
|
||||
agent_thought.tool_outputs_dict = None
|
||||
agent_thought.tool_meta = {"tool_a": {"error": "failed"}}
|
||||
agent_thought.tools = ["tool_a"]
|
||||
|
||||
message = _make_message([agent_thought])
|
||||
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
|
||||
app_model_config = MagicMock()
|
||||
app_model_config.agent_mode_dict = {}
|
||||
app_model_config.to_dict.return_value = {"tools": []}
|
||||
app_model = _make_app_model(app_model_config)
|
||||
current_user = _make_current_user_account()
|
||||
agent_config = MagicMock()
|
||||
agent_config.tools = []
|
||||
|
||||
with (
|
||||
patch("services.agent_service.db") as mock_db,
|
||||
patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config),
|
||||
patch("services.agent_service.ToolManager.get_tool_icon", return_value=None),
|
||||
patch("services.agent_service.current_user", current_user),
|
||||
):
|
||||
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, None)
|
||||
|
||||
# Act
|
||||
result = AgentService.get_agent_logs(app_model, conversation.id, message.id)
|
||||
|
||||
# Assert
|
||||
assert result["meta"]["executor"] == "Unknown"
|
||||
assert result["meta"]["agent_mode"] == "react"
|
||||
tool_call = result["iterations"][0]["tool_calls"][0]
|
||||
assert tool_call["status"] == "error"
|
||||
assert tool_call["error"] == "failed"
|
||||
assert tool_call["tool_label"] == "tool_a"
|
||||
assert tool_call["tool_input"] == {}
|
||||
assert tool_call["tool_output"] == {}
|
||||
assert tool_call["time_cost"] == 0
|
||||
assert tool_call["tool_parameters"] == {}
|
||||
assert tool_call["tool_icon"] is None
|
||||
|
||||
|
||||
class TestAgentServiceProviders:
|
||||
"""Test suite for AgentService provider methods."""
|
||||
|
||||
def test_list_agent_providers_should_delegate_to_plugin_client(self) -> None:
|
||||
"""Test list_agent_providers delegates to PluginAgentClient."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-1"
|
||||
expected = [{"name": "provider"}]
|
||||
with patch("services.agent_service.PluginAgentClient") as mock_client:
|
||||
mock_client.return_value.fetch_agent_strategy_providers.return_value = expected
|
||||
|
||||
# Act
|
||||
result = AgentService.list_agent_providers("user-1", tenant_id)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
mock_client.return_value.fetch_agent_strategy_providers.assert_called_once_with(tenant_id)
|
||||
|
||||
def test_get_agent_provider_should_return_provider_when_successful(self) -> None:
|
||||
"""Test get_agent_provider returns provider when successful."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-1"
|
||||
provider_name = "provider-a"
|
||||
expected = {"name": provider_name}
|
||||
with patch("services.agent_service.PluginAgentClient") as mock_client:
|
||||
mock_client.return_value.fetch_agent_strategy_provider.return_value = expected
|
||||
|
||||
# Act
|
||||
result = AgentService.get_agent_provider("user-1", tenant_id, provider_name)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
mock_client.return_value.fetch_agent_strategy_provider.assert_called_once_with(tenant_id, provider_name)
|
||||
|
||||
def test_get_agent_provider_should_raise_value_error_on_plugin_error(self) -> None:
|
||||
"""Test get_agent_provider wraps PluginDaemonClientSideError into ValueError."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-1"
|
||||
provider_name = "provider-a"
|
||||
with patch("services.agent_service.PluginAgentClient") as mock_client:
|
||||
mock_client.return_value.fetch_agent_strategy_provider.side_effect = PluginDaemonClientSideError(
|
||||
"plugin error"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError):
|
||||
AgentService.get_agent_provider("user-1", tenant_id, provider_name)
|
||||
1685
api/tests/unit_tests/services/test_annotation_service.py
Normal file
1685
api/tests/unit_tests/services/test_annotation_service.py
Normal file
File diff suppressed because it is too large
Load Diff
609
api/tests/unit_tests/services/test_app_service.py
Normal file
609
api/tests/unit_tests/services/test_app_service.py
Normal file
@ -0,0 +1,609 @@
|
||||
"""Unit tests for services.app_service."""
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from models import Account, Tenant
|
||||
from models.model import App, AppMode
|
||||
from services.app_service import AppService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def service() -> AppService:
|
||||
"""Provide AppService instance."""
|
||||
return AppService()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def account() -> Account:
|
||||
"""Create account object for create_app tests."""
|
||||
tenant = Tenant(name="Tenant")
|
||||
tenant.id = "tenant-1"
|
||||
result = Account(name="Account User", email="account@example.com")
|
||||
result.id = "acc-1"
|
||||
result._current_tenant = tenant
|
||||
return result
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_args() -> dict:
|
||||
"""Create default create_app args."""
|
||||
return {
|
||||
"name": "Test App",
|
||||
"mode": AppMode.CHAT.value,
|
||||
"icon": "🤖",
|
||||
"icon_background": "#FFFFFF",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_template() -> dict:
|
||||
"""Create basic app template for create_app tests."""
|
||||
return {
|
||||
AppMode.CHAT: {
|
||||
"app": {},
|
||||
"model_config": {
|
||||
"model": {
|
||||
"provider": "provider-a",
|
||||
"name": "model-a",
|
||||
"mode": "chat",
|
||||
"completion_params": {},
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _make_current_user() -> Account:
|
||||
user = Account(name="Tester", email="tester@example.com")
|
||||
user.id = "user-1"
|
||||
tenant = Tenant(name="Tenant")
|
||||
tenant.id = "tenant-1"
|
||||
user._current_tenant = tenant
|
||||
return user
|
||||
|
||||
|
||||
class TestAppServicePagination:
|
||||
"""Test suite for get_paginate_apps."""
|
||||
|
||||
def test_get_paginate_apps_should_return_none_when_tag_filter_empty(self, service: AppService) -> None:
|
||||
"""Test pagination returns None when tag filter has no targets."""
|
||||
# Arrange
|
||||
args = {"mode": "chat", "page": 1, "limit": 20, "tag_ids": ["tag-1"]}
|
||||
|
||||
with patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=[]):
|
||||
# Act
|
||||
result = service.get_paginate_apps("user-1", "tenant-1", args)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_get_paginate_apps_should_delegate_to_db_paginate(self, service: AppService) -> None:
|
||||
"""Test pagination delegates to db.paginate when filters are valid."""
|
||||
# Arrange
|
||||
args = {
|
||||
"mode": "workflow",
|
||||
"is_created_by_me": True,
|
||||
"name": "My_App%",
|
||||
"tag_ids": ["tag-1"],
|
||||
"page": 2,
|
||||
"limit": 10,
|
||||
}
|
||||
expected_pagination = MagicMock()
|
||||
|
||||
with (
|
||||
patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=["app-1"]),
|
||||
patch("libs.helper.escape_like_pattern", return_value="escaped"),
|
||||
patch("services.app_service.db") as mock_db,
|
||||
):
|
||||
mock_db.paginate.return_value = expected_pagination
|
||||
|
||||
# Act
|
||||
result = service.get_paginate_apps("user-1", "tenant-1", args)
|
||||
|
||||
# Assert
|
||||
assert result is expected_pagination
|
||||
mock_db.paginate.assert_called_once()
|
||||
|
||||
|
||||
class TestAppServiceCreate:
|
||||
"""Test suite for create_app."""
|
||||
|
||||
def test_create_app_should_create_with_matching_default_model(
|
||||
self,
|
||||
service: AppService,
|
||||
account: Account,
|
||||
default_args: dict,
|
||||
app_template: dict,
|
||||
) -> None:
|
||||
"""Test create_app uses matching default model and persists app config."""
|
||||
# Arrange
|
||||
app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
app_model_config = SimpleNamespace(id="cfg-1")
|
||||
model_instance = SimpleNamespace(
|
||||
model_name="model-a",
|
||||
provider="provider-a",
|
||||
model_type_instance=MagicMock(),
|
||||
credentials={"k": "v"},
|
||||
)
|
||||
|
||||
with (
|
||||
patch("services.app_service.default_app_templates", app_template),
|
||||
patch("services.app_service.App", return_value=app_instance),
|
||||
patch("services.app_service.AppModelConfig", return_value=app_model_config),
|
||||
patch("services.app_service.ModelManager") as mock_model_manager,
|
||||
patch("services.app_service.db") as mock_db,
|
||||
patch("services.app_service.app_was_created") as mock_event,
|
||||
patch("services.app_service.FeatureService.get_system_features") as mock_features,
|
||||
patch("services.app_service.BillingService") as mock_billing,
|
||||
patch("services.app_service.dify_config") as mock_config,
|
||||
):
|
||||
manager = mock_model_manager.return_value
|
||||
manager.get_default_model_instance.return_value = model_instance
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
|
||||
mock_config.BILLING_ENABLED = True
|
||||
|
||||
# Act
|
||||
result = service.create_app("tenant-1", default_args, account)
|
||||
|
||||
# Assert
|
||||
assert result is app_instance
|
||||
assert app_instance.app_model_config_id == "cfg-1"
|
||||
mock_db.session.add.assert_any_call(app_instance)
|
||||
mock_db.session.add.assert_any_call(app_model_config)
|
||||
assert mock_db.session.flush.call_count == 2
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_event.send.assert_called_once_with(app_instance, account=account)
|
||||
mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1")
|
||||
|
||||
def test_create_app_should_raise_when_model_schema_missing(
|
||||
self,
|
||||
service: AppService,
|
||||
account: Account,
|
||||
default_args: dict,
|
||||
app_template: dict,
|
||||
) -> None:
|
||||
"""Test create_app raises ValueError when non-matching model has no schema."""
|
||||
# Arrange
|
||||
app_instance = SimpleNamespace(id="app-1")
|
||||
model_instance = SimpleNamespace(
|
||||
model_name="model-b",
|
||||
provider="provider-b",
|
||||
model_type_instance=MagicMock(),
|
||||
credentials={"k": "v"},
|
||||
)
|
||||
model_instance.model_type_instance.get_model_schema.return_value = None
|
||||
|
||||
with (
|
||||
patch("services.app_service.default_app_templates", app_template),
|
||||
patch("services.app_service.App", return_value=app_instance),
|
||||
patch("services.app_service.ModelManager") as mock_model_manager,
|
||||
patch("services.app_service.db") as mock_db,
|
||||
):
|
||||
manager = mock_model_manager.return_value
|
||||
manager.get_default_model_instance.return_value = model_instance
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="model schema not found"):
|
||||
service.create_app("tenant-1", default_args, account)
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
def test_create_app_should_fallback_to_default_provider_when_model_missing(
|
||||
self,
|
||||
service: AppService,
|
||||
account: Account,
|
||||
default_args: dict,
|
||||
app_template: dict,
|
||||
) -> None:
|
||||
"""Test create_app falls back to provider/model name when no default model instance is available."""
|
||||
# Arrange
|
||||
app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
app_model_config = SimpleNamespace(id="cfg-1")
|
||||
|
||||
with (
|
||||
patch("services.app_service.default_app_templates", app_template),
|
||||
patch("services.app_service.App", return_value=app_instance),
|
||||
patch("services.app_service.AppModelConfig", return_value=app_model_config),
|
||||
patch("services.app_service.ModelManager") as mock_model_manager,
|
||||
patch("services.app_service.db") as mock_db,
|
||||
patch("services.app_service.app_was_created") as mock_event,
|
||||
patch("services.app_service.FeatureService.get_system_features") as mock_features,
|
||||
patch("services.app_service.EnterpriseService") as mock_enterprise,
|
||||
patch("services.app_service.dify_config") as mock_config,
|
||||
):
|
||||
manager = mock_model_manager.return_value
|
||||
manager.get_default_model_instance.side_effect = ProviderTokenNotInitError("not ready")
|
||||
manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model")
|
||||
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
|
||||
mock_config.BILLING_ENABLED = False
|
||||
|
||||
# Act
|
||||
result = service.create_app("tenant-1", default_args, account)
|
||||
|
||||
# Assert
|
||||
assert result is app_instance
|
||||
mock_event.send.assert_called_once_with(app_instance, account=account)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_enterprise.WebAppAuth.update_app_access_mode.assert_called_once_with("app-1", "private")
|
||||
|
||||
def test_create_app_should_log_and_fallback_on_unexpected_model_error(
|
||||
self,
|
||||
service: AppService,
|
||||
account: Account,
|
||||
default_args: dict,
|
||||
app_template: dict,
|
||||
) -> None:
|
||||
"""Test unexpected model manager errors are logged and fallback provider is used."""
|
||||
# Arrange
|
||||
app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
app_model_config = SimpleNamespace(id="cfg-1")
|
||||
|
||||
with (
|
||||
patch("services.app_service.default_app_templates", app_template),
|
||||
patch("services.app_service.App", return_value=app_instance),
|
||||
patch("services.app_service.AppModelConfig", return_value=app_model_config),
|
||||
patch("services.app_service.ModelManager") as mock_model_manager,
|
||||
patch("services.app_service.db"),
|
||||
patch("services.app_service.app_was_created"),
|
||||
patch(
|
||||
"services.app_service.FeatureService.get_system_features",
|
||||
return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)),
|
||||
),
|
||||
patch("services.app_service.dify_config", new=SimpleNamespace(BILLING_ENABLED=False)),
|
||||
patch("services.app_service.logger") as mock_logger,
|
||||
):
|
||||
manager = mock_model_manager.return_value
|
||||
manager.get_default_model_instance.side_effect = RuntimeError("boom")
|
||||
manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model")
|
||||
|
||||
# Act
|
||||
result = service.create_app("tenant-1", default_args, account)
|
||||
|
||||
# Assert
|
||||
assert result is app_instance
|
||||
mock_logger.exception.assert_called_once()
|
||||
|
||||
|
||||
class TestAppServiceGetAndUpdate:
|
||||
"""Test suite for app retrieval and update methods."""
|
||||
|
||||
def test_get_app_should_return_original_when_not_agent_app(self, service: AppService) -> None:
|
||||
"""Test get_app returns original app for non-agent modes."""
|
||||
# Arrange
|
||||
app = MagicMock()
|
||||
app.mode = AppMode.CHAT
|
||||
app.is_agent = False
|
||||
|
||||
with patch("services.app_service.current_user", _make_current_user()):
|
||||
# Act
|
||||
result = service.get_app(app)
|
||||
|
||||
# Assert
|
||||
assert result is app
|
||||
|
||||
def test_get_app_should_return_original_when_model_config_missing(self, service: AppService) -> None:
|
||||
"""Test get_app returns app when agent mode has no model config."""
|
||||
# Arrange
|
||||
app = MagicMock()
|
||||
app.id = "app-1"
|
||||
app.mode = AppMode.AGENT_CHAT
|
||||
app.is_agent = False
|
||||
app.app_model_config = None
|
||||
|
||||
with patch("services.app_service.current_user", _make_current_user()):
|
||||
# Act
|
||||
result = service.get_app(app)
|
||||
|
||||
# Assert
|
||||
assert result is app
|
||||
|
||||
def test_get_app_should_mask_tool_parameters_for_agent_tools(self, service: AppService) -> None:
|
||||
"""Test get_app decrypts and masks secret tool parameters."""
|
||||
# Arrange
|
||||
tool = {
|
||||
"provider_type": "builtin",
|
||||
"provider_id": "provider-1",
|
||||
"tool_name": "tool-a",
|
||||
"tool_parameters": {"secret": "encrypted"},
|
||||
"extra": True,
|
||||
}
|
||||
model_config = MagicMock()
|
||||
model_config.agent_mode_dict = {"tools": [tool, {"skip": True}]}
|
||||
|
||||
app = MagicMock()
|
||||
app.id = "app-1"
|
||||
app.mode = AppMode.AGENT_CHAT
|
||||
app.is_agent = False
|
||||
app.app_model_config = model_config
|
||||
|
||||
manager = MagicMock()
|
||||
manager.decrypt_tool_parameters.return_value = {"secret": "decrypted"}
|
||||
manager.mask_tool_parameters.return_value = {"secret": "***"}
|
||||
|
||||
with (
|
||||
patch("services.app_service.current_user", _make_current_user()),
|
||||
patch("services.app_service.ToolManager.get_agent_tool_runtime", return_value=MagicMock()),
|
||||
patch("services.app_service.ToolParameterConfigurationManager", return_value=manager),
|
||||
):
|
||||
# Act
|
||||
result = service.get_app(app)
|
||||
|
||||
# Assert
|
||||
assert result.app_model_config is model_config
|
||||
assert tool["tool_parameters"] == {"secret": "***"}
|
||||
assert json.loads(model_config.agent_mode)["tools"][0]["tool_parameters"] == {"secret": "***"}
|
||||
|
||||
def test_get_app_should_continue_when_tool_parameter_masking_fails(self, service: AppService) -> None:
|
||||
"""Test get_app logs and continues when masking fails."""
|
||||
# Arrange
|
||||
tool = {
|
||||
"provider_type": "builtin",
|
||||
"provider_id": "provider-1",
|
||||
"tool_name": "tool-a",
|
||||
"tool_parameters": {"secret": "encrypted"},
|
||||
"extra": True,
|
||||
}
|
||||
model_config = MagicMock()
|
||||
model_config.agent_mode_dict = {"tools": [tool]}
|
||||
|
||||
app = MagicMock()
|
||||
app.id = "app-1"
|
||||
app.mode = AppMode.AGENT_CHAT
|
||||
app.is_agent = False
|
||||
app.app_model_config = model_config
|
||||
|
||||
with (
|
||||
patch("services.app_service.current_user", _make_current_user()),
|
||||
patch("services.app_service.ToolManager.get_agent_tool_runtime", side_effect=RuntimeError("mask-failed")),
|
||||
patch("services.app_service.logger") as mock_logger,
|
||||
):
|
||||
# Act
|
||||
result = service.get_app(app)
|
||||
|
||||
# Assert
|
||||
assert result.app_model_config is model_config
|
||||
mock_logger.exception.assert_called_once()
|
||||
|
||||
def test_update_methods_should_mutate_app_and_commit(self, service: AppService) -> None:
|
||||
"""Test update methods set fields and commit changes."""
|
||||
# Arrange
|
||||
app = cast(
|
||||
App,
|
||||
SimpleNamespace(
|
||||
name="old",
|
||||
description="old",
|
||||
icon_type="emoji",
|
||||
icon="a",
|
||||
icon_background="#111",
|
||||
enable_site=True,
|
||||
enable_api=True,
|
||||
),
|
||||
)
|
||||
args = {
|
||||
"name": "new",
|
||||
"description": "new-desc",
|
||||
"icon_type": "image",
|
||||
"icon": "new-icon",
|
||||
"icon_background": "#222",
|
||||
"use_icon_as_answer_icon": True,
|
||||
"max_active_requests": 5,
|
||||
}
|
||||
user = SimpleNamespace(id="user-1")
|
||||
|
||||
with (
|
||||
patch("services.app_service.current_user", user),
|
||||
patch("services.app_service.db") as mock_db,
|
||||
patch("services.app_service.naive_utc_now", return_value="now"),
|
||||
):
|
||||
# Act
|
||||
updated = service.update_app(app, args)
|
||||
renamed = service.update_app_name(app, "rename")
|
||||
iconed = service.update_app_icon(app, "icon-2", "#333")
|
||||
site_same = service.update_app_site_status(app, app.enable_site)
|
||||
api_same = service.update_app_api_status(app, app.enable_api)
|
||||
site_changed = service.update_app_site_status(app, False)
|
||||
api_changed = service.update_app_api_status(app, False)
|
||||
|
||||
# Assert
|
||||
assert updated is app
|
||||
assert renamed is app
|
||||
assert iconed is app
|
||||
assert site_same is app
|
||||
assert api_same is app
|
||||
assert site_changed is app
|
||||
assert api_changed is app
|
||||
assert mock_db.session.commit.call_count >= 5
|
||||
|
||||
|
||||
class TestAppServiceDeleteAndMeta:
|
||||
"""Test suite for delete and metadata methods."""
|
||||
|
||||
def test_delete_app_should_cleanup_and_enqueue_task(self, service: AppService) -> None:
|
||||
"""Test delete_app removes app, runs cleanup, and triggers async deletion task."""
|
||||
# Arrange
|
||||
app = cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1"))
|
||||
|
||||
with (
|
||||
patch("services.app_service.db") as mock_db,
|
||||
patch(
|
||||
"services.app_service.FeatureService.get_system_features",
|
||||
return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)),
|
||||
),
|
||||
patch("services.app_service.EnterpriseService") as mock_enterprise,
|
||||
patch(
|
||||
"services.app_service.dify_config",
|
||||
new=SimpleNamespace(BILLING_ENABLED=True, CONSOLE_API_URL="https://console.example"),
|
||||
),
|
||||
patch("services.app_service.BillingService") as mock_billing,
|
||||
patch("services.app_service.remove_app_and_related_data_task") as mock_task,
|
||||
):
|
||||
# Act
|
||||
service.delete_app(app)
|
||||
|
||||
# Assert
|
||||
mock_db.session.delete.assert_called_once_with(app)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_enterprise.WebAppAuth.cleanup_webapp.assert_called_once_with("app-1")
|
||||
mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1")
|
||||
mock_task.delay.assert_called_once_with(tenant_id="tenant-1", app_id="app-1")
|
||||
|
||||
def test_get_app_meta_should_handle_workflow_and_tool_provider_icons(self, service: AppService) -> None:
|
||||
"""Test get_app_meta extracts builtin and API tool icons from workflow graph."""
|
||||
# Arrange
|
||||
workflow = SimpleNamespace(
|
||||
graph_dict={
|
||||
"nodes": [
|
||||
{
|
||||
"data": {
|
||||
"type": "tool",
|
||||
"provider_type": "builtin",
|
||||
"provider_id": "builtin-provider",
|
||||
"tool_name": "tool_builtin",
|
||||
}
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "tool",
|
||||
"provider_type": "api",
|
||||
"provider_id": "api-provider-id",
|
||||
"tool_name": "tool_api",
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
)
|
||||
app = cast(
|
||||
App,
|
||||
SimpleNamespace(
|
||||
mode=AppMode.WORKFLOW.value,
|
||||
workflow=workflow,
|
||||
app_model_config=None,
|
||||
tenant_id="tenant-1",
|
||||
icon_type="emoji",
|
||||
icon_background="#fff",
|
||||
),
|
||||
)
|
||||
|
||||
provider = SimpleNamespace(icon=json.dumps({"background": "#000", "content": "A"}))
|
||||
|
||||
with (
|
||||
patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")),
|
||||
patch("services.app_service.db") as mock_db,
|
||||
):
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = provider
|
||||
mock_db.session.query.return_value = query
|
||||
|
||||
# Act
|
||||
meta = service.get_app_meta(app)
|
||||
|
||||
# Assert
|
||||
assert meta["tool_icons"]["tool_builtin"].endswith("/builtin-provider/icon")
|
||||
assert meta["tool_icons"]["tool_api"] == {"background": "#000", "content": "A"}
|
||||
|
||||
def test_get_app_meta_should_use_default_api_icon_on_lookup_error(self, service: AppService) -> None:
|
||||
"""Test get_app_meta falls back to default icon when API provider lookup fails."""
|
||||
# Arrange
|
||||
app_model_config = SimpleNamespace(
|
||||
agent_mode_dict={
|
||||
"tools": [{"provider_type": "api", "provider_id": "x", "tool_name": "t", "tool_parameters": {}}]
|
||||
}
|
||||
)
|
||||
app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=app_model_config, workflow=None))
|
||||
|
||||
with (
|
||||
patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")),
|
||||
patch("services.app_service.db") as mock_db,
|
||||
):
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = None
|
||||
mock_db.session.query.return_value = query
|
||||
|
||||
# Act
|
||||
meta = service.get_app_meta(app)
|
||||
|
||||
# Assert
|
||||
assert meta["tool_icons"]["t"] == {"background": "#252525", "content": "\ud83d\ude01"}
|
||||
|
||||
def test_get_app_meta_should_return_empty_when_required_data_missing(self, service: AppService) -> None:
|
||||
"""Test get_app_meta returns empty metadata when workflow/model config is absent."""
|
||||
# Arrange
|
||||
workflow_app = cast(App, SimpleNamespace(mode=AppMode.WORKFLOW.value, workflow=None))
|
||||
chat_app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=None))
|
||||
|
||||
# Act
|
||||
workflow_meta = service.get_app_meta(workflow_app)
|
||||
chat_meta = service.get_app_meta(chat_app)
|
||||
|
||||
# Assert
|
||||
assert workflow_meta == {"tool_icons": {}}
|
||||
assert chat_meta == {"tool_icons": {}}
|
||||
|
||||
|
||||
class TestAppServiceCodeLookup:
|
||||
"""Test suite for app code lookup methods."""
|
||||
|
||||
def test_get_app_code_by_id_should_raise_when_site_missing(self) -> None:
|
||||
"""Test get_app_code_by_id raises when site is missing."""
|
||||
# Arrange
|
||||
with patch("services.app_service.db") as mock_db:
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = None
|
||||
mock_db.session.query.return_value = query
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
AppService.get_app_code_by_id("app-1")
|
||||
|
||||
def test_get_app_code_by_id_should_return_code(self) -> None:
|
||||
"""Test get_app_code_by_id returns site code."""
|
||||
# Arrange
|
||||
site = SimpleNamespace(code="code-1")
|
||||
with patch("services.app_service.db") as mock_db:
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = site
|
||||
mock_db.session.query.return_value = query
|
||||
|
||||
# Act
|
||||
result = AppService.get_app_code_by_id("app-1")
|
||||
|
||||
# Assert
|
||||
assert result == "code-1"
|
||||
|
||||
def test_get_app_id_by_code_should_raise_when_site_missing(self) -> None:
|
||||
"""Test get_app_id_by_code raises when code does not exist."""
|
||||
# Arrange
|
||||
with patch("services.app_service.db") as mock_db:
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = None
|
||||
mock_db.session.query.return_value = query
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
AppService.get_app_id_by_code("missing")
|
||||
|
||||
def test_get_app_id_by_code_should_return_app_id(self) -> None:
|
||||
"""Test get_app_id_by_code returns linked app id."""
|
||||
# Arrange
|
||||
site = SimpleNamespace(app_id="app-1")
|
||||
with patch("services.app_service.db") as mock_db:
|
||||
query = MagicMock()
|
||||
query.where.return_value = query
|
||||
query.first.return_value = site
|
||||
mock_db.session.query.return_value = query
|
||||
|
||||
# Act
|
||||
result = AppService.get_app_id_by_code("code-1")
|
||||
|
||||
# Assert
|
||||
assert result == "app-1"
|
||||
387
api/tests/unit_tests/services/test_batch_indexing_base.py
Normal file
387
api/tests/unit_tests/services/test_batch_indexing_base.py
Normal file
@ -0,0 +1,387 @@
|
||||
from dataclasses import asdict
|
||||
from typing import Any, ClassVar, cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.document_task import DocumentTask
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Concrete subclass for testing (the base class is abstract)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ConcreteBatchProxy(BatchDocumentIndexingProxy):
|
||||
"""Minimal concrete implementation that provides the required class-level vars."""
|
||||
|
||||
QUEUE_NAME: ClassVar[str] = "test_queue"
|
||||
NORMAL_TASK_FUNC: ClassVar[Any] = MagicMock(name="NORMAL_TASK_FUNC")
|
||||
PRIORITY_TASK_FUNC: ClassVar[Any] = MagicMock(name="PRIORITY_TASK_FUNC")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TENANT_ID = "tenant-abc"
|
||||
DATASET_ID = "dataset-xyz"
|
||||
DOC_IDS: list[str] = ["doc-1", "doc-2", "doc-3"]
|
||||
|
||||
|
||||
def make_proxy(**kwargs: Any) -> ConcreteBatchProxy:
|
||||
"""Factory: returns a ConcreteBatchProxy with TenantIsolatedTaskQueue mocked out."""
|
||||
with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue") as MockQueue:
|
||||
proxy = ConcreteBatchProxy(
|
||||
tenant_id=kwargs.get("tenant_id", TENANT_ID),
|
||||
dataset_id=kwargs.get("dataset_id", DATASET_ID),
|
||||
document_ids=kwargs.get("document_ids", DOC_IDS),
|
||||
)
|
||||
# Expose the mock queue on the proxy so tests can assert on it
|
||||
proxy._tenant_isolated_task_queue = MockQueue.return_value
|
||||
return proxy
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test suite
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBatchDocumentIndexingProxyInit:
|
||||
"""Tests for __init__ of BatchDocumentIndexingProxy."""
|
||||
|
||||
def test_should_store_document_ids_when_initialized(self) -> None:
|
||||
"""Verify that document_ids are stored on the proxy instance."""
|
||||
# Arrange
|
||||
doc_ids: list[str] = ["doc-a", "doc-b"]
|
||||
|
||||
# Act
|
||||
with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue"):
|
||||
proxy = ConcreteBatchProxy(TENANT_ID, DATASET_ID, doc_ids)
|
||||
|
||||
# Assert
|
||||
assert proxy._document_ids == doc_ids
|
||||
|
||||
def test_should_propagate_tenant_and_dataset_to_base_when_initialized(self) -> None:
|
||||
"""Verify that tenant_id and dataset_id are forwarded to the parent class."""
|
||||
# Arrange / Act
|
||||
with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue"):
|
||||
proxy = ConcreteBatchProxy(TENANT_ID, DATASET_ID, DOC_IDS)
|
||||
|
||||
# Assert
|
||||
assert proxy._tenant_id == TENANT_ID
|
||||
assert proxy._dataset_id == DATASET_ID
|
||||
|
||||
def test_should_create_tenant_isolated_queue_with_correct_args_when_initialized(self) -> None:
|
||||
"""Verify that TenantIsolatedTaskQueue is constructed with (tenant_id, QUEUE_NAME)."""
|
||||
# Arrange / Act
|
||||
with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue") as MockQueue:
|
||||
ConcreteBatchProxy(TENANT_ID, DATASET_ID, DOC_IDS)
|
||||
|
||||
# Assert
|
||||
MockQueue.assert_called_once_with(TENANT_ID, ConcreteBatchProxy.QUEUE_NAME)
|
||||
|
||||
@pytest.mark.parametrize("doc_ids", [[], ["single-doc"], ["d1", "d2", "d3", "d4"]])
|
||||
def test_should_accept_any_length_document_ids_when_initialized(self, doc_ids: list[str]) -> None:
|
||||
"""Verify that empty, single, and multiple document IDs are all accepted."""
|
||||
# Arrange / Act
|
||||
with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue"):
|
||||
proxy = ConcreteBatchProxy(TENANT_ID, DATASET_ID, doc_ids)
|
||||
|
||||
# Assert
|
||||
assert list(proxy._document_ids) == doc_ids
|
||||
|
||||
|
||||
class TestSendToDirectQueue:
|
||||
"""Tests for _send_to_direct_queue."""
|
||||
|
||||
def test_should_call_task_func_delay_with_correct_args_when_sent_to_direct_queue(
|
||||
self,
|
||||
) -> None:
|
||||
"""Verify that task_func.delay is called with the right kwargs."""
|
||||
# Arrange
|
||||
proxy = make_proxy()
|
||||
task_func = MagicMock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_direct_queue(task_func)
|
||||
|
||||
# Assert
|
||||
task_func.delay.assert_called_once_with(
|
||||
tenant_id=TENANT_ID,
|
||||
dataset_id=DATASET_ID,
|
||||
document_ids=DOC_IDS,
|
||||
)
|
||||
|
||||
def test_should_not_interact_with_tenant_queue_when_sent_to_direct_queue(self) -> None:
|
||||
"""Direct queue path must never touch the tenant-isolated queue."""
|
||||
# Arrange
|
||||
proxy = make_proxy()
|
||||
task_func = MagicMock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_direct_queue(task_func)
|
||||
|
||||
# Assert
|
||||
mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue)
|
||||
mock_queue.push_tasks.assert_not_called()
|
||||
mock_queue.set_task_waiting_time.assert_not_called()
|
||||
|
||||
def test_should_forward_any_callable_when_sent_to_direct_queue(self) -> None:
|
||||
"""Verify that different task functions are each called correctly."""
|
||||
# Arrange
|
||||
proxy = make_proxy()
|
||||
task_a, task_b = MagicMock(), MagicMock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_direct_queue(task_a)
|
||||
proxy._send_to_direct_queue(task_b)
|
||||
|
||||
# Assert
|
||||
task_a.delay.assert_called_once()
|
||||
task_b.delay.assert_called_once()
|
||||
|
||||
|
||||
class TestSendToTenantQueue:
|
||||
"""Tests for _send_to_tenant_queue — both branches."""
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Branch 1: get_task_key() is truthy → push to waiting queue
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_should_push_task_to_queue_when_task_key_exists(self) -> None:
|
||||
"""When get_task_key() is truthy, tasks must be pushed via push_tasks()."""
|
||||
# Arrange
|
||||
proxy = make_proxy()
|
||||
proxy._tenant_isolated_task_queue.get_task_key.return_value = "existing-key"
|
||||
task_func = MagicMock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(task_func)
|
||||
|
||||
# Assert
|
||||
mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue)
|
||||
expected_payload = [asdict(DocumentTask(tenant_id=TENANT_ID, dataset_id=DATASET_ID, document_ids=DOC_IDS))]
|
||||
mock_queue.push_tasks.assert_called_once_with(expected_payload)
|
||||
|
||||
def test_should_not_call_task_func_delay_when_task_key_exists(self) -> None:
|
||||
"""When a key already exists, task_func.delay must never be called."""
|
||||
# Arrange
|
||||
proxy = make_proxy()
|
||||
proxy._tenant_isolated_task_queue.get_task_key.return_value = "existing-key"
|
||||
task_func = MagicMock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(task_func)
|
||||
|
||||
# Assert
|
||||
cast(MagicMock, task_func.delay).assert_not_called()
|
||||
|
||||
def test_should_not_set_waiting_time_when_task_key_exists(self) -> None:
|
||||
"""When a key already exists, set_task_waiting_time must never be called."""
|
||||
# Arrange
|
||||
proxy = make_proxy()
|
||||
proxy._tenant_isolated_task_queue.get_task_key.return_value = "existing-key"
|
||||
task_func = MagicMock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(task_func)
|
||||
|
||||
# Assert
|
||||
mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue)
|
||||
mock_queue.set_task_waiting_time.assert_not_called()
|
||||
|
||||
def test_should_serialize_document_task_correctly_when_pushing_to_queue(self) -> None:
|
||||
"""Verify the serialised payload matches asdict(DocumentTask(...))."""
|
||||
# Arrange
|
||||
proxy = make_proxy(document_ids=["doc-x"])
|
||||
proxy._tenant_isolated_task_queue.get_task_key.return_value = "k"
|
||||
task_func = MagicMock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(task_func)
|
||||
|
||||
# Assert — inspect the payload passed to push_tasks
|
||||
mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue)
|
||||
call_args = mock_queue.push_tasks.call_args
|
||||
pushed_list = call_args[0][0] # first positional arg
|
||||
assert len(pushed_list) == 1
|
||||
assert pushed_list[0]["tenant_id"] == TENANT_ID
|
||||
assert pushed_list[0]["dataset_id"] == DATASET_ID
|
||||
assert pushed_list[0]["document_ids"] == ["doc-x"]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Branch 2: get_task_key() is falsy → set flag + dispatch via delay
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def test_should_set_waiting_time_and_call_delay_when_no_task_key(self) -> None:
|
||||
"""When get_task_key() is falsy, set_task_waiting_time and task_func.delay are invoked."""
|
||||
# Arrange
|
||||
proxy = make_proxy()
|
||||
proxy._tenant_isolated_task_queue.get_task_key.return_value = None
|
||||
task_func = MagicMock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(task_func)
|
||||
|
||||
# Assert
|
||||
mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue)
|
||||
mock_queue.set_task_waiting_time.assert_called_once()
|
||||
cast(MagicMock, task_func.delay).assert_called_once_with(
|
||||
tenant_id=TENANT_ID,
|
||||
dataset_id=DATASET_ID,
|
||||
document_ids=DOC_IDS,
|
||||
)
|
||||
|
||||
def test_should_not_push_tasks_when_no_task_key(self) -> None:
|
||||
"""When get_task_key() is falsy, push_tasks must never be called."""
|
||||
# Arrange
|
||||
proxy = make_proxy()
|
||||
proxy._tenant_isolated_task_queue.get_task_key.return_value = None
|
||||
task_func = MagicMock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(task_func)
|
||||
|
||||
# Assert
|
||||
mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue)
|
||||
mock_queue.push_tasks.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize("falsy_key", [None, "", 0, False])
|
||||
def test_should_init_task_when_key_is_any_falsy_value(self, falsy_key: Any) -> None:
|
||||
"""Verify that any falsy return from get_task_key() triggers the init branch."""
|
||||
# Arrange
|
||||
proxy = make_proxy()
|
||||
proxy._tenant_isolated_task_queue.get_task_key.return_value = falsy_key
|
||||
task_func = MagicMock()
|
||||
|
||||
# Act
|
||||
proxy._send_to_tenant_queue(task_func)
|
||||
|
||||
# Assert
|
||||
mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue)
|
||||
mock_queue.set_task_waiting_time.assert_called_once()
|
||||
cast(MagicMock, task_func.delay).assert_called_once()
|
||||
|
||||
|
||||
class TestDispatchRouting:
|
||||
"""Tests for the _dispatch / delay routing logic inherited from the base class."""
|
||||
|
||||
def _mock_features(self, enabled: bool, plan: CloudPlan) -> MagicMock:
|
||||
features = MagicMock()
|
||||
features.billing.enabled = enabled
|
||||
features.billing.subscription.plan = plan
|
||||
return features
|
||||
|
||||
def test_should_send_to_normal_tenant_queue_when_billing_enabled_and_sandbox_plan(self) -> None:
|
||||
"""Sandbox plan routes to normal priority queue with tenant isolation."""
|
||||
# Arrange
|
||||
proxy = make_proxy()
|
||||
proxy._tenant_isolated_task_queue.get_task_key.return_value = None
|
||||
|
||||
with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value = self._mock_features(enabled=True, plan=CloudPlan.SANDBOX)
|
||||
|
||||
# Act
|
||||
with patch.object(proxy, "_send_to_default_tenant_queue") as mock_method:
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
mock_method.assert_called_once()
|
||||
|
||||
def test_should_send_to_priority_tenant_queue_when_billing_enabled_and_paid_plan(self) -> None:
|
||||
"""Non-sandbox paid plan routes to priority queue with tenant isolation."""
|
||||
# Arrange
|
||||
proxy = make_proxy()
|
||||
|
||||
with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value = self._mock_features(enabled=True, plan=CloudPlan.PROFESSIONAL)
|
||||
|
||||
# Act
|
||||
with patch.object(proxy, "_send_to_priority_tenant_queue") as mock_method:
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
mock_method.assert_called_once()
|
||||
|
||||
def test_should_send_to_priority_direct_queue_when_billing_not_enabled(self) -> None:
|
||||
"""Self-hosted / no billing → priority direct queue (no tenant isolation)."""
|
||||
# Arrange
|
||||
proxy = make_proxy()
|
||||
|
||||
with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value = self._mock_features(enabled=False, plan=CloudPlan.SANDBOX)
|
||||
|
||||
# Act
|
||||
with patch.object(proxy, "_send_to_priority_direct_queue") as mock_method:
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
mock_method.assert_called_once()
|
||||
|
||||
def test_should_call_dispatch_when_delay_is_invoked(self) -> None:
|
||||
"""Calling delay() must invoke _dispatch() exactly once."""
|
||||
# Arrange
|
||||
proxy = make_proxy()
|
||||
|
||||
# Act
|
||||
with patch.object(proxy, "_dispatch") as mock_dispatch:
|
||||
proxy.delay()
|
||||
|
||||
# Assert
|
||||
mock_dispatch.assert_called_once()
|
||||
|
||||
def test_should_use_feature_service_for_billing_info(self) -> None:
|
||||
"""Verify that FeatureService.get_features is consulted during dispatch."""
|
||||
# Arrange
|
||||
proxy = make_proxy()
|
||||
|
||||
with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features:
|
||||
mock_features.return_value = self._mock_features(enabled=False, plan=CloudPlan.SANDBOX)
|
||||
with patch.object(proxy, "_send_to_priority_direct_queue"):
|
||||
# Act
|
||||
proxy._dispatch()
|
||||
|
||||
# Assert
|
||||
mock_features.assert_called_once_with(TENANT_ID)
|
||||
|
||||
|
||||
class TestBaseRouterHelpers:
|
||||
"""Tests for the three routing helper methods from the base class."""
|
||||
|
||||
def test_should_call_send_to_tenant_queue_with_normal_func_when_default_tenant_queue(self) -> None:
|
||||
"""_send_to_default_tenant_queue must forward NORMAL_TASK_FUNC."""
|
||||
# Arrange
|
||||
proxy = make_proxy()
|
||||
|
||||
# Act
|
||||
with patch.object(proxy, "_send_to_tenant_queue") as mock_method:
|
||||
proxy._send_to_default_tenant_queue()
|
||||
|
||||
# Assert
|
||||
mock_method.assert_called_once_with(ConcreteBatchProxy.NORMAL_TASK_FUNC)
|
||||
|
||||
def test_should_call_send_to_tenant_queue_with_priority_func_when_priority_tenant_queue(self) -> None:
|
||||
"""_send_to_priority_tenant_queue must forward PRIORITY_TASK_FUNC."""
|
||||
# Arrange
|
||||
proxy = make_proxy()
|
||||
|
||||
# Act
|
||||
with patch.object(proxy, "_send_to_tenant_queue") as mock_method:
|
||||
proxy._send_to_priority_tenant_queue()
|
||||
|
||||
# Assert
|
||||
mock_method.assert_called_once_with(ConcreteBatchProxy.PRIORITY_TASK_FUNC)
|
||||
|
||||
def test_should_call_send_to_direct_queue_with_priority_func_when_priority_direct_queue(self) -> None:
|
||||
"""_send_to_priority_direct_queue must forward PRIORITY_TASK_FUNC."""
|
||||
# Arrange
|
||||
proxy = make_proxy()
|
||||
|
||||
# Act
|
||||
with patch.object(proxy, "_send_to_direct_queue") as mock_method:
|
||||
proxy._send_to_priority_direct_queue()
|
||||
|
||||
# Assert
|
||||
mock_method.assert_called_once_with(ConcreteBatchProxy.PRIORITY_TASK_FUNC)
|
||||
@ -0,0 +1,760 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.plugin.entities.plugin_daemon import CredentialType
|
||||
from dify_graph.model_runtime.entities.provider_entities import FormType
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
from models.oauth import DatasourceProvider
|
||||
from models.provider_ids import DatasourceProviderID
|
||||
from services.datasource_provider_service import DatasourceProviderService, get_current_user
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_id(s: str = "org/plugin/provider") -> DatasourceProviderID:
|
||||
return DatasourceProviderID(s)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test class
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDatasourceProviderService:
|
||||
"""Comprehensive tests for DatasourceProviderService targeting >95% coverage."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
return DatasourceProviderService()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session(self):
|
||||
"""
|
||||
Robust, chainable query mock.
|
||||
q returns itself for .filter_by(), .order_by(), .where() so any
|
||||
SQLAlchemy chaining pattern works without multiple brittle sub-mocks.
|
||||
"""
|
||||
with patch("services.datasource_provider_service.Session") as mock_cls:
|
||||
sess = MagicMock(spec=Session)
|
||||
|
||||
q = MagicMock()
|
||||
sess.query.return_value = q
|
||||
|
||||
# Self-returning chain — any method called on q returns q
|
||||
q.filter_by.return_value = q
|
||||
q.order_by.return_value = q
|
||||
q.where.return_value = q
|
||||
|
||||
# Default terminal values (tests override per-case)
|
||||
q.first.return_value = None
|
||||
q.all.return_value = []
|
||||
q.count.return_value = 0
|
||||
q.delete.return_value = 1
|
||||
|
||||
mock_cls.return_value.__enter__.return_value = sess
|
||||
mock_cls.return_value.no_autoflush.__enter__.return_value = sess
|
||||
|
||||
yield sess
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_db(self, mock_db_session):
|
||||
with patch("services.datasource_provider_service.db") as mock_db:
|
||||
mock_db.session = mock_db_session
|
||||
mock_db.engine = MagicMock()
|
||||
yield mock_db
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def patch_externals(self):
|
||||
with (
|
||||
patch("httpx.request") as mock_httpx,
|
||||
patch("services.datasource_provider_service.dify_config") as mock_cfg,
|
||||
patch("services.datasource_provider_service.encrypter") as mock_enc,
|
||||
patch("services.datasource_provider_service.redis_client") as mock_redis,
|
||||
patch("services.datasource_provider_service.generate_incremental_name") as mock_genname,
|
||||
patch("services.datasource_provider_service.OAuthHandler") as mock_oauth,
|
||||
):
|
||||
mock_cfg.CONSOLE_API_URL = "http://localhost"
|
||||
mock_enc.encrypt_token.return_value = "enc_tok"
|
||||
mock_enc.decrypt_token.return_value = "dec_tok"
|
||||
mock_enc.decrypt.return_value = {"k": "dec"}
|
||||
mock_enc.encrypt.return_value = {"k": "enc"}
|
||||
mock_enc.obfuscated_token.return_value = "obf"
|
||||
mock_enc.mask_plugin_credentials.return_value = {"k": "mask"}
|
||||
|
||||
mock_redis.lock.return_value.__enter__.return_value = MagicMock()
|
||||
mock_genname.return_value = "gen_name"
|
||||
|
||||
mock_oauth.return_value.refresh_credentials.return_value = MagicMock(
|
||||
credentials={"k": "v"}, expires_at=9999
|
||||
)
|
||||
|
||||
resp = MagicMock()
|
||||
resp.status_code = 200
|
||||
resp.json.return_value = {
|
||||
"code": 0,
|
||||
"message": "ok",
|
||||
"data": {
|
||||
"provider": "prov",
|
||||
"plugin_unique_identifier": "pui",
|
||||
"plugin_id": "org/plug",
|
||||
"is_authorized": False,
|
||||
"declaration": {
|
||||
"identity": {
|
||||
"author": "a",
|
||||
"name": "n",
|
||||
"description": {"en_US": "d"},
|
||||
"icon": "i",
|
||||
"label": {"en_US": "l"},
|
||||
},
|
||||
"credentials_schema": [],
|
||||
"oauth_schema": {"credentials_schema": [], "client_schema": []},
|
||||
"provider_type": "local_file",
|
||||
"datasources": [],
|
||||
},
|
||||
},
|
||||
}
|
||||
mock_httpx.return_value = resp
|
||||
|
||||
# Store handles for assertions
|
||||
self._enc = mock_enc
|
||||
self._redis = mock_redis
|
||||
yield
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user(self):
|
||||
u = MagicMock()
|
||||
u.id = "uid-1"
|
||||
return u
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# get_current_user (lines 27-40)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_return_proxy_when_current_object_is_account(self):
|
||||
with patch("libs.login.current_user", new_callable=MagicMock) as proxy:
|
||||
user_obj = MagicMock()
|
||||
user_obj.__class__ = Account
|
||||
proxy._get_current_object.return_value = user_obj
|
||||
assert get_current_user() is proxy
|
||||
|
||||
def test_should_return_proxy_when_current_object_is_enduser(self):
|
||||
with patch("libs.login.current_user", new_callable=MagicMock) as proxy:
|
||||
user_obj = MagicMock()
|
||||
user_obj.__class__ = EndUser
|
||||
proxy._get_current_object.return_value = user_obj
|
||||
assert get_current_user() is proxy
|
||||
|
||||
def test_should_return_proxy_when_get_current_object_raises_attribute_error(self):
|
||||
"""AttributeError from LocalProxy falls back to the proxy itself."""
|
||||
with patch("libs.login.current_user", new_callable=MagicMock) as proxy:
|
||||
proxy._get_current_object.side_effect = AttributeError("no attr")
|
||||
proxy.__class__ = Account # make the proxy itself satisfy isinstance
|
||||
assert get_current_user() is proxy
|
||||
|
||||
def test_should_raise_type_error_when_user_is_not_account_or_enduser(self):
|
||||
with patch("libs.login.current_user", new_callable=MagicMock) as proxy:
|
||||
proxy._get_current_object.return_value = "plain_string"
|
||||
with pytest.raises(TypeError, match="current_user must be Account or EndUser"):
|
||||
get_current_user()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# is_system_oauth_params_exist (line 357-363)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_return_true_when_system_oauth_params_exist(self, service, mock_db_session):
|
||||
mock_db_session.query().first.return_value = MagicMock()
|
||||
assert service.is_system_oauth_params_exist(make_id()) is True
|
||||
|
||||
def test_should_return_false_when_system_oauth_params_missing(self, service, mock_db_session):
|
||||
mock_db_session.query().first.return_value = None
|
||||
assert service.is_system_oauth_params_exist(make_id()) is False
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# is_tenant_oauth_params_enabled (lines 365-379)
|
||||
# NOTE: uses .count() not .first()
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_return_true_when_tenant_oauth_params_enabled(self, service, mock_db_session):
|
||||
mock_db_session.query().count.return_value = 1
|
||||
assert service.is_tenant_oauth_params_enabled("t1", make_id()) is True
|
||||
|
||||
def test_should_return_false_when_tenant_oauth_params_disabled(self, service, mock_db_session):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
assert service.is_tenant_oauth_params_enabled("t1", make_id()) is False
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# remove_oauth_custom_client_params (lines 55-61)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_delete_tenant_config_when_removing_oauth_params(self, service, mock_db_session):
|
||||
service.remove_oauth_custom_client_params("t1", make_id())
|
||||
mock_db_session.query().delete.assert_called_once()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# setup_oauth_custom_client_params (315-351)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_skip_db_write_when_credentials_are_none(self, service, mock_db_session):
|
||||
"""When credentials=None, should return immediately without any DB write."""
|
||||
service.setup_oauth_custom_client_params("t1", make_id(), None, None)
|
||||
mock_db_session.add.assert_not_called()
|
||||
|
||||
def test_should_create_new_config_when_none_exists(self, service, mock_db_session):
|
||||
mock_db_session.query().first.return_value = None
|
||||
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
|
||||
service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, True)
|
||||
mock_db_session.add.assert_called_once()
|
||||
|
||||
def test_should_update_existing_config_when_record_found(self, service, mock_db_session):
|
||||
existing = MagicMock()
|
||||
mock_db_session.query().first.return_value = existing
|
||||
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
|
||||
service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, False)
|
||||
mock_db_session.add.assert_not_called() # update in place, no add
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# decrypt / encrypt credentials (lines 70-98)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_decrypt_secret_fields_when_decrypting_api_key_credentials(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.auth_type = "api_key"
|
||||
p.encrypted_credentials = {"sk": "enc_val"}
|
||||
with patch.object(service, "extract_secret_variables", return_value=["sk"]):
|
||||
result = service.decrypt_datasource_provider_credentials("t1", p, "org/plug", "prov")
|
||||
assert result["sk"] == "dec_tok"
|
||||
|
||||
def test_should_encrypt_secret_fields_when_encrypting_api_key_credentials(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.auth_type = "api_key"
|
||||
with patch.object(service, "extract_secret_variables", return_value=["sk"]):
|
||||
result = service.encrypt_datasource_provider_credentials("t1", "prov", "org/plug", {"sk": "plain"}, p)
|
||||
assert result["sk"] == "enc_tok"
|
||||
self._enc.encrypt_token.assert_called()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# get_datasource_credentials (lines 113-165)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_return_empty_dict_when_credential_not_found(self, service, mock_db_session, mock_user):
|
||||
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
|
||||
mock_db_session.query().first.return_value = None
|
||||
assert service.get_datasource_credentials("t1", "prov", "org/plug") == {}
|
||||
|
||||
def test_should_refresh_oauth_tokens_when_expired(self, service, mock_db_session, mock_user):
|
||||
"""Expired OAuth credential (expires_at near zero) triggers a silent refresh."""
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.auth_type = "oauth2"
|
||||
p.expires_at = 0 # expired
|
||||
p.encrypted_credentials = {"tok": "x"}
|
||||
mock_db_session.query().first.return_value = p
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
|
||||
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}),
|
||||
):
|
||||
service.get_datasource_credentials("t1", "prov", "org/plug")
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user):
|
||||
"""API key credentials with expires_at=-1 skip refresh and return directly."""
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.auth_type = "api_key"
|
||||
p.expires_at = -1 # sentinel: never expires
|
||||
p.encrypted_credentials = {"k": "v"}
|
||||
mock_db_session.query().first.return_value = p
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "plain"}),
|
||||
):
|
||||
result = service.get_datasource_credentials("t1", "prov", "org/plug")
|
||||
assert result == {"k": "plain"}
|
||||
|
||||
def test_should_fetch_by_credential_id_when_provided(self, service, mock_db_session, mock_user):
|
||||
"""When credential_id is passed, the credential_id filter path (line 113) is taken."""
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.auth_type = "api_key"
|
||||
p.expires_at = -1
|
||||
p.encrypted_credentials = {}
|
||||
mock_db_session.query().first.return_value = p
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "v"}),
|
||||
):
|
||||
result = service.get_datasource_credentials("t1", "prov", "org/plug", credential_id="cred-id")
|
||||
assert result == {"k": "v"}
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# get_all_datasource_credentials_by_provider (lines 176-228)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_return_empty_list_when_no_provider_credentials_exist(self, service, mock_db_session, mock_user):
|
||||
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
|
||||
mock_db_session.query().all.return_value = []
|
||||
assert service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") == []
|
||||
|
||||
def test_should_refresh_and_return_credentials_when_oauth_expired(self, service, mock_db_session, mock_user):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.auth_type = "oauth2"
|
||||
p.expires_at = 0
|
||||
p.encrypted_credentials = {"t": "x"}
|
||||
mock_db_session.query().all.return_value = [p]
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
|
||||
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"t": "plain"}),
|
||||
):
|
||||
result = service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug")
|
||||
assert len(result) == 1
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# update_datasource_provider_name (lines 236-303)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_raise_value_error_when_provider_not_found_on_name_update(self, service, mock_db_session):
|
||||
mock_db_session.query().first.return_value = None
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
service.update_datasource_provider_name("t1", make_id(), "new", "cred-id")
|
||||
|
||||
def test_should_return_early_when_new_name_matches_current(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.name = "same"
|
||||
mock_db_session.query().first.return_value = p
|
||||
service.update_datasource_provider_name("t1", make_id(), "same", "cred-id")
|
||||
mock_db_session.commit.assert_not_called()
|
||||
|
||||
def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.name = "old_name"
|
||||
p.is_default = False
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 1 # conflict
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
|
||||
|
||||
def test_should_update_name_and_commit_when_no_conflict(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.name = "old_name"
|
||||
p.is_default = False
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 0
|
||||
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
|
||||
assert p.name == "new_name"
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# set_default_datasource_provider (lines 277-303)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_raise_value_error_when_target_provider_not_found(self, service, mock_db_session):
|
||||
mock_db_session.query().first.return_value = None
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
service.set_default_datasource_provider("t1", make_id(), "bad-id")
|
||||
|
||||
def test_should_mark_target_as_default_and_commit(self, service, mock_db_session):
|
||||
target = MagicMock(spec=DatasourceProvider)
|
||||
target.provider = "provider"
|
||||
target.plugin_id = "org/plug"
|
||||
mock_db_session.query().first.return_value = target
|
||||
service.set_default_datasource_provider("t1", make_id(), "new-id")
|
||||
assert target.is_default is True
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# get_oauth_encrypter (lines 404-420)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_raise_value_error_when_oauth_schema_missing(self, service):
|
||||
pm = MagicMock()
|
||||
pm.declaration.oauth_schema = None
|
||||
with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm):
|
||||
with pytest.raises(ValueError, match="oauth schema not found"):
|
||||
service.get_oauth_encrypter("t1", make_id())
|
||||
|
||||
def test_should_return_encrypter_when_oauth_schema_exists(self, service):
|
||||
schema_item = MagicMock()
|
||||
schema_item.to_basic_provider_config.return_value = MagicMock()
|
||||
pm = MagicMock()
|
||||
pm.declaration.oauth_schema.client_schema = [schema_item]
|
||||
with (
|
||||
patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm),
|
||||
patch(
|
||||
"services.datasource_provider_service.create_provider_encrypter",
|
||||
return_value=(MagicMock(), MagicMock()),
|
||||
),
|
||||
):
|
||||
result = service.get_oauth_encrypter("t1", make_id())
|
||||
assert result is not None
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# get_tenant_oauth_client (lines 381-402)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_return_masked_credentials_when_mask_is_true(self, service, mock_db_session):
|
||||
tenant_params = MagicMock()
|
||||
tenant_params.client_params = {"k": "v"}
|
||||
mock_db_session.query().first.return_value = tenant_params
|
||||
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
|
||||
result = service.get_tenant_oauth_client("t1", make_id(), mask=True)
|
||||
assert result == {"k": "mask"}
|
||||
|
||||
def test_should_return_decrypted_credentials_when_mask_is_false(self, service, mock_db_session):
|
||||
tenant_params = MagicMock()
|
||||
tenant_params.client_params = {"k": "v"}
|
||||
mock_db_session.query().first.return_value = tenant_params
|
||||
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
|
||||
result = service.get_tenant_oauth_client("t1", make_id(), mask=False)
|
||||
assert result == {"k": "dec"}
|
||||
|
||||
def test_should_return_none_when_no_tenant_oauth_config_exists(self, service, mock_db_session):
|
||||
mock_db_session.query().first.return_value = None
|
||||
assert service.get_tenant_oauth_client("t1", make_id()) is None
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# get_oauth_client (lines 423-457)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_use_tenant_config_when_available(self, service, mock_db_session):
|
||||
mock_db_session.query().first.return_value = MagicMock(client_params={"k": "v"})
|
||||
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
|
||||
result = service.get_oauth_client("t1", make_id())
|
||||
assert result == {"k": "dec"}
|
||||
|
||||
def test_should_fallback_to_system_credentials_when_tenant_config_missing(self, service, mock_db_session):
|
||||
mock_db_session.query().first.side_effect = [None, MagicMock(system_credentials={"k": "sys"})]
|
||||
with (
|
||||
patch.object(service.provider_manager, "fetch_datasource_provider"),
|
||||
patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=True),
|
||||
):
|
||||
result = service.get_oauth_client("t1", make_id())
|
||||
assert result == {"k": "sys"}
|
||||
|
||||
def test_should_raise_value_error_when_no_oauth_config_available(self, service, mock_db_session):
|
||||
"""Neither tenant nor system credentials → raises ValueError."""
|
||||
mock_db_session.query().first.side_effect = [None, None]
|
||||
with (
|
||||
patch.object(service.provider_manager, "fetch_datasource_provider"),
|
||||
patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=False),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Please configure oauth client params"):
|
||||
service.get_oauth_client("t1", make_id())
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# add_datasource_oauth_provider (lines 539-607)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_add_oauth_provider_successfully_when_name_is_unique(self, service, mock_db_session):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {})
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session):
|
||||
"""Conflict on name results in auto-incremented name, not an error."""
|
||||
mock_db_session.query().count.return_value = 1 # conflict first, then auto-named
|
||||
mock_db_session.query().all.return_value = []
|
||||
with (
|
||||
patch.object(service, "extract_secret_variables", return_value=[]),
|
||||
patch.object(service, "generate_next_datasource_provider_name", return_value="new_gen"),
|
||||
):
|
||||
service.add_datasource_oauth_provider("conflict", "t1", make_id(), "http://cb", 9999, {})
|
||||
mock_db_session.add.assert_called_once()
|
||||
|
||||
def test_should_auto_generate_name_when_none_provided_for_oauth(self, service, mock_db_session):
|
||||
"""name=None causes auto-generation via generate_next_datasource_provider_name."""
|
||||
mock_db_session.query().count.return_value = 0
|
||||
mock_db_session.query().all.return_value = []
|
||||
with (
|
||||
patch.object(service, "extract_secret_variables", return_value=[]),
|
||||
patch.object(service, "generate_next_datasource_provider_name", return_value="auto"),
|
||||
):
|
||||
service.add_datasource_oauth_provider(None, "t1", make_id(), "http://cb", 9999, {})
|
||||
mock_db_session.add.assert_called_once()
|
||||
|
||||
def test_should_encrypt_secret_fields_when_adding_oauth_provider(self, service, mock_db_session):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
with patch.object(service, "extract_secret_variables", return_value=["secret_key"]):
|
||||
service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {"secret_key": "value"})
|
||||
self._enc.encrypt_token.assert_called()
|
||||
|
||||
def test_should_acquire_redis_lock_when_adding_oauth_provider(self, service, mock_db_session):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {})
|
||||
self._redis.lock.assert_called()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# reauthorize_datasource_oauth_provider (lines 477-537)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_raise_value_error_when_credential_id_not_found_on_reauth(self, service, mock_db_session):
|
||||
mock_db_session.query().first.return_value = None
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "bad-id")
|
||||
|
||||
def test_should_reauthorize_and_commit_when_credential_found(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 0
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 1 # conflict
|
||||
mock_db_session.query().all.return_value = []
|
||||
with patch.object(service, "extract_secret_variables", return_value=["tok"]):
|
||||
service.reauthorize_datasource_oauth_provider(
|
||||
"conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id"
|
||||
)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 0
|
||||
with patch.object(service, "extract_secret_variables", return_value=["tok"]):
|
||||
service.reauthorize_datasource_oauth_provider(None, "t1", make_id(), "u", 9999, {"tok": "val"}, "cred-id")
|
||||
self._enc.encrypt_token.assert_called()
|
||||
|
||||
def test_should_acquire_redis_lock_when_reauthorizing(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 0
|
||||
with patch.object(service, "extract_secret_variables", return_value=[]):
|
||||
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
|
||||
self._redis.lock.assert_called()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# add_datasource_api_key_provider (lines 608-675)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_raise_value_error_when_api_key_name_already_exists(self, service, mock_db_session, mock_user):
|
||||
"""explicit name supplied + conflict → raises ValueError immediately."""
|
||||
mock_db_session.query().count.return_value = 1
|
||||
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
service.add_datasource_api_key_provider("clash", "t1", make_id(), {"sk": "v"})
|
||||
|
||||
def test_should_raise_value_error_when_credentials_validation_fails(self, service, mock_db_session, mock_user):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad cred")),
|
||||
patch.object(service, "extract_secret_variables", return_value=[]),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Failed to validate"):
|
||||
service.add_datasource_api_key_provider("nm", "t1", make_id(), {"k": "v"})
|
||||
|
||||
def test_should_add_api_key_provider_and_commit_when_valid(self, service, mock_db_session, mock_user):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service.provider_manager, "validate_provider_credentials"),
|
||||
patch.object(service, "extract_secret_variables", return_value=["sk"]),
|
||||
):
|
||||
service.add_datasource_api_key_provider(None, "t1", make_id(), {"sk": "v"})
|
||||
mock_db_session.add.assert_called_once()
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user):
|
||||
mock_db_session.query().count.return_value = 0
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service.provider_manager, "validate_provider_credentials"),
|
||||
patch.object(service, "extract_secret_variables", return_value=[]),
|
||||
):
|
||||
service.add_datasource_api_key_provider(None, "t1", make_id(), {})
|
||||
self._redis.lock.assert_called()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# extract_secret_variables (lines 666-699)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_extract_secret_variable_names_for_api_key_schema(self, service):
|
||||
schema = MagicMock()
|
||||
schema.name = "my_secret"
|
||||
schema.type = MagicMock()
|
||||
schema.type.value = FormType.SECRET_INPUT # "secret-input"
|
||||
pm = MagicMock()
|
||||
pm.declaration.credentials_schema = [schema]
|
||||
with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm):
|
||||
result = service.extract_secret_variables("t1", "org/plug/prov", CredentialType.API_KEY)
|
||||
assert "my_secret" in result
|
||||
|
||||
def test_should_extract_secret_variable_names_for_oauth2_schema(self, service):
|
||||
schema = MagicMock()
|
||||
schema.name = "oauth_secret"
|
||||
schema.type = MagicMock()
|
||||
schema.type.value = FormType.SECRET_INPUT
|
||||
pm = MagicMock()
|
||||
pm.declaration.oauth_schema.credentials_schema = [schema]
|
||||
with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm):
|
||||
result = service.extract_secret_variables("t1", "org/plug/prov", CredentialType.OAUTH2)
|
||||
assert "oauth_secret" in result
|
||||
|
||||
def test_should_raise_value_error_when_credential_type_is_invalid(self, service):
|
||||
pm = MagicMock()
|
||||
with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm):
|
||||
with pytest.raises(ValueError, match="Invalid credential type"):
|
||||
service.extract_secret_variables("t1", "org/plug/prov", CredentialType.UNAUTHORIZED)
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# list_datasource_credentials (lines 721-754)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_return_empty_list_when_no_credentials_stored(self, service, mock_db_session):
|
||||
mock_db_session.query().all.return_value = []
|
||||
assert service.list_datasource_credentials("t1", "prov", "org/plug") == []
|
||||
|
||||
def test_should_return_masked_credentials_list_when_credentials_exist(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.auth_type = "api_key"
|
||||
p.encrypted_credentials = {"sk": "v"}
|
||||
p.is_default = False
|
||||
mock_db_session.query().all.return_value = [p]
|
||||
with patch.object(service, "extract_secret_variables", return_value=["sk"]):
|
||||
result = service.list_datasource_credentials("t1", "prov", "org/plug")
|
||||
assert len(result) == 1
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# get_all_datasource_credentials (lines 808-871)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_aggregate_credentials_for_non_hardcoded_plugin(self, service):
|
||||
with patch("services.datasource_provider_service.PluginDatasourceManager") as mock_mgr:
|
||||
ds = MagicMock()
|
||||
ds.provider = "prov"
|
||||
ds.plugin_id = "org/plug"
|
||||
ds.declaration.identity.label.model_dump.return_value = {"en_US": "Label"}
|
||||
mock_mgr.return_value.fetch_installed_datasource_providers.return_value = [ds]
|
||||
cred = {"credential": {"k": "v"}, "is_default": True}
|
||||
with patch.object(service, "list_datasource_credentials", return_value=[cred]):
|
||||
results = service.get_all_datasource_credentials("t1")
|
||||
assert len(results) == 1
|
||||
|
||||
def test_should_include_oauth_schema_for_hardcoded_plugin_ids(self, service, mock_db_session):
|
||||
"""Lines 819-871: get_all_datasource_credentials covers hardcoded langgenius plugin IDs."""
|
||||
with patch("services.datasource_provider_service.PluginDatasourceManager") as mock_mgr:
|
||||
ds = MagicMock()
|
||||
ds.plugin_id = "langgenius/firecrawl_datasource"
|
||||
ds.provider = "firecrawl"
|
||||
ds.plugin_unique_identifier = "pui"
|
||||
ds.declaration.identity.icon = "icon"
|
||||
ds.declaration.identity.name = "langgenius/firecrawl_datasource"
|
||||
ds.declaration.identity.label.model_dump.return_value = {"en_US": "Firecrawl"}
|
||||
ds.declaration.identity.description.model_dump.return_value = {"en_US": "desc"}
|
||||
ds.declaration.identity.author = "langgenius"
|
||||
ds.declaration.credentials_schema = []
|
||||
ds.declaration.oauth_schema.client_schema = []
|
||||
ds.declaration.oauth_schema.credentials_schema = []
|
||||
mock_mgr.return_value.fetch_installed_datasource_providers.return_value = [ds]
|
||||
with (
|
||||
patch.object(service, "list_datasource_credentials", return_value=[]),
|
||||
patch.object(service, "get_tenant_oauth_client", return_value=None),
|
||||
patch.object(service, "is_tenant_oauth_params_enabled", return_value=False),
|
||||
patch.object(service, "is_system_oauth_params_exist", return_value=False),
|
||||
):
|
||||
results = service.get_all_datasource_credentials("t1")
|
||||
assert len(results) == 1
|
||||
assert results[0]["oauth_schema"] is not None
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# get_real_datasource_credentials (lines 873-915)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_return_empty_list_when_no_real_credentials_exist(self, service, mock_db_session):
|
||||
mock_db_session.query().all.return_value = []
|
||||
assert service.get_real_datasource_credentials("t1", "prov", "org/plug") == []
|
||||
|
||||
def test_should_return_decrypted_credential_list_when_credentials_exist(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.auth_type = "api_key"
|
||||
p.encrypted_credentials = {"sk": "v"}
|
||||
mock_db_session.query().all.return_value = [p]
|
||||
with patch.object(service, "extract_secret_variables", return_value=["sk"]):
|
||||
result = service.get_real_datasource_credentials("t1", "prov", "org/plug")
|
||||
assert len(result) == 1
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# update_datasource_credentials (lines 917-978)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_raise_value_error_when_credential_not_found_on_update(self, service, mock_db_session, mock_user):
|
||||
mock_db_session.query().first.return_value = None
|
||||
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "name")
|
||||
|
||||
def test_should_raise_value_error_when_new_name_already_used_on_update(self, service, mock_db_session, mock_user):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.name = "old_name"
|
||||
p.auth_type = "api_key"
|
||||
p.encrypted_credentials = {"sk": "e"}
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 1
|
||||
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "new_name")
|
||||
|
||||
def test_should_raise_value_error_when_credential_validation_fails_on_update(
|
||||
self, service, mock_db_session, mock_user
|
||||
):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.name = "old_name"
|
||||
p.auth_type = "api_key"
|
||||
p.encrypted_credentials = {"sk": "e"}
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 0
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service, "extract_secret_variables", return_value=["sk"]),
|
||||
patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad")),
|
||||
):
|
||||
with pytest.raises(ValueError, match="Failed to validate"):
|
||||
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {"sk": "v"}, "name")
|
||||
|
||||
def test_should_encrypt_credentials_and_commit_when_update_succeeds(self, service, mock_db_session, mock_user):
|
||||
"""Verifies that encrypted_credentials is reassigned with encrypted value and commit is called."""
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
p.name = "old_name"
|
||||
p.auth_type = "api_key"
|
||||
p.encrypted_credentials = {"sk": "old_enc"}
|
||||
mock_db_session.query().first.return_value = p
|
||||
mock_db_session.query().count.return_value = 0
|
||||
with (
|
||||
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
|
||||
patch.object(service, "extract_secret_variables", return_value=["sk"]),
|
||||
patch.object(service.provider_manager, "validate_provider_credentials"),
|
||||
):
|
||||
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {"sk": "new_val"}, "name")
|
||||
# encrypter must have been called with the new secret value
|
||||
self._enc.encrypt_token.assert_called()
|
||||
# commit must be called exactly once
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# remove_datasource_credentials (lines 980-997)
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def test_should_delete_provider_and_commit_when_found(self, service, mock_db_session):
|
||||
p = MagicMock(spec=DatasourceProvider)
|
||||
mock_db_session.query().first.return_value = p
|
||||
service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
|
||||
mock_db_session.delete.assert_called_once_with(p)
|
||||
mock_db_session.commit.assert_called_once()
|
||||
|
||||
def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session):
|
||||
"""No error raised; no delete called when record doesn't exist (lines 994 branch)."""
|
||||
mock_db_session.query().first.return_value = None
|
||||
service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
|
||||
mock_db_session.delete.assert_not_called()
|
||||
385
api/tests/unit_tests/services/test_hit_testing_service.py
Normal file
385
api/tests/unit_tests/services/test_hit_testing_service.py
Normal file
@ -0,0 +1,385 @@
|
||||
import json
|
||||
from typing import Any, cast
|
||||
from unittest.mock import ANY, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.rag.models.document import Document
|
||||
from models.dataset import Dataset
|
||||
from services.hit_testing_service import HitTestingService
|
||||
|
||||
|
||||
class TestHitTestingService:
|
||||
"""Test suite for HitTestingService"""
|
||||
|
||||
# ===== Utility Method Tests =====
|
||||
|
||||
def test_escape_query_for_search_should_escape_double_quotes(self):
|
||||
"""Test that escape_query_for_search escapes double quotes correctly"""
|
||||
# Arrange
|
||||
query = 'test "query" with quotes'
|
||||
expected = 'test \\"query\\" with quotes'
|
||||
|
||||
# Act
|
||||
result = HitTestingService.escape_query_for_search(query)
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
|
||||
def test_hit_testing_args_check_should_pass_with_valid_query(self):
|
||||
"""Test that hit_testing_args_check passes with a valid query"""
|
||||
# Arrange
|
||||
args = {"query": "valid query"}
|
||||
|
||||
# Act & Assert (should not raise)
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
def test_hit_testing_args_check_should_pass_with_valid_attachments(self):
|
||||
"""Test that hit_testing_args_check passes with valid attachment_ids"""
|
||||
# Arrange
|
||||
args = {"attachment_ids": ["id1", "id2"]}
|
||||
|
||||
# Act & Assert (should not raise)
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
|
||||
def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self):
|
||||
"""Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing"""
|
||||
# Arrange
|
||||
args = {}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
assert "Query or attachment_ids is required" in str(exc_info.value)
|
||||
|
||||
def test_hit_testing_args_check_should_raise_error_when_query_too_long(self):
|
||||
"""Test that hit_testing_args_check raises ValueError if query exceeds 250 characters"""
|
||||
# Arrange
|
||||
args = {"query": "a" * 251}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
assert "Query cannot exceed 250 characters" in str(exc_info.value)
|
||||
|
||||
def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self):
|
||||
"""Test that hit_testing_args_check raises ValueError if attachment_ids is not a list"""
|
||||
# Arrange
|
||||
args = {"attachment_ids": "not a list"}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
HitTestingService.hit_testing_args_check(args)
|
||||
assert "Attachment_ids must be a list" in str(exc_info.value)
|
||||
|
||||
# ===== Response Formatting Tests =====
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents")
|
||||
def test_compact_retrieve_response_should_format_correctly(self, mock_format):
|
||||
"""Test that compact_retrieve_response formats the response correctly"""
|
||||
# Arrange
|
||||
query = "test query"
|
||||
mock_doc = MagicMock(spec=Document)
|
||||
documents = [mock_doc]
|
||||
|
||||
mock_record = MagicMock()
|
||||
mock_record.model_dump.return_value = {"content": "formatted content"}
|
||||
mock_format.return_value = [mock_record]
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents))
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert len(result["records"]) == 1
|
||||
assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content"
|
||||
mock_format.assert_called_once_with(documents)
|
||||
|
||||
def test_compact_external_retrieve_response_should_return_records_for_external_provider(self):
|
||||
"""Test that compact_external_retrieve_response returns records when dataset provider is external"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.provider = "external"
|
||||
query = "test query"
|
||||
documents = [
|
||||
{"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}},
|
||||
{"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}},
|
||||
]
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert len(result["records"]) == 2
|
||||
assert cast(dict[str, Any], result["records"][0])["content"] == "c1"
|
||||
assert cast(dict[str, Any], result["records"][1])["title"] == "t2"
|
||||
|
||||
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self):
|
||||
"""Test that compact_external_retrieve_response returns empty records for non-external provider"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.provider = "not_external"
|
||||
query = "test query"
|
||||
documents = [{"content": "c1"}]
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert result["records"] == []
|
||||
|
||||
# ===== External Retrieve Tests =====
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve):
|
||||
"""Test that external_retrieve successfully retrieves from external provider and commits query"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
dataset.provider = "external"
|
||||
query = 'test "query"'
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
|
||||
mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}]
|
||||
|
||||
# Act
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
HitTestingService.external_retrieve(
|
||||
dataset=dataset,
|
||||
query=query,
|
||||
account=account,
|
||||
external_retrieval_model={"model": "test"},
|
||||
metadata_filtering_conditions={"key": "val"},
|
||||
),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert cast(dict[str, Any], result["records"][0])["content"] == "ext content"
|
||||
|
||||
# Verify call to RetrievalService.external_retrieve with escaped query
|
||||
mock_ext_retrieve.assert_called_once_with(
|
||||
dataset_id="dataset_id",
|
||||
query='test \\"query\\"',
|
||||
external_retrieval_model={"model": "test"},
|
||||
metadata_filtering_conditions={"key": "val"},
|
||||
)
|
||||
|
||||
# Verify DatasetQuery record was added and committed
|
||||
mock_add.assert_called_once()
|
||||
mock_commit.assert_called_once()
|
||||
|
||||
def test_external_retrieve_should_return_empty_for_non_external_provider(self):
|
||||
"""Test that external_retrieve returns empty results immediately if provider is not external"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.provider = "not_external"
|
||||
query = "test query"
|
||||
account = MagicMock()
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account))
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
assert result["records"] == []
|
||||
|
||||
# ===== Retrieve Tests =====
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve):
|
||||
"""Test that retrieve uses default model when retrieval_model is not provided"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
dataset.retrieval_model = None
|
||||
query = "test query"
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
|
||||
mock_retrieve.return_value = []
|
||||
|
||||
# Act
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={}
|
||||
),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert cast(dict[str, Any], result["query"])["content"] == query
|
||||
mock_retrieve.assert_called_once()
|
||||
# Verify top_k from default_retrieval_model (4)
|
||||
assert mock_retrieve.call_args.kwargs["top_k"] == 4
|
||||
mock_commit.assert_called_once()
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve):
|
||||
"""Test that retrieve correctly calls metadata filtering when conditions are present"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
query = "test query"
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
"metadata_filtering_conditions": {"some": "condition"},
|
||||
"top_k": 5,
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
# Mock metadata filtering response
|
||||
mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string")
|
||||
mock_retrieve.return_value = []
|
||||
|
||||
# Act
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_get_meta.assert_called_once()
|
||||
mock_retrieve.assert_called_once()
|
||||
assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"]
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
|
||||
def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve):
|
||||
"""Test that retrieve returns empty response if metadata filtering returns condition but no document IDs"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
query = "test query"
|
||||
account = MagicMock()
|
||||
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
"metadata_filtering_conditions": {"some": "condition"},
|
||||
"top_k": 5,
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
# Mock metadata filtering response: condition returned but no IDs
|
||||
mock_get_meta.return_value = ({}, "condition_string")
|
||||
|
||||
# Act
|
||||
result = cast(
|
||||
dict[str, Any],
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=query,
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
),
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result["records"] == []
|
||||
mock_retrieve.assert_not_called()
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve):
|
||||
"""Test that retrieve handles attachment_ids and adds them to DatasetQuery"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
query = "test query"
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
attachment_ids = ["att1", "att2"]
|
||||
|
||||
retrieval_model = {
|
||||
"search_method": "semantic_search",
|
||||
"top_k": 4,
|
||||
"reranking_enable": False,
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
mock_retrieve.return_value = []
|
||||
|
||||
# Act
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset,
|
||||
query=query,
|
||||
account=account,
|
||||
retrieval_model=retrieval_model,
|
||||
external_retrieval_model={},
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_retrieve.assert_called_once_with(
|
||||
retrieval_method=ANY,
|
||||
dataset_id="dataset_id",
|
||||
query=query,
|
||||
attachment_ids=attachment_ids,
|
||||
top_k=4,
|
||||
score_threshold=0.0,
|
||||
reranking_model=None,
|
||||
reranking_mode="reranking_model",
|
||||
weights=None,
|
||||
document_ids_filter=None,
|
||||
)
|
||||
# Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images)
|
||||
# The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}])
|
||||
called_query = mock_add.call_args[0][0]
|
||||
query_content = json.loads(called_query.content)
|
||||
assert len(query_content) == 3 # 1 text + 2 images
|
||||
assert query_content[0]["content_type"] == "text_query"
|
||||
assert query_content[1]["content_type"] == "image_query"
|
||||
assert query_content[1]["content"] == "att1"
|
||||
|
||||
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
|
||||
@patch("extensions.ext_database.db.session.add")
|
||||
@patch("extensions.ext_database.db.session.commit")
|
||||
def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve):
|
||||
"""Test that retrieve passes reranking and threshold parameters correctly"""
|
||||
# Arrange
|
||||
dataset = MagicMock(spec=Dataset)
|
||||
dataset.id = "dataset_id"
|
||||
query = "test query"
|
||||
account = MagicMock()
|
||||
account.id = "account_id"
|
||||
|
||||
retrieval_model = {
|
||||
"search_method": "hybrid_search",
|
||||
"top_k": 10,
|
||||
"reranking_enable": True,
|
||||
"reranking_model": {"provider": "test"},
|
||||
"reranking_mode": "weighted_sum",
|
||||
"score_threshold_enabled": True,
|
||||
"score_threshold": 0.5,
|
||||
"weights": {"vector": 0.5, "keyword": 0.5},
|
||||
}
|
||||
mock_retrieve.return_value = []
|
||||
|
||||
# Act
|
||||
HitTestingService.retrieve(
|
||||
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
|
||||
)
|
||||
|
||||
# Assert
|
||||
mock_retrieve.assert_called_once()
|
||||
kwargs = mock_retrieve.call_args.kwargs
|
||||
assert kwargs["score_threshold"] == 0.5
|
||||
assert kwargs["reranking_model"] == {"provider": "test"}
|
||||
assert kwargs["reranking_mode"] == "weighted_sum"
|
||||
assert kwargs["weights"] == {"vector": 0.5, "keyword": 0.5}
|
||||
146
api/tests/unit_tests/services/test_knowledge_service.py
Normal file
146
api/tests/unit_tests/services/test_knowledge_service.py
Normal file
@ -0,0 +1,146 @@
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.knowledge_service import ExternalDatasetTestService
|
||||
|
||||
|
||||
class TestKnowledgeService:
|
||||
"""Test suite for ExternalDatasetTestService"""
|
||||
|
||||
# ===== Happy Path Tests =====
|
||||
|
||||
@patch("services.knowledge_service.boto3.client")
|
||||
@patch("services.knowledge_service.dify_config")
|
||||
def test_knowledge_retrieval_should_succeed_with_valid_results(
|
||||
self, mock_dify_config: MagicMock, mock_boto_client: MagicMock
|
||||
):
|
||||
"""Test that knowledge_retrieval successfully parses results from Bedrock"""
|
||||
# Arrange
|
||||
mock_dify_config.AWS_SECRET_ACCESS_KEY = "dummy_secret"
|
||||
mock_dify_config.AWS_ACCESS_KEY_ID = "dummy_id"
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_boto_client.return_value = mock_client
|
||||
|
||||
retrieval_setting = {"top_k": 4, "score_threshold": 0.5}
|
||||
query = "test query"
|
||||
knowledge_id = "kb-123"
|
||||
|
||||
# Mock successful response
|
||||
mock_client.retrieve.return_value = {
|
||||
"ResponseMetadata": {"HTTPStatusCode": 200},
|
||||
"retrievalResults": [
|
||||
{
|
||||
"score": 0.9,
|
||||
"metadata": {"x-amz-bedrock-kb-source-uri": "s3://bucket/doc1.pdf"},
|
||||
"content": {"text": "content from doc1"},
|
||||
},
|
||||
{
|
||||
"score": 0.4, # Below threshold
|
||||
"metadata": {"x-amz-bedrock-kb-source-uri": "s3://bucket/doc2.pdf"},
|
||||
"content": {"text": "content from doc2"},
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Act
|
||||
result = cast(
|
||||
dict[str, Any], ExternalDatasetTestService.knowledge_retrieval(retrieval_setting, query, knowledge_id)
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result["records"]) == 1
|
||||
record = result["records"][0]
|
||||
assert record["score"] == 0.9
|
||||
assert record["title"] == "s3://bucket/doc1.pdf"
|
||||
assert record["content"] == "content from doc1"
|
||||
|
||||
# verify retrieve called correctly
|
||||
mock_client.retrieve.assert_called_once_with(
|
||||
knowledgeBaseId=knowledge_id,
|
||||
retrievalConfiguration={
|
||||
"vectorSearchConfiguration": {
|
||||
"numberOfResults": 4,
|
||||
"overrideSearchType": "HYBRID",
|
||||
}
|
||||
},
|
||||
retrievalQuery={"text": query},
|
||||
)
|
||||
|
||||
# NEW: verify boto3.client created with proper service name and config values
|
||||
mock_boto_client.assert_called_once_with(
|
||||
"bedrock-agent-runtime",
|
||||
aws_secret_access_key="dummy_secret",
|
||||
aws_access_key_id="dummy_id",
|
||||
region_name="us-east-1",
|
||||
)
|
||||
|
||||
@patch("services.knowledge_service.boto3.client")
|
||||
def test_knowledge_retrieval_should_return_empty_when_no_results(self, mock_boto: MagicMock):
|
||||
"""Test that knowledge_retrieval returns empty records when Bedrock returns nothing"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_boto.return_value = mock_client
|
||||
|
||||
mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}, "retrievalResults": []}
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb"))
|
||||
|
||||
# Assert
|
||||
assert result["records"] == []
|
||||
|
||||
# ===== Error Handling Tests =====
|
||||
|
||||
@patch("services.knowledge_service.boto3.client")
|
||||
def test_knowledge_retrieval_should_return_empty_on_http_error(self, mock_boto: MagicMock):
|
||||
"""Test that knowledge_retrieval returns empty records if Bedrock returns non-200 status"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_boto.return_value = mock_client
|
||||
|
||||
mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 500}}
|
||||
|
||||
# Act
|
||||
result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb"))
|
||||
|
||||
# Assert
|
||||
assert result["records"] == []
|
||||
|
||||
def test_knowledge_retrieval_should_raise_when_boto_client_creation_fails(self):
|
||||
"""Test that exceptions from boto3.client propagate (e.g., network/credentials issues)"""
|
||||
with patch("services.knowledge_service.boto3.client") as mock_boto:
|
||||
mock_boto.side_effect = Exception("client init failed")
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")
|
||||
assert "client init failed" in str(exc_info.value)
|
||||
|
||||
# ===== Edge Cases =====
|
||||
|
||||
@patch("services.knowledge_service.boto3.client")
|
||||
def test_knowledge_retrieval_should_handle_missing_threshold_in_settings(self, mock_boto: MagicMock):
|
||||
"""Test that knowledge_retrieval uses 0.0 as default threshold if not provided"""
|
||||
# Arrange
|
||||
mock_client = MagicMock()
|
||||
mock_boto.return_value = mock_client
|
||||
|
||||
mock_client.retrieve.return_value = {
|
||||
"ResponseMetadata": {"HTTPStatusCode": 200},
|
||||
"retrievalResults": [
|
||||
{
|
||||
"score": 0.1,
|
||||
"metadata": {"x-amz-bedrock-kb-source-uri": "uri"},
|
||||
"content": {"text": "text"},
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Act
|
||||
# retrieval_setting missing "score_threshold"
|
||||
result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb"))
|
||||
|
||||
# Assert
|
||||
assert len(result["records"]) == 1
|
||||
assert result["records"][0]["score"] == 0.1
|
||||
120
api/tests/unit_tests/services/test_operation_service.py
Normal file
120
api/tests/unit_tests/services/test_operation_service.py
Normal file
@ -0,0 +1,120 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from services.operation_service import OperationService
|
||||
|
||||
|
||||
class TestOperationService:
|
||||
"""Test suite for OperationService"""
|
||||
|
||||
# ===== Internal Method Tests =====
|
||||
|
||||
@patch("httpx.request")
|
||||
def test_should_call_with_correct_parameters_when__send_request_invoked(
|
||||
self, mock_request: MagicMock, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Test that _send_request calls httpx.request with the correct URL, headers, and data"""
|
||||
# Arrange
|
||||
monkeypatch.setattr(OperationService, "base_url", "https://billing.example")
|
||||
monkeypatch.setattr(OperationService, "secret_key", "s3cr3t")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {"status": "success"}
|
||||
mock_request.return_value = mock_response
|
||||
|
||||
method = "POST"
|
||||
endpoint = "/test_endpoint"
|
||||
json_data = {"key": "value"}
|
||||
|
||||
# Act
|
||||
result = OperationService._send_request(method, endpoint, json=json_data)
|
||||
|
||||
# Assert
|
||||
assert result == {"status": "success"}
|
||||
|
||||
# Verify call parameters
|
||||
expected_url = "https://billing.example/test_endpoint"
|
||||
mock_request.assert_called_once()
|
||||
args, kwargs = mock_request.call_args
|
||||
assert args[0] == method
|
||||
assert args[1] == expected_url
|
||||
assert kwargs["json"] == json_data
|
||||
assert kwargs["headers"]["Billing-Api-Secret-Key"] == "s3cr3t"
|
||||
assert kwargs["headers"]["Content-Type"] == "application/json"
|
||||
|
||||
@patch("httpx.request")
|
||||
def test_should_propagate_httpx_error_when__send_request_raises(
|
||||
self, mock_request: MagicMock, monkeypatch: pytest.MonkeyPatch
|
||||
):
|
||||
"""Test that _send_request handles httpx raising an error"""
|
||||
# Arrange
|
||||
monkeypatch.setattr(OperationService, "base_url", "https://billing.example")
|
||||
mock_request.side_effect = httpx.RequestError("network error")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(httpx.RequestError):
|
||||
OperationService._send_request("POST", "/test")
|
||||
|
||||
# ===== Public Method Tests =====
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("utm_info", "expected_params"),
|
||||
[
|
||||
(
|
||||
{
|
||||
"utm_source": "google",
|
||||
"utm_medium": "cpc",
|
||||
"utm_campaign": "spring_sale",
|
||||
"utm_content": "ad_1",
|
||||
"utm_term": "ai_agent",
|
||||
},
|
||||
{
|
||||
"tenant_id": "tenant-123",
|
||||
"utm_source": "google",
|
||||
"utm_medium": "cpc",
|
||||
"utm_campaign": "spring_sale",
|
||||
"utm_content": "ad_1",
|
||||
"utm_term": "ai_agent",
|
||||
},
|
||||
),
|
||||
(
|
||||
{}, # Empty utm_info
|
||||
{
|
||||
"tenant_id": "tenant-123",
|
||||
"utm_source": "",
|
||||
"utm_medium": "",
|
||||
"utm_campaign": "",
|
||||
"utm_content": "",
|
||||
"utm_term": "",
|
||||
},
|
||||
),
|
||||
(
|
||||
{"utm_source": "newsletter"}, # Partial utm_info
|
||||
{
|
||||
"tenant_id": "tenant-123",
|
||||
"utm_source": "newsletter",
|
||||
"utm_medium": "",
|
||||
"utm_campaign": "",
|
||||
"utm_content": "",
|
||||
"utm_term": "",
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch.object(OperationService, "_send_request")
|
||||
def test_should_map_parameters_correctly_when_record_utm_called(
|
||||
self, mock_send: MagicMock, utm_info: dict, expected_params: dict
|
||||
):
|
||||
"""Test that record_utm correctly maps utm_info to parameters and calls _send_request"""
|
||||
# Arrange
|
||||
tenant_id = "tenant-123"
|
||||
mock_send.return_value = {"status": "recorded"}
|
||||
|
||||
# Act
|
||||
result = OperationService.record_utm(tenant_id, utm_info)
|
||||
|
||||
# Assert
|
||||
assert result == {"status": "recorded"}
|
||||
mock_send.assert_called_once_with("POST", "/tenant_utms", params=expected_params)
|
||||
Reference in New Issue
Block a user