test: add unit tests for some services (#32866)

Co-authored-by: akashseth-ifp <akash.seth@infocusp.com>
This commit is contained in:
Poojan
2026-03-11 13:35:07 +05:30
committed by GitHub
parent f44cd70752
commit b2df0010ce
9 changed files with 4652 additions and 0 deletions

View File

@ -0,0 +1,214 @@
"""
Unit tests for services.advanced_prompt_template_service
"""
import copy
from core.prompt.prompt_templates.advanced_prompt_templates import (
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG,
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
BAICHUAN_CONTEXT,
CHAT_APP_CHAT_PROMPT_CONFIG,
CHAT_APP_COMPLETION_PROMPT_CONFIG,
COMPLETION_APP_CHAT_PROMPT_CONFIG,
COMPLETION_APP_COMPLETION_PROMPT_CONFIG,
CONTEXT,
)
from models.model import AppMode
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class TestAdvancedPromptTemplateService:
"""Test suite for AdvancedPromptTemplateService."""
def test_get_prompt_should_use_baichuan_prompt_when_model_name_contains_baichuan(self) -> None:
"""Test baichuan model names use baichuan context prompt."""
# Arrange
args = {
"app_mode": AppMode.CHAT,
"model_mode": "chat",
"model_name": "Baichuan2-13B",
"has_context": "true",
}
# Act
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert
assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT)
def test_get_prompt_should_use_common_prompt_when_model_name_not_baichuan(self) -> None:
"""Test non-baichuan model names use common prompt."""
# Arrange
args = {
"app_mode": AppMode.CHAT,
"model_mode": "completion",
"model_name": "gpt-4",
"has_context": "false",
}
original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_prompt(args)
# Assert
assert result == original_config
assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG
def test_get_common_prompt_should_return_empty_dict_when_app_mode_invalid(self) -> None:
"""Test invalid app mode returns empty dict."""
# Arrange
app_mode = "invalid"
model_mode = "chat"
# Act
result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "true")
# Assert
assert result == {}
def test_get_common_prompt_should_prepend_context_for_completion_prompt(self) -> None:
"""Test context is prepended for completion prompt when has_context is true."""
# Arrange
original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true")
# Assert
assert result["completion_prompt_config"]["prompt"]["text"].startswith(CONTEXT)
assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG
def test_get_common_prompt_should_prepend_context_for_chat_prompt(self) -> None:
"""Test context is prepended for chat prompt when has_context is true."""
# Arrange
original_config = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true")
# Assert
assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(CONTEXT)
assert original_config == COMPLETION_APP_CHAT_PROMPT_CONFIG
def test_get_common_prompt_should_return_chat_prompt_without_context_when_has_context_false(self) -> None:
"""Test chat prompt remains unchanged when has_context is false."""
# Arrange
original_config = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "false")
# Assert
assert result == original_config
assert original_config == CHAT_APP_CHAT_PROMPT_CONFIG
def test_get_common_prompt_should_return_completion_prompt_for_completion_app_mode(self) -> None:
"""Test completion app mode with completion model returns completion prompt."""
# Arrange
original_config = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "false")
# Assert
assert result == original_config
assert original_config == COMPLETION_APP_COMPLETION_PROMPT_CONFIG
def test_get_common_prompt_should_return_empty_dict_when_model_mode_invalid(self) -> None:
"""Test invalid model mode returns empty dict."""
# Arrange
app_mode = AppMode.CHAT
model_mode = "invalid"
# Act
result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "false")
# Assert
assert result == {}
def test_get_completion_prompt_should_not_prepend_context_when_has_context_false(self) -> None:
"""Test helper keeps completion prompt unchanged when context is disabled."""
# Arrange
prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG)
original_text = prompt_template["completion_prompt_config"]["prompt"]["text"]
# Act
result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT)
# Assert
assert result["completion_prompt_config"]["prompt"]["text"] == original_text
def test_get_chat_prompt_should_not_prepend_context_when_has_context_false(self) -> None:
"""Test helper keeps chat prompt unchanged when context is disabled."""
# Arrange
prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG)
original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"]
# Act
result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT)
# Assert
assert result["chat_prompt_config"]["prompt"][0]["text"] == original_text
def test_get_baichuan_prompt_should_return_chat_completion_config_when_chat_completion(self) -> None:
"""Test baichuan chat/completion returns the expected config."""
# Arrange
original_config = copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false")
# Assert
assert result == original_config
assert original_config == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG
def test_get_baichuan_prompt_should_return_completion_chat_config_when_completion_chat(self) -> None:
"""Test baichuan completion/chat returns the expected config."""
# Arrange
original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "false")
# Assert
assert result == original_config
assert original_config == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG
def test_get_baichuan_prompt_should_return_completion_completion_config_when_enabled_context(self) -> None:
"""Test baichuan completion/completion prepends baichuan context when enabled."""
# Arrange
original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true")
# Assert
assert result["completion_prompt_config"]["prompt"]["text"].startswith(BAICHUAN_CONTEXT)
assert original_config == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG
def test_get_baichuan_prompt_should_return_chat_chat_config_when_enabled_context(self) -> None:
"""Test baichuan chat/chat prepends baichuan context when enabled."""
# Arrange
original_config = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG)
# Act
result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true")
# Assert
assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT)
assert original_config == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG
def test_get_baichuan_prompt_should_return_empty_dict_when_invalid_inputs(self) -> None:
"""Test invalid baichuan mode combinations return empty dict."""
# Arrange
app_mode = "invalid"
model_mode = "invalid"
# Act
result = AdvancedPromptTemplateService.get_baichuan_prompt(app_mode, model_mode, "true")
# Assert
assert result == {}

View File

