diff --git a/api/tests/unit_tests/services/test_advanced_prompt_template_service.py b/api/tests/unit_tests/services/test_advanced_prompt_template_service.py new file mode 100644 index 0000000000..a6bc79e82b --- /dev/null +++ b/api/tests/unit_tests/services/test_advanced_prompt_template_service.py @@ -0,0 +1,214 @@ +""" +Unit tests for services.advanced_prompt_template_service +""" + +import copy + +from core.prompt.prompt_templates.advanced_prompt_templates import ( + BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG, + BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG, + BAICHUAN_CONTEXT, + CHAT_APP_CHAT_PROMPT_CONFIG, + CHAT_APP_COMPLETION_PROMPT_CONFIG, + COMPLETION_APP_CHAT_PROMPT_CONFIG, + COMPLETION_APP_COMPLETION_PROMPT_CONFIG, + CONTEXT, +) +from models.model import AppMode +from services.advanced_prompt_template_service import AdvancedPromptTemplateService + + +class TestAdvancedPromptTemplateService: + """Test suite for AdvancedPromptTemplateService.""" + + def test_get_prompt_should_use_baichuan_prompt_when_model_name_contains_baichuan(self) -> None: + """Test baichuan model names use baichuan context prompt.""" + # Arrange + args = { + "app_mode": AppMode.CHAT, + "model_mode": "chat", + "model_name": "Baichuan2-13B", + "has_context": "true", + } + + # Act + result = AdvancedPromptTemplateService.get_prompt(args) + + # Assert + assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT) + + def test_get_prompt_should_use_common_prompt_when_model_name_not_baichuan(self) -> None: + """Test non-baichuan model names use common prompt.""" + # Arrange + args = { + "app_mode": AppMode.CHAT, + "model_mode": "completion", + "model_name": "gpt-4", + "has_context": "false", + } + original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_prompt(args) + + # Assert + assert result == original_config + 
assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG + + def test_get_common_prompt_should_return_empty_dict_when_app_mode_invalid(self) -> None: + """Test invalid app mode returns empty dict.""" + # Arrange + app_mode = "invalid" + model_mode = "chat" + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "true") + + # Assert + assert result == {} + + def test_get_common_prompt_should_prepend_context_for_completion_prompt(self) -> None: + """Test context is prepended for completion prompt when has_context is true.""" + # Arrange + original_config = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "completion", "true") + + # Assert + assert result["completion_prompt_config"]["prompt"]["text"].startswith(CONTEXT) + assert original_config == CHAT_APP_COMPLETION_PROMPT_CONFIG + + def test_get_common_prompt_should_prepend_context_for_chat_prompt(self) -> None: + """Test context is prepended for chat prompt when has_context is true.""" + # Arrange + original_config = copy.deepcopy(COMPLETION_APP_CHAT_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "chat", "true") + + # Assert + assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(CONTEXT) + assert original_config == COMPLETION_APP_CHAT_PROMPT_CONFIG + + def test_get_common_prompt_should_return_chat_prompt_without_context_when_has_context_false(self) -> None: + """Test chat prompt remains unchanged when has_context is false.""" + # Arrange + original_config = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.CHAT, "chat", "false") + + # Assert + assert result == original_config + assert original_config == CHAT_APP_CHAT_PROMPT_CONFIG + + def test_get_common_prompt_should_return_completion_prompt_for_completion_app_mode(self) -> None: + """Test completion app 
mode with completion model returns completion prompt.""" + # Arrange + original_config = copy.deepcopy(COMPLETION_APP_COMPLETION_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(AppMode.COMPLETION, "completion", "false") + + # Assert + assert result == original_config + assert original_config == COMPLETION_APP_COMPLETION_PROMPT_CONFIG + + def test_get_common_prompt_should_return_empty_dict_when_model_mode_invalid(self) -> None: + """Test invalid model mode returns empty dict.""" + # Arrange + app_mode = AppMode.CHAT + model_mode = "invalid" + + # Act + result = AdvancedPromptTemplateService.get_common_prompt(app_mode, model_mode, "false") + + # Assert + assert result == {} + + def test_get_completion_prompt_should_not_prepend_context_when_has_context_false(self) -> None: + """Test helper keeps completion prompt unchanged when context is disabled.""" + # Arrange + prompt_template = copy.deepcopy(CHAT_APP_COMPLETION_PROMPT_CONFIG) + original_text = prompt_template["completion_prompt_config"]["prompt"]["text"] + + # Act + result = AdvancedPromptTemplateService.get_completion_prompt(prompt_template, "false", CONTEXT) + + # Assert + assert result["completion_prompt_config"]["prompt"]["text"] == original_text + + def test_get_chat_prompt_should_not_prepend_context_when_has_context_false(self) -> None: + """Test helper keeps chat prompt unchanged when context is disabled.""" + # Arrange + prompt_template = copy.deepcopy(CHAT_APP_CHAT_PROMPT_CONFIG) + original_text = prompt_template["chat_prompt_config"]["prompt"][0]["text"] + + # Act + result = AdvancedPromptTemplateService.get_chat_prompt(prompt_template, "false", CONTEXT) + + # Assert + assert result["chat_prompt_config"]["prompt"][0]["text"] == original_text + + def test_get_baichuan_prompt_should_return_chat_completion_config_when_chat_completion(self) -> None: + """Test baichuan chat/completion returns the expected config.""" + # Arrange + original_config = 
copy.deepcopy(BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "completion", "false") + + # Assert + assert result == original_config + assert original_config == BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG + + def test_get_baichuan_prompt_should_return_completion_chat_config_when_completion_chat(self) -> None: + """Test baichuan completion/chat returns the expected config.""" + # Arrange + original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "chat", "false") + + # Assert + assert result == original_config + assert original_config == BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG + + def test_get_baichuan_prompt_should_return_completion_completion_config_when_enabled_context(self) -> None: + """Test baichuan completion/completion prepends baichuan context when enabled.""" + # Arrange + original_config = copy.deepcopy(BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.COMPLETION, "completion", "true") + + # Assert + assert result["completion_prompt_config"]["prompt"]["text"].startswith(BAICHUAN_CONTEXT) + assert original_config == BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG + + def test_get_baichuan_prompt_should_return_chat_chat_config_when_enabled_context(self) -> None: + """Test baichuan chat/chat prepends baichuan context when enabled.""" + # Arrange + original_config = copy.deepcopy(BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG) + + # Act + result = AdvancedPromptTemplateService.get_baichuan_prompt(AppMode.CHAT, "chat", "true") + + # Assert + assert result["chat_prompt_config"]["prompt"][0]["text"].startswith(BAICHUAN_CONTEXT) + assert original_config == BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG + + def test_get_baichuan_prompt_should_return_empty_dict_when_invalid_inputs(self) -> None: + """Test invalid 
baichuan mode combinations return empty dict.""" + # Arrange + app_mode = "invalid" + model_mode = "invalid" + + # Act + result = AdvancedPromptTemplateService.get_baichuan_prompt(app_mode, model_mode, "true") + + # Assert + assert result == {} diff --git a/api/tests/unit_tests/services/test_agent_service.py b/api/tests/unit_tests/services/test_agent_service.py new file mode 100644 index 0000000000..7ce3d7ef7b --- /dev/null +++ b/api/tests/unit_tests/services/test_agent_service.py @@ -0,0 +1,346 @@ +""" +Unit tests for services.agent_service +""" + +from collections.abc import Callable +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +import pytz + +from core.plugin.impl.exc import PluginDaemonClientSideError +from models import Account +from models.model import App, Conversation, EndUser, Message, MessageAgentThought +from services.agent_service import AgentService + + +def _make_current_user_account(timezone: str = "UTC") -> Account: + account = Account(name="Test User", email="test@example.com") + account.timezone = timezone + return account + + +def _make_app_model(app_model_config: MagicMock | None) -> MagicMock: + app_model = MagicMock(spec=App) + app_model.id = "app-123" + app_model.tenant_id = "tenant-123" + app_model.app_model_config = app_model_config + return app_model + + +def _make_conversation(from_end_user_id: str | None, from_account_id: str | None) -> MagicMock: + conversation = MagicMock(spec=Conversation) + conversation.id = "conv-123" + conversation.app_id = "app-123" + conversation.from_end_user_id = from_end_user_id + conversation.from_account_id = from_account_id + return conversation + + +def _make_message(agent_thoughts: list[MessageAgentThought]) -> MagicMock: + message = MagicMock(spec=Message) + message.id = "msg-123" + message.conversation_id = "conv-123" + message.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC) + message.provider_response_latency = 1.23 + message.answer_tokens = 4 + 
message.message_tokens = 6 + message.agent_thoughts = agent_thoughts + message.message_files = ["file-a.txt"] + return message + + +def _make_agent_thought() -> MagicMock: + agent_thought = MagicMock(spec=MessageAgentThought) + agent_thought.tokens = 3 + agent_thought.tool_input = "raw-input" + agent_thought.observation = "raw-output" + agent_thought.thought = "thinking" + agent_thought.created_at = datetime(2024, 1, 1, tzinfo=pytz.UTC) + agent_thought.files = [] + agent_thought.tools = ["tool_a", "dataset_tool"] + agent_thought.tool_labels = {"tool_a": "Tool A"} + agent_thought.tool_meta = { + "tool_a": { + "tool_config": { + "tool_provider_type": "custom", + "tool_provider": "provider-1", + }, + "tool_parameters": {"param": "value"}, + "time_cost": 2.5, + }, + "dataset_tool": { + "tool_config": { + "tool_provider_type": "dataset-retrieval", + "tool_provider": "dataset-provider", + } + }, + } + agent_thought.tool_inputs_dict = {"tool_a": {"q": "hello"}, "dataset_tool": {"k": "v"}} + agent_thought.tool_outputs_dict = {"tool_a": {"result": "ok"}} + return agent_thought + + +def _build_query_side_effect( + conversation: Conversation | None, + message: Message | None, + executor: EndUser | Account | None, +) -> Callable[..., MagicMock]: + def _query_side_effect(*args: object, **kwargs: object) -> MagicMock: + query = MagicMock() + query.where.return_value = query + if any(arg is Conversation for arg in args): + query.first.return_value = conversation + elif any(arg is Message for arg in args): + query.first.return_value = message + elif any(arg is EndUser for arg in args) or any(arg is Account for arg in args): + query.first.return_value = executor + return query + + return _query_side_effect + + +class TestAgentServiceGetAgentLogs: + """Test suite for AgentService.get_agent_logs.""" + + def test_get_agent_logs_should_raise_when_conversation_missing(self) -> None: + """Test missing conversation raises ValueError.""" + # Arrange + app_model = 
_make_app_model(MagicMock()) + with patch("services.agent_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + mock_db.session.query.return_value = query + + # Act & Assert + with pytest.raises(ValueError): + AgentService.get_agent_logs(app_model, "missing-conv", "msg-1") + + def test_get_agent_logs_should_raise_when_message_missing(self) -> None: + """Test missing message raises ValueError.""" + # Arrange + app_model = _make_app_model(MagicMock()) + conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) + with patch("services.agent_service.db") as mock_db: + conversation_query = MagicMock() + conversation_query.where.return_value = conversation_query + conversation_query.first.return_value = conversation + + message_query = MagicMock() + message_query.where.return_value = message_query + message_query.first.return_value = None + + mock_db.session.query.side_effect = [conversation_query, message_query] + + # Act & Assert + with pytest.raises(ValueError): + AgentService.get_agent_logs(app_model, conversation.id, "missing-msg") + + def test_get_agent_logs_should_raise_when_app_model_config_missing(self) -> None: + """Test missing app model config raises ValueError.""" + # Arrange + app_model = _make_app_model(None) + conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) + message = _make_message([]) + current_user = _make_current_user_account() + + with patch("services.agent_service.db") as mock_db, patch("services.agent_service.current_user", current_user): + mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock()) + + # Act & Assert + with pytest.raises(ValueError): + AgentService.get_agent_logs(app_model, conversation.id, message.id) + + def test_get_agent_logs_should_raise_when_agent_config_missing(self) -> None: + """Test missing agent config raises ValueError.""" + # Arrange + app_model_config 
= MagicMock() + app_model_config.agent_mode_dict = {"strategy": "react"} + app_model_config.to_dict.return_value = {"tools": []} + app_model = _make_app_model(app_model_config) + conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) + message = _make_message([]) + current_user = _make_current_user_account() + + with ( + patch("services.agent_service.db") as mock_db, + patch("services.agent_service.AgentConfigManager.convert", return_value=None), + patch("services.agent_service.current_user", current_user), + ): + mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, MagicMock()) + + # Act & Assert + with pytest.raises(ValueError): + AgentService.get_agent_logs(app_model, conversation.id, message.id) + + def test_get_agent_logs_should_return_logs_for_end_user_executor(self) -> None: + """Test agent logs returned for end-user executor with tool icons.""" + # Arrange + agent_thought = _make_agent_thought() + message = _make_message([agent_thought]) + conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) + executor = MagicMock(spec=EndUser) + executor.name = "End User" + app_model_config = MagicMock() + app_model_config.agent_mode_dict = {"strategy": "react"} + app_model_config.to_dict.return_value = {"tools": []} + app_model = _make_app_model(app_model_config) + current_user = _make_current_user_account() + agent_tool = MagicMock() + agent_tool.tool_name = "tool_a" + agent_tool.provider_type = "custom" + agent_tool.provider_id = "provider-2" + agent_config = MagicMock() + agent_config.tools = [agent_tool] + + with ( + patch("services.agent_service.db") as mock_db, + patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config) as mock_convert, + patch("services.agent_service.ToolManager.get_tool_icon") as mock_get_icon, + patch("services.agent_service.current_user", current_user), + ): + mock_db.session.query.side_effect = 
_build_query_side_effect(conversation, message, executor) + mock_get_icon.side_effect = [None, "icon-a"] + + # Act + result = AgentService.get_agent_logs(app_model, conversation.id, message.id) + + # Assert + assert result["meta"]["status"] == "success" + assert result["meta"]["executor"] == "End User" + assert result["meta"]["total_tokens"] == 10 + assert result["meta"]["agent_mode"] == "react" + assert result["meta"]["iterations"] == 1 + assert result["files"] == ["file-a.txt"] + assert len(result["iterations"]) == 1 + tool_calls = result["iterations"][0]["tool_calls"] + assert tool_calls[0]["tool_name"] == "tool_a" + assert tool_calls[0]["tool_icon"] == "icon-a" + assert tool_calls[1]["tool_name"] == "dataset_tool" + assert tool_calls[1]["tool_icon"] == "" + mock_convert.assert_called_once() + + def test_get_agent_logs_should_return_account_executor_when_no_end_user(self) -> None: + """Test agent logs fall back to account executor when end user is missing.""" + # Arrange + agent_thought = _make_agent_thought() + message = _make_message([agent_thought]) + conversation = _make_conversation(from_end_user_id=None, from_account_id="account-1") + executor = MagicMock(spec=Account) + executor.name = "Account User" + app_model_config = MagicMock() + app_model_config.agent_mode_dict = {"strategy": "react"} + app_model_config.to_dict.return_value = {"tools": []} + app_model = _make_app_model(app_model_config) + current_user = _make_current_user_account() + agent_config = MagicMock() + agent_config.tools = [] + + with ( + patch("services.agent_service.db") as mock_db, + patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config), + patch("services.agent_service.ToolManager.get_tool_icon", return_value=""), + patch("services.agent_service.current_user", current_user), + ): + mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, executor) + + # Act + result = AgentService.get_agent_logs(app_model, conversation.id, 
message.id) + + # Assert + assert result["meta"]["executor"] == "Account User" + + def test_get_agent_logs_should_use_defaults_when_executor_and_tool_data_missing(self) -> None: + """Test unknown executor and missing tool details fall back to defaults.""" + # Arrange + agent_thought = _make_agent_thought() + agent_thought.tool_labels = {} + agent_thought.tool_inputs_dict = {} + agent_thought.tool_outputs_dict = None + agent_thought.tool_meta = {"tool_a": {"error": "failed"}} + agent_thought.tools = ["tool_a"] + + message = _make_message([agent_thought]) + conversation = _make_conversation(from_end_user_id="end-user-1", from_account_id=None) + app_model_config = MagicMock() + app_model_config.agent_mode_dict = {} + app_model_config.to_dict.return_value = {"tools": []} + app_model = _make_app_model(app_model_config) + current_user = _make_current_user_account() + agent_config = MagicMock() + agent_config.tools = [] + + with ( + patch("services.agent_service.db") as mock_db, + patch("services.agent_service.AgentConfigManager.convert", return_value=agent_config), + patch("services.agent_service.ToolManager.get_tool_icon", return_value=None), + patch("services.agent_service.current_user", current_user), + ): + mock_db.session.query.side_effect = _build_query_side_effect(conversation, message, None) + + # Act + result = AgentService.get_agent_logs(app_model, conversation.id, message.id) + + # Assert + assert result["meta"]["executor"] == "Unknown" + assert result["meta"]["agent_mode"] == "react" + tool_call = result["iterations"][0]["tool_calls"][0] + assert tool_call["status"] == "error" + assert tool_call["error"] == "failed" + assert tool_call["tool_label"] == "tool_a" + assert tool_call["tool_input"] == {} + assert tool_call["tool_output"] == {} + assert tool_call["time_cost"] == 0 + assert tool_call["tool_parameters"] == {} + assert tool_call["tool_icon"] is None + + +class TestAgentServiceProviders: + """Test suite for AgentService provider methods.""" + + def 
test_list_agent_providers_should_delegate_to_plugin_client(self) -> None: + """Test list_agent_providers delegates to PluginAgentClient.""" + # Arrange + tenant_id = "tenant-1" + expected = [{"name": "provider"}] + with patch("services.agent_service.PluginAgentClient") as mock_client: + mock_client.return_value.fetch_agent_strategy_providers.return_value = expected + + # Act + result = AgentService.list_agent_providers("user-1", tenant_id) + + # Assert + assert result == expected + mock_client.return_value.fetch_agent_strategy_providers.assert_called_once_with(tenant_id) + + def test_get_agent_provider_should_return_provider_when_successful(self) -> None: + """Test get_agent_provider returns provider when successful.""" + # Arrange + tenant_id = "tenant-1" + provider_name = "provider-a" + expected = {"name": provider_name} + with patch("services.agent_service.PluginAgentClient") as mock_client: + mock_client.return_value.fetch_agent_strategy_provider.return_value = expected + + # Act + result = AgentService.get_agent_provider("user-1", tenant_id, provider_name) + + # Assert + assert result == expected + mock_client.return_value.fetch_agent_strategy_provider.assert_called_once_with(tenant_id, provider_name) + + def test_get_agent_provider_should_raise_value_error_on_plugin_error(self) -> None: + """Test get_agent_provider wraps PluginDaemonClientSideError into ValueError.""" + # Arrange + tenant_id = "tenant-1" + provider_name = "provider-a" + with patch("services.agent_service.PluginAgentClient") as mock_client: + mock_client.return_value.fetch_agent_strategy_provider.side_effect = PluginDaemonClientSideError( + "plugin error" + ) + + # Act & Assert + with pytest.raises(ValueError): + AgentService.get_agent_provider("user-1", tenant_id, provider_name) diff --git a/api/tests/unit_tests/services/test_annotation_service.py b/api/tests/unit_tests/services/test_annotation_service.py new file mode 100644 index 0000000000..0aacfc7f13 --- /dev/null +++ 
b/api/tests/unit_tests/services/test_annotation_service.py @@ -0,0 +1,1685 @@ +""" +Unit tests for services.annotation_service +""" + +from io import BytesIO +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock, patch + +import pandas as pd +import pytest +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import NotFound + +from models.model import App, AppAnnotationHitHistory, AppAnnotationSetting, Message, MessageAnnotation +from services.annotation_service import AppAnnotationService + + +def _make_app(app_id: str = "app-1", tenant_id: str = "tenant-1") -> MagicMock: + app = MagicMock(spec=App) + app.id = app_id + app.tenant_id = tenant_id + app.status = "normal" + return app + + +def _make_user(user_id: str = "user-1") -> MagicMock: + user = MagicMock() + user.id = user_id + return user + + +def _make_message(message_id: str = "msg-1", app_id: str = "app-1") -> MagicMock: + message = MagicMock(spec=Message) + message.id = message_id + message.app_id = app_id + message.conversation_id = "conv-1" + message.query = "default-question" + message.annotation = None + return message + + +def _make_annotation(annotation_id: str = "ann-1") -> MagicMock: + annotation = MagicMock(spec=MessageAnnotation) + annotation.id = annotation_id + annotation.content = "" + annotation.question = "" + annotation.question_text = "" + return annotation + + +def _make_setting(setting_id: str = "setting-1", with_detail: bool = True) -> MagicMock: + setting = MagicMock(spec=AppAnnotationSetting) + setting.id = setting_id + setting.score_threshold = 0.5 + setting.collection_binding_id = "collection-1" + if with_detail: + setting.collection_binding_detail = SimpleNamespace(provider_name="provider-a", model_name="model-a") + else: + setting.collection_binding_detail = None + return setting + + +def _make_file(content: bytes) -> FileStorage: + return FileStorage(stream=BytesIO(content)) + + +class 
TestAppAnnotationServiceUpInsert: + """Test suite for up_insert_app_annotation_from_message.""" + + def test_up_insert_app_annotation_from_message_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound.""" + # Arrange + args = {"answer": "hello", "message_id": "msg-1"} + current_user = _make_user() + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.up_insert_app_annotation_from_message(args, "app-1") + + def test_up_insert_app_annotation_from_message_should_raise_value_error_when_answer_missing(self) -> None: + """Test missing answer and content raises ValueError.""" + # Arrange + args = {"message_id": "msg-1"} + current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(ValueError): + AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + def test_up_insert_app_annotation_from_message_should_raise_not_found_when_message_missing(self) -> None: + """Test missing message raises NotFound.""" + # Arrange + args = {"answer": "hello", "message_id": "msg-1"} + current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + 
patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + message_query = MagicMock() + message_query.where.return_value = message_query + message_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, message_query] + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + def test_up_insert_app_annotation_from_message_should_update_existing_annotation_when_found(self) -> None: + """Test existing annotation is updated and indexed.""" + # Arrange + args = {"answer": "updated", "message_id": "msg-1"} + current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + annotation = _make_annotation("ann-1") + message = _make_message(message_id="msg-1", app_id=app.id) + message.annotation = annotation + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.add_annotation_to_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + message_query = MagicMock() + message_query.where.return_value = message_query + message_query.first.return_value = message + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, message_query, setting_query] + + # Act + result = AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + # Assert + assert result == annotation + assert annotation.content == "updated" + assert annotation.question == message.query + mock_db.session.add.assert_called_once_with(annotation) + mock_db.session.commit.assert_called_once() + 
mock_task.delay.assert_called_once_with( + annotation.id, + message.query, + tenant_id, + app.id, + setting.collection_binding_id, + ) + + def test_up_insert_app_annotation_from_message_should_create_annotation_when_message_has_no_annotation( + self, + ) -> None: + """Test new annotation is created when message has no annotation.""" + # Arrange + args = {"answer": "hello", "message_id": "msg-1", "question": "q1"} + current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + message = _make_message(message_id="msg-1", app_id=app.id) + message.annotation = None + annotation_instance = _make_annotation("ann-1") + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.MessageAnnotation", return_value=annotation_instance) as mock_cls, + patch("services.annotation_service.add_annotation_to_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + message_query = MagicMock() + message_query.where.return_value = message_query + message_query.first.return_value = message + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, message_query, setting_query] + + # Act + result = AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + # Assert + assert result == annotation_instance + mock_cls.assert_called_once_with( + app_id=app.id, + conversation_id=message.conversation_id, + message_id=message.id, + content="hello", + question="q1", + account_id=current_user.id, + ) + mock_db.session.add.assert_called_once_with(annotation_instance) + mock_db.session.commit.assert_called_once() + mock_task.delay.assert_not_called() + + def 
test_up_insert_app_annotation_from_message_should_raise_value_error_when_question_missing(self) -> None: + """Test missing question without message_id raises ValueError.""" + # Arrange + args = {"answer": "hello"} + current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(ValueError): + AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + def test_up_insert_app_annotation_from_message_should_create_annotation_when_message_missing(self) -> None: + """Test annotation is created when message_id is not provided.""" + # Arrange + args = {"answer": "hello", "question": "q1"} + current_user = _make_user() + tenant_id = "tenant-1" + app = _make_app() + annotation_instance = _make_annotation("ann-1") + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.MessageAnnotation", return_value=annotation_instance) as mock_cls, + patch("services.annotation_service.add_annotation_to_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.up_insert_app_annotation_from_message(args, app.id) + + # Assert + assert result == annotation_instance + mock_cls.assert_called_once_with( + 
app_id=app.id, + content="hello", + question="q1", + account_id=current_user.id, + ) + mock_db.session.add.assert_called_once_with(annotation_instance) + mock_db.session.commit.assert_called_once() + mock_task.delay.assert_called_once_with( + annotation_instance.id, + "q1", + tenant_id, + app.id, + setting.collection_binding_id, + ) + + +class TestAppAnnotationServiceEnableDisable: + """Test suite for enable/disable app annotation.""" + + def test_enable_app_annotation_should_return_processing_when_cache_hit(self) -> None: + """Test cache hit returns processing status.""" + # Arrange + args = {"score_threshold": 0.5, "embedding_provider_name": "p", "embedding_model_name": "m"} + + with ( + patch("services.annotation_service.redis_client") as mock_redis, + patch("services.annotation_service.enable_annotation_reply_task") as mock_task, + ): + mock_redis.get.return_value = "job-1" + + # Act + result = AppAnnotationService.enable_app_annotation(args, "app-1") + + # Assert + assert result == {"job_id": "job-1", "job_status": "processing"} + mock_task.delay.assert_not_called() + + def test_enable_app_annotation_should_enqueue_job_when_cache_miss(self) -> None: + """Test cache miss enqueues enable task.""" + # Arrange + args = {"score_threshold": 0.5, "embedding_provider_name": "p", "embedding_model_name": "m"} + current_user = _make_user("user-1") + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.redis_client") as mock_redis, + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.uuid.uuid4", return_value="uuid-1"), + patch("services.annotation_service.enable_annotation_reply_task") as mock_task, + ): + mock_redis.get.return_value = None + + # Act + result = AppAnnotationService.enable_app_annotation(args, "app-1") + + # Assert + assert result == {"job_id": "uuid-1", "job_status": "waiting"} + 
mock_redis.setnx.assert_called_once_with("enable_app_annotation_job_uuid-1", "waiting") + mock_task.delay.assert_called_once_with( + "uuid-1", + "app-1", + current_user.id, + tenant_id, + 0.5, + "p", + "m", + ) + + def test_disable_app_annotation_should_return_processing_when_cache_hit(self) -> None: + """Test disable cache hit returns processing status.""" + # Arrange + tenant_id = "tenant-1" + with ( + patch("services.annotation_service.redis_client") as mock_redis, + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.disable_annotation_reply_task") as mock_task, + ): + mock_redis.get.return_value = "job-2" + + # Act + result = AppAnnotationService.disable_app_annotation("app-1") + + # Assert + assert result == {"job_id": "job-2", "job_status": "processing"} + mock_task.delay.assert_not_called() + + def test_disable_app_annotation_should_enqueue_job_when_cache_miss(self) -> None: + """Test disable cache miss enqueues disable task.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.redis_client") as mock_redis, + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.uuid.uuid4", return_value="uuid-2"), + patch("services.annotation_service.disable_annotation_reply_task") as mock_task, + ): + mock_redis.get.return_value = None + + # Act + result = AppAnnotationService.disable_app_annotation("app-1") + + # Assert + assert result == {"job_id": "uuid-2", "job_status": "waiting"} + mock_redis.setnx.assert_called_once_with("disable_app_annotation_job_uuid-2", "waiting") + mock_task.delay.assert_called_once_with("uuid-2", "app-1", tenant_id) + + +class TestAppAnnotationServiceListAndExport: + """Test suite for list and export methods.""" + + def test_get_annotation_list_by_app_id_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app 
raises NotFound.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.get_annotation_list_by_app_id("app-1", 1, 10, "") + + def test_get_annotation_list_by_app_id_should_return_items_with_keyword(self) -> None: + """Test keyword search returns items and total.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + pagination = SimpleNamespace(items=["a1"], total=1) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("libs.helper.escape_like_pattern", return_value="safe"), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + mock_db.paginate.return_value = pagination + + # Act + items, total = AppAnnotationService.get_annotation_list_by_app_id(app.id, 1, 10, "keyword") + + # Assert + assert items == ["a1"] + assert total == 1 + + def test_get_annotation_list_by_app_id_should_return_items_without_keyword(self) -> None: + """Test list query without keyword returns paginated items.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + pagination = SimpleNamespace(items=["a1", "a2"], total=2) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + 
mock_db.paginate.return_value = pagination + + # Act + items, total = AppAnnotationService.get_annotation_list_by_app_id(app.id, 1, 10, "") + + # Assert + assert items == ["a1", "a2"] + assert total == 2 + + def test_export_annotation_list_by_app_id_should_sanitize_fields(self) -> None: + """Test export sanitizes question and content fields.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + annotation1 = _make_annotation("ann-1") + annotation1.question = "=cmd" + annotation1.content = "+1" + annotation2 = _make_annotation("ann-2") + annotation2.question = "@bad" + annotation2.content = "-2" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.CSVSanitizer.sanitize_value", side_effect=lambda v: f"safe:{v}"), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.order_by.return_value = annotation_query + annotation_query.all.return_value = [annotation1, annotation2] + + mock_db.session.query.side_effect = [app_query, annotation_query] + + # Act + result = AppAnnotationService.export_annotation_list_by_app_id(app.id) + + # Assert + assert result == [annotation1, annotation2] + assert annotation1.question == "safe:=cmd" + assert annotation1.content == "safe:+1" + assert annotation2.question == "safe:@bad" + assert annotation2.content == "safe:-2" + + def test_export_annotation_list_by_app_id_should_raise_not_found_when_app_missing(self) -> None: + """Test export raises NotFound when app is missing.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() 
+ app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.export_annotation_list_by_app_id("app-1") + + +class TestAppAnnotationServiceDirectManipulation: + """Test suite for direct insert/update/delete methods.""" + + def test_insert_app_annotation_directly_should_raise_not_found_when_app_missing(self) -> None: + """Test insert raises NotFound when app is missing.""" + # Arrange + args = {"answer": "hello", "question": "q1"} + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.insert_app_annotation_directly(args, "app-1") + + def test_insert_app_annotation_directly_should_raise_value_error_when_question_missing(self) -> None: + """Test missing question raises ValueError.""" + # Arrange + args = {"answer": "hello"} + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(ValueError): + AppAnnotationService.insert_app_annotation_directly(args, app.id) + + def test_insert_app_annotation_directly_should_create_annotation_and_index(self) -> None: + """Test insert creates annotation and triggers index task.""" + # Arrange + args = {"answer": "hello", "question": "q1"} + current_user = 
_make_user("user-1") + tenant_id = "tenant-1" + app = _make_app() + annotation_instance = _make_annotation("ann-1") + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.MessageAnnotation", return_value=annotation_instance) as mock_cls, + patch("services.annotation_service.add_annotation_to_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.insert_app_annotation_directly(args, app.id) + + # Assert + assert result == annotation_instance + mock_cls.assert_called_once_with( + app_id=app.id, + content="hello", + question="q1", + account_id=current_user.id, + ) + mock_db.session.add.assert_called_once_with(annotation_instance) + mock_db.session.commit.assert_called_once() + mock_task.delay.assert_called_once_with( + annotation_instance.id, + "q1", + tenant_id, + app.id, + setting.collection_binding_id, + ) + + def test_update_app_annotation_directly_should_raise_not_found_when_annotation_missing(self) -> None: + """Test missing annotation raises NotFound.""" + # Arrange + args = {"answer": "hello", "question": "q1"} + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = 
None + + mock_db.session.query.side_effect = [app_query, annotation_query] + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.update_app_annotation_directly(args, app.id, "ann-1") + + def test_update_app_annotation_directly_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound in update path.""" + # Arrange + args = {"answer": "hello", "question": "q1"} + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.update_app_annotation_directly(args, "app-1", "ann-1") + + def test_update_app_annotation_directly_should_raise_value_error_when_question_missing(self) -> None: + """Test missing question raises ValueError.""" + # Arrange + args = {"answer": "hello"} + tenant_id = "tenant-1" + app = _make_app() + annotation = _make_annotation("ann-1") + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = annotation + + mock_db.session.query.side_effect = [app_query, annotation_query] + + # Act & Assert + with pytest.raises(ValueError): + AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id) + + def test_update_app_annotation_directly_should_update_annotation_and_index(self) -> None: + """Test update changes fields and triggers index update.""" + # Arrange + args = 
{"answer": "hello", "question": "q1"} + tenant_id = "tenant-1" + app = _make_app() + annotation = _make_annotation("ann-1") + annotation.question_text = "q1" + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.update_annotation_to_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = annotation + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, annotation_query, setting_query] + + # Act + result = AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id) + + # Assert + assert result == annotation + assert annotation.content == "hello" + assert annotation.question == "q1" + mock_db.session.commit.assert_called_once() + mock_task.delay.assert_called_once_with( + annotation.id, + annotation.question_text, + tenant_id, + app.id, + setting.collection_binding_id, + ) + + def test_delete_app_annotation_should_delete_annotation_and_histories(self) -> None: + """Test delete removes annotation and hit histories.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + annotation = _make_annotation("ann-1") + history1 = MagicMock(spec=AppAnnotationHitHistory) + history2 = MagicMock(spec=AppAnnotationHitHistory) + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.delete_annotation_index_task") as mock_task, + ): + app_query = MagicMock() + 
app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = annotation + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + scalars_result = MagicMock() + scalars_result.all.return_value = [history1, history2] + + mock_db.session.query.side_effect = [app_query, annotation_query, setting_query] + mock_db.session.scalars.return_value = scalars_result + + # Act + AppAnnotationService.delete_app_annotation(app.id, annotation.id) + + # Assert + mock_db.session.delete.assert_any_call(annotation) + mock_db.session.delete.assert_any_call(history1) + mock_db.session.delete.assert_any_call(history2) + mock_db.session.commit.assert_called_once() + mock_task.delay.assert_called_once_with( + annotation.id, + app.id, + tenant_id, + setting.collection_binding_id, + ) + + def test_delete_app_annotation_should_raise_not_found_when_app_missing(self) -> None: + """Test delete raises NotFound when app is missing.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.delete_app_annotation("app-1", "ann-1") + + def test_delete_app_annotation_should_raise_not_found_when_annotation_missing(self) -> None: + """Test delete raises NotFound when annotation is missing.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + 
patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, annotation_query] + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.delete_app_annotation(app.id, "ann-1") + + def test_delete_app_annotations_in_batch_should_return_zero_when_none_found(self) -> None: + """Test batch delete returns zero when no annotations found.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotations_query = MagicMock() + annotations_query.outerjoin.return_value = annotations_query + annotations_query.where.return_value = annotations_query + annotations_query.all.return_value = [] + + mock_db.session.query.side_effect = [app_query, annotations_query] + + # Act + result = AppAnnotationService.delete_app_annotations_in_batch(app.id, ["ann-1"]) + + # Assert + assert result == {"deleted_count": 0} + + def test_delete_app_annotations_in_batch_should_raise_not_found_when_app_missing(self) -> None: + """Test batch delete raises NotFound when app is missing.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + 
AppAnnotationService.delete_app_annotations_in_batch("app-1", ["ann-1"]) + + def test_delete_app_annotations_in_batch_should_delete_annotations_and_histories(self) -> None: + """Test batch delete removes annotations and triggers index deletion.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + annotation1 = _make_annotation("ann-1") + annotation2 = _make_annotation("ann-2") + setting = _make_setting() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.delete_annotation_index_task") as mock_task, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotations_query = MagicMock() + annotations_query.outerjoin.return_value = annotations_query + annotations_query.where.return_value = annotations_query + annotations_query.all.return_value = [(annotation1, setting), (annotation2, None)] + + hit_history_query = MagicMock() + hit_history_query.where.return_value = hit_history_query + hit_history_query.delete.return_value = None + + delete_query = MagicMock() + delete_query.where.return_value = delete_query + delete_query.delete.return_value = 2 + + mock_db.session.query.side_effect = [app_query, annotations_query, hit_history_query, delete_query] + + # Act + result = AppAnnotationService.delete_app_annotations_in_batch(app.id, ["ann-1", "ann-2"]) + + # Assert + assert result == {"deleted_count": 2} + mock_task.delay.assert_called_once_with(annotation1.id, app.id, tenant_id, setting.collection_binding_id) + mock_db.session.commit.assert_called_once() + + +class TestAppAnnotationServiceBatchImport: + """Test suite for batch import.""" + + def test_batch_import_app_annotations_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = 
"tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.batch_import_app_annotations("app-1", file) + + def test_batch_import_app_annotations_should_return_error_when_columns_invalid(self) -> None: + """Test invalid column count returns error message.""" + # Arrange + file = _make_file(b"question\nq\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["only"]}) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "Invalid CSV format" in error_msg + + def test_batch_import_app_annotations_should_return_error_when_file_empty(self) -> None: + """Test empty file returns validation error before CSV parsing.""" + # Arrange + file = _make_file(b"") + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, 
ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "empty or invalid" in error_msg + + def test_batch_import_app_annotations_should_return_error_when_min_records_not_met(self) -> None: + """Test min records validation returns error message.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["q1"], "a": ["a1"]}) + features = SimpleNamespace(billing=SimpleNamespace(enabled=False), annotation_quota_limit=None) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch("services.annotation_service.FeatureService.get_features", return_value=features), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=2), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "at least" in error_msg + + def test_batch_import_app_annotations_should_return_error_when_row_limit_exceeded(self) -> None: + """Test row count over max limit returns explicit error.""" + # Arrange + file = _make_file(b"question,answer\nq1,a1\nq2,a2\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["q1", "q2"], "a": ["a1", "a2"]}) + + with ( + 
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=1, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "too many records" in error_msg + + def test_batch_import_app_annotations_should_skip_malformed_rows_and_fail_min_records(self) -> None: + """Test malformed row extraction is skipped and can fail min record validation.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + app = _make_app() + malformed_row = MagicMock() + malformed_row.iloc.__getitem__.side_effect = IndexError() + df = MagicMock() + df.columns = ["q", "a"] + df.iterrows.return_value = [(0, malformed_row)] + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "at least" in error_msg + + def test_batch_import_app_annotations_should_skip_nan_rows_and_fail_min_records(self) -> None: + """Test NaN 
rows are skipped by validation and reported via min record check.""" + # Arrange + file = _make_file(b"question,answer\nnan,nan\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["nan"], "a": ["nan"]}) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "at least" in error_msg + + def test_batch_import_app_annotations_should_return_error_when_question_too_long(self) -> None: + """Test oversized question is rejected with row context.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["q" * 2001], "a": ["a"]}) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "Question at row" in error_msg + + def 
test_batch_import_app_annotations_should_return_error_when_answer_too_long(self) -> None: + """Test oversized answer is rejected with row context.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["q"], "a": ["a" * 10001]}) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "Answer at row" in error_msg + + def test_batch_import_app_annotations_should_return_error_when_quota_exceeded(self) -> None: + """Test quota validation returns error message.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + app = _make_app() + df = pd.DataFrame({"q": ["q1"], "a": ["a1"]}) + features = SimpleNamespace( + billing=SimpleNamespace(enabled=True), + annotation_quota_limit=SimpleNamespace(limit=1, size=1), + ) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch("services.annotation_service.FeatureService.get_features", return_value=features), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + 
app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + error_msg = cast(str, result["error_msg"]) + assert "exceeds the limit" in error_msg + + def test_batch_import_app_annotations_should_enqueue_job_when_valid(self) -> None: + """Test successful batch import enqueues job and returns status.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + current_user = _make_user("user-1") + app = _make_app() + df = pd.DataFrame({"q": ["q1"], "a": ["a1"]}) + features = SimpleNamespace(billing=SimpleNamespace(enabled=False), annotation_quota_limit=None) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch("services.annotation_service.FeatureService.get_features", return_value=features), + patch("services.annotation_service.batch_import_annotations_task") as mock_task, + patch("services.annotation_service.redis_client") as mock_redis, + patch("services.annotation_service.uuid.uuid4", return_value="uuid-3"), + patch("services.annotation_service.naive_utc_now", return_value=SimpleNamespace(timestamp=lambda: 1)), + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + assert result == {"job_id": "uuid-3", "job_status": "waiting", "record_count": 1} + mock_redis.zadd.assert_called_once() + mock_redis.expire.assert_called_once() + 
mock_redis.setnx.assert_called_once_with("app_annotation_batch_import_uuid-3", "waiting") + mock_task.delay.assert_called_once() + + def test_batch_import_app_annotations_should_cleanup_active_job_on_unexpected_exception(self) -> None: + """Test unexpected runtime errors trigger cleanup and return wrapped error.""" + # Arrange + file = _make_file(b"question,answer\nq,a\n") + tenant_id = "tenant-1" + current_user = _make_user("user-1") + app = _make_app() + df = pd.DataFrame({"q": ["q1"], "a": ["a1"]}) + features = SimpleNamespace(billing=SimpleNamespace(enabled=False), annotation_quota_limit=None) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.pd.read_csv", return_value=df), + patch("services.annotation_service.FeatureService.get_features", return_value=features), + patch("services.annotation_service.redis_client") as mock_redis, + patch("services.annotation_service.uuid.uuid4", return_value="uuid-4"), + patch("services.annotation_service.naive_utc_now", return_value=SimpleNamespace(timestamp=lambda: 1)), + patch("services.annotation_service.logger") as mock_logger, + patch( + "configs.dify_config", + new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1), + ), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + mock_db.session.query.return_value = app_query + mock_redis.zadd.side_effect = RuntimeError("boom") + mock_redis.zrem.side_effect = RuntimeError("cleanup-failed") + + # Act + result = AppAnnotationService.batch_import_app_annotations(app.id, file) + + # Assert + assert result["error_msg"] == "An error occurred while processing the file: boom" + mock_redis.zrem.assert_called_once_with(f"annotation_import_active:{tenant_id}", "uuid-4") + mock_logger.debug.assert_called_once() + + +class 
TestAppAnnotationServiceHitHistoryAndSettings: + """Test suite for hit history and settings methods.""" + + def test_get_annotation_hit_histories_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.get_annotation_hit_histories("app-1", "ann-1", 1, 10) + + def test_get_annotation_hit_histories_should_return_items_and_total(self) -> None: + """Test hit histories pagination returns items and total.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + annotation = _make_annotation("ann-1") + pagination = SimpleNamespace(items=["h1"], total=2) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = annotation + + mock_db.session.query.side_effect = [app_query, annotation_query] + mock_db.paginate.return_value = pagination + + # Act + items, total = AppAnnotationService.get_annotation_hit_histories(app.id, annotation.id, 1, 10) + + # Assert + assert items == ["h1"] + assert total == 2 + + def test_get_annotation_hit_histories_should_raise_not_found_when_annotation_missing(self) -> None: + """Test missing annotation raises NotFound.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + + with ( + 
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + annotation_query = MagicMock() + annotation_query.where.return_value = annotation_query + annotation_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, annotation_query] + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.get_annotation_hit_histories(app.id, "ann-1", 1, 10) + + def test_get_annotation_by_id_should_return_none_when_missing(self) -> None: + """Test get_annotation_by_id returns None when not found.""" + # Arrange + with patch("services.annotation_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + mock_db.session.query.return_value = query + + # Act + result = AppAnnotationService.get_annotation_by_id("ann-1") + + # Assert + assert result is None + + def test_get_annotation_by_id_should_return_annotation_when_exists(self) -> None: + """Test get_annotation_by_id returns annotation when found.""" + # Arrange + annotation = _make_annotation("ann-1") + with patch("services.annotation_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = annotation + mock_db.session.query.return_value = query + + # Act + result = AppAnnotationService.get_annotation_by_id("ann-1") + + # Assert + assert result == annotation + + def test_add_annotation_history_should_update_hit_count_and_store_history(self) -> None: + """Test add_annotation_history updates hit count and creates history.""" + # Arrange + with ( + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.AppAnnotationHitHistory") as mock_history_cls, + ): + query = MagicMock() + query.where.return_value = query + 
mock_db.session.query.return_value = query + + # Act + AppAnnotationService.add_annotation_history( + annotation_id="ann-1", + app_id="app-1", + annotation_question="q", + annotation_content="a", + query="q", + user_id="user-1", + message_id="msg-1", + from_source="chat", + score=0.8, + ) + + # Assert + query.update.assert_called_once() + mock_history_cls.assert_called_once() + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_get_app_annotation_setting_by_app_id_should_return_embedding_model_when_detail_exists(self) -> None: + """Test setting detail returns embedding model info.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + setting = _make_setting(with_detail=True) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) + + # Assert + assert result["enabled"] is True + embedding_model = cast(dict[str, Any], result["embedding_model"]) + assert embedding_model["embedding_provider_name"] == "provider-a" + assert embedding_model["embedding_model_name"] == "model-a" + + def test_get_app_annotation_setting_by_app_id_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + 
app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.get_app_annotation_setting_by_app_id("app-1") + + def test_get_app_annotation_setting_by_app_id_should_return_empty_embedding_model_when_no_detail(self) -> None: + """Test setting without detail returns empty embedding model.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + setting = _make_setting(with_detail=False) + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) + + # Assert + assert result["enabled"] is True + assert result["embedding_model"] == {} + + def test_get_app_annotation_setting_by_app_id_should_return_disabled_when_setting_missing(self) -> None: + """Test missing setting returns disabled payload.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id) + + # Assert + assert result == {"enabled": False} + + def 
test_update_app_annotation_setting_should_update_and_return_detail(self) -> None: + """Test update_app_annotation_setting updates fields and returns detail.""" + # Arrange + tenant_id = "tenant-1" + current_user = _make_user("user-1") + app = _make_app() + setting = _make_setting(with_detail=True) + args = {"score_threshold": 0.8} + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.naive_utc_now", return_value="now"), + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.update_app_annotation_setting(app.id, setting.id, args) + + # Assert + assert result["enabled"] is True + assert result["score_threshold"] == 0.8 + embedding_model = cast(dict[str, Any], result["embedding_model"]) + assert embedding_model["embedding_provider_name"] == "provider-a" + mock_db.session.add.assert_called_once_with(setting) + mock_db.session.commit.assert_called_once() + + def test_update_app_annotation_setting_should_return_empty_embedding_model_when_detail_missing(self) -> None: + """Test update returns empty embedding_model when collection detail is absent.""" + # Arrange + tenant_id = "tenant-1" + current_user = _make_user("user-1") + app = _make_app() + setting = _make_setting(with_detail=False) + args = {"score_threshold": 0.7} + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.naive_utc_now", return_value="now"), + ): + app_query = MagicMock() + app_query.where.return_value = 
app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = setting + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act + result = AppAnnotationService.update_app_annotation_setting(app.id, setting.id, args) + + # Assert + assert result["enabled"] is True + assert result["score_threshold"] == 0.7 + assert result["embedding_model"] == {} + + def test_update_app_annotation_setting_should_raise_not_found_when_app_missing(self) -> None: + """Test update raises NotFound when app is missing.""" + # Arrange + tenant_id = "tenant-1" + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = None + mock_db.session.query.return_value = app_query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.update_app_annotation_setting("app-1", "setting-1", {"score_threshold": 0.5}) + + def test_update_app_annotation_setting_should_raise_not_found_when_setting_missing(self) -> None: + """Test update raises NotFound when setting is missing.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + app_query = MagicMock() + app_query.where.return_value = app_query + app_query.first.return_value = app + + setting_query = MagicMock() + setting_query.where.return_value = setting_query + setting_query.first.return_value = None + + mock_db.session.query.side_effect = [app_query, setting_query] + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.update_app_annotation_setting(app.id, "setting-1", {"score_threshold": 0.5}) + + 
+class TestAppAnnotationServiceClearAll: + """Test suite for clear_all_annotations.""" + + def test_clear_all_annotations_should_delete_annotations_and_histories(self) -> None: + """Test clear_all_annotations deletes all data and triggers index removal.""" + # Arrange + tenant_id = "tenant-1" + app = _make_app() + setting = _make_setting() + annotation1 = _make_annotation("ann-1") + annotation2 = _make_annotation("ann-2") + history = MagicMock(spec=AppAnnotationHitHistory) + + def query_side_effect(*args: object, **kwargs: object) -> MagicMock: + query = MagicMock() + query.where.return_value = query + if App in args: + query.first.return_value = app + elif AppAnnotationSetting in args: + query.first.return_value = setting + elif MessageAnnotation in args: + query.yield_per.return_value = [annotation1, annotation2] + elif AppAnnotationHitHistory in args: + query.yield_per.return_value = [history] + return query + + with ( + patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + patch("services.annotation_service.delete_annotation_index_task") as mock_task, + ): + mock_db.session.query.side_effect = query_side_effect + + # Act + result = AppAnnotationService.clear_all_annotations(app.id) + + # Assert + assert result == {"result": "success"} + mock_db.session.delete.assert_any_call(annotation1) + mock_db.session.delete.assert_any_call(annotation2) + mock_db.session.delete.assert_any_call(history) + mock_task.delay.assert_any_call(annotation1.id, app.id, tenant_id, setting.collection_binding_id) + mock_task.delay.assert_any_call(annotation2.id, app.id, tenant_id, setting.collection_binding_id) + mock_db.session.commit.assert_called_once() + + def test_clear_all_annotations_should_raise_not_found_when_app_missing(self) -> None: + """Test missing app raises NotFound.""" + # Arrange + tenant_id = "tenant-1" + + with ( + 
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)), + patch("services.annotation_service.db") as mock_db, + ): + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + mock_db.session.query.return_value = query + + # Act & Assert + with pytest.raises(NotFound): + AppAnnotationService.clear_all_annotations("app-1") diff --git a/api/tests/unit_tests/services/test_app_service.py b/api/tests/unit_tests/services/test_app_service.py new file mode 100644 index 0000000000..bff8dc92c6 --- /dev/null +++ b/api/tests/unit_tests/services/test_app_service.py @@ -0,0 +1,609 @@ +"""Unit tests for services.app_service.""" + +import json +from types import SimpleNamespace +from typing import cast +from unittest.mock import MagicMock, patch + +import pytest + +from core.errors.error import ProviderTokenNotInitError +from models import Account, Tenant +from models.model import App, AppMode +from services.app_service import AppService + + +@pytest.fixture +def service() -> AppService: + """Provide AppService instance.""" + return AppService() + + +@pytest.fixture +def account() -> Account: + """Create account object for create_app tests.""" + tenant = Tenant(name="Tenant") + tenant.id = "tenant-1" + result = Account(name="Account User", email="account@example.com") + result.id = "acc-1" + result._current_tenant = tenant + return result + + +@pytest.fixture +def default_args() -> dict: + """Create default create_app args.""" + return { + "name": "Test App", + "mode": AppMode.CHAT.value, + "icon": "🤖", + "icon_background": "#FFFFFF", + } + + +@pytest.fixture +def app_template() -> dict: + """Create basic app template for create_app tests.""" + return { + AppMode.CHAT: { + "app": {}, + "model_config": { + "model": { + "provider": "provider-a", + "name": "model-a", + "mode": "chat", + "completion_params": {}, + } + }, + } + } + + +def _make_current_user() -> Account: + user = Account(name="Tester", 
email="tester@example.com") + user.id = "user-1" + tenant = Tenant(name="Tenant") + tenant.id = "tenant-1" + user._current_tenant = tenant + return user + + +class TestAppServicePagination: + """Test suite for get_paginate_apps.""" + + def test_get_paginate_apps_should_return_none_when_tag_filter_empty(self, service: AppService) -> None: + """Test pagination returns None when tag filter has no targets.""" + # Arrange + args = {"mode": "chat", "page": 1, "limit": 20, "tag_ids": ["tag-1"]} + + with patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=[]): + # Act + result = service.get_paginate_apps("user-1", "tenant-1", args) + + # Assert + assert result is None + + def test_get_paginate_apps_should_delegate_to_db_paginate(self, service: AppService) -> None: + """Test pagination delegates to db.paginate when filters are valid.""" + # Arrange + args = { + "mode": "workflow", + "is_created_by_me": True, + "name": "My_App%", + "tag_ids": ["tag-1"], + "page": 2, + "limit": 10, + } + expected_pagination = MagicMock() + + with ( + patch("services.app_service.TagService.get_target_ids_by_tag_ids", return_value=["app-1"]), + patch("libs.helper.escape_like_pattern", return_value="escaped"), + patch("services.app_service.db") as mock_db, + ): + mock_db.paginate.return_value = expected_pagination + + # Act + result = service.get_paginate_apps("user-1", "tenant-1", args) + + # Assert + assert result is expected_pagination + mock_db.paginate.assert_called_once() + + +class TestAppServiceCreate: + """Test suite for create_app.""" + + def test_create_app_should_create_with_matching_default_model( + self, + service: AppService, + account: Account, + default_args: dict, + app_template: dict, + ) -> None: + """Test create_app uses matching default model and persists app config.""" + # Arrange + app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") + app_model_config = SimpleNamespace(id="cfg-1") + model_instance = SimpleNamespace( + 
model_name="model-a", + provider="provider-a", + model_type_instance=MagicMock(), + credentials={"k": "v"}, + ) + + with ( + patch("services.app_service.default_app_templates", app_template), + patch("services.app_service.App", return_value=app_instance), + patch("services.app_service.AppModelConfig", return_value=app_model_config), + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.db") as mock_db, + patch("services.app_service.app_was_created") as mock_event, + patch("services.app_service.FeatureService.get_system_features") as mock_features, + patch("services.app_service.BillingService") as mock_billing, + patch("services.app_service.dify_config") as mock_config, + ): + manager = mock_model_manager.return_value + manager.get_default_model_instance.return_value = model_instance + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)) + mock_config.BILLING_ENABLED = True + + # Act + result = service.create_app("tenant-1", default_args, account) + + # Assert + assert result is app_instance + assert app_instance.app_model_config_id == "cfg-1" + mock_db.session.add.assert_any_call(app_instance) + mock_db.session.add.assert_any_call(app_model_config) + assert mock_db.session.flush.call_count == 2 + mock_db.session.commit.assert_called_once() + mock_event.send.assert_called_once_with(app_instance, account=account) + mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1") + + def test_create_app_should_raise_when_model_schema_missing( + self, + service: AppService, + account: Account, + default_args: dict, + app_template: dict, + ) -> None: + """Test create_app raises ValueError when non-matching model has no schema.""" + # Arrange + app_instance = SimpleNamespace(id="app-1") + model_instance = SimpleNamespace( + model_name="model-b", + provider="provider-b", + model_type_instance=MagicMock(), + credentials={"k": "v"}, + ) + 
model_instance.model_type_instance.get_model_schema.return_value = None + + with ( + patch("services.app_service.default_app_templates", app_template), + patch("services.app_service.App", return_value=app_instance), + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.db") as mock_db, + ): + manager = mock_model_manager.return_value + manager.get_default_model_instance.return_value = model_instance + + # Act & Assert + with pytest.raises(ValueError, match="model schema not found"): + service.create_app("tenant-1", default_args, account) + mock_db.session.commit.assert_not_called() + + def test_create_app_should_fallback_to_default_provider_when_model_missing( + self, + service: AppService, + account: Account, + default_args: dict, + app_template: dict, + ) -> None: + """Test create_app falls back to provider/model name when no default model instance is available.""" + # Arrange + app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") + app_model_config = SimpleNamespace(id="cfg-1") + + with ( + patch("services.app_service.default_app_templates", app_template), + patch("services.app_service.App", return_value=app_instance), + patch("services.app_service.AppModelConfig", return_value=app_model_config), + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.db") as mock_db, + patch("services.app_service.app_was_created") as mock_event, + patch("services.app_service.FeatureService.get_system_features") as mock_features, + patch("services.app_service.EnterpriseService") as mock_enterprise, + patch("services.app_service.dify_config") as mock_config, + ): + manager = mock_model_manager.return_value + manager.get_default_model_instance.side_effect = ProviderTokenNotInitError("not ready") + manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model") + mock_features.return_value = SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)) + 
mock_config.BILLING_ENABLED = False + + # Act + result = service.create_app("tenant-1", default_args, account) + + # Assert + assert result is app_instance + mock_event.send.assert_called_once_with(app_instance, account=account) + mock_db.session.commit.assert_called_once() + mock_enterprise.WebAppAuth.update_app_access_mode.assert_called_once_with("app-1", "private") + + def test_create_app_should_log_and_fallback_on_unexpected_model_error( + self, + service: AppService, + account: Account, + default_args: dict, + app_template: dict, + ) -> None: + """Test unexpected model manager errors are logged and fallback provider is used.""" + # Arrange + app_instance = SimpleNamespace(id="app-1", tenant_id="tenant-1") + app_model_config = SimpleNamespace(id="cfg-1") + + with ( + patch("services.app_service.default_app_templates", app_template), + patch("services.app_service.App", return_value=app_instance), + patch("services.app_service.AppModelConfig", return_value=app_model_config), + patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.db"), + patch("services.app_service.app_was_created"), + patch( + "services.app_service.FeatureService.get_system_features", + return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=False)), + ), + patch("services.app_service.dify_config", new=SimpleNamespace(BILLING_ENABLED=False)), + patch("services.app_service.logger") as mock_logger, + ): + manager = mock_model_manager.return_value + manager.get_default_model_instance.side_effect = RuntimeError("boom") + manager.get_default_provider_model_name.return_value = ("fallback-provider", "fallback-model") + + # Act + result = service.create_app("tenant-1", default_args, account) + + # Assert + assert result is app_instance + mock_logger.exception.assert_called_once() + + +class TestAppServiceGetAndUpdate: + """Test suite for app retrieval and update methods.""" + + def test_get_app_should_return_original_when_not_agent_app(self, service: 
AppService) -> None: + """Test get_app returns original app for non-agent modes.""" + # Arrange + app = MagicMock() + app.mode = AppMode.CHAT + app.is_agent = False + + with patch("services.app_service.current_user", _make_current_user()): + # Act + result = service.get_app(app) + + # Assert + assert result is app + + def test_get_app_should_return_original_when_model_config_missing(self, service: AppService) -> None: + """Test get_app returns app when agent mode has no model config.""" + # Arrange + app = MagicMock() + app.id = "app-1" + app.mode = AppMode.AGENT_CHAT + app.is_agent = False + app.app_model_config = None + + with patch("services.app_service.current_user", _make_current_user()): + # Act + result = service.get_app(app) + + # Assert + assert result is app + + def test_get_app_should_mask_tool_parameters_for_agent_tools(self, service: AppService) -> None: + """Test get_app decrypts and masks secret tool parameters.""" + # Arrange + tool = { + "provider_type": "builtin", + "provider_id": "provider-1", + "tool_name": "tool-a", + "tool_parameters": {"secret": "encrypted"}, + "extra": True, + } + model_config = MagicMock() + model_config.agent_mode_dict = {"tools": [tool, {"skip": True}]} + + app = MagicMock() + app.id = "app-1" + app.mode = AppMode.AGENT_CHAT + app.is_agent = False + app.app_model_config = model_config + + manager = MagicMock() + manager.decrypt_tool_parameters.return_value = {"secret": "decrypted"} + manager.mask_tool_parameters.return_value = {"secret": "***"} + + with ( + patch("services.app_service.current_user", _make_current_user()), + patch("services.app_service.ToolManager.get_agent_tool_runtime", return_value=MagicMock()), + patch("services.app_service.ToolParameterConfigurationManager", return_value=manager), + ): + # Act + result = service.get_app(app) + + # Assert + assert result.app_model_config is model_config + assert tool["tool_parameters"] == {"secret": "***"} + assert 
json.loads(model_config.agent_mode)["tools"][0]["tool_parameters"] == {"secret": "***"} + + def test_get_app_should_continue_when_tool_parameter_masking_fails(self, service: AppService) -> None: + """Test get_app logs and continues when masking fails.""" + # Arrange + tool = { + "provider_type": "builtin", + "provider_id": "provider-1", + "tool_name": "tool-a", + "tool_parameters": {"secret": "encrypted"}, + "extra": True, + } + model_config = MagicMock() + model_config.agent_mode_dict = {"tools": [tool]} + + app = MagicMock() + app.id = "app-1" + app.mode = AppMode.AGENT_CHAT + app.is_agent = False + app.app_model_config = model_config + + with ( + patch("services.app_service.current_user", _make_current_user()), + patch("services.app_service.ToolManager.get_agent_tool_runtime", side_effect=RuntimeError("mask-failed")), + patch("services.app_service.logger") as mock_logger, + ): + # Act + result = service.get_app(app) + + # Assert + assert result.app_model_config is model_config + mock_logger.exception.assert_called_once() + + def test_update_methods_should_mutate_app_and_commit(self, service: AppService) -> None: + """Test update methods set fields and commit changes.""" + # Arrange + app = cast( + App, + SimpleNamespace( + name="old", + description="old", + icon_type="emoji", + icon="a", + icon_background="#111", + enable_site=True, + enable_api=True, + ), + ) + args = { + "name": "new", + "description": "new-desc", + "icon_type": "image", + "icon": "new-icon", + "icon_background": "#222", + "use_icon_as_answer_icon": True, + "max_active_requests": 5, + } + user = SimpleNamespace(id="user-1") + + with ( + patch("services.app_service.current_user", user), + patch("services.app_service.db") as mock_db, + patch("services.app_service.naive_utc_now", return_value="now"), + ): + # Act + updated = service.update_app(app, args) + renamed = service.update_app_name(app, "rename") + iconed = service.update_app_icon(app, "icon-2", "#333") + site_same = 
service.update_app_site_status(app, app.enable_site) + api_same = service.update_app_api_status(app, app.enable_api) + site_changed = service.update_app_site_status(app, False) + api_changed = service.update_app_api_status(app, False) + + # Assert + assert updated is app + assert renamed is app + assert iconed is app + assert site_same is app + assert api_same is app + assert site_changed is app + assert api_changed is app + assert mock_db.session.commit.call_count >= 5 + + +class TestAppServiceDeleteAndMeta: + """Test suite for delete and metadata methods.""" + + def test_delete_app_should_cleanup_and_enqueue_task(self, service: AppService) -> None: + """Test delete_app removes app, runs cleanup, and triggers async deletion task.""" + # Arrange + app = cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1")) + + with ( + patch("services.app_service.db") as mock_db, + patch( + "services.app_service.FeatureService.get_system_features", + return_value=SimpleNamespace(webapp_auth=SimpleNamespace(enabled=True)), + ), + patch("services.app_service.EnterpriseService") as mock_enterprise, + patch( + "services.app_service.dify_config", + new=SimpleNamespace(BILLING_ENABLED=True, CONSOLE_API_URL="https://console.example"), + ), + patch("services.app_service.BillingService") as mock_billing, + patch("services.app_service.remove_app_and_related_data_task") as mock_task, + ): + # Act + service.delete_app(app) + + # Assert + mock_db.session.delete.assert_called_once_with(app) + mock_db.session.commit.assert_called_once() + mock_enterprise.WebAppAuth.cleanup_webapp.assert_called_once_with("app-1") + mock_billing.clean_billing_info_cache.assert_called_once_with("tenant-1") + mock_task.delay.assert_called_once_with(tenant_id="tenant-1", app_id="app-1") + + def test_get_app_meta_should_handle_workflow_and_tool_provider_icons(self, service: AppService) -> None: + """Test get_app_meta extracts builtin and API tool icons from workflow graph.""" + # Arrange + workflow = 
SimpleNamespace( + graph_dict={ + "nodes": [ + { + "data": { + "type": "tool", + "provider_type": "builtin", + "provider_id": "builtin-provider", + "tool_name": "tool_builtin", + } + }, + { + "data": { + "type": "tool", + "provider_type": "api", + "provider_id": "api-provider-id", + "tool_name": "tool_api", + } + }, + ] + } + ) + app = cast( + App, + SimpleNamespace( + mode=AppMode.WORKFLOW.value, + workflow=workflow, + app_model_config=None, + tenant_id="tenant-1", + icon_type="emoji", + icon_background="#fff", + ), + ) + + provider = SimpleNamespace(icon=json.dumps({"background": "#000", "content": "A"})) + + with ( + patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")), + patch("services.app_service.db") as mock_db, + ): + query = MagicMock() + query.where.return_value = query + query.first.return_value = provider + mock_db.session.query.return_value = query + + # Act + meta = service.get_app_meta(app) + + # Assert + assert meta["tool_icons"]["tool_builtin"].endswith("/builtin-provider/icon") + assert meta["tool_icons"]["tool_api"] == {"background": "#000", "content": "A"} + + def test_get_app_meta_should_use_default_api_icon_on_lookup_error(self, service: AppService) -> None: + """Test get_app_meta falls back to default icon when API provider lookup fails.""" + # Arrange + app_model_config = SimpleNamespace( + agent_mode_dict={ + "tools": [{"provider_type": "api", "provider_id": "x", "tool_name": "t", "tool_parameters": {}}] + } + ) + app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=app_model_config, workflow=None)) + + with ( + patch("services.app_service.dify_config", new=SimpleNamespace(CONSOLE_API_URL="https://console.example")), + patch("services.app_service.db") as mock_db, + ): + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + mock_db.session.query.return_value = query + + # Act + meta = service.get_app_meta(app) + + # Assert + assert 
meta["tool_icons"]["t"] == {"background": "#252525", "content": "\ud83d\ude01"} + + def test_get_app_meta_should_return_empty_when_required_data_missing(self, service: AppService) -> None: + """Test get_app_meta returns empty metadata when workflow/model config is absent.""" + # Arrange + workflow_app = cast(App, SimpleNamespace(mode=AppMode.WORKFLOW.value, workflow=None)) + chat_app = cast(App, SimpleNamespace(mode=AppMode.CHAT.value, app_model_config=None)) + + # Act + workflow_meta = service.get_app_meta(workflow_app) + chat_meta = service.get_app_meta(chat_app) + + # Assert + assert workflow_meta == {"tool_icons": {}} + assert chat_meta == {"tool_icons": {}} + + +class TestAppServiceCodeLookup: + """Test suite for app code lookup methods.""" + + def test_get_app_code_by_id_should_raise_when_site_missing(self) -> None: + """Test get_app_code_by_id raises when site is missing.""" + # Arrange + with patch("services.app_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + mock_db.session.query.return_value = query + + # Act & Assert + with pytest.raises(ValueError, match="not found"): + AppService.get_app_code_by_id("app-1") + + def test_get_app_code_by_id_should_return_code(self) -> None: + """Test get_app_code_by_id returns site code.""" + # Arrange + site = SimpleNamespace(code="code-1") + with patch("services.app_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = site + mock_db.session.query.return_value = query + + # Act + result = AppService.get_app_code_by_id("app-1") + + # Assert + assert result == "code-1" + + def test_get_app_id_by_code_should_raise_when_site_missing(self) -> None: + """Test get_app_id_by_code raises when code does not exist.""" + # Arrange + with patch("services.app_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = None + 
mock_db.session.query.return_value = query + + # Act & Assert + with pytest.raises(ValueError, match="not found"): + AppService.get_app_id_by_code("missing") + + def test_get_app_id_by_code_should_return_app_id(self) -> None: + """Test get_app_id_by_code returns linked app id.""" + # Arrange + site = SimpleNamespace(app_id="app-1") + with patch("services.app_service.db") as mock_db: + query = MagicMock() + query.where.return_value = query + query.first.return_value = site + mock_db.session.query.return_value = query + + # Act + result = AppService.get_app_id_by_code("code-1") + + # Assert + assert result == "app-1" diff --git a/api/tests/unit_tests/services/test_batch_indexing_base.py b/api/tests/unit_tests/services/test_batch_indexing_base.py new file mode 100644 index 0000000000..bd68b67d89 --- /dev/null +++ b/api/tests/unit_tests/services/test_batch_indexing_base.py @@ -0,0 +1,387 @@ +from dataclasses import asdict +from typing import Any, ClassVar, cast +from unittest.mock import MagicMock, patch + +import pytest + +from core.entities.document_task import DocumentTask +from enums.cloud_plan import CloudPlan +from services.document_indexing_proxy.batch_indexing_base import BatchDocumentIndexingProxy + +# --------------------------------------------------------------------------- +# Concrete subclass for testing (the base class is abstract) +# --------------------------------------------------------------------------- + + +class ConcreteBatchProxy(BatchDocumentIndexingProxy): + """Minimal concrete implementation that provides the required class-level vars.""" + + QUEUE_NAME: ClassVar[str] = "test_queue" + NORMAL_TASK_FUNC: ClassVar[Any] = MagicMock(name="NORMAL_TASK_FUNC") + PRIORITY_TASK_FUNC: ClassVar[Any] = MagicMock(name="PRIORITY_TASK_FUNC") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +TENANT_ID = "tenant-abc" +DATASET_ID = 
"dataset-xyz" +DOC_IDS: list[str] = ["doc-1", "doc-2", "doc-3"] + + +def make_proxy(**kwargs: Any) -> ConcreteBatchProxy: + """Factory: returns a ConcreteBatchProxy with TenantIsolatedTaskQueue mocked out.""" + with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue") as MockQueue: + proxy = ConcreteBatchProxy( + tenant_id=kwargs.get("tenant_id", TENANT_ID), + dataset_id=kwargs.get("dataset_id", DATASET_ID), + document_ids=kwargs.get("document_ids", DOC_IDS), + ) + # Expose the mock queue on the proxy so tests can assert on it + proxy._tenant_isolated_task_queue = MockQueue.return_value + return proxy + + +# --------------------------------------------------------------------------- +# Test suite +# --------------------------------------------------------------------------- + + +class TestBatchDocumentIndexingProxyInit: + """Tests for __init__ of BatchDocumentIndexingProxy.""" + + def test_should_store_document_ids_when_initialized(self) -> None: + """Verify that document_ids are stored on the proxy instance.""" + # Arrange + doc_ids: list[str] = ["doc-a", "doc-b"] + + # Act + with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue"): + proxy = ConcreteBatchProxy(TENANT_ID, DATASET_ID, doc_ids) + + # Assert + assert proxy._document_ids == doc_ids + + def test_should_propagate_tenant_and_dataset_to_base_when_initialized(self) -> None: + """Verify that tenant_id and dataset_id are forwarded to the parent class.""" + # Arrange / Act + with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue"): + proxy = ConcreteBatchProxy(TENANT_ID, DATASET_ID, DOC_IDS) + + # Assert + assert proxy._tenant_id == TENANT_ID + assert proxy._dataset_id == DATASET_ID + + def test_should_create_tenant_isolated_queue_with_correct_args_when_initialized(self) -> None: + """Verify that TenantIsolatedTaskQueue is constructed with (tenant_id, QUEUE_NAME).""" + # Arrange / Act + with 
patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue") as MockQueue: + ConcreteBatchProxy(TENANT_ID, DATASET_ID, DOC_IDS) + + # Assert + MockQueue.assert_called_once_with(TENANT_ID, ConcreteBatchProxy.QUEUE_NAME) + + @pytest.mark.parametrize("doc_ids", [[], ["single-doc"], ["d1", "d2", "d3", "d4"]]) + def test_should_accept_any_length_document_ids_when_initialized(self, doc_ids: list[str]) -> None: + """Verify that empty, single, and multiple document IDs are all accepted.""" + # Arrange / Act + with patch("services.document_indexing_proxy.batch_indexing_base.TenantIsolatedTaskQueue"): + proxy = ConcreteBatchProxy(TENANT_ID, DATASET_ID, doc_ids) + + # Assert + assert list(proxy._document_ids) == doc_ids + + +class TestSendToDirectQueue: + """Tests for _send_to_direct_queue.""" + + def test_should_call_task_func_delay_with_correct_args_when_sent_to_direct_queue( + self, + ) -> None: + """Verify that task_func.delay is called with the right kwargs.""" + # Arrange + proxy = make_proxy() + task_func = MagicMock() + + # Act + proxy._send_to_direct_queue(task_func) + + # Assert + task_func.delay.assert_called_once_with( + tenant_id=TENANT_ID, + dataset_id=DATASET_ID, + document_ids=DOC_IDS, + ) + + def test_should_not_interact_with_tenant_queue_when_sent_to_direct_queue(self) -> None: + """Direct queue path must never touch the tenant-isolated queue.""" + # Arrange + proxy = make_proxy() + task_func = MagicMock() + + # Act + proxy._send_to_direct_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + mock_queue.push_tasks.assert_not_called() + mock_queue.set_task_waiting_time.assert_not_called() + + def test_should_forward_any_callable_when_sent_to_direct_queue(self) -> None: + """Verify that different task functions are each called correctly.""" + # Arrange + proxy = make_proxy() + task_a, task_b = MagicMock(), MagicMock() + + # Act + proxy._send_to_direct_queue(task_a) + 
proxy._send_to_direct_queue(task_b) + + # Assert + task_a.delay.assert_called_once() + task_b.delay.assert_called_once() + + +class TestSendToTenantQueue: + """Tests for _send_to_tenant_queue — both branches.""" + + # ------------------------------------------------------------------ + # Branch 1: get_task_key() is truthy → push to waiting queue + # ------------------------------------------------------------------ + + def test_should_push_task_to_queue_when_task_key_exists(self) -> None: + """When get_task_key() is truthy, tasks must be pushed via push_tasks().""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = "existing-key" + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + expected_payload = [asdict(DocumentTask(tenant_id=TENANT_ID, dataset_id=DATASET_ID, document_ids=DOC_IDS))] + mock_queue.push_tasks.assert_called_once_with(expected_payload) + + def test_should_not_call_task_func_delay_when_task_key_exists(self) -> None: + """When a key already exists, task_func.delay must never be called.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = "existing-key" + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + cast(MagicMock, task_func.delay).assert_not_called() + + def test_should_not_set_waiting_time_when_task_key_exists(self) -> None: + """When a key already exists, set_task_waiting_time must never be called.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = "existing-key" + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + mock_queue.set_task_waiting_time.assert_not_called() + + def test_should_serialize_document_task_correctly_when_pushing_to_queue(self) -> 
None: + """Verify the serialised payload matches asdict(DocumentTask(...)).""" + # Arrange + proxy = make_proxy(document_ids=["doc-x"]) + proxy._tenant_isolated_task_queue.get_task_key.return_value = "k" + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert — inspect the payload passed to push_tasks + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + call_args = mock_queue.push_tasks.call_args + pushed_list = call_args[0][0] # first positional arg + assert len(pushed_list) == 1 + assert pushed_list[0]["tenant_id"] == TENANT_ID + assert pushed_list[0]["dataset_id"] == DATASET_ID + assert pushed_list[0]["document_ids"] == ["doc-x"] + + # ------------------------------------------------------------------ + # Branch 2: get_task_key() is falsy → set flag + dispatch via delay + # ------------------------------------------------------------------ + + def test_should_set_waiting_time_and_call_delay_when_no_task_key(self) -> None: + """When get_task_key() is falsy, set_task_waiting_time and task_func.delay are invoked.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = None + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + mock_queue.set_task_waiting_time.assert_called_once() + cast(MagicMock, task_func.delay).assert_called_once_with( + tenant_id=TENANT_ID, + dataset_id=DATASET_ID, + document_ids=DOC_IDS, + ) + + def test_should_not_push_tasks_when_no_task_key(self) -> None: + """When get_task_key() is falsy, push_tasks must never be called.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = None + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + mock_queue.push_tasks.assert_not_called() + + 
@pytest.mark.parametrize("falsy_key", [None, "", 0, False]) + def test_should_init_task_when_key_is_any_falsy_value(self, falsy_key: Any) -> None: + """Verify that any falsy return from get_task_key() triggers the init branch.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = falsy_key + task_func = MagicMock() + + # Act + proxy._send_to_tenant_queue(task_func) + + # Assert + mock_queue = cast(MagicMock, proxy._tenant_isolated_task_queue) + mock_queue.set_task_waiting_time.assert_called_once() + cast(MagicMock, task_func.delay).assert_called_once() + + +class TestDispatchRouting: + """Tests for the _dispatch / delay routing logic inherited from the base class.""" + + def _mock_features(self, enabled: bool, plan: CloudPlan) -> MagicMock: + features = MagicMock() + features.billing.enabled = enabled + features.billing.subscription.plan = plan + return features + + def test_should_send_to_normal_tenant_queue_when_billing_enabled_and_sandbox_plan(self) -> None: + """Sandbox plan routes to normal priority queue with tenant isolation.""" + # Arrange + proxy = make_proxy() + proxy._tenant_isolated_task_queue.get_task_key.return_value = None + + with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features: + mock_features.return_value = self._mock_features(enabled=True, plan=CloudPlan.SANDBOX) + + # Act + with patch.object(proxy, "_send_to_default_tenant_queue") as mock_method: + proxy._dispatch() + + # Assert + mock_method.assert_called_once() + + def test_should_send_to_priority_tenant_queue_when_billing_enabled_and_paid_plan(self) -> None: + """Non-sandbox paid plan routes to priority queue with tenant isolation.""" + # Arrange + proxy = make_proxy() + + with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features: + mock_features.return_value = self._mock_features(enabled=True, plan=CloudPlan.PROFESSIONAL) + + # Act + with patch.object(proxy, 
"_send_to_priority_tenant_queue") as mock_method: + proxy._dispatch() + + # Assert + mock_method.assert_called_once() + + def test_should_send_to_priority_direct_queue_when_billing_not_enabled(self) -> None: + """Self-hosted / no billing → priority direct queue (no tenant isolation).""" + # Arrange + proxy = make_proxy() + + with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features: + mock_features.return_value = self._mock_features(enabled=False, plan=CloudPlan.SANDBOX) + + # Act + with patch.object(proxy, "_send_to_priority_direct_queue") as mock_method: + proxy._dispatch() + + # Assert + mock_method.assert_called_once() + + def test_should_call_dispatch_when_delay_is_invoked(self) -> None: + """Calling delay() must invoke _dispatch() exactly once.""" + # Arrange + proxy = make_proxy() + + # Act + with patch.object(proxy, "_dispatch") as mock_dispatch: + proxy.delay() + + # Assert + mock_dispatch.assert_called_once() + + def test_should_use_feature_service_for_billing_info(self) -> None: + """Verify that FeatureService.get_features is consulted during dispatch.""" + # Arrange + proxy = make_proxy() + + with patch("services.document_indexing_proxy.base.FeatureService.get_features") as mock_features: + mock_features.return_value = self._mock_features(enabled=False, plan=CloudPlan.SANDBOX) + with patch.object(proxy, "_send_to_priority_direct_queue"): + # Act + proxy._dispatch() + + # Assert + mock_features.assert_called_once_with(TENANT_ID) + + +class TestBaseRouterHelpers: + """Tests for the three routing helper methods from the base class.""" + + def test_should_call_send_to_tenant_queue_with_normal_func_when_default_tenant_queue(self) -> None: + """_send_to_default_tenant_queue must forward NORMAL_TASK_FUNC.""" + # Arrange + proxy = make_proxy() + + # Act + with patch.object(proxy, "_send_to_tenant_queue") as mock_method: + proxy._send_to_default_tenant_queue() + + # Assert + 
mock_method.assert_called_once_with(ConcreteBatchProxy.NORMAL_TASK_FUNC) + + def test_should_call_send_to_tenant_queue_with_priority_func_when_priority_tenant_queue(self) -> None: + """_send_to_priority_tenant_queue must forward PRIORITY_TASK_FUNC.""" + # Arrange + proxy = make_proxy() + + # Act + with patch.object(proxy, "_send_to_tenant_queue") as mock_method: + proxy._send_to_priority_tenant_queue() + + # Assert + mock_method.assert_called_once_with(ConcreteBatchProxy.PRIORITY_TASK_FUNC) + + def test_should_call_send_to_direct_queue_with_priority_func_when_priority_direct_queue(self) -> None: + """_send_to_priority_direct_queue must forward PRIORITY_TASK_FUNC.""" + # Arrange + proxy = make_proxy() + + # Act + with patch.object(proxy, "_send_to_direct_queue") as mock_method: + proxy._send_to_priority_direct_queue() + + # Assert + mock_method.assert_called_once_with(ConcreteBatchProxy.PRIORITY_TASK_FUNC) diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py new file mode 100644 index 0000000000..105ef7ba48 --- /dev/null +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -0,0 +1,760 @@ +from unittest.mock import MagicMock, patch + +import pytest +from sqlalchemy.orm import Session + +from core.plugin.entities.plugin_daemon import CredentialType +from dify_graph.model_runtime.entities.provider_entities import FormType +from models.account import Account +from models.model import EndUser +from models.oauth import DatasourceProvider +from models.provider_ids import DatasourceProviderID +from services.datasource_provider_service import DatasourceProviderService, get_current_user + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_id(s: str = "org/plugin/provider") -> DatasourceProviderID: + return DatasourceProviderID(s) + 
+ +# --------------------------------------------------------------------------- +# Test class +# --------------------------------------------------------------------------- + + +class TestDatasourceProviderService: + """Comprehensive tests for DatasourceProviderService targeting >95% coverage.""" + + @pytest.fixture + def service(self): + return DatasourceProviderService() + + @pytest.fixture + def mock_db_session(self): + """ + Robust, chainable query mock. + q returns itself for .filter_by(), .order_by(), .where() so any + SQLAlchemy chaining pattern works without multiple brittle sub-mocks. + """ + with patch("services.datasource_provider_service.Session") as mock_cls: + sess = MagicMock(spec=Session) + + q = MagicMock() + sess.query.return_value = q + + # Self-returning chain — any method called on q returns q + q.filter_by.return_value = q + q.order_by.return_value = q + q.where.return_value = q + + # Default terminal values (tests override per-case) + q.first.return_value = None + q.all.return_value = [] + q.count.return_value = 0 + q.delete.return_value = 1 + + mock_cls.return_value.__enter__.return_value = sess + mock_cls.return_value.no_autoflush.__enter__.return_value = sess + + yield sess + + @pytest.fixture(autouse=True) + def patch_db(self, mock_db_session): + with patch("services.datasource_provider_service.db") as mock_db: + mock_db.session = mock_db_session + mock_db.engine = MagicMock() + yield mock_db + + @pytest.fixture(autouse=True) + def patch_externals(self): + with ( + patch("httpx.request") as mock_httpx, + patch("services.datasource_provider_service.dify_config") as mock_cfg, + patch("services.datasource_provider_service.encrypter") as mock_enc, + patch("services.datasource_provider_service.redis_client") as mock_redis, + patch("services.datasource_provider_service.generate_incremental_name") as mock_genname, + patch("services.datasource_provider_service.OAuthHandler") as mock_oauth, + ): + mock_cfg.CONSOLE_API_URL = "http://localhost" + 
mock_enc.encrypt_token.return_value = "enc_tok" + mock_enc.decrypt_token.return_value = "dec_tok" + mock_enc.decrypt.return_value = {"k": "dec"} + mock_enc.encrypt.return_value = {"k": "enc"} + mock_enc.obfuscated_token.return_value = "obf" + mock_enc.mask_plugin_credentials.return_value = {"k": "mask"} + + mock_redis.lock.return_value.__enter__.return_value = MagicMock() + mock_genname.return_value = "gen_name" + + mock_oauth.return_value.refresh_credentials.return_value = MagicMock( + credentials={"k": "v"}, expires_at=9999 + ) + + resp = MagicMock() + resp.status_code = 200 + resp.json.return_value = { + "code": 0, + "message": "ok", + "data": { + "provider": "prov", + "plugin_unique_identifier": "pui", + "plugin_id": "org/plug", + "is_authorized": False, + "declaration": { + "identity": { + "author": "a", + "name": "n", + "description": {"en_US": "d"}, + "icon": "i", + "label": {"en_US": "l"}, + }, + "credentials_schema": [], + "oauth_schema": {"credentials_schema": [], "client_schema": []}, + "provider_type": "local_file", + "datasources": [], + }, + }, + } + mock_httpx.return_value = resp + + # Store handles for assertions + self._enc = mock_enc + self._redis = mock_redis + yield + + @pytest.fixture + def mock_user(self): + u = MagicMock() + u.id = "uid-1" + return u + + # ----------------------------------------------------------------------- + # get_current_user (lines 27-40) + # ----------------------------------------------------------------------- + + def test_should_return_proxy_when_current_object_is_account(self): + with patch("libs.login.current_user", new_callable=MagicMock) as proxy: + user_obj = MagicMock() + user_obj.__class__ = Account + proxy._get_current_object.return_value = user_obj + assert get_current_user() is proxy + + def test_should_return_proxy_when_current_object_is_enduser(self): + with patch("libs.login.current_user", new_callable=MagicMock) as proxy: + user_obj = MagicMock() + user_obj.__class__ = EndUser + 
proxy._get_current_object.return_value = user_obj + assert get_current_user() is proxy + + def test_should_return_proxy_when_get_current_object_raises_attribute_error(self): + """AttributeError from LocalProxy falls back to the proxy itself.""" + with patch("libs.login.current_user", new_callable=MagicMock) as proxy: + proxy._get_current_object.side_effect = AttributeError("no attr") + proxy.__class__ = Account # make the proxy itself satisfy isinstance + assert get_current_user() is proxy + + def test_should_raise_type_error_when_user_is_not_account_or_enduser(self): + with patch("libs.login.current_user", new_callable=MagicMock) as proxy: + proxy._get_current_object.return_value = "plain_string" + with pytest.raises(TypeError, match="current_user must be Account or EndUser"): + get_current_user() + + # ----------------------------------------------------------------------- + # is_system_oauth_params_exist (line 357-363) + # ----------------------------------------------------------------------- + + def test_should_return_true_when_system_oauth_params_exist(self, service, mock_db_session): + mock_db_session.query().first.return_value = MagicMock() + assert service.is_system_oauth_params_exist(make_id()) is True + + def test_should_return_false_when_system_oauth_params_missing(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + assert service.is_system_oauth_params_exist(make_id()) is False + + # ----------------------------------------------------------------------- + # is_tenant_oauth_params_enabled (lines 365-379) + # NOTE: uses .count() not .first() + # ----------------------------------------------------------------------- + + def test_should_return_true_when_tenant_oauth_params_enabled(self, service, mock_db_session): + mock_db_session.query().count.return_value = 1 + assert service.is_tenant_oauth_params_enabled("t1", make_id()) is True + + def test_should_return_false_when_tenant_oauth_params_disabled(self, service, 
mock_db_session): + mock_db_session.query().count.return_value = 0 + assert service.is_tenant_oauth_params_enabled("t1", make_id()) is False + + # ----------------------------------------------------------------------- + # remove_oauth_custom_client_params (lines 55-61) + # ----------------------------------------------------------------------- + + def test_should_delete_tenant_config_when_removing_oauth_params(self, service, mock_db_session): + service.remove_oauth_custom_client_params("t1", make_id()) + mock_db_session.query().delete.assert_called_once() + + # ----------------------------------------------------------------------- + # setup_oauth_custom_client_params (315-351) + # ----------------------------------------------------------------------- + + def test_should_skip_db_write_when_credentials_are_none(self, service, mock_db_session): + """When credentials=None, should return immediately without any DB write.""" + service.setup_oauth_custom_client_params("t1", make_id(), None, None) + mock_db_session.add.assert_not_called() + + def test_should_create_new_config_when_none_exists(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): + service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, True) + mock_db_session.add.assert_called_once() + + def test_should_update_existing_config_when_record_found(self, service, mock_db_session): + existing = MagicMock() + mock_db_session.query().first.return_value = existing + with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): + service.setup_oauth_custom_client_params("t1", make_id(), {"k": "v"}, False) + mock_db_session.add.assert_not_called() # update in place, no add + + # ----------------------------------------------------------------------- + # decrypt / encrypt credentials (lines 70-98) + # 
----------------------------------------------------------------------- + + def test_should_decrypt_secret_fields_when_decrypting_api_key_credentials(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "enc_val"} + with patch.object(service, "extract_secret_variables", return_value=["sk"]): + result = service.decrypt_datasource_provider_credentials("t1", p, "org/plug", "prov") + assert result["sk"] == "dec_tok" + + def test_should_encrypt_secret_fields_when_encrypting_api_key_credentials(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + with patch.object(service, "extract_secret_variables", return_value=["sk"]): + result = service.encrypt_datasource_provider_credentials("t1", "prov", "org/plug", {"sk": "plain"}, p) + assert result["sk"] == "enc_tok" + self._enc.encrypt_token.assert_called() + + # ----------------------------------------------------------------------- + # get_datasource_credentials (lines 113-165) + # ----------------------------------------------------------------------- + + def test_should_return_empty_dict_when_credential_not_found(self, service, mock_db_session, mock_user): + with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): + mock_db_session.query().first.return_value = None + assert service.get_datasource_credentials("t1", "prov", "org/plug") == {} + + def test_should_refresh_oauth_tokens_when_expired(self, service, mock_db_session, mock_user): + """Expired OAuth credential (expires_at near zero) triggers a silent refresh.""" + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "oauth2" + p.expires_at = 0 # expired + p.encrypted_credentials = {"tok": "x"} + mock_db_session.query().first.return_value = p + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "get_oauth_client", return_value={"oc": 
"v"}), + patch.object(service, "decrypt_datasource_provider_credentials", return_value={"tok": "plain"}), + ): + service.get_datasource_credentials("t1", "prov", "org/plug") + mock_db_session.commit.assert_called_once() + + def test_should_return_decrypted_credentials_when_api_key_not_expired(self, service, mock_db_session, mock_user): + """API key credentials with expires_at=-1 skip refresh and return directly.""" + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + p.expires_at = -1 # sentinel: never expires + p.encrypted_credentials = {"k": "v"} + mock_db_session.query().first.return_value = p + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "plain"}), + ): + result = service.get_datasource_credentials("t1", "prov", "org/plug") + assert result == {"k": "plain"} + + def test_should_fetch_by_credential_id_when_provided(self, service, mock_db_session, mock_user): + """When credential_id is passed, the credential_id filter path (line 113) is taken.""" + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + p.expires_at = -1 + p.encrypted_credentials = {} + mock_db_session.query().first.return_value = p + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "decrypt_datasource_provider_credentials", return_value={"k": "v"}), + ): + result = service.get_datasource_credentials("t1", "prov", "org/plug", credential_id="cred-id") + assert result == {"k": "v"} + + # ----------------------------------------------------------------------- + # get_all_datasource_credentials_by_provider (lines 176-228) + # ----------------------------------------------------------------------- + + def test_should_return_empty_list_when_no_provider_credentials_exist(self, service, mock_db_session, mock_user): + with 
patch("services.datasource_provider_service.get_current_user", return_value=mock_user): + mock_db_session.query().all.return_value = [] + assert service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") == [] + + def test_should_refresh_and_return_credentials_when_oauth_expired(self, service, mock_db_session, mock_user): + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "oauth2" + p.expires_at = 0 + p.encrypted_credentials = {"t": "x"} + mock_db_session.query().all.return_value = [p] + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "get_oauth_client", return_value={"oc": "v"}), + patch.object(service, "decrypt_datasource_provider_credentials", return_value={"t": "plain"}), + ): + result = service.get_all_datasource_credentials_by_provider("t1", "prov", "org/plug") + assert len(result) == 1 + + # ----------------------------------------------------------------------- + # update_datasource_provider_name (lines 236-303) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_provider_not_found_on_name_update(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + with pytest.raises(ValueError, match="not found"): + service.update_datasource_provider_name("t1", make_id(), "new", "cred-id") + + def test_should_return_early_when_new_name_matches_current(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.name = "same" + mock_db_session.query().first.return_value = p + service.update_datasource_provider_name("t1", make_id(), "same", "cred-id") + mock_db_session.commit.assert_not_called() + + def test_should_raise_value_error_when_name_already_exists(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.name = "old_name" + p.is_default = False + mock_db_session.query().first.return_value = p + 
mock_db_session.query().count.return_value = 1 # conflict + with pytest.raises(ValueError, match="already exists"): + service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id") + + def test_should_update_name_and_commit_when_no_conflict(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.name = "old_name" + p.is_default = False + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + service.update_datasource_provider_name("t1", make_id(), "new_name", "some-id") + assert p.name == "new_name" + mock_db_session.commit.assert_called_once() + + # ----------------------------------------------------------------------- + # set_default_datasource_provider (lines 277-303) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_target_provider_not_found(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + with pytest.raises(ValueError, match="not found"): + service.set_default_datasource_provider("t1", make_id(), "bad-id") + + def test_should_mark_target_as_default_and_commit(self, service, mock_db_session): + target = MagicMock(spec=DatasourceProvider) + target.provider = "provider" + target.plugin_id = "org/plug" + mock_db_session.query().first.return_value = target + service.set_default_datasource_provider("t1", make_id(), "new-id") + assert target.is_default is True + mock_db_session.commit.assert_called_once() + + # ----------------------------------------------------------------------- + # get_oauth_encrypter (lines 404-420) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_oauth_schema_missing(self, service): + pm = MagicMock() + pm.declaration.oauth_schema = None + with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm): + with pytest.raises(ValueError, match="oauth schema not 
found"): + service.get_oauth_encrypter("t1", make_id()) + + def test_should_return_encrypter_when_oauth_schema_exists(self, service): + schema_item = MagicMock() + schema_item.to_basic_provider_config.return_value = MagicMock() + pm = MagicMock() + pm.declaration.oauth_schema.client_schema = [schema_item] + with ( + patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm), + patch( + "services.datasource_provider_service.create_provider_encrypter", + return_value=(MagicMock(), MagicMock()), + ), + ): + result = service.get_oauth_encrypter("t1", make_id()) + assert result is not None + + # ----------------------------------------------------------------------- + # get_tenant_oauth_client (lines 381-402) + # ----------------------------------------------------------------------- + + def test_should_return_masked_credentials_when_mask_is_true(self, service, mock_db_session): + tenant_params = MagicMock() + tenant_params.client_params = {"k": "v"} + mock_db_session.query().first.return_value = tenant_params + with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): + result = service.get_tenant_oauth_client("t1", make_id(), mask=True) + assert result == {"k": "mask"} + + def test_should_return_decrypted_credentials_when_mask_is_false(self, service, mock_db_session): + tenant_params = MagicMock() + tenant_params.client_params = {"k": "v"} + mock_db_session.query().first.return_value = tenant_params + with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): + result = service.get_tenant_oauth_client("t1", make_id(), mask=False) + assert result == {"k": "dec"} + + def test_should_return_none_when_no_tenant_oauth_config_exists(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + assert service.get_tenant_oauth_client("t1", make_id()) is None + + # ----------------------------------------------------------------------- + # get_oauth_client (lines 423-457) + # 
----------------------------------------------------------------------- + + def test_should_use_tenant_config_when_available(self, service, mock_db_session): + mock_db_session.query().first.return_value = MagicMock(client_params={"k": "v"}) + with patch.object(service, "get_oauth_encrypter", return_value=(self._enc, None)): + result = service.get_oauth_client("t1", make_id()) + assert result == {"k": "dec"} + + def test_should_fallback_to_system_credentials_when_tenant_config_missing(self, service, mock_db_session): + mock_db_session.query().first.side_effect = [None, MagicMock(system_credentials={"k": "sys"})] + with ( + patch.object(service.provider_manager, "fetch_datasource_provider"), + patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=True), + ): + result = service.get_oauth_client("t1", make_id()) + assert result == {"k": "sys"} + + def test_should_raise_value_error_when_no_oauth_config_available(self, service, mock_db_session): + """Neither tenant nor system credentials → raises ValueError.""" + mock_db_session.query().first.side_effect = [None, None] + with ( + patch.object(service.provider_manager, "fetch_datasource_provider"), + patch("services.datasource_provider_service.PluginService.is_plugin_verified", return_value=False), + ): + with pytest.raises(ValueError, match="Please configure oauth client params"): + service.get_oauth_client("t1", make_id()) + + # ----------------------------------------------------------------------- + # add_datasource_oauth_provider (lines 539-607) + # ----------------------------------------------------------------------- + + def test_should_add_oauth_provider_successfully_when_name_is_unique(self, service, mock_db_session): + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=[]): + service.add_datasource_oauth_provider("new", "t1", make_id(), "http://cb", 9999, {}) + mock_db_session.add.assert_called_once() + 
mock_db_session.commit.assert_called_once() + + def test_should_auto_rename_when_oauth_provider_name_conflicts(self, service, mock_db_session): + """Conflict on name results in auto-incremented name, not an error.""" + mock_db_session.query().count.return_value = 1 # conflict first, then auto-named + mock_db_session.query().all.return_value = [] + with ( + patch.object(service, "extract_secret_variables", return_value=[]), + patch.object(service, "generate_next_datasource_provider_name", return_value="new_gen"), + ): + service.add_datasource_oauth_provider("conflict", "t1", make_id(), "http://cb", 9999, {}) + mock_db_session.add.assert_called_once() + + def test_should_auto_generate_name_when_none_provided_for_oauth(self, service, mock_db_session): + """name=None causes auto-generation via generate_next_datasource_provider_name.""" + mock_db_session.query().count.return_value = 0 + mock_db_session.query().all.return_value = [] + with ( + patch.object(service, "extract_secret_variables", return_value=[]), + patch.object(service, "generate_next_datasource_provider_name", return_value="auto"), + ): + service.add_datasource_oauth_provider(None, "t1", make_id(), "http://cb", 9999, {}) + mock_db_session.add.assert_called_once() + + def test_should_encrypt_secret_fields_when_adding_oauth_provider(self, service, mock_db_session): + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=["secret_key"]): + service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {"secret_key": "value"}) + self._enc.encrypt_token.assert_called() + + def test_should_acquire_redis_lock_when_adding_oauth_provider(self, service, mock_db_session): + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=[]): + service.add_datasource_oauth_provider("nm", "t1", make_id(), "http://cb", 9999, {}) + self._redis.lock.assert_called() + + # 
----------------------------------------------------------------------- + # reauthorize_datasource_oauth_provider (lines 477-537) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_credential_id_not_found_on_reauth(self, service, mock_db_session): + mock_db_session.query().first.return_value = None + with patch.object(service, "extract_secret_variables", return_value=[]): + with pytest.raises(ValueError, match="not found"): + service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "bad-id") + + def test_should_reauthorize_and_commit_when_credential_found(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=[]): + service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid") + mock_db_session.commit.assert_called_once() + + def test_should_auto_rename_when_reauth_name_conflicts(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 1 # conflict + mock_db_session.query().all.return_value = [] + with patch.object(service, "extract_secret_variables", return_value=["tok"]): + service.reauthorize_datasource_oauth_provider( + "conflict_name", "t1", make_id(), "u", 9999, {"tok": "v"}, "cred-id" + ) + mock_db_session.commit.assert_called_once() + + def test_should_encrypt_secret_fields_when_reauthorizing(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=["tok"]): + service.reauthorize_datasource_oauth_provider(None, "t1", make_id(), "u", 9999, {"tok": "val"}, "cred-id") + 
self._enc.encrypt_token.assert_called() + + def test_should_acquire_redis_lock_when_reauthorizing(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + with patch.object(service, "extract_secret_variables", return_value=[]): + service.reauthorize_datasource_oauth_provider("n", "t1", make_id(), "u", 1, {}, "oid") + self._redis.lock.assert_called() + + # ----------------------------------------------------------------------- + # add_datasource_api_key_provider (lines 608-675) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_api_key_name_already_exists(self, service, mock_db_session, mock_user): + """explicit name supplied + conflict → raises ValueError immediately.""" + mock_db_session.query().count.return_value = 1 + with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): + with pytest.raises(ValueError, match="already exists"): + service.add_datasource_api_key_provider("clash", "t1", make_id(), {"sk": "v"}) + + def test_should_raise_value_error_when_credentials_validation_fails(self, service, mock_db_session, mock_user): + mock_db_session.query().count.return_value = 0 + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad cred")), + patch.object(service, "extract_secret_variables", return_value=[]), + ): + with pytest.raises(ValueError, match="Failed to validate"): + service.add_datasource_api_key_provider("nm", "t1", make_id(), {"k": "v"}) + + def test_should_add_api_key_provider_and_commit_when_valid(self, service, mock_db_session, mock_user): + mock_db_session.query().count.return_value = 0 + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + 
patch.object(service.provider_manager, "validate_provider_credentials"), + patch.object(service, "extract_secret_variables", return_value=["sk"]), + ): + service.add_datasource_api_key_provider(None, "t1", make_id(), {"sk": "v"}) + mock_db_session.add.assert_called_once() + mock_db_session.commit.assert_called_once() + + def test_should_acquire_redis_lock_when_adding_api_key_provider(self, service, mock_db_session, mock_user): + mock_db_session.query().count.return_value = 0 + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service.provider_manager, "validate_provider_credentials"), + patch.object(service, "extract_secret_variables", return_value=[]), + ): + service.add_datasource_api_key_provider(None, "t1", make_id(), {}) + self._redis.lock.assert_called() + + # ----------------------------------------------------------------------- + # extract_secret_variables (lines 666-699) + # ----------------------------------------------------------------------- + + def test_should_extract_secret_variable_names_for_api_key_schema(self, service): + schema = MagicMock() + schema.name = "my_secret" + schema.type = MagicMock() + schema.type.value = FormType.SECRET_INPUT # "secret-input" + pm = MagicMock() + pm.declaration.credentials_schema = [schema] + with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm): + result = service.extract_secret_variables("t1", "org/plug/prov", CredentialType.API_KEY) + assert "my_secret" in result + + def test_should_extract_secret_variable_names_for_oauth2_schema(self, service): + schema = MagicMock() + schema.name = "oauth_secret" + schema.type = MagicMock() + schema.type.value = FormType.SECRET_INPUT + pm = MagicMock() + pm.declaration.oauth_schema.credentials_schema = [schema] + with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm): + result = service.extract_secret_variables("t1", "org/plug/prov", 
CredentialType.OAUTH2) + assert "oauth_secret" in result + + def test_should_raise_value_error_when_credential_type_is_invalid(self, service): + pm = MagicMock() + with patch.object(service.provider_manager, "fetch_datasource_provider", return_value=pm): + with pytest.raises(ValueError, match="Invalid credential type"): + service.extract_secret_variables("t1", "org/plug/prov", CredentialType.UNAUTHORIZED) + + # ----------------------------------------------------------------------- + # list_datasource_credentials (lines 721-754) + # ----------------------------------------------------------------------- + + def test_should_return_empty_list_when_no_credentials_stored(self, service, mock_db_session): + mock_db_session.query().all.return_value = [] + assert service.list_datasource_credentials("t1", "prov", "org/plug") == [] + + def test_should_return_masked_credentials_list_when_credentials_exist(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "v"} + p.is_default = False + mock_db_session.query().all.return_value = [p] + with patch.object(service, "extract_secret_variables", return_value=["sk"]): + result = service.list_datasource_credentials("t1", "prov", "org/plug") + assert len(result) == 1 + + # ----------------------------------------------------------------------- + # get_all_datasource_credentials (lines 808-871) + # ----------------------------------------------------------------------- + + def test_should_aggregate_credentials_for_non_hardcoded_plugin(self, service): + with patch("services.datasource_provider_service.PluginDatasourceManager") as mock_mgr: + ds = MagicMock() + ds.provider = "prov" + ds.plugin_id = "org/plug" + ds.declaration.identity.label.model_dump.return_value = {"en_US": "Label"} + mock_mgr.return_value.fetch_installed_datasource_providers.return_value = [ds] + cred = {"credential": {"k": "v"}, "is_default": True} + with patch.object(service, 
"list_datasource_credentials", return_value=[cred]): + results = service.get_all_datasource_credentials("t1") + assert len(results) == 1 + + def test_should_include_oauth_schema_for_hardcoded_plugin_ids(self, service, mock_db_session): + """Lines 819-871: get_all_datasource_credentials covers hardcoded langgenius plugin IDs.""" + with patch("services.datasource_provider_service.PluginDatasourceManager") as mock_mgr: + ds = MagicMock() + ds.plugin_id = "langgenius/firecrawl_datasource" + ds.provider = "firecrawl" + ds.plugin_unique_identifier = "pui" + ds.declaration.identity.icon = "icon" + ds.declaration.identity.name = "langgenius/firecrawl_datasource" + ds.declaration.identity.label.model_dump.return_value = {"en_US": "Firecrawl"} + ds.declaration.identity.description.model_dump.return_value = {"en_US": "desc"} + ds.declaration.identity.author = "langgenius" + ds.declaration.credentials_schema = [] + ds.declaration.oauth_schema.client_schema = [] + ds.declaration.oauth_schema.credentials_schema = [] + mock_mgr.return_value.fetch_installed_datasource_providers.return_value = [ds] + with ( + patch.object(service, "list_datasource_credentials", return_value=[]), + patch.object(service, "get_tenant_oauth_client", return_value=None), + patch.object(service, "is_tenant_oauth_params_enabled", return_value=False), + patch.object(service, "is_system_oauth_params_exist", return_value=False), + ): + results = service.get_all_datasource_credentials("t1") + assert len(results) == 1 + assert results[0]["oauth_schema"] is not None + + # ----------------------------------------------------------------------- + # get_real_datasource_credentials (lines 873-915) + # ----------------------------------------------------------------------- + + def test_should_return_empty_list_when_no_real_credentials_exist(self, service, mock_db_session): + mock_db_session.query().all.return_value = [] + assert service.get_real_datasource_credentials("t1", "prov", "org/plug") == [] + + def 
test_should_return_decrypted_credential_list_when_credentials_exist(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "v"} + mock_db_session.query().all.return_value = [p] + with patch.object(service, "extract_secret_variables", return_value=["sk"]): + result = service.get_real_datasource_credentials("t1", "prov", "org/plug") + assert len(result) == 1 + + # ----------------------------------------------------------------------- + # update_datasource_credentials (lines 917-978) + # ----------------------------------------------------------------------- + + def test_should_raise_value_error_when_credential_not_found_on_update(self, service, mock_db_session, mock_user): + mock_db_session.query().first.return_value = None + with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): + with pytest.raises(ValueError, match="not found"): + service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "name") + + def test_should_raise_value_error_when_new_name_already_used_on_update(self, service, mock_db_session, mock_user): + p = MagicMock(spec=DatasourceProvider) + p.name = "old_name" + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "e"} + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 1 + with patch("services.datasource_provider_service.get_current_user", return_value=mock_user): + with pytest.raises(ValueError, match="already exists"): + service.update_datasource_credentials("t1", "id", "prov", "org/plug", {}, "new_name") + + def test_should_raise_value_error_when_credential_validation_fails_on_update( + self, service, mock_db_session, mock_user + ): + p = MagicMock(spec=DatasourceProvider) + p.name = "old_name" + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "e"} + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + with ( 
+ patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "extract_secret_variables", return_value=["sk"]), + patch.object(service.provider_manager, "validate_provider_credentials", side_effect=Exception("bad")), + ): + with pytest.raises(ValueError, match="Failed to validate"): + service.update_datasource_credentials("t1", "id", "prov", "org/plug", {"sk": "v"}, "name") + + def test_should_encrypt_credentials_and_commit_when_update_succeeds(self, service, mock_db_session, mock_user): + """Verifies that encrypted_credentials is reassigned with encrypted value and commit is called.""" + p = MagicMock(spec=DatasourceProvider) + p.name = "old_name" + p.auth_type = "api_key" + p.encrypted_credentials = {"sk": "old_enc"} + mock_db_session.query().first.return_value = p + mock_db_session.query().count.return_value = 0 + with ( + patch("services.datasource_provider_service.get_current_user", return_value=mock_user), + patch.object(service, "extract_secret_variables", return_value=["sk"]), + patch.object(service.provider_manager, "validate_provider_credentials"), + ): + service.update_datasource_credentials("t1", "id", "prov", "org/plug", {"sk": "new_val"}, "name") + # encrypter must have been called with the new secret value + self._enc.encrypt_token.assert_called() + # commit must be called exactly once + mock_db_session.commit.assert_called_once() + + # ----------------------------------------------------------------------- + # remove_datasource_credentials (lines 980-997) + # ----------------------------------------------------------------------- + + def test_should_delete_provider_and_commit_when_found(self, service, mock_db_session): + p = MagicMock(spec=DatasourceProvider) + mock_db_session.query().first.return_value = p + service.remove_datasource_credentials("t1", "id", "prov", "org/plug") + mock_db_session.delete.assert_called_once_with(p) + mock_db_session.commit.assert_called_once() + + def 
test_should_do_nothing_when_credential_not_found_on_remove(self, service, mock_db_session): + """No error raised; no delete called when record doesn't exist (lines 994 branch).""" + mock_db_session.query().first.return_value = None + service.remove_datasource_credentials("t1", "id", "prov", "org/plug") + mock_db_session.delete.assert_not_called() diff --git a/api/tests/unit_tests/services/test_hit_testing_service.py b/api/tests/unit_tests/services/test_hit_testing_service.py new file mode 100644 index 0000000000..80e9729f5b --- /dev/null +++ b/api/tests/unit_tests/services/test_hit_testing_service.py @@ -0,0 +1,385 @@ +import json +from typing import Any, cast +from unittest.mock import ANY, MagicMock, patch + +import pytest + +from core.rag.models.document import Document +from models.dataset import Dataset +from services.hit_testing_service import HitTestingService + + +class TestHitTestingService: + """Test suite for HitTestingService""" + + # ===== Utility Method Tests ===== + + def test_escape_query_for_search_should_escape_double_quotes(self): + """Test that escape_query_for_search escapes double quotes correctly""" + # Arrange + query = 'test "query" with quotes' + expected = 'test \\"query\\" with quotes' + + # Act + result = HitTestingService.escape_query_for_search(query) + + # Assert + assert result == expected + + def test_hit_testing_args_check_should_pass_with_valid_query(self): + """Test that hit_testing_args_check passes with a valid query""" + # Arrange + args = {"query": "valid query"} + + # Act & Assert (should not raise) + HitTestingService.hit_testing_args_check(args) + + def test_hit_testing_args_check_should_pass_with_valid_attachments(self): + """Test that hit_testing_args_check passes with valid attachment_ids""" + # Arrange + args = {"attachment_ids": ["id1", "id2"]} + + # Act & Assert (should not raise) + HitTestingService.hit_testing_args_check(args) + + def 
test_hit_testing_args_check_should_raise_error_when_no_query_or_attachments(self): + """Test that hit_testing_args_check raises ValueError if both query and attachment_ids are missing""" + # Arrange + args = {} + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + HitTestingService.hit_testing_args_check(args) + assert "Query or attachment_ids is required" in str(exc_info.value) + + def test_hit_testing_args_check_should_raise_error_when_query_too_long(self): + """Test that hit_testing_args_check raises ValueError if query exceeds 250 characters""" + # Arrange + args = {"query": "a" * 251} + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + HitTestingService.hit_testing_args_check(args) + assert "Query cannot exceed 250 characters" in str(exc_info.value) + + def test_hit_testing_args_check_should_raise_error_when_attachments_not_list(self): + """Test that hit_testing_args_check raises ValueError if attachment_ids is not a list""" + # Arrange + args = {"attachment_ids": "not a list"} + + # Act & Assert + with pytest.raises(ValueError) as exc_info: + HitTestingService.hit_testing_args_check(args) + assert "Attachment_ids must be a list" in str(exc_info.value) + + # ===== Response Formatting Tests ===== + + @patch("core.rag.datasource.retrieval_service.RetrievalService.format_retrieval_documents") + def test_compact_retrieve_response_should_format_correctly(self, mock_format): + """Test that compact_retrieve_response formats the response correctly""" + # Arrange + query = "test query" + mock_doc = MagicMock(spec=Document) + documents = [mock_doc] + + mock_record = MagicMock() + mock_record.model_dump.return_value = {"content": "formatted content"} + mock_format.return_value = [mock_record] + + # Act + result = cast(dict[str, Any], HitTestingService.compact_retrieve_response(query, documents)) + + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + assert len(result["records"]) == 1 + assert cast(dict[str, Any], 
result["records"][0])["content"] == "formatted content" + mock_format.assert_called_once_with(documents) + + def test_compact_external_retrieve_response_should_return_records_for_external_provider(self): + """Test that compact_external_retrieve_response returns records when dataset provider is external""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.provider = "external" + query = "test query" + documents = [ + {"content": "c1", "title": "t1", "score": 0.9, "metadata": {"m1": "v1"}}, + {"content": "c2", "title": "t2", "score": 0.8, "metadata": {"m2": "v2"}}, + ] + + # Act + result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + assert len(result["records"]) == 2 + assert cast(dict[str, Any], result["records"][0])["content"] == "c1" + assert cast(dict[str, Any], result["records"][1])["title"] == "t2" + + def test_compact_external_retrieve_response_should_return_empty_for_non_external_provider(self): + """Test that compact_external_retrieve_response returns empty records for non-external provider""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.provider = "not_external" + query = "test query" + documents = [{"content": "c1"}] + + # Act + result = cast(dict[str, Any], HitTestingService.compact_external_retrieve_response(dataset, query, documents)) + + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + assert result["records"] == [] + + # ===== External Retrieve Tests ===== + + @patch("core.rag.datasource.retrieval_service.RetrievalService.external_retrieve") + @patch("extensions.ext_database.db.session.add") + @patch("extensions.ext_database.db.session.commit") + def test_external_retrieve_should_succeed_for_external_provider(self, mock_commit, mock_add, mock_ext_retrieve): + """Test that external_retrieve successfully retrieves from external provider and commits query""" + # 
Arrange + dataset = MagicMock(spec=Dataset) + dataset.id = "dataset_id" + dataset.provider = "external" + query = 'test "query"' + account = MagicMock() + account.id = "account_id" + + mock_ext_retrieve.return_value = [{"content": "ext content", "score": 1.0}] + + # Act + result = cast( + dict[str, Any], + HitTestingService.external_retrieve( + dataset=dataset, + query=query, + account=account, + external_retrieval_model={"model": "test"}, + metadata_filtering_conditions={"key": "val"}, + ), + ) + + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + assert cast(dict[str, Any], result["records"][0])["content"] == "ext content" + + # Verify call to RetrievalService.external_retrieve with escaped query + mock_ext_retrieve.assert_called_once_with( + dataset_id="dataset_id", + query='test \\"query\\"', + external_retrieval_model={"model": "test"}, + metadata_filtering_conditions={"key": "val"}, + ) + + # Verify DatasetQuery record was added and committed + mock_add.assert_called_once() + mock_commit.assert_called_once() + + def test_external_retrieve_should_return_empty_for_non_external_provider(self): + """Test that external_retrieve returns empty results immediately if provider is not external""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.provider = "not_external" + query = "test query" + account = MagicMock() + + # Act + result = cast(dict[str, Any], HitTestingService.external_retrieve(dataset, query, account)) + + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + assert result["records"] == [] + + # ===== Retrieve Tests ===== + + @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") + @patch("extensions.ext_database.db.session.add") + @patch("extensions.ext_database.db.session.commit") + def test_retrieve_should_use_default_model_when_none_provided(self, mock_commit, mock_add, mock_retrieve): + """Test that retrieve uses default model when retrieval_model is not provided""" + # 
Arrange + dataset = MagicMock(spec=Dataset) + dataset.id = "dataset_id" + dataset.retrieval_model = None + query = "test query" + account = MagicMock() + account.id = "account_id" + + mock_retrieve.return_value = [] + + # Act + result = cast( + dict[str, Any], + HitTestingService.retrieve( + dataset=dataset, query=query, account=account, retrieval_model=None, external_retrieval_model={} + ), + ) + + # Assert + assert cast(dict[str, Any], result["query"])["content"] == query + mock_retrieve.assert_called_once() + # Verify top_k from default_retrieval_model (4) + assert mock_retrieve.call_args.kwargs["top_k"] == 4 + mock_commit.assert_called_once() + + @patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve") + @patch("core.rag.retrieval.dataset_retrieval.DatasetRetrieval.get_metadata_filter_condition") + @patch("extensions.ext_database.db.session.add") + @patch("extensions.ext_database.db.session.commit") + def test_retrieve_should_handle_metadata_filtering(self, mock_commit, mock_add, mock_get_meta, mock_retrieve): + """Test that retrieve correctly calls metadata filtering when conditions are present""" + # Arrange + dataset = MagicMock(spec=Dataset) + dataset.id = "dataset_id" + query = "test query" + account = MagicMock() + account.id = "account_id" + + retrieval_model = { + "search_method": "semantic_search", + "metadata_filtering_conditions": {"some": "condition"}, + "top_k": 5, + "reranking_enable": False, + "score_threshold_enabled": False, + } + + # Mock metadata filtering response + mock_get_meta.return_value = ({"dataset_id": ["doc_id1"]}, "condition_string") + mock_retrieve.return_value = [] + + # Act + HitTestingService.retrieve( + dataset=dataset, query=query, account=account, retrieval_model=retrieval_model, external_retrieval_model={} + ) + + # Assert + mock_get_meta.assert_called_once() + mock_retrieve.assert_called_once() + assert mock_retrieve.call_args.kwargs["document_ids_filter"] == ["doc_id1"] + + 
@patch("core.rag.datasource.retrieval_service.RetrievalService.retrieve")
@patch("extensions.ext_database.db.session.add")
@patch("extensions.ext_database.db.session.commit")
def test_retrieve_should_handle_attachments(self, mock_commit, mock_add, mock_retrieve):
    """Verify retrieve passes attachment IDs on and records them in the saved query.

    Checks two things: the exact keyword arguments given to
    ``RetrievalService.retrieve``, and the JSON content of the object handed
    to ``db.session.add``.
    """
    # Arrange
    mock_retrieve.return_value = []
    attachment_ids = ["att1", "att2"]
    query = "test query"

    requester = MagicMock()
    requester.id = "account_id"
    target_dataset = MagicMock(spec=Dataset)
    target_dataset.id = "dataset_id"

    model_config = {
        "search_method": "semantic_search",
        "top_k": 4,
        "reranking_enable": False,
        "score_threshold_enabled": False,
    }

    # Act
    HitTestingService.retrieve(
        dataset=target_dataset,
        query=query,
        account=requester,
        retrieval_model=model_config,
        external_retrieval_model={},
        attachment_ids=attachment_ids,
    )

    # Assert: the retriever received the attachments plus the expected settings.
    expected_retrieve_kwargs = {
        "retrieval_method": ANY,
        "dataset_id": "dataset_id",
        "query": query,
        "attachment_ids": attachment_ids,
        "top_k": 4,
        "score_threshold": 0.0,
        "reranking_model": None,
        "reranking_mode": "reranking_model",
        "weights": None,
        "document_ids_filter": None,
    }
    mock_retrieve.assert_called_once_with(**expected_retrieve_kwargs)

    # The first positional argument of the last db.session.add call is the
    # stored query record; its content is a JSON list with one text entry
    # followed by one image entry per attachment (1 + 2 = 3 entries here).
    stored_record = mock_add.call_args.args[0]
    stored_entries = json.loads(stored_record.content)
    assert len(stored_entries) == 3
    text_entry, first_image_entry = stored_entries[0], stored_entries[1]
    assert text_entry["content_type"] == "text_query"
    assert first_image_entry["content_type"] == "image_query"
    assert first_image_entry["content"] == "att1"