@ -0,0 +1,346 @@
"""
Unit tests for services.agent_service
"""
from collections.abc import Callable
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
import pytz
from core.plugin.impl.exc import PluginDaemonClientSideError
from models import Account
from models.model import App, Conversation, EndUser, Message, MessageAgentThought
from services.agent_service import AgentService
def _make_current_user_account(timezone: str = "UTC") -> Account:
account = Account(name="Test User", email="test@example.com")
account.timezone = timezone
return account
def _make_app_model(app_model_config: MagicMock | None) -> MagicMock:
app_model = MagicMock(spec=App)
app_model.id = "app-123"
app_model.tenant_id = "tenant-123"
app_model.app_model_config = app_model_config
return app_model
def _make_conversation(from_end_user_id: str | None, from_account_id: str | None) -> MagicMock:
conversation = MagicMock(spec=Conversation)
conversation.id = "conv-123"
conversation.app_id = "app-123"
conversation.from_end_user_id = from_end_user_id
conversation.from_account_id = from_account_id
return conversation
def _make_message(agent_thoughts: list[MessageAgentThought]) -> MagicMock:
message = MagicMock(spec=Message)
message.id = "msg-123"
message.conversation_id = "conv-123"
message.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC)
message.provider_response_latency = 1.23
message.answer_tokens = 4
message.message_tokens = 6
message.agent_thoughts = agent_thoughts
message.message_files = ["file-a.txt"]
return message
def _make_agent_thought() -> MagicMock:
agent_thought = MagicMock(spec=MessageAgentThought)
agent_thought.tokens = 3
agent_thought.tool_input = "raw-input"
agent_thought.observation = "raw-output"
agent_thought.thought = "thinking"
agent_thought.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC)
agent_thought.files = []
agent_thought.tools = ["tool_a", "dataset_tool"]
agent_thought.tool_labels = {"tool_a": "Tool A"}
agent_thought.tool_meta = {
"tool_a": {
"tool_config": {
"tool_provider_type": "custom",
"tool_provider": "provider-1",
},
"tool_parameters": {"param": "value"},
"time_cost": 2.5,
},
"dataset_tool": {
"tool_config": {
"tool_provider_type": "dataset-retrieval",
"tool_provider": "dataset-provider",
}
},
}
agent_thought.tool_inputs_dict = {"tool_a": {"q": "hello"}, "dataset_tool": {"k": "v"}}
agent_thought.tool_outputs_dict = {"tool_a": {"result": "ok"}}
return agent_thought
def _build_query_side_effect(
conversation: Conversation | None,
message: Message | None,
executor: EndUser | Account | None,
) -> Callable[..., MagicMock]:
def _query_side_effect(*args: object, **kwargs: object) -> MagicMock:
query = MagicMock()
query.where.return_value = query
if any(arg is Conversation for arg in args):
query.first.return_value = conversation
elif any(arg is Message for arg in args):
query.first.return_value = message
elif any(arg is EndUser for arg in args) or any(arg is Account for arg in args):
query.first.return_value = executor
return query
return _query_side_effect
class TestAgentServiceGetAgentLogs:
"""Test suite for AgentService.get_agent_logs."""
def test_get_agent_logs_should_raise_when_conversation_missing(self) -> None:
"""Test missing conversation raises ValueError."""
# Arrange
app_model = _make_app_model(MagicMock())
with patch("services.agent_service.db") as mock_db:
query = MagicMock()
query.where.return_value = query
query.first.return_value = None
mock_db.session.query.return_value = query
# Act & Assert
with pytest.raises(ValueError):
AgentService.get_agent_logs(app_model, "missing-conv", "msg-1")
def test_get_agent_logs_should_raise_when_message_missing(self) -> None:
"""Test missing message raises ValueError."""
# Arrange
app_model = _make_app_model(MagicMock())
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
with patch("services.agent_service.db") as mock_db:
conversation_query = MagicMock()
conversation_query.where.return_value = conversation_query
conversation_query.first.return_value = conversation
message_query = MagicMock()
message_query.where.return_value = message_query
message_query.first.return_value = None
mock_db.session.query.side_effect = [conversation_query, message_query]
# Act & Assert
with pytest.raises(ValueError):
AgentService.get_agent_logs(app_model, conversation.id, "missing-msg")
def test_get_agent_logs_should_raise_when_app_model_config_missing(self) -> None:
"""Test missing app model config raises ValueError."""
# Arrange
app_model = _make_app_model(None)
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
message = _make_message([])
current_user = _make_current_user_account()
with patch("services.agent_service.db") as mock_db, patch("services.agent_service.current_user", current_user):
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock())
# Act & Assert
with pytest.raises(ValueError):
AgentService.get_agent_logs(app_model, conversation.id, message.id)
def test_get_agent_logs_should_raise_when_agent_config_missing(self) -> None:
"""Test missing agent config raises ValueError."""
# Arrange
app_model_config = MagicMock()
app_model_config.agent_mode_dict = {"strategy": "react"}
app_model_config.to_dict.return_value = {"tools": []}
app_model = _make_app_model(app_model_config)
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
message = _make_message([])
current_user = _make_current_user_account()
with (
patch("services.agent_service.db") as mock_db,
patch("services.agent_service.AgentConfigManager.convert", return_value=None),
patch("services.agent_service.current_user", current_user),
):
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock())
# Act & Assert
with pytest.raises(ValueError):
AgentService.get_agent_logs(app_model, conversation.id, message.id)
def test_get_agent_logs_should_return_logs_for_end_user_executor(self) -> None:
"""Test agent logs returned for end-user executor with tool icons."""
# Arrange
agent_thought = _make_agent_thought()
message = _make_message([agent_thought])
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
executor = MagicMock(spec=EndUser)
executor.name = "End User"
app_model_config = MagicMock()
app_model_config.agent_mode_dict = {"strategy": "react"}
app_model_config.to_dict.return_value = {"tools": []}
app_model = _make_app_model(app_model_config)
current_user = _make_current_user_account()
agent_tool = MagicMock()
agent_tool.tool_name = "tool_a"
agent_tool.provider_type = "custom"
agent_tool.provider_id = "provider-2"
agent_config = MagicMock()
agent_config.tools = [agent_tool]
with (
patch("services.agent_service.db") as mock_db,
patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config) as mock_convert,
patch("services.agent_service.ToolManager.get_tool_icon") as mock_get_icon,
patch("services.agent_service.current_user", current_user),
):
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor)
mock_get_icon.side_effect = [None, "icon-a"]
# Act
result = AgentService.get_agent_logs(app_model, conversation.id, message.id)
# Assert
assert result["meta"]["status"] == "success"
assert result["meta"]["executor"] == "End User"
assert result["meta"]["total_tokens"] == 10
assert result["meta"]["agent_mode"] == "react"
assert result["meta"]["iterations"] == 1
assert result["files"] == ["file-a.txt"]
assert len(result["iterations"]) == 1
tool_calls = result["iterations"][0]["tool_calls"]
assert tool_calls[0]["tool_name"] == "tool_a"
assert tool_calls[0]["tool_icon"] == "icon-a"
assert tool_calls[1]["tool_name"] == "dataset_tool"
assert tool_calls[1]["tool_icon"] == ""
mock_convert.assert_called_once()
def test_get_agent_logs_should_return_account_executor_when_no_end_user(self) -> None:
"""Test agent logs fall back to account executor when end user is missing."""
# Arrange
agent_thought = _make_agent_thought()
message = _make_message([agent_thought])
conversation = _make_conversation(from_end_user_id=None, from_account_id="account-1")
executor = MagicMock(spec=Account)
executor.name = "Account User"
app_model_config = MagicMock()
app_model_config.agent_mode_dict = {"strategy": "react"}
app_model_config.to_dict.return_value = {"tools": []}
app_model = _make_app_model(app_model_config)
current_user = _make_current_user_account()
agent_config = MagicMock()
agent_config.tools = []
with (
patch("services.agent_service.db") as mock_db,
patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config),
patch("services.agent_service.ToolManager.get_tool_icon", return_value=""),
patch("services.agent_service.current_user", current_user),
):
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor)
# Act
result = AgentService.get_agent_logs(app_model, conversation.id, message.id)
# Assert
assert result["meta"]["executor"] == "Account User"
def test_get_agent_logs_should_use_defaults_when_executor_and_tool_data_missing(self) -> None:
"""Test unknown executor and missing tool details fall back to defaults."""
# Arrange
agent_thought = _make_agent_thought()
agent_thought.tool_labels = {}
agent_thought.tool_inputs_dict = {}
agent_thought.tool_outputs_dict = None
agent_thought.tool_meta = {"tool_a": {"error": "failed"}}
agent_thought.tools = ["tool_a"]
message = _make_message([agent_thought])
conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None)
app_model_config = MagicMock()
app_model_config.agent_mode_dict = {}
app_model_config.to_dict.return_value = {"tools": []}
app_model = _make_app_model(app_model_config)
current_user = _make_current_user_account()
agent_config = MagicMock()
agent_config.tools = []
with (
patch("services.agent_service.db") as mock_db,
patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config),
patch("services.agent_service.ToolManager.get_tool_icon", return_value=None),
patch("services.agent_service.current_user", current_user),
):
mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, None)
# Act
result = AgentService.get_agent_logs(app_model, conversation.id, message.id)
# Assert
assert result["meta"]["executor"] == "Unknown"
assert result["meta"]["agent_mode"] == "react"
tool_call = result["iterations"][0]["tool_calls"][0]
assert tool_call["status"] == "error"
assert tool_call["error"] == "failed"
assert tool_call["tool_label"] == "tool_a"
assert tool_call["tool_input"] == {}
assert tool_call["tool_output"] == {}
assert tool_call["time_cost"] == 0
assert tool_call["tool_parameters"] == {}
assert tool_call["tool_icon"] is None
class TestAgentServiceProviders:
"""Test suite for AgentService provider methods."""
def test_list_agent_providers_should_delegate_to_plugin_client(self) -> None:
"""Test list_agent_providers delegates to PluginAgentClient."""
# Arrange
tenant_id = "tenant-1"
expected = [{"name": "provider"}]
with patch("services.agent_service.PluginAgentClient") as mock_client:
mock_client.return_value.fetch_agent_strategy_providers.return_value = expected
# Act
result = AgentService.list_agent_providers("user-1", tenant_id)
# Assert
assert result == expected
mock_client.return_value.fetch_agent_strategy_providers.assert_called_once_with(tenant_id)
def test_get_agent_provider_should_return_provider_when_successful(self) -> None:
"""Test get_agent_provider returns provider when successful."""
# Arrange
tenant_id = "tenant-1"
provider_name = "provider-a"
expected = {"name": provider_name}
with patch("services.agent_service.PluginAgentClient") as mock_client:
mock_client.return_value.fetch_agent_strategy_provider.return_value = expected
# Act
result = AgentService.get_agent_provider("user-1", tenant_id, provider_name)
# Assert
assert result == expected
mock_client.return_value.fetch_agent_strategy_provider.assert_called_once_with(tenant_id, provider_name)
def test_get_agent_provider_should_raise_value_error_on_plugin_error(self) -> None:
"""Test get_agent_provider wraps PluginDaemonClientSideError into ValueError."""
# Arrange
tenant_id = "tenant-1"
provider_name = "provider-a"
with patch("services.agent_service.PluginAgentClient") as mock_client:
mock_client.return_value.fetch_agent_strategy_provider.side_effect = PluginDaemonClientSideError(
"plugin error"
)
# Act & Assert
with pytest.raises(ValueError):
AgentService.get_agent_provider("user-1", tenant_id, provider_name)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,609 @@
"""Unit tests for services.app_service."""
import json
from types import SimpleNamespace
from typing import cast
from unittest.mock import MagicMock, patch
import pytest
from core.errors.error import ProviderTokenNotInitError
from models import Account, Tenant
from models.model import App, AppMode
from services.app_service import AppService
@pytest.fixture
def service() -> AppService:
"""Provide AppService instance."""
return AppService()
@pytest.fixture
def account() -> Account:
"""Create account object for create_app tests."""
tenant = Tenant(name="Tenant")
tenant.id = "tenant-1"
result = Account(name="Account User", email="account@example.com")
result.id = "acc-1"
result._current_tenant = tenant
return result
@pytest.fixture
def default_args() -> dict:
"""Create default create_app args."""
return {
"name": "Test App",
"mode": AppMode.CHAT.value,
"icon": "🤖",
"icon_background": "#FFFFFF",
}
@pytest.fixture
def app_template() -> dict:
"""Create basic app template for create_app tests."""
return {
AppMode.CHAT: {
"app": {},
"model_config": {
"model": {
"provider": "provider-a",
"name": "model-a",
"mode": "chat",
"completion_params": {},
}
},
}
}
def _make_current_user() -> Account:
user = Account(name="Tester", email="tester@example.com")
user.id = "user-1"
tenant = Tenant(name="Tenant")
tenant.id = "tenant-1"
user._current_tenant = tenant
return user
class TestAppServicePagination:
"""Test suite for get_paginate_apps."""
def test_get_paginate_apps_should_return_none_when_tag_filter_empty(self, service: AppService) -> None:
"""Test pagination returns None when tag filter has no targets."""
# Arrange
args = {"mode": "chat", "page": 1, "limit": 20, "tag_ids": ["tag-1"]}
with patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=[]):
# Act
result = service.get_paginate_apps("user-1", "tenant-1", args)
# Assert
assert result is None
def test_get_paginate_apps_should_delegate_to_db_paginate(self, service: AppService) -> None:
"""Test pagination delegates to db.paginate when filters are valid."""
# Arrange
args = {
"mode": "workflow",
"is_created_by_me": True,
"name": "My_App%",
"tag_ids": ["tag-1"],
"page": 2,
"limit": 10,
}
expected_pagination = MagicMock()
with (
patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=["app-1"]),
patch("libs.helper.escape_like_pattern", return_value="escaped"),
patch("services.app_service.db") as mock_db,
):
mock_db.paginate.return_value = expected_pagination
# Act
result = service.get_paginate_apps("user-1", "tenant-1", args)
# Assert
assert result is expected_pagination
mock_db.paginate.assert_called_once()
class TestAppServiceCreate:
"""Test suite for create_app."""
def test_create_app_should_create_with_matching_default_model(
self,
service: AppService,
account: Account,
default_args: dict,
app_template: dict,
) -> None:
"""Test create_app uses matching default model and persists app config."""
# Arrange
app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1")
app_model_config = SimpleNamespace(id="cfg-1")
model_instance = SimpleNamespace(
model_name="model-a",
provider="provider-a",
model_type_instance=MagicMock(),
credentials={"k": "v"},
)
with (
patch("services.app_service.default_app_templates", app_template),
patch("services.app_service.App", return_value=app_instance),
patch("services.app_service.AppModelConfig", return_value=app_model_config),
patch("services.app_service.ModelManager") as mock_model_manager,
patch("services.app_service.db") as mock_db,
patch("services.app_service.app_was_created") as mock_event,
patch("services.app_service.FeatureService.get_system_features") as mock_features,
patch("services.app_service.BillingService") as mock_billing,
patch("services.app_service.dify_config") as mock_config,
):
manager = mock_model_manager.return_value
manager.get_default_model_instance.return_value = model_instance
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False))
mock_config.BILLING_ENABLED = True
# Act
result = service.create_app("tenant-1", default_args, account)
# Assert
assert result is app_instance
assert app_instance.app_model_config_id == "cfg-1"
mock_db.session.add.assert_any_call(app_instance)
mock_db.session.add.assert_any_call(app_model_config)
assert mock_db.session.flush.call_count == 2
mock_db.session.commit.assert_called_once()
mock_event.send.assert_called_once_with(app_instance, account=account)
mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1")
def test_create_app_should_raise_when_model_schema_missing(
self,
service: AppService,
account: Account,
default_args: dict,
app_template: dict,
) -> None:
"""Test create_app raises ValueError when non-matching model has no schema."""
# Arrange
app_instance = SimpleNamespace(id="app-1")
model_instance = SimpleNamespace(
model_name="model-b",
provider="provider-b",
model_type_instance=MagicMock(),
credentials={"k": "v"},
)
model_instance.model_type_instance.get_model_schema.return_value = None
with (
patch("services.app_service.default_app_templates", app_template),
patch("services.app_service.App", return_value=app_instance),
patch("services.app_service.ModelManager") as mock_model_manager,
patch("services.app_service.db") as mock_db,
):
manager = mock_model_manager.return_value
manager.get_default_model_instance.return_value = model_instance
# Act & Assert
with pytest.raises(ValueError, match="model schema not found"):
service.create_app("tenant-1", default_args, account)
mock_db.session.commit.assert_not_called()
def test_create_app_should_fallback_to_default_provider_when_model_missing(
self,
service: AppService,
account: Account,
default_args: dict,
app_template: dict,
) -> None:
"""Test create_app falls back to provider/model name when no default model instance is available."""
# Arrange
app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1")
app_model_config = SimpleNamespace(id="cfg-1")
with (
patch("services.app_service.default_app_templates", app_template),
patch("services.app_service.App", return_value=app_instance),
patch("services.app_service.AppModelConfig", return_value=app_model_config),
patch("services.app_service.ModelManager") as mock_model_manager,
patch("services.app_service.db") as mock_db,
patch("services.app_service.app_was_created") as mock_event,
patch("services.app_service.FeatureService.get_system_features") as mock_features,
patch("services.app_service.EnterpriseService") as mock_enterprise,
patch("services.app_service.dify_config") as mock_config,
):
manager = mock_model_manager.return_value
manager.get_default_model_instance.side_effect = ProviderTokenNotInitError("not ready")
manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model")
mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True))
mock_config.BILLING_ENABLED = False
# Act
result = service.create_app("tenant-1", default_args, account)
# Assert
assert result is app_instance
mock_event.send.assert_called_once_with(app_instance, account=account)
mock_db.session.commit.assert_called_once()
mock_enterprise.WebAppAuth.update_app_access_mode.assert_called_once_with("app-1", "private")
def test_create_app_should_log_and_fallback_on_unexpected_model_error(
self,
service: AppService,
account: Account,
default_args: dict,
app_template: dict,
) -> None:
"""Test unexpected model manager errors are logged and fallback provider is used."""
# Arrange
app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1")
app_model_config = SimpleNamespace(id="cfg-1")
with (
patch("services.app_service.default_app_templates", app_template),
patch("services.app_service.App", return_value=app_instance),
patch("services.app_service.AppModelConfig", return_value=app_model_config),
patch("services.app_service.ModelManager") as mock_model_manager,
patch("services.app_service.db"),
patch("services.app_service.app_was_created"),
patch(
"services.app_service.FeatureService.get_system_features",
return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)),
),
patch("services.app_service.dify_config", new=SimpleNamespace(BILLING_ENABLED=False)),
patch("services.app_service.logger") as mock_logger,
):
manager = mock_model_manager.return_value
manager.get_default_model_instance.side_effect = RuntimeError("boom")
manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model")
# Act
result = service.create_app("tenant-1", default_args, account)
# Assert
assert result is app_instance
mock_logger.exception.assert_called_once()
class TestAppServiceGetAndUpdate:
"""Test suite for app retrieval and update methods."""
def test_get_app_should_return_original_when_not_agent_app(self, service: AppService) -> None:
"""Test get_app returns original app for non-agent modes."""
# Arrange
app = MagicMock()
app.mode = AppMode.CHAT
app.is_agent = False
with patch("services.app_service.current_user", _make_current_user()):
# Act
result = service.get_app(app)
# Assert
assert result is app
def test_get_app_should_return_original_when_model_config_missing(self, service: AppService) -> None:
"""Test get_app returns app when agent mode has no model config."""
# Arrange
app = MagicMock()
app.id = "app-1"
app.mode = AppMode.AGENT_CHAT
app.is_agent = False
app.app_model_config = None
with patch("services.app_service.current_user", _make_current_user()):
# Act
result = service.get_app(app)
# Assert
assert result is app
def test_get_app_should_mask_tool_parameters_for_agent_tools(self, service: AppService) -> None:
"""Test get_app decrypts and masks secret tool parameters."""
# Arrange
tool = {
"provider_type": "builtin",
"provider_id": "provider-1",
"tool_name": "tool-a",
"tool_parameters": {"secret": "encrypted"},
"extra": True,
}
model_config = MagicMock()
model_config.agent_mode_dict = {"tools": [tool, {"skip": True}]}
app = MagicMock()
app.id = "app-1"
app.mode = AppMode.AGENT_CHAT
app.is_agent = False
app.app_model_config = model_config
manager = MagicMock()
manager.decrypt_tool_parameters.return_value = {"secret": "decrypted"}
manager.mask_tool_parameters.return_value = {"secret": "***"}
with (
patch("services.app_service.current_user", _make_current_user()),
patch("services.app_service.ToolManager.get_agent_tool_runtime", return_value=MagicMock()),
patch("services.app_service.ToolParameterConfigurationManager", return_value=manager),
):
# Act
result = service.get_app(app)
# Assert
assert result.app_model_config is model_config
assert tool["tool_parameters"] == {"secret": "***"}
assert json.loads(model_config.agent_mode)["tools"][0]["tool_parameters"] == {"secret": "***"}
def test_get_app_should_continue_when_tool_parameter_masking_fails(self, service: AppService) -> None:
"""Test get_app logs and continues when masking fails."""
# Arrange
tool = {
"provider_type": "builtin",
"provider_id": "provider-1",
"tool_name": "tool-a",
"tool_parameters": {"secret": "encrypted"},
"extra": True,
}
model_config = MagicMock()
model_config.agent_mode_dict = {"tools": [tool]}
app = MagicMock()
app.id = "app-1"
app.mode = AppMode.AGENT_CHAT
app.is_agent = False
app.app_model_config = model_config
with (
patch("services.app_service.current_user", _make_current_user()),
patch("services.app_service.ToolManager.get_agent_tool_runtime", side_effect=RuntimeError("mask-failed")),
patch("services.app_service.logger") as mock_logger,
):
# Act
result = service.get_app(app)
# Assert
assert result.app_model_config is model_config
mock_logger.exception.assert_called_once()
def test_update_methods_should_mutate_app_and_commit(self, service: AppService) -> None:
"""Test update methods set fields and commit changes."""
# Arrange
app = cast(
App,
SimpleNamespace(
name="old",
description="old",
icon_type="emoji",
icon="a",
icon_background="#111",
enable_site=True,
enable_api=True,
),
)
args = {
"name": "new",
"description": "new-desc",
"icon_type": "image",
"icon": "new-icon",
"icon_background": "#222",
"use_icon_as_answer_icon": True,
"max_active_requests": 5,
}
user = SimpleNamespace(id="user-1")
with (
patch("services.app_service.current_user", user),
patch("services.app_service.db") as mock_db,
patch("services.app_service.naive_utc_now", return_value="now"),
):
# Act
updated = service.update_app(app, args)
renamed = service.update_app_name(app, "rename")
iconed = service.update_app_icon(app, "icon-2", "#333")
site_same = service.update_app_site_status(app, app.enable_site)
api_same = service.update_app_api_status(app, app.enable_api)
site_changed = service.update_app_site_status(app, False)
api_changed = service.update_app_api_status(app, False)
# Assert
assert updated is app
assert renamed is app
assert iconed is app
assert site_same is app
assert api_same is app
assert site_changed is app
assert api_changed is app
assert mock_db.session.commit.call_count >= 5
class TestAppServiceDeleteAndMeta:
"""Test suite for delete and metadata methods."""
def test_delete_app_should_cleanup_and_enqueue_task(self, service: AppService) -> None:
"""Test delete_app removes app, runs cleanup, and triggers async deletion task."""
# Arrange
app = cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1"))
with (
patch("services.app_service.db") as mock_db,
patch(
"services.app_service.FeatureService.get_system_features",
return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)),
),
patch("services.app_service.EnterpriseService") as mock_enterprise,
patch(
"services.app_service.dify_config",
new=SimpleNamespace(BILLING_ENABLED=True, CONSOLE_API_URL="https://console.example"),
),
patch("services.app_service.BillingService") as mock_billing,
patch("services.app_service.remove_app_and_related_data_task") as mock_task,
):
# Act
service.delete_app(app)
# Assert
mock_db.session.delete.assert_called_once_with(app)
mock_db.session.commit.assert_called_once()
mock_enterprise.WebAppAuth.cleanup_webapp.assert_called_once_with("app-1")
mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1")
mock_task.delay.assert_called_once_with(tenant_id="tenant-1", app_id="app-1")
def test_get_app_meta_should_handle_workflow_and_tool_provider_icons(self, service: AppService) -> None:
"""Test get_app_meta extracts builtin and API tool icons from workflow graph."""
# Arrange
workflow = SimpleNamespace(
graph_dict={
"nodes": [
{
"data": {
"type": "tool",
"provider_type": "builtin",
"provider_id": "builtin-provider",
"tool_name": "tool_builtin",
}
},
{
"data": {
"type": "tool",
"provider_type": "api",
"provider_id": "api-provider-id",
"tool_name": "tool_api",
}
},
]
}
)
app = cast(
App,
SimpleNamespace(
mode=AppMode.WORKFLOW.value,
workflow=workflow,
app_model_config=None,
tenant_id="tenant-1",
icon_type="emoji",
icon_background="#fff",
),
)
provider = SimpleNamespace(icon=json.dumps({"background": "#000", "content": "A"}))
with (
patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")),
patch("services.app_service.db") as mock_db,
):
query = MagicMock()
query.where.return_value = query
query.first.return_value = provider
mock_db.session.query.return_value = query
# Act
meta = service.get_app_meta(app)
# Assert
assert meta["tool_icons"]["tool_builtin"].endswith("/builtin-provider/icon")
assert meta["tool_icons"]["tool_api"] == {"background": "#000", "content": "A"}
def test_get_app_meta_should_use_default_api_icon_on_lookup_error(self, service: AppService) -> None:
"""Test get_app_meta falls back to default icon when API provider lookup fails."""
# Arrange
app_model_config = SimpleNamespace(
agent_mode_dict={
"tools": [{"provider_type": "api", "provider_id": "x", "tool_name": "t", "tool_parameters": {}}]
}
)
app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=app_model_config, workflow=None))
with (
patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")),
patch("services.app_service.db") as mock_db,
):
query = MagicMock()
query.where.return_value = query
query.first.return_value = None
mock_db.session.query.return_value = query
# Act
meta = service.get_app_meta(app)
# Assert
assert meta["tool_icons"]["t"] == {"background": "#252525", "content": "\ud83d\ude01"}
def test_get_app_meta_should_return_empty_when_required_data_missing(self, service: AppService) -> None:
"""Test get_app_meta returns empty metadata when workflow/model config is absent."""
# Arrange
workflow_app = cast(App, SimpleNamespace(mode=AppMode.WORKFLOW.value, workflow=None))
chat_app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=None))
# Act
workflow_meta = service.get_app_meta(workflow_app)
chat_meta = service.get_app_meta(chat_app)
# Assert
assert workflow_meta == {"tool_icons": {}}
assert chat_meta == {"tool_icons": {}}
class TestAppServiceCodeLookup:
"""Test suite for app code lookup methods."""
def test_get_app_code_by_id_should_raise_when_site_missing(self) -> None:
"""Test get_app_code_by_id raises when site is missing."""
# Arrange
with patch("services.app_service.db") as mock_db:
query = MagicMock()
query.where.return_value = query
query.first.return_value = None
mock_db.session.query.return_value = query
# Act & Assert
with pytest.raises(ValueError, match="not found"):
AppService.get_app_code_by_id("app-1")
def test_get_app_code_by_id_should_return_code(self) -> None:
"""Test get_app_code_by_id returns site code."""
# Arrange
site = SimpleNamespace(code="code-1")
with patch("services.app_service.db") as mock_db:
query = MagicMock()
query.where.return_value = query
query.first.return_value = site
mock_db.session.query.return_value = query
# Act
result = AppService.get_app_code_by_id("app-1")
# Assert
assert result == "code-1"
def test_get_app_id_by_code_should_raise_when_site_missing(self) -> None:
"""Test get_app_id_by_code raises when code does not exist."""
# Arrange
with patch("services.app_service.db") as mock_db:
query = MagicMock()
query.where.return_value = query
query.first.return_value = None
mock_db.session.query.return_value = query
# Act & Assert
with pytest.raises(ValueError, match="not found"):
AppService.get_app_id_by_code("missing")
def test_get_app_id_by_code_should_return_app_id(self) -> None:
"""Test get_app_id_by_code returns linked app id."""
# Arrange
site = SimpleNamespace(app_id="app-1")
with patch("services.app_service.db") as mock_db:
query = MagicMock()
query.where.return_value = query
query.first.return_value = site
mock_db.session.query.return_value = query
# Act
result = AppService.get_app_id_by_code("code-1")
# Assert
assert result == "app-1"

View File

@ -0,0 +1,387 @@
from dataclasses import asdict
from typing import Any, ClassVar, cast
from unittest.mock import MagicMock, patch
import pytest
from core.entities.document_task import DocumentTask
from enums.cloud_plan import CloudPlan
from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy
# ---------------------------------------------------------------------------
# Concrete subclass for testing (the base class is abstract)
# ---------------------------------------------------------------------------
class ConcreteBatchProxy(BatchDocumentIndexingProxy):
"""Minimal concrete implementation that provides the required class-level vars."""
QUEUE_NAME: ClassVar[str] = "test_queue"
NORMAL_TASK_FUNC: ClassVar[Any] = MagicMock(name="NORMAL_TASK_FUNC")
PRIORITY_TASK_FUNC: ClassVar[Any] = MagicMock(name="PRIORITY_TASK_FUNC")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
TENANT_ID = "tenant-abc"
DATASET_ID = "dataset-xyz"
DOC_IDS: list[str] = ["doc-1", "doc-2", "doc-3"]
def make_proxy(**kwargs: Any) -> ConcreteBatchProxy:
"""Factory: returns a ConcreteBatchProxy with TenantIsolatedTaskQueue mocked out."""
with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue") as MockQueue:
proxy = ConcreteBatchProxy(
tenant_id=kwargs.get("tenant_id", TENANT_ID),
dataset_id=kwargs.get("dataset_id", DATASET_ID),
document_ids=kwargs.get("document_ids", DOC_IDS),
)
# Expose the mock queue on the proxy so tests can assert on it
proxy._tenant_isolated_task_queue = MockQueue.return_value
return proxy
# ---------------------------------------------------------------------------
# Test suite
# ---------------------------------------------------------------------------
class TestBatchDocumentIndexingProxyInit:
"""Tests for __init__ of BatchDocumentIndexingProxy."""
def test_should_store_document_ids_when_initialized(self) -> None:
"""Verify that document_ids are stored on the proxy instance."""
# Arrange
doc_ids: list[str] = ["doc-a", "doc-b"]
# Act
with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue"):
proxy = ConcreteBatchProxy(TENANT_ID, DATASET_ID, doc_ids)
# Assert
assert proxy._document_ids == doc_ids
def test_should_propagate_tenant_and_dataset_to_base_when_initialized(self) -> None:
"""Verify that tenant_id and dataset_id are forwarded to the parent class."""
# Arrange / Act
with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue"):
proxy = ConcreteBatchProxy(TENANT_ID, DATASET_ID, DOC_IDS)
# Assert
assert proxy._tenant_id == TENANT_ID
assert proxy._dataset_id == DATASET_ID
def test_should_create_tenant_isolated_queue_with_correct_args_when_initialized(self) -> None:
"""Verify that TenantIsolatedTaskQueue is constructed with (tenant_id, QUEUE_NAME)."""
# Arrange / Act
with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue") as MockQueue:
ConcreteBatchProxy(TENANT_ID, DATASET_ID, DOC_IDS)
# Assert
MockQueue.assert_called_once_with(TENANT_ID, ConcreteBatchProxy.QUEUE_NAME)
@pytest.mark.parametrize("doc_ids", [[], ["single-doc"], ["d1", "d2", "d3", "d4"]])
def test_should_accept_any_length_document_ids_when_initialized(self, doc_ids: list[str]) -> None:
"""Verify that empty, single, and multiple document IDs are all accepted."""
# Arrange / Act
with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue"):
proxy = ConcreteBatchProxy(TENANT_ID, DATASET_ID, doc_ids)
# Assert
assert list(proxy._document_ids) == doc_ids
class TestSendToDirectQueue:
"""Tests for _send_to_direct_queue."""
def test_should_call_task_func_delay_with_correct_args_when_sent_to_direct_queue(
self,
) -> None:
"""Verify that task_func.delay is called with the right kwargs."""
# Arrange
proxy = make_proxy()
task_func = MagicMock()
# Act
proxy._send_to_direct_queue(task_func)
# Assert
task_func.delay.assert_called_once_with(
tenant_id=TENANT_ID,
dataset_id=DATASET_ID,
document_ids=DOC_IDS,
)
def test_should_not_interact_with_tenant_queue_when_sent_to_direct_queue(self) -> None:
"""Direct queue path must never touch the tenant-isolated queue."""
# Arrange
proxy = make_proxy()
task_func = MagicMock()
# Act
proxy._send_to_direct_queue(task_func)
# Assert
mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue)
mock_queue.push_tasks.assert_not_called()
mock_queue.set_task_waiting_time.assert_not_called()
def test_should_forward_any_callable_when_sent_to_direct_queue(self) -> None:
"""Verify that different task functions are each called correctly."""
# Arrange
proxy = make_proxy()
task_a, task_b = MagicMock(), MagicMock()
# Act
proxy._send_to_direct_queue(task_a)
proxy._send_to_direct_queue(task_b)
# Assert
task_a.delay.assert_called_once()
task_b.delay.assert_called_once()
class TestSendToTenantQueue:
"""Tests for _send_to_tenant_queue — both branches."""
# ------------------------------------------------------------------
# Branch 1: get_task_key() is truthy → push to waiting queue
# ------------------------------------------------------------------
def test_should_push_task_to_queue_when_task_key_exists(self) -> None:
"""When get_task_key() is truthy, tasks must be pushed via push_tasks()."""
# Arrange
proxy = make_proxy()
proxy._tenant_isolated_task_queue.get_task_key.return_value = "existing-key"
task_func = MagicMock()
# Act
proxy._send_to_tenant_queue(task_func)
# Assert
mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue)
expected_payload = [asdict(DocumentTask(tenant_id=TENANT_ID, dataset_id=DATASET_ID, document_ids=DOC_IDS))]
mock_queue.push_tasks.assert_called_once_with(expected_payload)
def test_should_not_call_task_func_delay_when_task_key_exists(self) -> None:
"""When a key already exists, task_func.delay must never be called."""
# Arrange
proxy = make_proxy()
proxy._tenant_isolated_task_queue.get_task_key.return_value = "existing-key"
task_func = MagicMock()
# Act
proxy._send_to_tenant_queue(task_func)
# Assert
cast(MagicMock, task_func.delay).assert_not_called()
def test_should_not_set_waiting_time_when_task_key_exists(self) -> None:
"""When a key already exists, set_task_waiting_time must never be called."""
# Arrange
proxy = make_proxy()
proxy._tenant_isolated_task_queue.get_task_key.return_value = "existing-key"
task_func = MagicMock()
# Act
proxy._send_to_tenant_queue(task_func)
# Assert
mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue)
mock_queue.set_task_waiting_time.assert_not_called()
def test_should_serialize_document_task_correctly_when_pushing_to_queue(self) -> None:
"""Verify the serialised payload matches asdict(DocumentTask(...))."""
# Arrange
proxy = make_proxy(document_ids=["doc-x"])
proxy._tenant_isolated_task_queue.get_task_key.return_value = "k"
task_func = MagicMock()
# Act
proxy._send_to_tenant_queue(task_func)
# Assert — inspect the payload passed to push_tasks
mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue)
call_args = mock_queue.push_tasks.call_args
pushed_list = call_args[0][0] # first positional arg
assert len(pushed_list) == 1
assert pushed_list[0]["tenant_id"] == TENANT_ID
assert pushed_list[0]["dataset_id"] == DATASET_ID
assert pushed_list[0]["document_ids"] == ["doc-x"]
# ------------------------------------------------------------------
# Branch 2: get_task_key() is falsy → set flag + dispatch via delay
# ------------------------------------------------------------------
def test_should_set_waiting_time_and_call_delay_when_no_task_key(self) -> None:
"""When get_task_key() is falsy, set_task_waiting_time and task_func.delay are invoked."""
# Arrange
proxy = make_proxy()
proxy._tenant_isolated_task_queue.get_task_key.return_value = None
task_func = MagicMock()
# Act
proxy._send_to_tenant_queue(task_func)
# Assert
mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue)
mock_queue.set_task_waiting_time.assert_called_once()
cast(MagicMock, task_func.delay).assert_called_once_with(
tenant_id=TENANT_ID,
dataset_id=DATASET_ID,
document_ids=DOC_IDS,
)
def test_should_not_push_tasks_when_no_task_key(self) -> None:
"""When get_task_key() is falsy, push_tasks must never be called."""
# Arrange
proxy = make_proxy()
proxy._tenant_isolated_task_queue.get_task_key.return_value = None
task_func = MagicMock()
# Act
proxy._send_to_tenant_queue(task_func)
# Assert
mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue)
mock_queue.push_tasks.assert_not_called()
@pytest.mark.parametrize("falsy_key", [None, "", 0, False])
def test_should_init_task_when_key_is_any_falsy_value(self, falsy_key: Any) -> None:
"""Verify that any falsy return from get_task_key() triggers the init branch."""
# Arrange
proxy = make_proxy()
proxy._tenant_isolated_task_queue.get_task_key.return_value = falsy_key
task_func = MagicMock()
# Act
proxy._send_to_tenant_queue(task_func)
# Assert
mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue)
mock_queue.set_task_waiting_time.assert_called_once()
cast(MagicMock, task_func.delay).assert_called_once()
class TestDispatchRouting:
"""Tests for the _dispatch / delay routing logic inherited from the base class."""
def _mock_features(self, enabled: bool, plan: CloudPlan) -> MagicMock:
features = MagicMock()
features.billing.enabled = enabled
features.billing.subscription.plan = plan
return features
def test_should_send_to_normal_tenant_queue_when_billing_enabled_and_sandbox_plan(self) -> None:
"""Sandbox plan routes to normal priority queue with tenant isolation."""
# Arrange
proxy = make_proxy()
proxy._tenant_isolated_task_queue.get_task_key.return_value = None
with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features:
mock_features.return_value = self._mock_features(enabled=True, plan=CloudPlan.SANDBOX)
# Act
with patch.object(proxy, "_send_to_default_tenant_queue") as mock_method:
proxy._dispatch()
# Assert
mock_method.assert_called_once()
def test_should_send_to_priority_tenant_queue_when_billing_enabled_and_paid_plan(self) -> None:
"""Non-sandbox paid plan routes to priority queue with tenant isolation."""
# Arrange
proxy = make_proxy()
with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features:
mock_features.return_value = self._mock_features(enabled=True, plan=CloudPlan.PROFESSIONAL)
# Act
with patch.object(proxy, "_send_to_priority_tenant_queue") as mock_method:
proxy._dispatch()
# Assert
mock_method.assert_called_once()
def test_should_send_to_priority_direct_queue_when_billing_not_enabled(self) -> None:
"""Self-hosted / no billing → priority direct queue (no tenant isolation)."""
# Arrange
proxy = make_proxy()
with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features:
mock_features.return_value = self._mock_features(enabled=False, plan=CloudPlan.SANDBOX)
# Act
with patch.object(proxy, "_send_to_priority_direct_queue") as mock_method:
proxy._dispatch()
# Assert
mock_method.assert_called_once()
def test_should_call_dispatch_when_delay_is_invoked(self) -> None:
"""Calling delay() must invoke _dispatch() exactly once."""
# Arrange
proxy = make_proxy()
# Act
with patch.object(proxy, "_dispatch") as mock_dispatch:
proxy.delay()
# Assert
mock_dispatch.assert_called_once()
def test_should_use_feature_service_for_billing_info(self) -> None:
"""Verify that FeatureService.get_features is consulted during dispatch."""
# Arrange
proxy = make_proxy()
with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features:
mock_features.return_value = self._mock_features(enabled=False, plan=CloudPlan.SANDBOX)
with patch.object(proxy, "_send_to_priority_direct_queue"):
# Act
proxy._dispatch()
# Assert
mock_features.assert_called_once_with(TENANT_ID)
class TestBaseRouterHelpers:
"""Tests for the three routing helper methods from the base class."""
def test_should_call_send_to_tenant_queue_with_normal_func_when_default_tenant_queue(self) -> None:
"""_send_to_default_tenant_queue must forward NORMAL_TASK_FUNC."""
# Arrange
proxy = make_proxy()
# Act
with patch.object(proxy, "_send_to_tenant_queue") as mock_method:
proxy._send_to_default_tenant_queue()
# Assert
mock_method.assert_called_once_with(ConcreteBatchProxy.NORMAL_TASK_FUNC)
def test_should_call_send_to_tenant_queue_with_priority_func_when_priority_tenant_queue(self) -> None:
"""_send_to_priority_tenant_queue must forward PRIORITY_TASK_FUNC."""
# Arrange
proxy = make_proxy()
# Act
with patch.object(proxy, "_send_to_tenant_queue") as mock_method:
proxy._send_to_priority_tenant_queue()
# Assert
mock_method.assert_called_once_with(ConcreteBatchProxy.PRIORITY_TASK_FUNC)
def test_should_call_send_to_direct_queue_with_priority_func_when_priority_direct_queue(self) -> None:
"""_send_to_priority_direct_queue must forward PRIORITY_TASK_FUNC."""
# Arrange
proxy = make_proxy()
# Act
with patch.object(proxy, "_send_to_direct_queue") as mock_method:
proxy._send_to_priority_direct_queue()
# Assert
mock_method.assert_called_once_with(ConcreteBatchProxy.PRIORITY_TASK_FUNC)