class TestKnowledgeService:
    """Tests for ExternalDatasetTestService.knowledge_retrieval with boto3.client patched."""

    # ===== Happy Path Tests =====

    @patch("services.knowledge_service.boto3.client")
    @patch("services.knowledge_service.dify_config")
    def test_knowledge_retrieval_should_succeed_with_valid_results(
        self, mock_dify_config: MagicMock, mock_boto_client: MagicMock
    ):
        """Verify a 200 response is turned into records and the client is built and called as expected."""
        # Arrange
        search_text = "test query"
        kb_id = "kb-123"
        settings = {"top_k": 4, "score_threshold": 0.5}

        mock_dify_config.AWS_SECRET_ACCESS_KEY = "dummy_secret"
        mock_dify_config.AWS_ACCESS_KEY_ID = "dummy_id"

        bedrock = MagicMock()
        mock_boto_client.return_value = bedrock

        kept_hit = {
            "score": 0.9,
            "metadata": {"x-amz-bedrock-kb-source-uri": "s3://bucket/doc1.pdf"},
            "content": {"text": "content from doc1"},
        }
        # 0.4 is under the 0.5 score_threshold above, so this hit is expected to be dropped.
        dropped_hit = {
            "score": 0.4,
            "metadata": {"x-amz-bedrock-kb-source-uri": "s3://bucket/doc2.pdf"},
            "content": {"text": "content from doc2"},
        }
        bedrock.retrieve.return_value = {
            "ResponseMetadata": {"HTTPStatusCode": 200},
            "retrievalResults": [kept_hit, dropped_hit],
        }

        # Act
        raw_response = ExternalDatasetTestService.knowledge_retrieval(settings, search_text, kb_id)
        response = cast(dict[str, Any], raw_response)

        # Assert: only the high-scoring hit survives, mapped to score/title/content.
        records = response["records"]
        assert len(records) == 1
        only_record = records[0]
        assert only_record["score"] == 0.9
        assert only_record["title"] == "s3://bucket/doc1.pdf"
        assert only_record["content"] == "content from doc1"

        # The retrieve call carries the knowledge base ID, top_k and the query text.
        expected_search_config = {
            "vectorSearchConfiguration": {
                "numberOfResults": 4,
                "overrideSearchType": "HYBRID",
            }
        }
        bedrock.retrieve.assert_called_once_with(
            knowledgeBaseId=kb_id,
            retrievalConfiguration=expected_search_config,
            retrievalQuery={"text": search_text},
        )

        # The boto3 client is created once with the service name, the patched
        # config credentials and the region.
        mock_boto_client.assert_called_once_with(
            "bedrock-agent-runtime",
            aws_secret_access_key="dummy_secret",
            aws_access_key_id="dummy_id",
            region_name="us-east-1",
        )

    @patch("services.knowledge_service.boto3.client")
    def test_knowledge_retrieval_should_return_empty_when_no_results(self, mock_boto: MagicMock):
        """Verify an empty ``retrievalResults`` list produces empty records."""
        # Arrange
        bedrock = MagicMock()
        bedrock.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}, "retrievalResults": []}
        mock_boto.return_value = bedrock

        # Act
        response = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb"))

        # Assert
        assert response["records"] == []

    # ===== Error Handling Tests =====

    @patch("services.knowledge_service.boto3.client")
    def test_knowledge_retrieval_should_return_empty_on_http_error(self, mock_boto: MagicMock):
        """Verify a non-200 HTTPStatusCode (500 here) produces empty records."""
        # Arrange
        bedrock = MagicMock()
        bedrock.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 500}}
        mock_boto.return_value = bedrock

        # Act
        response = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb"))

        # Assert
        assert response["records"] == []

    def test_knowledge_retrieval_should_raise_when_boto_client_creation_fails(self):
        """Verify an exception raised by boto3.client reaches the caller unchanged."""
        # The patched factory raises as soon as it is called.
        boom = Exception("client init failed")
        with patch("services.knowledge_service.boto3.client", side_effect=boom):
            with pytest.raises(Exception) as exc_info:
                ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")
            assert "client init failed" in str(exc_info.value)

    # ===== Edge Cases =====

    @patch("services.knowledge_service.boto3.client")
    def test_knowledge_retrieval_should_handle_missing_threshold_in_settings(self, mock_boto: MagicMock):
        """Verify a low-scoring hit is kept when the settings carry no ``score_threshold``.

        NOTE(review): the original docstring states the default threshold is
        0.0; this test only shows that a 0.1 score passes - confirm against
        the service.
        """
        # Arrange
        low_score_hit = {
            "score": 0.1,
            "metadata": {"x-amz-bedrock-kb-source-uri": "uri"},
            "content": {"text": "text"},
        }
        bedrock = MagicMock()
        bedrock.retrieve.return_value = {
            "ResponseMetadata": {"HTTPStatusCode": 200},
            "retrievalResults": [low_score_hit],
        }
        mock_boto.return_value = bedrock

        # Act: the settings dict deliberately omits "score_threshold".
        response = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb"))

        # Assert
        records = response["records"]
        assert len(records) == 1
        assert records[0]["score"] == 0.1