View File

@ -0,0 +1,760 @@
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy.orm import Session
from core.plugin.entities.plugin_daemon import CredentialType
from dify_graph.model_runtime.entities.provider_entities import FormType
from models.account import Account
from models.model import EndUser
from models.oauth import DatasourceProvider
from models.provider_ids import DatasourceProviderID
from services.datasource_provider_service import DatasourceProviderService, get_current_user
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def make_id(s: str = "org/plugin/provider") -> DatasourceProviderID:
return DatasourceProviderID(s)
# ---------------------------------------------------------------------------
# Test class
# ---------------------------------------------------------------------------
class TestDatasourceProviderService:
"""Comprehensive tests for DatasourceProviderService targeting >95% coverage."""
@pytest.fixture
def service(self):
return DatasourceProviderService()
@pytest.fixture
def mock_db_session(self):
"""
Robust, chainable query mock.
q returns itself for .filter_by(), .order_by(), .where() so any
SQLAlchemy chaining pattern works without multiple brittle sub-mocks.
"""
with patch("services.datasource_provider_service.Session") as mock_cls:
sess = MagicMock(spec=Session)
q = MagicMock()
sess.query.return_value = q
# Self-returning chain — any method called on q returns q
q.filter_by.return_value = q
q.order_by.return_value = q
q.where.return_value = q
# Default terminal values (tests override per-case)
q.first.return_value = None
q.all.return_value = []
q.count.return_value = 0
q.delete.return_value = 1
mock_cls.return_value.__enter__.return_value = sess
mock_cls.return_value.no_autoflush.__enter__.return_value = sess
yield sess
@pytest.fixture(autouse=True)
def patch_db(self, mock_db_session):
with patch("services.datasource_provider_service.db") as mock_db:
mock_db.session = mock_db_session
mock_db.engine = MagicMock()
yield mock_db
@pytest.fixture(autouse=True)
def patch_externals(self):
with (
patch("httpx.request") as mock_httpx,
patch("services.datasource_provider_service.dify_config") as mock_cfg,
patch("services.datasource_provider_service.encrypter") as mock_enc,
patch("services.datasource_provider_service.redis_client") as mock_redis,
patch("services.datasource_provider_service.generate_incremental_name") as mock_genname,
patch("services.datasource_provider_service.OAuthHandler") as mock_oauth,
):
mock_cfg.CONSOLE_API_URL = "http://localhost"
mock_enc.encrypt_token.return_value = "enc_tok"
mock_enc.decrypt_token.return_value = "dec_tok"
mock_enc.decrypt.return_value = {"k": "dec"}
mock_enc.encrypt.return_value = {"k": "enc"}
mock_enc.obfuscated_token.return_value = "obf"
mock_enc.mask_plugin_credentials.return_value = {"k": "mask"}
mock_redis.lock.return_value.__enter__.return_value = MagicMock()
mock_genname.return_value = "gen_name"
mock_oauth.return_value.refresh_credentials.return_value = MagicMock(
credentials={"k": "v"}, expires_at=9999
)
resp = MagicMock()
resp.status_code = 200
resp.json.return_value = {
"code": 0,
"message": "ok",
"data": {
"provider": "prov",
"plugin_unique_identifier": "pui",
"plugin_id": "org/plug",
"is_authorized": False,
"declaration": {
"identity": {
"author": "a",
"name": "n",
"description": {"en_US": "d"},
"icon": "i",
"label": {"en_US": "l"},
},
"credentials_schema": [],
"oauth_schema": {"credentials_schema": [], "client_schema": []},
"provider_type": "local_file",
"datasources": [],
},
},
}
mock_httpx.return_value = resp
# Store handles for assertions
self._enc = mock_enc
self._redis = mock_redis
yield
@pytest.fixture
def mock_user(self):
u = MagicMock()
u.id = "uid-1"
return u
# -----------------------------------------------------------------------
# get_current_user (lines 27-40)
# -----------------------------------------------------------------------
def test_should_return_proxy_when_current_object_is_account(self):
with patch("libs.login.current_user", new_callable=MagicMock) as proxy:
user_obj = MagicMock()
user_obj.__class__ = Account
proxy._get_current_object.return_value = user_obj
assert get_current_user() is proxy
def test_should_return_proxy_when_current_object_is_enduser(self):
with patch("libs.login.current_user", new_callable=MagicMock) as proxy:
user_obj = MagicMock()
user_obj.__class__ = EndUser
proxy._get_current_object.return_value = user_obj
assert get_current_user() is proxy
def test_should_return_proxy_when_get_current_object_raises_attribute_error(self):
"""AttributeError from LocalProxy falls back to the proxy itself."""
with patch("libs.login.current_user", new_callable=MagicMock) as proxy:
proxy._get_current_object.side_effect = AttributeError("no attr")
proxy.__class__ = Account # make the proxy itself satisfy isinstance
assert get_current_user() is proxy
def test_should_raise_type_error_when_user_is_not_account_or_enduser(self):
with patch("libs.login.current_user", new_callable=MagicMock) as proxy:
proxy._get_current_object.return_value = "plain_string"
with pytest.raises(TypeError, match="current_user must be Account or EndUser"):
get_current_user()
# -----------------------------------------------------------------------
# is_system_oauth_params_exist (line 357-363)
# -----------------------------------------------------------------------
def test_should_return_true_when_system_oauth_params_exist(self, service, mock_db_session):
mock_db_session.query().first.return_value = MagicMock()
assert service.is_system_oauth_params_exist(make_id()) is True
def test_should_return_false_when_system_oauth_params_missing(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
assert service.is_system_oauth_params_exist(make_id()) is False
# -----------------------------------------------------------------------
# is_tenant_oauth_params_enabled (lines 365-379)
# NOTE: uses .count() not .first()
# -----------------------------------------------------------------------
def test_should_return_true_when_tenant_oauth_params_enabled(self, service, mock_db_session):
mock_db_session.query().count.return_value = 1
assert service.is_tenant_oauth_params_enabled("t1", make_id()) is True
def test_should_return_false_when_tenant_oauth_params_disabled(self, service, mock_db_session):
mock_db_session.query().count.return_value = 0
assert service.is_tenant_oauth_params_enabled("t1", make_id()) is False
# -----------------------------------------------------------------------
# remove_oauth_custom_client_params (lines 55-61)
# -----------------------------------------------------------------------
def test_should_delete_tenant_config_when_removing_oauth_params(self, service, mock_db_session):
service.remove_oauth_custom_client_params("t1", make_id())
mock_db_session.query().delete.assert_called_once()
# -----------------------------------------------------------------------
# setup_oauth_custom_client_params (315-351)
# -----------------------------------------------------------------------
def test_should_skip_db_write_when_credentials_are_none(self, service, mock_db_session):
"""When credentials=None, should return immediately without any DB write."""
service.setup_oauth_custom_client_params("t1", make_id(), None, None)
mock_db_session.add.assert_not_called()
def test_should_create_new_config_when_none_exists(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, True)
mock_db_session.add.assert_called_once()
def test_should_update_existing_config_when_record_found(self, service, mock_db_session):
existing = MagicMock()
mock_db_session.query().first.return_value = existing
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, False)
mock_db_session.add.assert_not_called() # update in place, no add
# -----------------------------------------------------------------------
# decrypt / encrypt credentials (lines 70-98)
# -----------------------------------------------------------------------
def test_should_decrypt_secret_fields_when_decrypting_api_key_credentials(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
p.auth_type = "api_key"
p.encrypted_credentials = {"sk": "enc_val"}
with patch.object(service, "extract_secret_variables", return_value=["sk"]):
result = service.decrypt_datasource_provider_credentials("t1", p, "org/plug", "prov")
assert result["sk"] == "dec_tok"
def test_should_encrypt_secret_fields_when_encrypting_api_key_credentials(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
p.auth_type = "api_key"
with patch.object(service, "extract_secret_variables", return_value=["sk"]):
result = service.encrypt_datasource_provider_credentials("t1", "prov", "org/plug", {"sk": "plain"}, p)
assert result["sk"] == "enc_tok"
self._enc.encrypt_token.assert_called()
# -----------------------------------------------------------------------
# get_datasource_credentials (lines 113-165)
# -----------------------------------------------------------------------
def test_should_return_empty_dict_when_credential_not_found(self, service, mock_db_session, mock_user):
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
mock_db_session.query().first.return_value = None
assert service.get_datasource_credentials("t1", "prov", "org/plug") == {}
def test_should_refresh_oauth_tokens_when_expired(self, service, mock_db_session, mock_user):
"""Expired OAuth credential (expires_at near zero) triggers a silent refresh."""
p = MagicMock(spec=DatasourceProvider)
p.auth_type = "oauth2"
p.expires_at = 0 # expired
p.encrypted_credentials = {"tok": "x"}
mock_db_session.query().first.return_value = p
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}),
):
service.get_datasource_credentials("t1", "prov", "org/plug")
mock_db_session.commit.assert_called_once()
def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user):
"""API key credentials with expires_at=-1 skip refresh and return directly."""
p = MagicMock(spec=DatasourceProvider)
p.auth_type = "api_key"
p.expires_at = -1 # sentinel: never expires
p.encrypted_credentials = {"k": "v"}
mock_db_session.query().first.return_value = p
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "plain"}),
):
result = service.get_datasource_credentials("t1", "prov", "org/plug")
assert result == {"k": "plain"}
def test_should_fetch_by_credential_id_when_provided(self, service, mock_db_session, mock_user):
"""When credential_id is passed, the credential_id filter path (line 113) is taken."""
p = MagicMock(spec=DatasourceProvider)
p.auth_type = "api_key"
p.expires_at = -1
p.encrypted_credentials = {}
mock_db_session.query().first.return_value = p
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "v"}),
):
result = service.get_datasource_credentials("t1", "prov", "org/plug", credential_id="cred-id")
assert result == {"k": "v"}
# -----------------------------------------------------------------------
# get_all_datasource_credentials_by_provider (lines 176-228)
# -----------------------------------------------------------------------
def test_should_return_empty_list_when_no_provider_credentials_exist(self, service, mock_db_session, mock_user):
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
mock_db_session.query().all.return_value = []
assert service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") == []
def test_should_refresh_and_return_credentials_when_oauth_expired(self, service, mock_db_session, mock_user):
p = MagicMock(spec=DatasourceProvider)
p.auth_type = "oauth2"
p.expires_at = 0
p.encrypted_credentials = {"t": "x"}
mock_db_session.query().all.return_value = [p]
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "get_oauth_client", return_value={"oc": "v"}),
patch.object(service, "decrypt_datasource_provider_credentials", return_value={"t": "plain"}),
):
result = service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug")
assert len(result) == 1
# -----------------------------------------------------------------------
# update_datasource_provider_name (lines 236-303)
# -----------------------------------------------------------------------
def test_should_raise_value_error_when_provider_not_found_on_name_update(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
with pytest.raises(ValueError, match="not found"):
service.update_datasource_provider_name("t1", make_id(), "new", "cred-id")
def test_should_return_early_when_new_name_matches_current(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
p.name = "same"
mock_db_session.query().first.return_value = p
service.update_datasource_provider_name("t1", make_id(), "same", "cred-id")
mock_db_session.commit.assert_not_called()
def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
p.name = "old_name"
p.is_default = False
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 1 # conflict
with pytest.raises(ValueError, match="already exists"):
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
def test_should_update_name_and_commit_when_no_conflict(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
p.name = "old_name"
p.is_default = False
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id")
assert p.name == "new_name"
mock_db_session.commit.assert_called_once()
# -----------------------------------------------------------------------
# set_default_datasource_provider (lines 277-303)
# -----------------------------------------------------------------------
def test_should_raise_value_error_when_target_provider_not_found(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
with pytest.raises(ValueError, match="not found"):
service.set_default_datasource_provider("t1", make_id(), "bad-id")
def test_should_mark_target_as_default_and_commit(self, service, mock_db_session):
target = MagicMock(spec=DatasourceProvider)
target.provider = "provider"
target.plugin_id = "org/plug"
mock_db_session.query().first.return_value = target
service.set_default_datasource_provider("t1", make_id(), "new-id")
assert target.is_default is True
mock_db_session.commit.assert_called_once()
# -----------------------------------------------------------------------
# get_oauth_encrypter (lines 404-420)
# -----------------------------------------------------------------------
def test_should_raise_value_error_when_oauth_schema_missing(self, service):
pm = MagicMock()
pm.declaration.oauth_schema = None
with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm):
with pytest.raises(ValueError, match="oauth schema not found"):
service.get_oauth_encrypter("t1", make_id())
def test_should_return_encrypter_when_oauth_schema_exists(self, service):
schema_item = MagicMock()
schema_item.to_basic_provider_config.return_value = MagicMock()
pm = MagicMock()
pm.declaration.oauth_schema.client_schema = [schema_item]
with (
patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm),
patch(
"services.datasource_provider_service.create_provider_encrypter",
return_value=(MagicMock(), MagicMock()),
),
):
result = service.get_oauth_encrypter("t1", make_id())
assert result is not None
# -----------------------------------------------------------------------
# get_tenant_oauth_client (lines 381-402)
# -----------------------------------------------------------------------
def test_should_return_masked_credentials_when_mask_is_true(self, service, mock_db_session):
tenant_params = MagicMock()
tenant_params.client_params = {"k": "v"}
mock_db_session.query().first.return_value = tenant_params
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
result = service.get_tenant_oauth_client("t1", make_id(), mask=True)
assert result == {"k": "mask"}
def test_should_return_decrypted_credentials_when_mask_is_false(self, service, mock_db_session):
tenant_params = MagicMock()
tenant_params.client_params = {"k": "v"}
mock_db_session.query().first.return_value = tenant_params
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
result = service.get_tenant_oauth_client("t1", make_id(), mask=False)
assert result == {"k": "dec"}
def test_should_return_none_when_no_tenant_oauth_config_exists(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
assert service.get_tenant_oauth_client("t1", make_id()) is None
# -----------------------------------------------------------------------
# get_oauth_client (lines 423-457)
# -----------------------------------------------------------------------
def test_should_use_tenant_config_when_available(self, service, mock_db_session):
mock_db_session.query().first.return_value = MagicMock(client_params={"k": "v"})
with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)):
result = service.get_oauth_client("t1", make_id())
assert result == {"k": "dec"}
def test_should_fallback_to_system_credentials_when_tenant_config_missing(self, service, mock_db_session):
mock_db_session.query().first.side_effect = [None, MagicMock(system_credentials={"k": "sys"})]
with (
patch.object(service.provider_manager, "fetch_datasource_provider"),
patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=True),
):
result = service.get_oauth_client("t1", make_id())
assert result == {"k": "sys"}
def test_should_raise_value_error_when_no_oauth_config_available(self, service, mock_db_session):
"""Neither tenant nor system credentials → raises ValueError."""
mock_db_session.query().first.side_effect = [None, None]
with (
patch.object(service.provider_manager, "fetch_datasource_provider"),
patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=False),
):
with pytest.raises(ValueError, match="Please configure oauth client params"):
service.get_oauth_client("t1", make_id())
# -----------------------------------------------------------------------
# add_datasource_oauth_provider (lines 539-607)
# -----------------------------------------------------------------------
def test_should_add_oauth_provider_successfully_when_name_is_unique(self, service, mock_db_session):
mock_db_session.query().count.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=[]):
service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {})
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session):
"""Conflict on name results in auto-incremented name, not an error."""
mock_db_session.query().count.return_value = 1 # conflict first, then auto-named
mock_db_session.query().all.return_value = []
with (
patch.object(service, "extract_secret_variables", return_value=[]),
patch.object(service, "generate_next_datasource_provider_name", return_value="new_gen"),
):
service.add_datasource_oauth_provider("conflict", "t1", make_id(), "http://cb", 9999, {})
mock_db_session.add.assert_called_once()
def test_should_auto_generate_name_when_none_provided_for_oauth(self, service, mock_db_session):
"""name=None causes auto-generation via generate_next_datasource_provider_name."""
mock_db_session.query().count.return_value = 0
mock_db_session.query().all.return_value = []
with (
patch.object(service, "extract_secret_variables", return_value=[]),
patch.object(service, "generate_next_datasource_provider_name", return_value="auto"),
):
service.add_datasource_oauth_provider(None, "t1", make_id(), "http://cb", 9999, {})
mock_db_session.add.assert_called_once()
def test_should_encrypt_secret_fields_when_adding_oauth_provider(self, service, mock_db_session):
mock_db_session.query().count.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=["secret_key"]):
service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {"secret_key": "value"})
self._enc.encrypt_token.assert_called()
def test_should_acquire_redis_lock_when_adding_oauth_provider(self, service, mock_db_session):
mock_db_session.query().count.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=[]):
service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {})
self._redis.lock.assert_called()
# -----------------------------------------------------------------------
# reauthorize_datasource_oauth_provider (lines 477-537)
# -----------------------------------------------------------------------
def test_should_raise_value_error_when_credential_id_not_found_on_reauth(self, service, mock_db_session):
mock_db_session.query().first.return_value = None
with patch.object(service, "extract_secret_variables", return_value=[]):
with pytest.raises(ValueError, match="not found"):
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "bad-id")
def test_should_reauthorize_and_commit_when_credential_found(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=[]):
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
mock_db_session.commit.assert_called_once()
def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 1 # conflict
mock_db_session.query().all.return_value = []
with patch.object(service, "extract_secret_variables", return_value=["tok"]):
service.reauthorize_datasource_oauth_provider(
"conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id"
)
mock_db_session.commit.assert_called_once()
def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=["tok"]):
service.reauthorize_datasource_oauth_provider(None, "t1", make_id(), "u", 9999, {"tok": "val"}, "cred-id")
self._enc.encrypt_token.assert_called()
def test_should_acquire_redis_lock_when_reauthorizing(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
with patch.object(service, "extract_secret_variables", return_value=[]):
service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid")
self._redis.lock.assert_called()
# -----------------------------------------------------------------------
# add_datasource_api_key_provider (lines 608-675)
# -----------------------------------------------------------------------
def test_should_raise_value_error_when_api_key_name_already_exists(self, service, mock_db_session, mock_user):
"""explicit name supplied + conflict → raises ValueError immediately."""
mock_db_session.query().count.return_value = 1
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
with pytest.raises(ValueError, match="already exists"):
service.add_datasource_api_key_provider("clash", "t1", make_id(), {"sk": "v"})
def test_should_raise_value_error_when_credentials_validation_fails(self, service, mock_db_session, mock_user):
mock_db_session.query().count.return_value = 0
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad cred")),
patch.object(service, "extract_secret_variables", return_value=[]),
):
with pytest.raises(ValueError, match="Failed to validate"):
service.add_datasource_api_key_provider("nm", "t1", make_id(), {"k": "v"})
def test_should_add_api_key_provider_and_commit_when_valid(self, service, mock_db_session, mock_user):
mock_db_session.query().count.return_value = 0
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service.provider_manager, "validate_provider_credentials"),
patch.object(service, "extract_secret_variables", return_value=["sk"]),
):
service.add_datasource_api_key_provider(None, "t1", make_id(), {"sk": "v"})
mock_db_session.add.assert_called_once()
mock_db_session.commit.assert_called_once()
def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user):
mock_db_session.query().count.return_value = 0
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service.provider_manager, "validate_provider_credentials"),
patch.object(service, "extract_secret_variables", return_value=[]),
):
service.add_datasource_api_key_provider(None, "t1", make_id(), {})
self._redis.lock.assert_called()
# -----------------------------------------------------------------------
# extract_secret_variables (lines 666-699)
# -----------------------------------------------------------------------
def test_should_extract_secret_variable_names_for_api_key_schema(self, service):
schema = MagicMock()
schema.name = "my_secret"
schema.type = MagicMock()
schema.type.value = FormType.SECRET_INPUT # "secret-input"
pm = MagicMock()
pm.declaration.credentials_schema = [schema]
with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm):
result = service.extract_secret_variables("t1", "org/plug/prov", CredentialType.API_KEY)
assert "my_secret" in result
def test_should_extract_secret_variable_names_for_oauth2_schema(self, service):
schema = MagicMock()
schema.name = "oauth_secret"
schema.type = MagicMock()
schema.type.value = FormType.SECRET_INPUT
pm = MagicMock()
pm.declaration.oauth_schema.credentials_schema = [schema]
with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm):
result = service.extract_secret_variables("t1", "org/plug/prov", CredentialType.OAUTH2)
assert "oauth_secret" in result
def test_should_raise_value_error_when_credential_type_is_invalid(self, service):
pm = MagicMock()
with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm):
with pytest.raises(ValueError, match="Invalid credential type"):
service.extract_secret_variables("t1", "org/plug/prov", CredentialType.UNAUTHORIZED)
# -----------------------------------------------------------------------
# list_datasource_credentials (lines 721-754)
# -----------------------------------------------------------------------
def test_should_return_empty_list_when_no_credentials_stored(self, service, mock_db_session):
mock_db_session.query().all.return_value = []
assert service.list_datasource_credentials("t1", "prov", "org/plug") == []
def test_should_return_masked_credentials_list_when_credentials_exist(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
p.auth_type = "api_key"
p.encrypted_credentials = {"sk": "v"}
p.is_default = False
mock_db_session.query().all.return_value = [p]
with patch.object(service, "extract_secret_variables", return_value=["sk"]):
result = service.list_datasource_credentials("t1", "prov", "org/plug")
assert len(result) == 1
# -----------------------------------------------------------------------
# get_all_datasource_credentials (lines 808-871)
# -----------------------------------------------------------------------
def test_should_aggregate_credentials_for_non_hardcoded_plugin(self, service):
with patch("services.datasource_provider_service.PluginDatasourceManager") as mock_mgr:
ds = MagicMock()
ds.provider = "prov"
ds.plugin_id = "org/plug"
ds.declaration.identity.label.model_dump.return_value = {"en_US": "Label"}
mock_mgr.return_value.fetch_installed_datasource_providers.return_value = [ds]
cred = {"credential": {"k": "v"}, "is_default": True}
with patch.object(service, "list_datasource_credentials", return_value=[cred]):
results = service.get_all_datasource_credentials("t1")
assert len(results) == 1
def test_should_include_oauth_schema_for_hardcoded_plugin_ids(self, service, mock_db_session):
"""Lines 819-871: get_all_datasource_credentials covers hardcoded langgenius plugin IDs."""
with patch("services.datasource_provider_service.PluginDatasourceManager") as mock_mgr:
ds = MagicMock()
ds.plugin_id = "langgenius/firecrawl_datasource"
ds.provider = "firecrawl"
ds.plugin_unique_identifier = "pui"
ds.declaration.identity.icon = "icon"
ds.declaration.identity.name = "langgenius/firecrawl_datasource"
ds.declaration.identity.label.model_dump.return_value = {"en_US": "Firecrawl"}
ds.declaration.identity.description.model_dump.return_value = {"en_US": "desc"}
ds.declaration.identity.author = "langgenius"
ds.declaration.credentials_schema = []
ds.declaration.oauth_schema.client_schema = []
ds.declaration.oauth_schema.credentials_schema = []
mock_mgr.return_value.fetch_installed_datasource_providers.return_value = [ds]
with (
patch.object(service, "list_datasource_credentials", return_value=[]),
patch.object(service, "get_tenant_oauth_client", return_value=None),
patch.object(service, "is_tenant_oauth_params_enabled", return_value=False),
patch.object(service, "is_system_oauth_params_exist", return_value=False),
):
results = service.get_all_datasource_credentials("t1")
assert len(results) == 1
assert results[0]["oauth_schema"] is not None
# -----------------------------------------------------------------------
# get_real_datasource_credentials (lines 873-915)
# -----------------------------------------------------------------------
def test_should_return_empty_list_when_no_real_credentials_exist(self, service, mock_db_session):
mock_db_session.query().all.return_value = []
assert service.get_real_datasource_credentials("t1", "prov", "org/plug") == []
def test_should_return_decrypted_credential_list_when_credentials_exist(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
p.auth_type = "api_key"
p.encrypted_credentials = {"sk": "v"}
mock_db_session.query().all.return_value = [p]
with patch.object(service, "extract_secret_variables", return_value=["sk"]):
result = service.get_real_datasource_credentials("t1", "prov", "org/plug")
assert len(result) == 1
# -----------------------------------------------------------------------
# update_datasource_credentials (lines 917-978)
# -----------------------------------------------------------------------
def test_should_raise_value_error_when_credential_not_found_on_update(self, service, mock_db_session, mock_user):
mock_db_session.query().first.return_value = None
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
with pytest.raises(ValueError, match="not found"):
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "name")
def test_should_raise_value_error_when_new_name_already_used_on_update(self, service, mock_db_session, mock_user):
p = MagicMock(spec=DatasourceProvider)
p.name = "old_name"
p.auth_type = "api_key"
p.encrypted_credentials = {"sk": "e"}
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 1
with patch("services.datasource_provider_service.get_current_user", return_value=mock_user):
with pytest.raises(ValueError, match="already exists"):
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "new_name")
def test_should_raise_value_error_when_credential_validation_fails_on_update(
self, service, mock_db_session, mock_user
):
p = MagicMock(spec=DatasourceProvider)
p.name = "old_name"
p.auth_type = "api_key"
p.encrypted_credentials = {"sk": "e"}
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "extract_secret_variables", return_value=["sk"]),
patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad")),
):
with pytest.raises(ValueError, match="Failed to validate"):
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {"sk": "v"}, "name")
def test_should_encrypt_credentials_and_commit_when_update_succeeds(self, service, mock_db_session, mock_user):
"""Verifies that encrypted_credentials is reassigned with encrypted value and commit is called."""
p = MagicMock(spec=DatasourceProvider)
p.name = "old_name"
p.auth_type = "api_key"
p.encrypted_credentials = {"sk": "old_enc"}
mock_db_session.query().first.return_value = p
mock_db_session.query().count.return_value = 0
with (
patch("services.datasource_provider_service.get_current_user", return_value=mock_user),
patch.object(service, "extract_secret_variables", return_value=["sk"]),
patch.object(service.provider_manager, "validate_provider_credentials"),
):
service.update_datasource_credentials("t1", "id", "prov", "org/plug", {"sk": "new_val"}, "name")
# encrypter must have been called with the new secret value
self._enc.encrypt_token.assert_called()
# commit must be called exactly once
mock_db_session.commit.assert_called_once()
# -----------------------------------------------------------------------
# remove_datasource_credentials (lines 980-997)
# -----------------------------------------------------------------------
def test_should_delete_provider_and_commit_when_found(self, service, mock_db_session):
p = MagicMock(spec=DatasourceProvider)
mock_db_session.query().first.return_value = p
service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
mock_db_session.delete.assert_called_once_with(p)
mock_db_session.commit.assert_called_once()
def test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session):
"""No error raised; no delete called when record doesn't exist (lines 994 branch)."""
mock_db_session.query().first.return_value = None
service.remove_datasource_credentials("t1", "id", "prov", "org/plug")
mock_db_session.delete.assert_not_called()

View File

@ -0,0 +1,385 @@
import json
from typing import Any, cast
from unittest.mock import ANY, MagicMock, patch
import pytest
from core.rag.models.document import Document
from models.dataset import Dataset
from services.hit_testing_service import HitTestingService
class TestHitTestingService:
"""Test suite for HitTestingService"""
# ===== Utility Method Tests =====
def test_escape_query_for_search_should_escape_double_quotes(self):
"""Test that escape_query_for_search escapes double quotes correctly"""
# Arrange
query = 'test "query" with quotes'
expected = 'test \\"query\\" with quotes'
# Act
result = HitTestingService.escape_query_for_search(query)
# Assert
assert result == expected
def test_hit_testing_args_check_should_pass_with_valid_query(self):
"""Test that hit_testing_args_check passes with a valid query"""
# Arrange
args = {"query": "valid query"}
# Act & Assert (should not raise)
HitTestingService.hit_testing_args_check(args)
def test_hit_testing_args_check_should_pass_with_valid_attachments(self):
"""Test that hit_testing_args_check passes with valid attachment_ids"""
# Arrange
args = {"attachment_ids": ["id1", "id2"]}
# Act & Assert (should not raise)
HitTestingService.hit_testing_args_check(args)
def test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self):
"""Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing"""
# Arrange
args = {}
# Act & Assert
with pytest.raises(ValueError) as exc_info:
HitTestingService.hit_testing_args_check(args)
assert "Query or attachment_ids is required" in str(exc_info.value)
def test_hit_testing_args_check_should_raise_error_when_query_too_long(self):
"""Test that hit_testing_args_check raises ValueError if query exceeds 250 characters"""
# Arrange
args = {"query": "a" * 251}
# Act & Assert
with pytest.raises(ValueError) as exc_info:
HitTestingService.hit_testing_args_check(args)
assert "Query cannot exceed 250 characters" in str(exc_info.value)
def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self):
"""Test that hit_testing_args_check raises ValueError if attachment_ids is not a list"""
# Arrange
args = {"attachment_ids": "not a list"}
# Act & Assert
with pytest.raises(ValueError) as exc_info:
HitTestingService.hit_testing_args_check(args)
assert "Attachment_ids must be a list" in str(exc_info.value)
# ===== Response Formatting Tests =====
@patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents")
def test_compact_retrieve_response_should_format_correctly(self, mock_format):
"""Test that compact_retrieve_response formats the response correctly"""
# Arrange
query = "test query"
mock_doc = MagicMock(spec=Document)
documents = [mock_doc]
mock_record = MagicMock()
mock_record.model_dump.return_value = {"content": "formatted content"}
mock_format.return_value = [mock_record]
# Act
result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents))
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert len(result["records"]) == 1
assert cast(dict[str, Any], result["records"][0])["content"] == "formatted content"
mock_format.assert_called_once_with(documents)
def test_compact_external_retrieve_response_should_return_records_for_external_provider(self):
"""Test that compact_external_retrieve_response returns records when dataset provider is external"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.provider = "external"
query = "test query"
documents = [
{"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}},
{"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}},
]
# Act
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert len(result["records"]) == 2
assert cast(dict[str, Any], result["records"][0])["content"] == "c1"
assert cast(dict[str, Any], result["records"][1])["title"] == "t2"
def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self):
"""Test that compact_external_retrieve_response returns empty records for non-external provider"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.provider = "not_external"
query = "test query"
documents = [{"content": "c1"}]
# Act
result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents))
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert result["records"] == []
# ===== External Retrieve Tests =====
@patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve):
"""Test that external_retrieve successfully retrieves from external provider and commits query"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
dataset.provider = "external"
query = 'test "query"'
account = MagicMock()
account.id = "account_id"
mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}]
# Act
result = cast(
dict[str, Any],
HitTestingService.external_retrieve(
dataset=dataset,
query=query,
account=account,
external_retrieval_model={"model": "test"},
metadata_filtering_conditions={"key": "val"},
),
)
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert cast(dict[str, Any], result["records"][0])["content"] == "ext content"
# Verify call to RetrievalService.external_retrieve with escaped query
mock_ext_retrieve.assert_called_once_with(
dataset_id="dataset_id",
query='test \\"query\\"',
external_retrieval_model={"model": "test"},
metadata_filtering_conditions={"key": "val"},
)
# Verify DatasetQuery record was added and committed
mock_add.assert_called_once()
mock_commit.assert_called_once()
def test_external_retrieve_should_return_empty_for_non_external_provider(self):
"""Test that external_retrieve returns empty results immediately if provider is not external"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.provider = "not_external"
query = "test query"
account = MagicMock()
# Act
result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account))
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
assert result["records"] == []
# ===== Retrieve Tests =====
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve):
"""Test that retrieve uses default model when retrieval_model is not provided"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
dataset.retrieval_model = None
query = "test query"
account = MagicMock()
account.id = "account_id"
mock_retrieve.return_value = []
# Act
result = cast(
dict[str, Any],
HitTestingService.retrieve(
dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={}
),
)
# Assert
assert cast(dict[str, Any], result["query"])["content"] == query
mock_retrieve.assert_called_once()
# Verify top_k from default_retrieval_model (4)
assert mock_retrieve.call_args.kwargs["top_k"] == 4
mock_commit.assert_called_once()
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve):
"""Test that retrieve correctly calls metadata filtering when conditions are present"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
account = MagicMock()
account.id = "account_id"
retrieval_model = {
"search_method": "semantic_search",
"metadata_filtering_conditions": {"some": "condition"},
"top_k": 5,
"reranking_enable": False,
"score_threshold_enabled": False,
}
# Mock metadata filtering response
mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string")
mock_retrieve.return_value = []
# Act
HitTestingService.retrieve(
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
)
# Assert
mock_get_meta.assert_called_once()
mock_retrieve.assert_called_once()
assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"]
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition")
def test_retrieve_should_return_empty_if_metadata_filtering_fails(self, mock_get_meta, mock_retrieve):
"""Test that retrieve returns empty response if metadata filtering returns condition but no document IDs"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
account = MagicMock()
retrieval_model = {
"search_method": "semantic_search",
"metadata_filtering_conditions": {"some": "condition"},
"top_k": 5,
"reranking_enable": False,
"score_threshold_enabled": False,
}
# Mock metadata filtering response: condition returned but no IDs
mock_get_meta.return_value = ({}, "condition_string")
# Act
result = cast(
dict[str, Any],
HitTestingService.retrieve(
dataset=dataset,
query=query,
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
),
)
# Assert
assert result["records"] == []
mock_retrieve.assert_not_called()
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve):
"""Test that retrieve handles attachment_ids and adds them to DatasetQuery"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
account = MagicMock()
account.id = "account_id"
attachment_ids = ["att1", "att2"]
retrieval_model = {
"search_method": "semantic_search",
"top_k": 4,
"reranking_enable": False,
"score_threshold_enabled": False,
}
mock_retrieve.return_value = []
# Act
HitTestingService.retrieve(
dataset=dataset,
query=query,
account=account,
retrieval_model=retrieval_model,
external_retrieval_model={},
attachment_ids=attachment_ids,
)
# Assert
mock_retrieve.assert_called_once_with(
retrieval_method=ANY,
dataset_id="dataset_id",
query=query,
attachment_ids=attachment_ids,
top_k=4,
score_threshold=0.0,
reranking_model=None,
reranking_mode="reranking_model",
weights=None,
document_ids_filter=None,
)
# Verify DatasetQuery record (there should be 2 queries: 1 text, 2 images)
# The content is json.dumps([{"content_type": "text_query", ...}, {"content_type": "image_query", ...}])
called_query = mock_add.call_args[0][0]
query_content = json.loads(called_query.content)
assert len(query_content) == 3 # 1 text + 2 images
assert query_content[0]["content_type"] == "text_query"
assert query_content[1]["content_type"] == "image_query"
assert query_content[1]["content"] == "att1"
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_retrieve_should_handle_reranking_and_threshold(self, mock_commit, mock_add, mock_retrieve):
"""Test that retrieve passes reranking and threshold parameters correctly"""
# Arrange
dataset = MagicMock(spec=Dataset)
dataset.id = "dataset_id"
query = "test query"
account = MagicMock()
account.id = "account_id"
retrieval_model = {
"search_method": "hybrid_search",
"top_k": 10,
"reranking_enable": True,
"reranking_model": {"provider": "test"},
"reranking_mode": "weighted_sum",
"score_threshold_enabled": True,
"score_threshold": 0.5,
"weights": {"vector": 0.5, "keyword": 0.5},
}
mock_retrieve.return_value = []
# Act
HitTestingService.retrieve(
dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={}
)
# Assert
mock_retrieve.assert_called_once()
kwargs = mock_retrieve.call_args.kwargs
assert kwargs["score_threshold"] == 0.5
assert kwargs["reranking_model"] == {"provider": "test"}
assert kwargs["reranking_mode"] == "weighted_sum"
assert kwargs["weights"] == {"vector": 0.5, "keyword": 0.5}