@pytest.mark.parametrize(
    ("utm_info", "expected_params"),
    [
        # Every UTM field supplied: each value is forwarded as-is.
        pytest.param(
            {
                "utm_source": "google",
                "utm_medium": "cpc",
                "utm_campaign": "spring_sale",
                "utm_content": "ad_1",
                "utm_term": "ai_agent",
            },
            {
                "tenant_id": "tenant-123",
                "utm_source": "google",
                "utm_medium": "cpc",
                "utm_campaign": "spring_sale",
                "utm_content": "ad_1",
                "utm_term": "ai_agent",
            },
        ),
        # No UTM fields supplied: every field is expected as an empty string.
        pytest.param(
            {},
            {
                "tenant_id": "tenant-123",
                "utm_source": "",
                "utm_medium": "",
                "utm_campaign": "",
                "utm_content": "",
                "utm_term": "",
            },
        ),
        # Only one field supplied: the rest are expected as empty strings.
        pytest.param(
            {"utm_source": "newsletter"},
            {
                "tenant_id": "tenant-123",
                "utm_source": "newsletter",
                "utm_medium": "",
                "utm_campaign": "",
                "utm_content": "",
                "utm_term": "",
            },
        ),
    ],
)
@patch.object(OperationService, "_send_request")
def test_should_map_parameters_correctly_when_record_utm_called(
    self, mock_send: MagicMock, utm_info: dict, expected_params: dict
):
    """Verify record_utm sends the expected params to ``/tenant_utms`` and returns the reply.

    ``_send_request`` is patched, so the test checks only the arguments
    record_utm builds and that its return value is passed through.
    """
    # Arrange
    canned_reply = {"status": "recorded"}
    mock_send.return_value = canned_reply

    # Act
    outcome = OperationService.record_utm("tenant-123", utm_info)

    # Assert
    assert outcome == {"status": "recorded"}
    mock_send.assert_called_once_with("POST", "/tenant_utms", params=expected_params)