View File

@ -0,0 +1,146 @@
from typing import Any, cast
from unittest.mock import MagicMock, patch
import pytest
from services.knowledge_service import ExternalDatasetTestService
class TestKnowledgeService:
"""Test suite for ExternalDatasetTestService"""
# ===== Happy Path Tests =====
@patch("services.knowledge_service.boto3.client")
@patch("services.knowledge_service.dify_config")
def test_knowledge_retrieval_should_succeed_with_valid_results(
self, mock_dify_config: MagicMock, mock_boto_client: MagicMock
):
"""Test that knowledge_retrieval successfully parses results from Bedrock"""
# Arrange
mock_dify_config.AWS_SECRET_ACCESS_KEY = "dummy_secret"
mock_dify_config.AWS_ACCESS_KEY_ID = "dummy_id"
mock_client = MagicMock()
mock_boto_client.return_value = mock_client
retrieval_setting = {"top_k": 4, "score_threshold": 0.5}
query = "test query"
knowledge_id = "kb-123"
# Mock successful response
mock_client.retrieve.return_value = {
"ResponseMetadata": {"HTTPStatusCode": 200},
"retrievalResults": [
{
"score": 0.9,
"metadata": {"x-amz-bedrock-kb-source-uri": "s3://bucket/doc1.pdf"},
"content": {"text": "content from doc1"},
},
{
"score": 0.4, # Below threshold
"metadata": {"x-amz-bedrock-kb-source-uri": "s3://bucket/doc2.pdf"},
"content": {"text": "content from doc2"},
},
],
}
# Act
result = cast(
dict[str, Any], ExternalDatasetTestService.knowledge_retrieval(retrieval_setting, query, knowledge_id)
)
# Assert
assert len(result["records"]) == 1
record = result["records"][0]
assert record["score"] == 0.9
assert record["title"] == "s3://bucket/doc1.pdf"
assert record["content"] == "content from doc1"
# verify retrieve called correctly
mock_client.retrieve.assert_called_once_with(
knowledgeBaseId=knowledge_id,
retrievalConfiguration={
"vectorSearchConfiguration": {
"numberOfResults": 4,
"overrideSearchType": "HYBRID",
}
},
retrievalQuery={"text": query},
)
# NEW: verify boto3.client created with proper service name and config values
mock_boto_client.assert_called_once_with(
"bedrock-agent-runtime",
aws_secret_access_key="dummy_secret",
aws_access_key_id="dummy_id",
region_name="us-east-1",
)
@patch("services.knowledge_service.boto3.client")
def test_knowledge_retrieval_should_return_empty_when_no_results(self, mock_boto: MagicMock):
"""Test that knowledge_retrieval returns empty records when Bedrock returns nothing"""
# Arrange
mock_client = MagicMock()
mock_boto.return_value = mock_client
mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}, "retrievalResults": []}
# Act
result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb"))
# Assert
assert result["records"] == []
# ===== Error Handling Tests =====
@patch("services.knowledge_service.boto3.client")
def test_knowledge_retrieval_should_return_empty_on_http_error(self, mock_boto: MagicMock):
"""Test that knowledge_retrieval returns empty records if Bedrock returns non-200 status"""
# Arrange
mock_client = MagicMock()
mock_boto.return_value = mock_client
mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 500}}
# Act
result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb"))
# Assert
assert result["records"] == []
def test_knowledge_retrieval_should_raise_when_boto_client_creation_fails(self):
"""Test that exceptions from boto3.client propagate (e.g., network/credentials issues)"""
with patch("services.knowledge_service.boto3.client") as mock_boto:
mock_boto.side_effect = Exception("client init failed")
with pytest.raises(Exception) as exc_info:
ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")
assert "client init failed" in str(exc_info.value)
# ===== Edge Cases =====
@patch("services.knowledge_service.boto3.client")
def test_knowledge_retrieval_should_handle_missing_threshold_in_settings(self, mock_boto: MagicMock):
"""Test that knowledge_retrieval uses 0.0 as default threshold if not provided"""
# Arrange
mock_client = MagicMock()
mock_boto.return_value = mock_client
mock_client.retrieve.return_value = {
"ResponseMetadata": {"HTTPStatusCode": 200},
"retrievalResults": [
{
"score": 0.1,
"metadata": {"x-amz-bedrock-kb-source-uri": "uri"},
"content": {"text": "text"},
}
],
}
# Act
# retrieval_setting missing "score_threshold"
result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb"))
# Assert
assert len(result["records"]) == 1
assert result["records"][0]["score"] == 0.1

View File

@ -0,0 +1,120 @@
from unittest.mock import MagicMock, patch
import httpx
import pytest
from services.operation_service import OperationService
class TestOperationService:
"""Test suite for OperationService"""
# ===== Internal Method Tests =====
@patch("httpx.request")
def test_should_call_with_correct_parameters_when__send_request_invoked(
self, mock_request: MagicMock, monkeypatch: pytest.MonkeyPatch
):
"""Test that _send_request calls httpx.request with the correct URL, headers, and data"""
# Arrange
monkeypatch.setattr(OperationService, "base_url", "https://billing.example")
monkeypatch.setattr(OperationService, "secret_key", "s3cr3t")
mock_response = MagicMock()
mock_response.json.return_value = {"status": "success"}
mock_request.return_value = mock_response
method = "POST"
endpoint = "/test_endpoint"
json_data = {"key": "value"}
# Act
result = OperationService._send_request(method, endpoint, json=json_data)
# Assert
assert result == {"status": "success"}
# Verify call parameters
expected_url = "https://billing.example/test_endpoint"
mock_request.assert_called_once()
args, kwargs = mock_request.call_args
assert args[0] == method
assert args[1] == expected_url
assert kwargs["json"] == json_data
assert kwargs["headers"]["Billing-Api-Secret-Key"] == "s3cr3t"
assert kwargs["headers"]["Content-Type"] == "application/json"
@patch("httpx.request")
def test_should_propagate_httpx_error_when__send_request_raises(
self, mock_request: MagicMock, monkeypatch: pytest.MonkeyPatch
):
"""Test that _send_request handles httpx raising an error"""
# Arrange
monkeypatch.setattr(OperationService, "base_url", "https://billing.example")
mock_request.side_effect = httpx.RequestError("network error")
# Act & Assert
with pytest.raises(httpx.RequestError):
OperationService._send_request("POST", "/test")
# ===== Public Method Tests =====
@pytest.mark.parametrize(
("utm_info", "expected_params"),
[
(
{
"utm_source": "google",
"utm_medium": "cpc",
"utm_campaign": "spring_sale",
"utm_content": "ad_1",
"utm_term": "ai_agent",
},
{
"tenant_id": "tenant-123",
"utm_source": "google",
"utm_medium": "cpc",
"utm_campaign": "spring_sale",
"utm_content": "ad_1",
"utm_term": "ai_agent",
},
),
(
{}, # Empty utm_info
{
"tenant_id": "tenant-123",
"utm_source": "",
"utm_medium": "",
"utm_campaign": "",
"utm_content": "",
"utm_term": "",
},
),
(
{"utm_source": "newsletter"}, # Partial utm_info
{
"tenant_id": "tenant-123",
"utm_source": "newsletter",
"utm_medium": "",
"utm_campaign": "",
"utm_content": "",
"utm_term": "",
},
),
],
)
@patch.object(OperationService, "_send_request")
def test_should_map_parameters_correctly_when_record_utm_called(
self, mock_send: MagicMock, utm_info: dict, expected_params: dict
):
"""Test that record_utm correctly maps utm_info to parameters and calls _send_request"""
# Arrange
tenant_id = "tenant-123"
mock_send.return_value = {"status": "recorded"}
# Act
result = OperationService.record_utm(tenant_id, utm_info)
# Assert
assert result == {"status": "recorded"}
mock_send.assert_called_once_with("POST", "/tenant_utms", params=expected_params)