diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py index 388ba8a840..deb26438a8 100644 --- a/api/services/tools/mcp_tools_manage_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -704,7 +704,7 @@ class MCPToolManageService: raise ValueError(f"MCP tool {server_url} already exists") if "unique_mcp_provider_server_identifier" in error_msg: raise ValueError(f"MCP tool {server_identifier} already exists") - raise + raise error def _is_valid_url(self, url: str) -> bool: """Validate URL format.""" diff --git a/api/tests/unit_tests/services/test_metadata_service.py b/api/tests/unit_tests/services/test_metadata_service.py new file mode 100644 index 0000000000..bbdc16d4f8 --- /dev/null +++ b/api/tests/unit_tests/services/test_metadata_service.py @@ -0,0 +1,558 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource +from models.dataset import Dataset +from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, + MetadataArgs, + MetadataDetail, + MetadataOperationData, +) +from services.metadata_service import MetadataService + + +@dataclass +class _DocumentStub: + id: str + name: str + uploader: str + upload_date: datetime + last_update_date: datetime + data_source_type: str + doc_metadata: dict[str, object] | None + + +@pytest.fixture +def mock_db(mocker: MockerFixture) -> MagicMock: + mocked_db = mocker.patch("services.metadata_service.db") + mocked_db.session = MagicMock() + return mocked_db + + +@pytest.fixture +def mock_redis_client(mocker: MockerFixture) -> MagicMock: + return mocker.patch("services.metadata_service.redis_client") + + +@pytest.fixture +def mock_current_account(mocker: MockerFixture) -> MagicMock: + mock_user = SimpleNamespace(id="user-1") + return mocker.patch("services.metadata_service.current_account_with_tenant", return_value=(mock_user, "tenant-1")) + + +def _build_document(document_id: str, doc_metadata: dict[str, object] | None = None) -> _DocumentStub: + now = datetime(2025, 1, 1, 10, 30, tzinfo=UTC) + return _DocumentStub( + id=document_id, + name=f"doc-{document_id}", + uploader="qa@example.com", + upload_date=now, + last_update_date=now, + data_source_type="upload_file", + doc_metadata=doc_metadata, + ) + + +def _dataset(**kwargs: Any) -> Dataset: + return cast(Dataset, SimpleNamespace(**kwargs)) + + +def test_create_metadata_should_raise_value_error_when_name_exceeds_limit() -> None: + # Arrange + metadata_args = MetadataArgs(type="string", name="x" * 256) + + # Act + Assert + with pytest.raises(ValueError, match="cannot exceed 255"): + MetadataService.create_metadata("dataset-1", metadata_args) + + +def test_create_metadata_should_raise_value_error_when_metadata_name_already_exists( + mock_db: MagicMock, + mock_current_account: MagicMock, +) -> None: + # Arrange + metadata_args = MetadataArgs(type="string", name="priority") + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + + # Act + Assert + with pytest.raises(ValueError, match="already exists"): + MetadataService.create_metadata("dataset-1", metadata_args) + + # Assert + mock_current_account.assert_called_once() + + +def test_create_metadata_should_raise_value_error_when_name_collides_with_builtin( + mock_db: MagicMock, mock_current_account: MagicMock +) -> None: + # Arrange + metadata_args = MetadataArgs(type="string", name=BuiltInField.document_name) + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="Built-in fields"): + MetadataService.create_metadata("dataset-1", metadata_args) + + +def test_create_metadata_should_persist_metadata_when_input_is_valid( + mock_db: MagicMock, mock_current_account: MagicMock +) -> None: + # Arrange + metadata_args = MetadataArgs(type="number", name="score") + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + result = MetadataService.create_metadata("dataset-1", metadata_args) + + # Assert + assert result.tenant_id == "tenant-1" + assert result.dataset_id == "dataset-1" + assert result.type == "number" + assert result.name == "score" + assert result.created_by == "user-1" + mock_db.session.add.assert_called_once_with(result) + mock_db.session.commit.assert_called_once() + mock_current_account.assert_called_once() + + +def test_update_metadata_name_should_raise_value_error_when_name_exceeds_limit() -> None: + # Arrange + too_long_name = "x" * 256 + + # Act + Assert + with pytest.raises(ValueError, match="cannot exceed 255"): + MetadataService.update_metadata_name("dataset-1", "metadata-1", too_long_name) + + +def test_update_metadata_name_should_raise_value_error_when_duplicate_name_exists( + mock_db: MagicMock, mock_current_account: MagicMock +) -> None: + # Arrange + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + + # Act + Assert + with pytest.raises(ValueError, match="already exists"): + MetadataService.update_metadata_name("dataset-1", "metadata-1", "duplicate") + + # Assert + mock_current_account.assert_called_once() + + +def test_update_metadata_name_should_raise_value_error_when_name_collides_with_builtin( + mock_db: MagicMock, + mock_current_account: MagicMock, +) -> None: + # Arrange + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="Built-in fields"): + MetadataService.update_metadata_name("dataset-1", "metadata-1", BuiltInField.source) + + # Assert + mock_current_account.assert_called_once() + + +def test_update_metadata_name_should_update_bound_documents_and_return_metadata( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mock_current_account: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + fixed_now = datetime(2025, 2, 1, 0, 0, tzinfo=UTC) + mocker.patch("services.metadata_service.naive_utc_now", return_value=fixed_now) + + metadata = SimpleNamespace(id="metadata-1", name="old_name", updated_by=None, updated_at=None) + bindings = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")] + query_duplicate = MagicMock() + query_duplicate.filter_by.return_value.first.return_value = None + query_metadata = MagicMock() + query_metadata.filter_by.return_value.first.return_value = metadata + query_bindings = MagicMock() + query_bindings.filter_by.return_value.all.return_value = bindings + mock_db.session.query.side_effect = [query_duplicate, query_metadata, query_bindings] + + doc_1 = _build_document("1", {"old_name": "value", "other": "keep"}) + doc_2 = _build_document("2", None) + mock_get_documents = mocker.patch("services.metadata_service.DocumentService.get_document_by_ids") + mock_get_documents.return_value = [doc_1, doc_2] + + # Act + result = MetadataService.update_metadata_name("dataset-1", "metadata-1", "new_name") + + # Assert + assert result is metadata + assert metadata.name == "new_name" + assert metadata.updated_by == "user-1" + assert metadata.updated_at == fixed_now + assert doc_1.doc_metadata == {"other": "keep", "new_name": "value"} + assert doc_2.doc_metadata == {"new_name": None} + mock_get_documents.assert_called_once_with(["doc-1", "doc-2"]) + mock_db.session.commit.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + mock_current_account.assert_called_once() + + +def test_update_metadata_name_should_return_none_when_metadata_does_not_exist( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mock_current_account: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + mock_logger = mocker.patch("services.metadata_service.logger") + + query_duplicate = MagicMock() + query_duplicate.filter_by.return_value.first.return_value = None + query_metadata = MagicMock() + query_metadata.filter_by.return_value.first.return_value = None + mock_db.session.query.side_effect = [query_duplicate, query_metadata] + + # Act + result = MetadataService.update_metadata_name("dataset-1", "missing-id", "new_name") + + # Assert + assert result is None + mock_logger.exception.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + mock_current_account.assert_called_once() + + +def test_delete_metadata_should_remove_metadata_and_related_document_fields( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + metadata = SimpleNamespace(id="metadata-1", name="obsolete") + bindings = [SimpleNamespace(document_id="doc-1")] + query_metadata = MagicMock() + query_metadata.filter_by.return_value.first.return_value = metadata + query_bindings = MagicMock() + query_bindings.filter_by.return_value.all.return_value = bindings + mock_db.session.query.side_effect = [query_metadata, query_bindings] + + document = _build_document("1", {"obsolete": "legacy", "remaining": "value"}) + mocker.patch("services.metadata_service.DocumentService.get_document_by_ids", return_value=[document]) + + # Act + result = MetadataService.delete_metadata("dataset-1", "metadata-1") + + # Assert + assert result is metadata + assert document.doc_metadata == {"remaining": "value"} + mock_db.session.delete.assert_called_once_with(metadata) + mock_db.session.commit.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + + +def test_delete_metadata_should_return_none_when_metadata_is_missing( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + mock_logger = mocker.patch("services.metadata_service.logger") + + # Act + result = MetadataService.delete_metadata("dataset-1", "missing-id") + + # Assert + assert result is None + mock_logger.exception.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + + +def test_get_built_in_fields_should_return_all_expected_fields() -> None: + # Arrange + expected_names = { + BuiltInField.document_name, + BuiltInField.uploader, + BuiltInField.upload_date, + BuiltInField.last_update_date, + BuiltInField.source, + } + + # Act + result = MetadataService.get_built_in_fields() + + # Assert + assert {item["name"] for item in result} == expected_names + assert [item["type"] for item in result] == ["string", "string", "time", "time", "string"] + + +def test_enable_built_in_field_should_return_immediately_when_already_enabled( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + dataset = _dataset(id="dataset-1", built_in_field_enabled=True) + get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id") + + # Act + MetadataService.enable_built_in_field(dataset) + + # Assert + get_docs.assert_not_called() + mock_db.session.commit.assert_not_called() + + +def test_enable_built_in_field_should_populate_documents_and_enable_flag( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + dataset = _dataset(id="dataset-1", built_in_field_enabled=False) + doc_1 = _build_document("1", {"custom": "value"}) + doc_2 = _build_document("2", None) + mocker.patch( + "services.metadata_service.DocumentService.get_working_documents_by_dataset_id", + return_value=[doc_1, doc_2], + ) + + # Act + MetadataService.enable_built_in_field(dataset) + + # Assert + assert dataset.built_in_field_enabled is True + assert doc_1.doc_metadata is not None + assert doc_1.doc_metadata[BuiltInField.document_name] == "doc-1" + assert doc_1.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file + assert doc_2.doc_metadata is not None + assert doc_2.doc_metadata[BuiltInField.uploader] == "qa@example.com" + mock_db.session.commit.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + + +def test_disable_built_in_field_should_return_immediately_when_already_disabled( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + dataset = _dataset(id="dataset-1", built_in_field_enabled=False) + get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id") + + # Act + MetadataService.disable_built_in_field(dataset) + + # Assert + get_docs.assert_not_called() + mock_db.session.commit.assert_not_called() + + +def test_disable_built_in_field_should_remove_builtin_keys_and_disable_flag( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + dataset = _dataset(id="dataset-1", built_in_field_enabled=True) + document = _build_document( + "1", + { + BuiltInField.document_name: "doc", + BuiltInField.uploader: "user", + BuiltInField.upload_date: 1.0, + BuiltInField.last_update_date: 2.0, + BuiltInField.source: MetadataDataSource.upload_file, + "custom": "keep", + }, + ) + mocker.patch( + "services.metadata_service.DocumentService.get_working_documents_by_dataset_id", + return_value=[document], + ) + + # Act + MetadataService.disable_built_in_field(dataset) + + # Assert + assert dataset.built_in_field_enabled is False + assert document.doc_metadata == {"custom": "keep"} + mock_db.session.commit.assert_called_once() + mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") + + +def test_update_documents_metadata_should_replace_metadata_and_create_bindings_on_full_update( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mock_current_account: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + dataset = _dataset(id="dataset-1", built_in_field_enabled=False) + document = _build_document("1", {"legacy": "value"}) + mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document) + delete_chain = mock_db.session.query.return_value.filter_by.return_value + delete_chain.delete.return_value = 1 + operation = DocumentMetadataOperation( + document_id="1", + metadata_list=[MetadataDetail(id="meta-1", name="priority", value="high")], + partial_update=False, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + # Act + MetadataService.update_documents_metadata(dataset, metadata_args) + + # Assert + assert document.doc_metadata == {"priority": "high"} + delete_chain.delete.assert_called_once() + assert mock_db.session.commit.call_count == 1 + mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1") + mock_current_account.assert_called_once() + + +def test_update_documents_metadata_should_skip_existing_binding_and_preserve_existing_fields_on_partial_update( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mock_current_account: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + dataset = _dataset(id="dataset-1", built_in_field_enabled=True) + document = _build_document("1", {"existing": "value"}) + mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document) + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + operation = DocumentMetadataOperation( + document_id="1", + metadata_list=[MetadataDetail(id="meta-1", name="new_key", value="new_value")], + partial_update=True, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + # Act + MetadataService.update_documents_metadata(dataset, metadata_args) + + # Assert + assert document.doc_metadata is not None + assert document.doc_metadata["existing"] == "value" + assert document.doc_metadata["new_key"] == "new_value" + assert document.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file + assert mock_db.session.commit.call_count == 1 + assert mock_db.session.add.call_count == 1 + mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1") + mock_current_account.assert_called_once() + + +def test_update_documents_metadata_should_raise_and_rollback_when_document_not_found( + mock_db: MagicMock, + mock_redis_client: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + dataset = _dataset(id="dataset-1", built_in_field_enabled=False) + mocker.patch("services.metadata_service.DocumentService.get_document", return_value=None) + operation = DocumentMetadataOperation(document_id="404", metadata_list=[], partial_update=True) + metadata_args = MetadataOperationData(operation_data=[operation]) + + # Act + Assert + with pytest.raises(ValueError, match="Document not found"): + MetadataService.update_documents_metadata(dataset, metadata_args) + + # Assert + mock_db.session.rollback.assert_called_once() + mock_redis_client.delete.assert_called_once_with("document_metadata_lock_404") + + +@pytest.mark.parametrize( + ("dataset_id", "document_id", "expected_key"), + [ + ("dataset-1", None, "dataset_metadata_lock_dataset-1"), + (None, "doc-1", "document_metadata_lock_doc-1"), + ], +) +def test_knowledge_base_metadata_lock_check_should_set_lock_when_not_already_locked( + dataset_id: str | None, + document_id: str | None, + expected_key: str, + mock_redis_client: MagicMock, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + + # Act + MetadataService.knowledge_base_metadata_lock_check(dataset_id, document_id) + + # Assert + mock_redis_client.set.assert_called_once_with(expected_key, 1, ex=3600) + + +def test_knowledge_base_metadata_lock_check_should_raise_when_dataset_lock_exists( + mock_redis_client: MagicMock, +) -> None: + # Arrange + mock_redis_client.get.return_value = 1 + + # Act + Assert + with pytest.raises(ValueError, match="knowledge base metadata operation is running"): + MetadataService.knowledge_base_metadata_lock_check("dataset-1", None) + + +def test_knowledge_base_metadata_lock_check_should_raise_when_document_lock_exists( + mock_redis_client: MagicMock, +) -> None: + # Arrange + mock_redis_client.get.return_value = 1 + + # Act + Assert + with pytest.raises(ValueError, match="document metadata operation is running"): + MetadataService.knowledge_base_metadata_lock_check(None, "doc-1") + + +def test_get_dataset_metadatas_should_exclude_builtin_and_include_binding_counts(mock_db: MagicMock) -> None: + # Arrange + dataset = _dataset( + id="dataset-1", + built_in_field_enabled=True, + doc_metadata=[ + {"id": "meta-1", "name": "priority", "type": "string"}, + {"id": "built-in", "name": "ignored", "type": "string"}, + {"id": "meta-2", "name": "score", "type": "number"}, + ], + ) + count_chain = mock_db.session.query.return_value.filter_by.return_value + count_chain.count.side_effect = [3, 1] + + # Act + result = MetadataService.get_dataset_metadatas(dataset) + + # Assert + assert result["built_in_field_enabled"] is True + assert result["doc_metadata"] == [ + {"id": "meta-1", "name": "priority", "type": "string", "count": 3}, + {"id": "meta-2", "name": "score", "type": "number", "count": 1}, + ] + + +def test_get_dataset_metadatas_should_return_empty_list_when_no_metadata(mock_db: MagicMock) -> None: + # Arrange + dataset = _dataset(id="dataset-1", built_in_field_enabled=False, doc_metadata=None) + + # Act + result = MetadataService.get_dataset_metadatas(dataset) + + # Assert + assert result == {"doc_metadata": [], "built_in_field_enabled": False} + mock_db.session.query.assert_not_called() diff --git a/api/tests/unit_tests/services/test_model_load_balancing_service.py b/api/tests/unit_tests/services/test_model_load_balancing_service.py new file mode 100644 index 0000000000..49e572584b --- /dev/null +++ b/api/tests/unit_tests/services/test_model_load_balancing_service.py @@ -0,0 +1,808 @@ +from __future__ import annotations + +import json +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from constants import HIDDEN_VALUE +from dify_graph.model_runtime.entities.common_entities import I18nObject +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.entities.provider_entities import ( + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, +) +from models.provider import LoadBalancingModelConfig +from services.model_load_balancing_service import ModelLoadBalancingService + + +def _build_provider_credential_schema() -> ProviderCredentialSchema: + return ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.SECRET_INPUT) + ] + ) + + +def _build_model_credential_schema() -> ModelCredentialSchema: + return ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model")), + credential_form_schemas=[ + CredentialFormSchema(variable="api_key", label=I18nObject(en_US="API Key"), type=FormType.SECRET_INPUT) + ], + ) + + +def _build_provider_configuration( + *, + custom_provider: bool = False, + load_balancing_enabled: bool | None = None, + model_schema: ModelCredentialSchema | None = None, + provider_schema: ProviderCredentialSchema | None = None, +) -> MagicMock: + provider_configuration = MagicMock() + provider_configuration.provider = SimpleNamespace( + provider="openai", + model_credential_schema=model_schema, + provider_credential_schema=provider_schema, + ) + provider_configuration.custom_configuration = SimpleNamespace(provider=custom_provider) + provider_configuration.extract_secret_variables.return_value = ["api_key"] + provider_configuration.obfuscated_credentials.side_effect = lambda credentials, credential_form_schemas: credentials + provider_configuration.get_provider_model_setting.return_value = ( + None if load_balancing_enabled is None else SimpleNamespace(load_balancing_enabled=load_balancing_enabled) + ) + return provider_configuration + + +def _load_balancing_model_config(**kwargs: Any) -> LoadBalancingModelConfig: + return cast(LoadBalancingModelConfig, SimpleNamespace(**kwargs)) + + +@pytest.fixture +def service(mocker: MockerFixture) -> ModelLoadBalancingService: + # Arrange + provider_manager = MagicMock() + mocker.patch("services.model_load_balancing_service.ProviderManager", return_value=provider_manager) + svc = ModelLoadBalancingService() + svc.provider_manager = provider_manager + return svc + + +@pytest.fixture +def mock_db(mocker: MockerFixture) -> MagicMock: + # Arrange + mocked_db = mocker.patch("services.model_load_balancing_service.db") + mocked_db.session = MagicMock() + return mocked_db + + +@pytest.mark.parametrize( + ("method_name", "expected_provider_method"), + [ + ("enable_model_load_balancing", "enable_model_load_balancing"), + ("disable_model_load_balancing", "disable_model_load_balancing"), + ], +) +def test_enable_disable_model_load_balancing_should_call_provider_configuration_method_when_provider_exists( + method_name: str, + expected_provider_method: str, + service: ModelLoadBalancingService, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + + # Act + getattr(service, method_name)("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value) + + # Assert + getattr(provider_configuration, expected_provider_method).assert_called_once_with( + model="gpt-4o-mini", model_type=ModelType.LLM + ) + + +@pytest.mark.parametrize( + "method_name", + ["enable_model_load_balancing", "disable_model_load_balancing"], +) +def test_enable_disable_model_load_balancing_should_raise_value_error_when_provider_missing( + method_name: str, + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + getattr(service, method_name)("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value) + + +def test_get_load_balancing_configs_should_raise_value_error_when_provider_missing( + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + service.get_load_balancing_configs("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value) + + +def test_get_load_balancing_configs_should_insert_inherit_config_when_missing_for_custom_provider( + service: ModelLoadBalancingService, + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration( + custom_provider=True, + load_balancing_enabled=True, + provider_schema=_build_provider_credential_schema(), + ) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + config = SimpleNamespace( + id="cfg-1", + name="primary", + encrypted_config=json.dumps({"api_key": "encrypted-key"}), + credential_id="cred-1", + enabled=True, + ) + mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [config] + mocker.patch( + "services.model_load_balancing_service.encrypter.get_decrypt_decoding", + return_value=("rsa", "cipher"), + ) + mocker.patch( + "services.model_load_balancing_service.encrypter.decrypt_token_with_decoding", + return_value="plain-key", + ) + mocker.patch( + "services.model_load_balancing_service.LBModelManager.get_config_in_cooldown_and_ttl", + return_value=(False, 0), + ) + + # Act + is_enabled, configs = service.get_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + ) + + # Assert + assert is_enabled is True + assert len(configs) == 2 + assert configs[0]["name"] == "__inherit__" + assert configs[1]["name"] == "primary" + assert configs[1]["credentials"] == {"api_key": "plain-key"} + assert mock_db.session.add.call_count == 1 + assert mock_db.session.commit.call_count == 1 + + +def test_get_load_balancing_configs_should_reorder_existing_inherit_and_tolerate_json_or_decrypt_errors( + service: ModelLoadBalancingService, + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration( + custom_provider=True, + load_balancing_enabled=None, + provider_schema=_build_provider_credential_schema(), + ) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + normal_config = SimpleNamespace( + id="cfg-1", + name="normal", + encrypted_config=json.dumps({"api_key": "bad-encrypted"}), + credential_id="cred-1", + enabled=True, + ) + inherit_config = SimpleNamespace( + id="cfg-2", + name="__inherit__", + encrypted_config="not-json", + credential_id=None, + enabled=False, + ) + mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [ + normal_config, + inherit_config, + ] + mocker.patch( + "services.model_load_balancing_service.encrypter.get_decrypt_decoding", + return_value=("rsa", "cipher"), + ) + mocker.patch( + "services.model_load_balancing_service.encrypter.decrypt_token_with_decoding", + side_effect=ValueError("cannot decrypt"), + ) + mocker.patch( + "services.model_load_balancing_service.LBModelManager.get_config_in_cooldown_and_ttl", + return_value=(True, 15), + ) + + # Act + is_enabled, configs = service.get_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + config_from="predefined-model", + ) + + # Assert + assert is_enabled is False + assert configs[0]["name"] == "__inherit__" + assert configs[0]["credentials"] == {} + assert configs[1]["credentials"] == {"api_key": "bad-encrypted"} + assert configs[1]["in_cooldown"] is True + assert configs[1]["ttl"] == 15 + + +def test_get_load_balancing_config_should_raise_value_error_when_provider_missing( + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1") + + +def test_get_load_balancing_config_should_return_none_when_config_not_found( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1") + + # Assert + assert result is None + + +def test_get_load_balancing_config_should_return_obfuscated_payload_when_config_exists( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + provider_configuration.obfuscated_credentials.side_effect = lambda credentials, credential_form_schemas: { + "masked": credentials.get("api_key", "") + } + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + config = SimpleNamespace(id="cfg-1", name="primary", encrypted_config="not-json", enabled=True) + mock_db.session.query.return_value.where.return_value.first.return_value = config + + # Act + result = service.get_load_balancing_config("tenant-1", "openai", "gpt-4o-mini", ModelType.LLM.value, "cfg-1") + + # Assert + assert result == { + "id": "cfg-1", + "name": "primary", + "credentials": {"masked": ""}, + "enabled": True, + } + + +def test_init_inherit_config_should_create_and_persist_inherit_configuration( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + model_type = ModelType.LLM + + # Act + inherit_config = service._init_inherit_config("tenant-1", "openai", "gpt-4o-mini", model_type) + + # Assert + assert inherit_config.tenant_id == "tenant-1" + assert inherit_config.provider_name == "openai" + assert inherit_config.model_name == "gpt-4o-mini" + assert inherit_config.model_type == "text-generation" + assert inherit_config.name == "__inherit__" + mock_db.session.add.assert_called_once_with(inherit_config) + mock_db.session.commit.assert_called_once() + + +def test_update_load_balancing_configs_should_raise_value_error_when_provider_missing( + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_configs_is_not_list( + service: ModelLoadBalancingService, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing configs"): + service.update_load_balancing_configs( # type: ignore[arg-type] + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + cast(list[dict[str, object]], "invalid-configs"), + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_config_item_is_not_dict( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config"): + service.update_load_balancing_configs( # type: ignore[list-item] + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + cast(list[dict[str, object]], ["bad-item"]), + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_credential_id_not_found( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="Provider credential with id cred-1 not found"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"credential_id": "cred-1", "enabled": True}], + "predefined-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_name_or_enabled_is_invalid( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config name"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"enabled": True}], + "custom-model", + ) + + with pytest.raises(ValueError, match="Invalid load balancing config enabled"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"name": "cfg-without-enabled"}], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_existing_config_id_is_invalid( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + current_config = SimpleNamespace(id="cfg-1") + mock_db.session.scalars.return_value.all.return_value = [current_config] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config id: cfg-2"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"id": "cfg-2", "name": "invalid", "enabled": True}], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_raise_value_error_when_credentials_are_invalid_for_update_or_create( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + existing_config = SimpleNamespace(id="cfg-1", name="old", enabled=True, encrypted_config=None, updated_at=None) + mock_db.session.scalars.return_value.all.return_value = [existing_config] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config credentials"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"id": "cfg-1", "name": "new", "enabled": True, "credentials": "bad"}], + "custom-model", + ) + + with pytest.raises(ValueError, match="Invalid load balancing config credentials"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"name": "new-config", "enabled": True, "credentials": "bad"}], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_update_existing_create_new_and_delete_removed_configs( + service: ModelLoadBalancingService, + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + existing_config_1 = SimpleNamespace( + id="cfg-1", + name="existing-one", + enabled=True, + encrypted_config=json.dumps({"api_key": "old"}), + updated_at=None, + ) + existing_config_2 = SimpleNamespace( + id="cfg-2", + name="existing-two", + enabled=True, + encrypted_config=None, + updated_at=None, + ) + mock_db.session.scalars.return_value.all.return_value = [existing_config_1, existing_config_2] + mocker.patch.object(service, "_custom_credentials_validate", return_value={"api_key": "encrypted"}) + mock_clear_cache = mocker.patch.object(service, "_clear_credentials_cache") + + # Act + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [ + {"id": "cfg-1", "name": "updated-name", "enabled": False, "credentials": {"api_key": "plain"}}, + {"name": "new-config", "enabled": True, "credentials": {"api_key": "plain"}}, + ], + "custom-model", + ) + + # Assert + assert existing_config_1.name == "updated-name" + assert existing_config_1.enabled is False + assert json.loads(existing_config_1.encrypted_config) == {"api_key": "encrypted"} + assert mock_db.session.add.call_count == 1 + mock_db.session.delete.assert_called_once_with(existing_config_2) + assert mock_db.session.commit.call_count >= 3 + mock_clear_cache.assert_any_call("tenant-1", "cfg-1") + mock_clear_cache.assert_any_call("tenant-1", "cfg-2") + + +def test_update_load_balancing_configs_should_raise_value_error_for_invalid_new_config_name_or_missing_credentials( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + + # Act + Assert + with pytest.raises(ValueError, match="Invalid load balancing config name"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"name": "__inherit__", "enabled": True, "credentials": {"api_key": "x"}}], + "custom-model", + ) + + with pytest.raises(ValueError, match="Invalid load balancing config credentials"): + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"name": "new", "enabled": True}], + "custom-model", + ) + + +def test_update_load_balancing_configs_should_create_from_existing_provider_credential_when_credential_id_provided( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.scalars.return_value.all.return_value = [] + credential_record = SimpleNamespace(credential_name="Main Credential", encrypted_config='{"api_key":"enc"}') + mock_db.session.query.return_value.filter_by.return_value.first.return_value = credential_record + + # Act + service.update_load_balancing_configs( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + [{"credential_id": "cred-1", "enabled": True}], + "predefined-model", + ) + + # Assert + created_config = mock_db.session.add.call_args.args[0] + assert created_config.name == "Main Credential" + assert created_config.credential_id == "cred-1" + assert created_config.credential_source_type == "provider" + assert created_config.encrypted_config == '{"api_key":"enc"}' + mock_db.session.commit.assert_called() + + +def test_validate_load_balancing_credentials_should_raise_value_error_when_provider_missing( + service: ModelLoadBalancingService, +) -> None: + # Arrange + service.provider_manager.get_configurations.return_value = {} + + # Act + Assert + with pytest.raises(ValueError, match="Provider openai does not exist"): + service.validate_load_balancing_credentials( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + {"api_key": "plain"}, + ) + + +def test_validate_load_balancing_credentials_should_raise_value_error_when_config_id_is_invalid( + service: ModelLoadBalancingService, + mock_db: MagicMock, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="Load balancing config cfg-1 does not exist"): + service.validate_load_balancing_credentials( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + {"api_key": "plain"}, + config_id="cfg-1", + ) + + +def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_with_or_without_config( + service: ModelLoadBalancingService, + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + service.provider_manager.get_configurations.return_value = {"openai": provider_configuration} + existing_config = SimpleNamespace(id="cfg-1") + mock_db.session.query.return_value.where.return_value.first.return_value = existing_config + mock_validate = mocker.patch.object(service, "_custom_credentials_validate") + + # Act + service.validate_load_balancing_credentials( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + {"api_key": "plain"}, + config_id="cfg-1", + ) + service.validate_load_balancing_credentials( + "tenant-1", + "openai", + "gpt-4o-mini", + ModelType.LLM.value, + {"api_key": "plain"}, + ) + + # Assert + assert mock_validate.call_count == 2 + assert mock_validate.call_args_list[0].kwargs["load_balancing_model_config"] is existing_config + assert mock_validate.call_args_list[1].kwargs["load_balancing_model_config"] is None + + +def test_custom_credentials_validate_should_replace_hidden_secret_with_original_value_and_encrypt( + service: ModelLoadBalancingService, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + load_balancing_model_config = _load_balancing_model_config( + encrypted_config=json.dumps({"api_key": "old-encrypted-token"}) + ) + mocker.patch("services.model_load_balancing_service.encrypter.decrypt_token", return_value="old-plain-value") + mock_encrypt = mocker.patch( + "services.model_load_balancing_service.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc:{value}", + ) + + # Act + result = service._custom_credentials_validate( + tenant_id="tenant-1", + provider_configuration=provider_configuration, + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": HIDDEN_VALUE, "region": "us"}, + load_balancing_model_config=load_balancing_model_config, + validate=False, + ) + + # Assert + assert result == {"api_key": "enc:old-plain-value", "region": "us"} + mock_encrypt.assert_called_once_with("tenant-1", "old-plain-value") + + +def test_custom_credentials_validate_should_handle_invalid_original_json_and_validate_with_model_schema( + service: ModelLoadBalancingService, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(model_schema=_build_model_credential_schema()) + load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json") + mock_factory = MagicMock() + mock_factory.model_credentials_validate.return_value = {"api_key": "validated"} + mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory) + mock_encrypt = mocker.patch( + "services.model_load_balancing_service.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc:{value}", + ) + + # Act + result = service._custom_credentials_validate( + tenant_id="tenant-1", + provider_configuration=provider_configuration, + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "plain"}, + load_balancing_model_config=load_balancing_model_config, + validate=True, + ) + + # Assert + assert result == {"api_key": "enc:validated"} + mock_factory.model_credentials_validate.assert_called_once() + mock_factory.provider_credentials_validate.assert_not_called() + mock_encrypt.assert_called_once_with("tenant-1", "validated") + + +def test_custom_credentials_validate_should_validate_with_provider_schema_when_model_schema_absent( + service: ModelLoadBalancingService, + mocker: MockerFixture, +) -> None: + # Arrange + provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) + mock_factory = MagicMock() + mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"} + mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory) + mocker.patch( + "services.model_load_balancing_service.encrypter.encrypt_token", + side_effect=lambda tenant_id, value: f"enc:{value}", + ) + + # Act + result = service._custom_credentials_validate( + tenant_id="tenant-1", + provider_configuration=provider_configuration, + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "plain"}, + validate=True, + ) + + # Assert + assert result == {"api_key": "enc:provider-validated"} + mock_factory.provider_credentials_validate.assert_called_once() + mock_factory.model_credentials_validate.assert_not_called() + + +def test_get_credential_schema_should_return_model_schema_or_provider_schema_or_raise( + service: ModelLoadBalancingService, +) -> None: + # Arrange + model_schema = _build_model_credential_schema() + provider_schema = _build_provider_credential_schema() + provider_configuration_with_model = _build_provider_configuration(model_schema=model_schema) + provider_configuration_with_provider = _build_provider_configuration(provider_schema=provider_schema) + provider_configuration_without_schema = _build_provider_configuration() + + # Act + schema_from_model = service._get_credential_schema(provider_configuration_with_model) + schema_from_provider = service._get_credential_schema(provider_configuration_with_provider) + + # Assert + assert schema_from_model is model_schema + assert schema_from_provider is provider_schema + with pytest.raises(ValueError, match="No credential schema found"): + service._get_credential_schema(provider_configuration_without_schema) + + +def test_clear_credentials_cache_should_delete_load_balancing_cache_entry( + service: ModelLoadBalancingService, + mocker: MockerFixture, +) -> None: + # Arrange + mock_cache_instance = MagicMock() + mock_cache_cls = mocker.patch( + "services.model_load_balancing_service.ProviderCredentialsCache", + return_value=mock_cache_instance, + ) + + # Act + service._clear_credentials_cache("tenant-1", "cfg-1") + + # Assert + mock_cache_cls.assert_called_once() + assert mock_cache_cls.call_args.kwargs == { + "tenant_id": "tenant-1", + "identity_id": "cfg-1", + "cache_type": mocker.ANY, + } + assert mock_cache_cls.call_args.kwargs["cache_type"].name == "LOAD_BALANCING_MODEL" + mock_cache_instance.delete.assert_called_once() diff --git a/api/tests/unit_tests/services/test_oauth_server_service.py b/api/tests/unit_tests/services/test_oauth_server_service.py new file mode 100644 index 0000000000..231ceb74dc --- /dev/null +++ b/api/tests/unit_tests/services/test_oauth_server_service.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +import uuid +from types import SimpleNamespace +from typing import cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture +from werkzeug.exceptions import BadRequest + +from services.oauth_server import ( + OAUTH_ACCESS_TOKEN_EXPIRES_IN, + OAUTH_ACCESS_TOKEN_REDIS_KEY, + OAUTH_AUTHORIZATION_CODE_REDIS_KEY, + OAUTH_REFRESH_TOKEN_EXPIRES_IN, + OAUTH_REFRESH_TOKEN_REDIS_KEY, + OAuthGrantType, + OAuthServerService, +) + + +@pytest.fixture +def mock_redis_client(mocker: MockerFixture) -> MagicMock: + return mocker.patch("services.oauth_server.redis_client") + + +@pytest.fixture +def mock_session(mocker: MockerFixture) -> MagicMock: + """Mock the OAuth server Session context manager.""" + mocker.patch("services.oauth_server.db", SimpleNamespace(engine=object())) + session = MagicMock() + session_cm = MagicMock() + session_cm.__enter__.return_value = session + mocker.patch("services.oauth_server.Session", return_value=session_cm) + return session + + +def test_get_oauth_provider_app_should_return_app_when_record_exists(mock_session: MagicMock) -> None: + # Arrange + mock_execute_result = MagicMock() + expected_app = MagicMock() + mock_execute_result.scalar_one_or_none.return_value = expected_app + mock_session.execute.return_value = mock_execute_result + + # Act + result = OAuthServerService.get_oauth_provider_app("client-1") + + # Assert + assert result is expected_app + mock_session.execute.assert_called_once() + mock_execute_result.scalar_one_or_none.assert_called_once() + + +def test_sign_oauth_authorization_code_should_store_code_and_return_value( + mocker: MockerFixture, mock_redis_client: MagicMock +) -> None: + # Arrange + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000111") + mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) + + # Act + code = OAuthServerService.sign_oauth_authorization_code("client-1", "user-1") + + # Assert + expected_code = str(deterministic_uuid) + assert code == expected_code + mock_redis_client.set.assert_called_once_with( + OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code=expected_code), + "user-1", + ex=600, + ) + + +def test_sign_oauth_access_token_should_raise_bad_request_when_authorization_code_is_invalid( + mock_redis_client: MagicMock, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + + # Act + Assert + with pytest.raises(BadRequest, match="invalid code"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="bad-code", + client_id="client-1", + ) + + +def test_sign_oauth_access_token_should_issue_access_and_refresh_token_when_authorization_code_is_valid( + mocker: MockerFixture, mock_redis_client: MagicMock +) -> None: + # Arrange + token_uuids = [ + uuid.UUID("00000000-0000-0000-0000-000000000201"), + uuid.UUID("00000000-0000-0000-0000-000000000202"), + ] + mocker.patch("services.oauth_server.uuid.uuid4", side_effect=token_uuids) + mock_redis_client.get.return_value = b"user-1" + code_key = OAUTH_AUTHORIZATION_CODE_REDIS_KEY.format(client_id="client-1", code="code-1") + + # Act + access_token, refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.AUTHORIZATION_CODE, + code="code-1", + client_id="client-1", + ) + + # Assert + assert access_token == str(token_uuids[0]) + assert refresh_token == str(token_uuids[1]) + mock_redis_client.delete.assert_called_once_with(code_key) + mock_redis_client.set.assert_any_call( + OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), + b"user-1", + ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, + ) + mock_redis_client.set.assert_any_call( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-1", token=refresh_token), + b"user-1", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + +def test_sign_oauth_access_token_should_raise_bad_request_when_refresh_token_is_invalid( + mock_redis_client: MagicMock, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + + # Act + Assert + with pytest.raises(BadRequest, match="invalid refresh token"): + OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="stale-token", + client_id="client-1", + ) + + +def test_sign_oauth_access_token_should_issue_new_access_token_when_refresh_token_is_valid( + mocker: MockerFixture, mock_redis_client: MagicMock +) -> None: + # Arrange + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000301") + mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) + mock_redis_client.get.return_value = b"user-1" + + # Act + access_token, returned_refresh_token = OAuthServerService.sign_oauth_access_token( + grant_type=OAuthGrantType.REFRESH_TOKEN, + refresh_token="refresh-1", + client_id="client-1", + ) + + # Assert + assert access_token == str(deterministic_uuid) + assert returned_refresh_token == "refresh-1" + mock_redis_client.set.assert_called_once_with( + OAUTH_ACCESS_TOKEN_REDIS_KEY.format(client_id="client-1", token=access_token), + b"user-1", + ex=OAUTH_ACCESS_TOKEN_EXPIRES_IN, + ) + + +def test_sign_oauth_access_token_with_unknown_grant_type_should_return_none() -> None: + # Arrange + grant_type = cast(OAuthGrantType, "invalid-grant-type") + + # Act + result = OAuthServerService.sign_oauth_access_token( + grant_type=grant_type, + client_id="client-1", + ) + + # Assert + assert result is None + + +def test_sign_oauth_refresh_token_should_store_token_with_expected_expiry( + mocker: MockerFixture, mock_redis_client: MagicMock +) -> None: + # Arrange + deterministic_uuid = uuid.UUID("00000000-0000-0000-0000-000000000401") + mocker.patch("services.oauth_server.uuid.uuid4", return_value=deterministic_uuid) + + # Act + refresh_token = OAuthServerService._sign_oauth_refresh_token("client-2", "user-2") + + # Assert + assert refresh_token == str(deterministic_uuid) + mock_redis_client.set.assert_called_once_with( + OAUTH_REFRESH_TOKEN_REDIS_KEY.format(client_id="client-2", token=refresh_token), + "user-2", + ex=OAUTH_REFRESH_TOKEN_EXPIRES_IN, + ) + + +def test_validate_oauth_access_token_should_return_none_when_token_not_found( + mock_redis_client: MagicMock, +) -> None: + # Arrange + mock_redis_client.get.return_value = None + + # Act + result = OAuthServerService.validate_oauth_access_token("client-1", "missing-token") + + # Assert + assert result is None + + +def test_validate_oauth_access_token_should_load_user_when_token_exists( + mocker: MockerFixture, mock_redis_client: MagicMock +) -> None: + # Arrange + mock_redis_client.get.return_value = b"user-88" + expected_user = MagicMock() + mock_load_user = mocker.patch("services.oauth_server.AccountService.load_user", return_value=expected_user) + + # Act + result = OAuthServerService.validate_oauth_access_token("client-1", "access-token") + + # Assert + assert result is expected_user + mock_load_user.assert_called_once_with("user-88") diff --git a/api/tests/unit_tests/services/test_trigger_provider_service.py b/api/tests/unit_tests/services/test_trigger_provider_service.py new file mode 100644 index 0000000000..81a3b181fd --- /dev/null +++ b/api/tests/unit_tests/services/test_trigger_provider_service.py @@ -0,0 +1,1249 @@ +from __future__ import annotations + +import contextlib +import json +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from constants import HIDDEN_VALUE +from core.plugin.entities.plugin_daemon import CredentialType +from models.provider_ids import TriggerProviderID +from services.trigger.trigger_provider_service import TriggerProviderService + + +def _patch_redis_lock(mocker: MockerFixture) -> None: + mock_redis = mocker.patch("services.trigger.trigger_provider_service.redis_client") + mock_redis.lock.return_value = contextlib.nullcontext() + + +def _mock_get_trigger_provider(mocker: MockerFixture, provider: object | None) -> None: + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.get_trigger_provider", + return_value=provider, + ) + + +def _encrypter_mock( + *, + decrypted: dict | None = None, + encrypted: dict | None = None, + masked: dict | None = None, +) -> MagicMock: + enc = MagicMock() + enc.decrypt.return_value = decrypted or {} + enc.encrypt.return_value = encrypted or {} + enc.mask_credentials.return_value = masked or {} + enc.mask_plugin_credentials.return_value = masked or {} + return enc + + +@pytest.fixture +def provider_id() -> TriggerProviderID: + # Arrange + return TriggerProviderID("langgenius/github/github") + + +@pytest.fixture(autouse=True) +def mock_db_engine(mocker: MockerFixture) -> SimpleNamespace: + # Arrange + mocked_db = SimpleNamespace(engine=object()) + mocker.patch("services.trigger.trigger_provider_service.db", mocked_db) + return mocked_db + + +@pytest.fixture +def mock_session(mocker: MockerFixture) -> MagicMock: + """Mocks the database session context manager used by TriggerProviderService.""" + # Arrange + mock_session_instance = MagicMock() + mock_session_cm = MagicMock() + mock_session_cm.__enter__.return_value = mock_session_instance + mock_session_cm.__exit__.return_value = False + mocker.patch("services.trigger.trigger_provider_service.Session", return_value=mock_session_cm) + return mock_session_instance + + +@pytest.fixture +def provider_controller() -> MagicMock: + # Arrange + controller = MagicMock() + controller.get_credential_schema_config.return_value = [] + controller.get_properties_schema.return_value = [] + controller.get_oauth_client_schema.return_value = [] + controller.plugin_unique_identifier = "langgenius/github:0.0.1" + return controller + + +def test_get_trigger_provider_should_return_api_entity_from_manager( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + provider = MagicMock() + provider.to_api_entity.return_value = {"provider": "ok"} + _mock_get_trigger_provider(mocker, provider) + + # Act + result = TriggerProviderService.get_trigger_provider("tenant-1", provider_id) + + # Assert + assert result == {"provider": "ok"} + + +def test_list_trigger_providers_should_return_api_entities_from_manager(mocker: MockerFixture) -> None: + # Arrange + provider_a = MagicMock() + provider_b = MagicMock() + provider_a.to_api_entity.return_value = {"id": "a"} + provider_b.to_api_entity.return_value = {"id": "b"} + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.list_all_trigger_providers", + return_value=[provider_a, provider_b], + ) + + # Act + result = TriggerProviderService.list_trigger_providers("tenant-1") + + # Assert + assert result == [{"id": "a"}, {"id": "b"}] + + +def test_list_trigger_provider_subscriptions_should_return_empty_list_when_no_subscriptions( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + query = MagicMock() + query.filter_by.return_value.order_by.return_value.all.return_value = [] + mock_session.query.return_value = query + + # Act + result = TriggerProviderService.list_trigger_provider_subscriptions("tenant-1", provider_id) + + # Assert + assert result == [] + + +def test_list_trigger_provider_subscriptions_should_mask_fields_and_attach_workflow_counts( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + api_sub = SimpleNamespace( + id="sub-1", + credentials={"token": "enc"}, + properties={"hook": "enc"}, + parameters={"event": "push"}, + workflows_in_use=0, + ) + db_sub = SimpleNamespace(to_api_entity=lambda: api_sub) + usage_row = SimpleNamespace(subscription_id="sub-1", app_count=2) + + query_subs = MagicMock() + query_subs.filter_by.return_value.order_by.return_value.all.return_value = [db_sub] + query_usage = MagicMock() + query_usage.filter.return_value.group_by.return_value.all.return_value = [usage_row] + mock_session.query.side_effect = [query_subs, query_usage] + + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(decrypted={"token": "plain"}, masked={"token": "****"}) + prop_enc = _encrypter_mock(decrypted={"hook": "plain"}, masked={"hook": "****"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(cred_enc, MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_properties", + return_value=(prop_enc, MagicMock()), + ) + + # Act + result = TriggerProviderService.list_trigger_provider_subscriptions("tenant-1", provider_id) + + # Assert + assert len(result) == 1 + assert result[0].credentials == {"token": "****"} + assert result[0].properties == {"hook": "****"} + assert result[0].workflows_in_use == 2 + + +def test_add_trigger_subscription_should_create_subscription_successfully_for_api_key( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_count = MagicMock() + query_count.filter_by.return_value.count.return_value = 0 + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = None + mock_session.query.side_effect = [query_count, query_existing] + + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(encrypted={"api_key": "enc"}) + prop_enc = _encrypter_mock(encrypted={"project": "enc"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + side_effect=[(cred_enc, MagicMock()), (prop_enc, MagicMock())], + ) + + # Act + result = TriggerProviderService.add_trigger_subscription( + tenant_id="tenant-1", + user_id="user-1", + name="main", + provider_id=provider_id, + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY, + parameters={"event": "push"}, + properties={"project": "demo"}, + credentials={"api_key": "plain"}, + ) + + # Assert + assert result["result"] == "success" + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + + +def test_add_trigger_subscription_should_store_empty_credentials_for_unauthorized_type( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_count = MagicMock() + query_count.filter_by.return_value.count.return_value = 0 + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = None + mock_session.query.side_effect = [query_count, query_existing] + + _mock_get_trigger_provider(mocker, provider_controller) + prop_enc = _encrypter_mock(encrypted={"p": "enc"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + return_value=(prop_enc, MagicMock()), + ) + + # Act + result = TriggerProviderService.add_trigger_subscription( + tenant_id="tenant-1", + user_id="user-1", + name="main", + provider_id=provider_id, + endpoint_id="endpoint-1", + credential_type=CredentialType.UNAUTHORIZED, + parameters={}, + properties={"p": "v"}, + credentials={}, + subscription_id="sub-fixed", + ) + + # Assert + assert result == {"result": "success", "id": "sub-fixed"} + + +def test_add_trigger_subscription_should_raise_error_when_provider_limit_reached( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_count = MagicMock() + query_count.filter_by.return_value.count.return_value = TriggerProviderService.__MAX_TRIGGER_PROVIDER_COUNT__ + mock_session.query.return_value = query_count + _mock_get_trigger_provider(mocker, provider_controller) + mock_logger = mocker.patch("services.trigger.trigger_provider_service.logger") + + # Act + Assert + with pytest.raises(ValueError, match="Maximum number of providers"): + TriggerProviderService.add_trigger_subscription( + tenant_id="tenant-1", + user_id="user-1", + name="main", + provider_id=provider_id, + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY, + parameters={}, + properties={}, + credentials={}, + ) + mock_logger.exception.assert_called_once() + + +def test_add_trigger_subscription_should_raise_error_when_name_exists( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_count = MagicMock() + query_count.filter_by.return_value.count.return_value = 0 + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = object() + mock_session.query.side_effect = [query_count, query_existing] + _mock_get_trigger_provider(mocker, provider_controller) + + # Act + Assert + with pytest.raises(ValueError, match="Credential name 'main' already exists"): + TriggerProviderService.add_trigger_subscription( + tenant_id="tenant-1", + user_id="user-1", + name="main", + provider_id=provider_id, + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY, + parameters={}, + properties={}, + credentials={}, + ) + + +def test_update_trigger_subscription_should_raise_error_when_subscription_not_found( + mocker: MockerFixture, + mock_session: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + query_sub = MagicMock() + query_sub.filter_by.return_value.first.return_value = None + mock_session.query.return_value = query_sub + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + TriggerProviderService.update_trigger_subscription("tenant-1", "sub-1") + + +def test_update_trigger_subscription_should_raise_error_when_name_conflicts( + mocker: MockerFixture, + mock_session: MagicMock, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + subscription = SimpleNamespace( + id="sub-1", + name="old", + provider_id="langgenius/github/github", + credential_type=CredentialType.API_KEY.value, + ) + query_sub = MagicMock() + query_sub.filter_by.return_value.first.return_value = subscription + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = object() + mock_session.query.side_effect = [query_sub, query_existing] + _mock_get_trigger_provider(mocker, provider_controller) + + # Act + Assert + with pytest.raises(ValueError, match="already exists"): + TriggerProviderService.update_trigger_subscription("tenant-1", "sub-1", name="new-name") + + +def test_update_trigger_subscription_should_update_fields_and_clear_cache( + mocker: MockerFixture, + mock_session: MagicMock, + provider_controller: MagicMock, +) -> None: + # Arrange + _patch_redis_lock(mocker) + subscription = SimpleNamespace( + id="sub-1", + name="old", + tenant_id="tenant-1", + provider_id="langgenius/github/github", + properties={"project": "enc-old"}, + parameters={"event": "old"}, + credentials={"api_key": "enc-old"}, + credential_type=CredentialType.API_KEY.value, + credential_expires_at=0, + expires_at=0, + ) + query_sub = MagicMock() + query_sub.filter_by.return_value.first.return_value = subscription + query_existing = MagicMock() + query_existing.filter_by.return_value.first.return_value = None + mock_session.query.side_effect = [query_sub, query_existing] + + _mock_get_trigger_provider(mocker, provider_controller) + prop_enc = _encrypter_mock(decrypted={"project": "old-value"}, encrypted={"project": "new-value"}) + cred_enc = _encrypter_mock(encrypted={"api_key": "new-key"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + side_effect=[(prop_enc, MagicMock()), (cred_enc, MagicMock())], + ) + mock_delete_cache = mocker.patch("services.trigger.trigger_provider_service.delete_cache_for_subscription") + + # Act + TriggerProviderService.update_trigger_subscription( + tenant_id="tenant-1", + subscription_id="sub-1", + name="new", + properties={"project": HIDDEN_VALUE, "region": "us"}, + parameters={"event": "new"}, + credentials={"api_key": "plain-key"}, + credential_expires_at=100, + expires_at=200, + ) + + # Assert + assert subscription.name == "new" + assert subscription.parameters == {"event": "new"} + assert subscription.credentials == {"api_key": "new-key"} + assert subscription.credential_expires_at == 100 + assert subscription.expires_at == 200 + mock_session.commit.assert_called_once() + mock_delete_cache.assert_called_once() + + +def test_get_subscription_by_id_should_return_none_when_missing(mocker: MockerFixture, mock_session: MagicMock) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + result = TriggerProviderService.get_subscription_by_id("tenant-1", "sub-1") + + # Assert + assert result is None + + +def test_get_subscription_by_id_should_decrypt_credentials_and_properties( + mocker: MockerFixture, + mock_session: MagicMock, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + tenant_id="tenant-1", + provider_id="langgenius/github/github", + credentials={"token": "enc"}, + properties={"project": "enc"}, + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(decrypted={"token": "plain"}) + prop_enc = _encrypter_mock(decrypted={"project": "plain"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(cred_enc, MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_properties", + return_value=(prop_enc, MagicMock()), + ) + + # Act + result = TriggerProviderService.get_subscription_by_id("tenant-1", "sub-1") + + # Assert + assert result is subscription + assert subscription.credentials == {"token": "plain"} + assert subscription.properties == {"project": "plain"} + + +def test_delete_trigger_provider_should_raise_error_when_subscription_missing( + mocker: MockerFixture, + mock_session: MagicMock, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + TriggerProviderService.delete_trigger_provider(mock_session, "tenant-1", "sub-1") + + +def test_delete_trigger_provider_should_delete_and_clear_cache_even_if_unsubscribe_fails( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + user_id="user-1", + provider_id=str(provider_id), + credential_type=CredentialType.OAUTH2.value, + credentials={"token": "enc"}, + to_entity=lambda: SimpleNamespace(id="sub-1"), + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(decrypted={"token": "plain"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(cred_enc, MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger", + side_effect=RuntimeError("remote fail"), + ) + mock_delete_cache = mocker.patch("services.trigger.trigger_provider_service.delete_cache_for_subscription") + + # Act + TriggerProviderService.delete_trigger_provider(mock_session, "tenant-1", "sub-1") + + # Assert + mock_session.delete.assert_called_once_with(subscription) + mock_delete_cache.assert_called_once() + + +def test_delete_trigger_provider_should_skip_unsubscribe_for_unauthorized( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-2", + user_id="user-1", + provider_id=str(provider_id), + credential_type=CredentialType.UNAUTHORIZED.value, + credentials={}, + to_entity=lambda: SimpleNamespace(id="sub-2"), + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + mock_unsubscribe = mocker.patch("services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger") + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(_encrypter_mock(decrypted={}), MagicMock()), + ) + + # Act + TriggerProviderService.delete_trigger_provider(mock_session, "tenant-1", "sub-2") + + # Assert + mock_unsubscribe.assert_not_called() + mock_session.delete.assert_called_once_with(subscription) + + +def test_refresh_oauth_token_should_raise_error_when_subscription_missing( + mocker: MockerFixture, mock_session: MagicMock +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + TriggerProviderService.refresh_oauth_token("tenant-1", "sub-1") + + +def test_refresh_oauth_token_should_raise_error_for_non_oauth_credentials( + mocker: MockerFixture, mock_session: MagicMock +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + + # Act + Assert + with pytest.raises(ValueError, match="Only OAuth credentials can be refreshed"): + TriggerProviderService.refresh_oauth_token("tenant-1", "sub-1") + + +def test_refresh_oauth_token_should_refresh_and_persist_new_credentials( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + provider_id=str(provider_id), + user_id="user-1", + credential_type=CredentialType.OAUTH2.value, + credentials={"access_token": "enc"}, + credential_expires_at=0, + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + cache = MagicMock() + cred_enc = _encrypter_mock(decrypted={"access_token": "old"}, encrypted={"access_token": "new"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + return_value=(cred_enc, cache), + ) + mocker.patch.object(TriggerProviderService, "get_oauth_client", return_value={"client_id": "id"}) + refreshed = SimpleNamespace(credentials={"access_token": "new"}, expires_at=12345) + oauth_handler = MagicMock() + oauth_handler.refresh_credentials.return_value = refreshed + mocker.patch("services.trigger.trigger_provider_service.OAuthHandler", return_value=oauth_handler) + + # Act + result = TriggerProviderService.refresh_oauth_token("tenant-1", "sub-1") + + # Assert + assert result == {"result": "success", "expires_at": 12345} + assert subscription.credentials == {"access_token": "new"} + assert subscription.credential_expires_at == 12345 + mock_session.commit.assert_called_once() + cache.delete.assert_called_once() + + +def test_refresh_subscription_should_raise_error_when_subscription_missing( + mocker: MockerFixture, mock_session: MagicMock +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100) + + +def test_refresh_subscription_should_skip_when_not_due(mocker: MockerFixture, mock_session: MagicMock) -> None: + # Arrange + subscription = SimpleNamespace(expires_at=200) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + + # Act + result = TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100) + + # Assert + assert result == {"result": "skipped", "expires_at": 200} + + +def test_refresh_subscription_should_refresh_and_persist_properties( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + tenant_id="tenant-1", + endpoint_id="endpoint-1", + expires_at=50, + provider_id=str(provider_id), + parameters={"event": "push"}, + properties={"p": "enc"}, + credentials={"c": "enc"}, + credential_type=CredentialType.API_KEY.value, + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + cred_enc = _encrypter_mock(decrypted={"c": "plain"}) + prop_cache = MagicMock() + prop_enc = _encrypter_mock(decrypted={"p": "plain"}, encrypted={"p": "new-enc"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(cred_enc, MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_properties", + return_value=(prop_enc, prop_cache), + ) + mocker.patch( + "services.trigger.trigger_provider_service.generate_plugin_trigger_endpoint_url", + return_value="https://endpoint", + ) + provider_controller.refresh_trigger.return_value = SimpleNamespace(properties={"p": "new"}, expires_at=999) + + # Act + result = TriggerProviderService.refresh_subscription("tenant-1", "sub-1", now=100) + + # Assert + assert result == {"result": "success", "expires_at": 999} + assert subscription.properties == {"p": "new-enc"} + assert subscription.expires_at == 999 + mock_session.commit.assert_called_once() + prop_cache.delete.assert_called_once() + + +def test_get_oauth_client_should_return_tenant_client_when_available( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + tenant_client = SimpleNamespace(oauth_params={"client_id": "enc"}) + system_client = None + query_tenant = MagicMock() + query_tenant.filter_by.return_value.first.return_value = tenant_client + mock_session.query.return_value = query_tenant + _mock_get_trigger_provider(mocker, provider_controller) + enc = _encrypter_mock(decrypted={"client_id": "plain"}) + mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock())) + + # Act + result = TriggerProviderService.get_oauth_client("tenant-1", provider_id) + + # Assert + assert result == {"client_id": "plain"} + + +def test_get_oauth_client_should_return_none_when_plugin_not_verified( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + query_tenant = MagicMock() + query_tenant.filter_by.return_value.first.return_value = None + query_system = MagicMock() + query_system.filter_by.return_value.first.return_value = None + mock_session.query.side_effect = [query_tenant, query_system] + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=False) + + # Act + result = TriggerProviderService.get_oauth_client("tenant-1", provider_id) + + # Assert + assert result is None + + +def test_get_oauth_client_should_return_decrypted_system_client_when_verified( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + query_tenant = MagicMock() + query_tenant.filter_by.return_value.first.return_value = None + query_system = MagicMock() + query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc") + mock_session.query.side_effect = [query_tenant, query_system] + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) + mocker.patch( + "services.trigger.trigger_provider_service.decrypt_system_oauth_params", + return_value={"client_id": "system"}, + ) + + # Act + result = TriggerProviderService.get_oauth_client("tenant-1", provider_id) + + # Assert + assert result == {"client_id": "system"} + + +def test_get_oauth_client_should_raise_error_when_system_decryption_fails( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + query_tenant = MagicMock() + query_tenant.filter_by.return_value.first.return_value = None + query_system = MagicMock() + query_system.filter_by.return_value.first.return_value = SimpleNamespace(encrypted_oauth_params="enc") + mock_session.query.side_effect = [query_tenant, query_system] + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) + mocker.patch( + "services.trigger.trigger_provider_service.decrypt_system_oauth_params", + side_effect=RuntimeError("bad data"), + ) + + # Act + Assert + with pytest.raises(ValueError, match="Error decrypting system oauth params"): + TriggerProviderService.get_oauth_client("tenant-1", provider_id) + + +def test_is_oauth_system_client_exists_should_return_false_when_unverified( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=False) + + # Act + result = TriggerProviderService.is_oauth_system_client_exists("tenant-1", provider_id) + + # Assert + assert result is False + + +@pytest.mark.parametrize("has_client", [True, False]) +def test_is_oauth_system_client_exists_should_reflect_database_record( + has_client: bool, + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = object() if has_client else None + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch("services.trigger.trigger_provider_service.PluginService.is_plugin_verified", return_value=True) + + # Act + result = TriggerProviderService.is_oauth_system_client_exists("tenant-1", provider_id) + + # Assert + assert result is has_client + + +def test_save_custom_oauth_client_params_should_return_success_when_nothing_to_update( + provider_id: TriggerProviderID, +) -> None: + # Arrange + # Act + result = TriggerProviderService.save_custom_oauth_client_params("tenant-1", provider_id, None, None) + + # Assert + assert result == {"result": "success"} + + +def test_save_custom_oauth_client_params_should_create_record_and_clear_params_when_client_params_none( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + query = MagicMock() + query.filter_by.return_value.first.return_value = None + mock_session.query.return_value = query + _mock_get_trigger_provider(mocker, provider_controller) + fake_model = SimpleNamespace(encrypted_oauth_params="", enabled=False, oauth_params={}) + mocker.patch("services.trigger.trigger_provider_service.TriggerOAuthTenantClient", return_value=fake_model) + + # Act + result = TriggerProviderService.save_custom_oauth_client_params( + tenant_id="tenant-1", + provider_id=provider_id, + client_params=None, + enabled=True, + ) + + # Assert + assert result == {"result": "success"} + assert fake_model.encrypted_oauth_params == "{}" + assert fake_model.enabled is True + mock_session.add.assert_called_once_with(fake_model) + mock_session.commit.assert_called_once() + + +def test_save_custom_oauth_client_params_should_merge_hidden_values_and_delete_cache( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + custom_client = SimpleNamespace(oauth_params={"client_id": "enc-old"}, enabled=False) + mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client + _mock_get_trigger_provider(mocker, provider_controller) + cache = MagicMock() + enc = _encrypter_mock(decrypted={"client_id": "old-id"}, encrypted={"client_id": "new-id"}) + mocker.patch( + "services.trigger.trigger_provider_service.create_provider_encrypter", + return_value=(enc, cache), + ) + + # Act + result = TriggerProviderService.save_custom_oauth_client_params( + tenant_id="tenant-1", + provider_id=provider_id, + client_params={"client_id": HIDDEN_VALUE, "client_secret": "new"}, + enabled=None, + ) + + # Assert + assert result == {"result": "success"} + assert json.loads(custom_client.encrypted_oauth_params) == {"client_id": "new-id"} + cache.delete.assert_called_once() + mock_session.commit.assert_called_once() + + +def test_get_custom_oauth_client_params_should_return_empty_when_record_missing( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + result = TriggerProviderService.get_custom_oauth_client_params("tenant-1", provider_id) + + # Assert + assert result == {} + + +def test_get_custom_oauth_client_params_should_return_masked_decrypted_values( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + custom_client = SimpleNamespace(oauth_params={"client_id": "enc"}) + mock_session.query.return_value.filter_by.return_value.first.return_value = custom_client + _mock_get_trigger_provider(mocker, provider_controller) + enc = _encrypter_mock(decrypted={"client_id": "plain"}, masked={"client_id": "pl***id"}) + mocker.patch("services.trigger.trigger_provider_service.create_provider_encrypter", return_value=(enc, MagicMock())) + + # Act + result = TriggerProviderService.get_custom_oauth_client_params("tenant-1", provider_id) + + # Assert + assert result == {"client_id": "pl***id"} + + +def test_delete_custom_oauth_client_params_should_delete_record_and_commit( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.delete.return_value = 1 + + # Act + result = TriggerProviderService.delete_custom_oauth_client_params("tenant-1", provider_id) + + # Assert + assert result == {"result": "success"} + mock_session.commit.assert_called_once() + + +@pytest.mark.parametrize("exists", [True, False]) +def test_is_oauth_custom_client_enabled_should_return_expected_boolean( + exists: bool, + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = object() if exists else None + + # Act + result = TriggerProviderService.is_oauth_custom_client_enabled("tenant-1", provider_id) + + # Assert + assert result is exists + + +def test_get_subscription_by_endpoint_should_return_none_when_not_found( + mocker: MockerFixture, mock_session: MagicMock +) -> None: + # Arrange + mock_session.query.return_value.filter_by.return_value.first.return_value = None + + # Act + result = TriggerProviderService.get_subscription_by_endpoint("endpoint-1") + + # Assert + assert result is None + + +def test_get_subscription_by_endpoint_should_decrypt_credentials_and_properties( + mocker: MockerFixture, + mock_session: MagicMock, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + tenant_id="tenant-1", + provider_id="langgenius/github/github", + credentials={"token": "enc"}, + properties={"hook": "enc"}, + ) + mock_session.query.return_value.filter_by.return_value.first.return_value = subscription + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_subscription", + return_value=(_encrypter_mock(decrypted={"token": "plain"}), MagicMock()), + ) + mocker.patch( + "services.trigger.trigger_provider_service.create_trigger_provider_encrypter_for_properties", + return_value=(_encrypter_mock(decrypted={"hook": "plain"}), MagicMock()), + ) + + # Act + result = TriggerProviderService.get_subscription_by_endpoint("endpoint-1") + + # Assert + assert result is subscription + assert subscription.credentials == {"token": "plain"} + assert subscription.properties == {"hook": "plain"} + + +def test_verify_subscription_credentials_should_raise_when_provider_not_found( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, None) + + # Act + Assert + with pytest.raises(ValueError, match="Provider .* not found"): + TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + ) + + +def test_verify_subscription_credentials_should_raise_when_subscription_not_found( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=None) + + # Act + Assert + with pytest.raises(ValueError, match="Subscription sub-1 not found"): + TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + ) + + +def test_verify_subscription_credentials_should_raise_when_api_key_validation_fails( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value, credentials={"api_key": "old"}) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + provider_controller.validate_credentials.side_effect = RuntimeError("bad credentials") + + # Act + Assert + with pytest.raises(ValueError, match="Invalid credentials: bad credentials"): + TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={"api_key": HIDDEN_VALUE}, + ) + + +def test_verify_subscription_credentials_should_return_verified_when_api_key_validation_succeeds( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.API_KEY.value, credentials={"api_key": "old"}) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + + # Act + result = TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={"api_key": HIDDEN_VALUE}, + ) + + # Assert + assert result == {"verified": True} + + +def test_verify_subscription_credentials_should_return_verified_for_non_api_key_credentials( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.OAUTH2.value, credentials={}) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + + # Act + result = TriggerProviderService.verify_subscription_credentials( + tenant_id="tenant-1", + user_id="user-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + ) + + # Assert + assert result == {"verified": True} + + +def test_rebuild_trigger_subscription_should_raise_when_provider_not_found( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, None) + + # Act + Assert + with pytest.raises(ValueError, match="Provider .* not found"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + parameters={}, + ) + + +def test_rebuild_trigger_subscription_should_raise_when_subscription_not_found( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=None) + + # Act + Assert + with pytest.raises(ValueError, match="Subscription sub-1 not found"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + parameters={}, + ) + + +def test_rebuild_trigger_subscription_should_raise_for_unsupported_credential_type( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace(credential_type=CredentialType.UNAUTHORIZED.value) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + + # Act + Assert + with pytest.raises(ValueError, match="not supported for auto creation"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + parameters={}, + ) + + +def test_rebuild_trigger_subscription_should_raise_when_unsubscribe_fails( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + user_id="user-1", + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY.value, + credentials={"api_key": "old"}, + to_entity=lambda: SimpleNamespace(id="sub-1"), + ) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger", + return_value=SimpleNamespace(success=False, message="remote error"), + ) + + # Act + Assert + with pytest.raises(ValueError, match="Failed to delete previous subscription"): + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={}, + parameters={}, + ) + + +def test_rebuild_trigger_subscription_should_resubscribe_and_update_existing_subscription( + mocker: MockerFixture, + mock_session: MagicMock, + provider_id: TriggerProviderID, + provider_controller: MagicMock, +) -> None: + # Arrange + subscription = SimpleNamespace( + id="sub-1", + user_id="user-1", + endpoint_id="endpoint-1", + credential_type=CredentialType.API_KEY.value, + credentials={"api_key": "old-key"}, + to_entity=lambda: SimpleNamespace(id="sub-1"), + ) + new_subscription = SimpleNamespace(properties={"project": "new"}, expires_at=888) + _mock_get_trigger_provider(mocker, provider_controller) + mocker.patch.object(TriggerProviderService, "get_subscription_by_id", return_value=subscription) + mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.unsubscribe_trigger", + return_value=SimpleNamespace(success=True, message="ok"), + ) + mock_subscribe = mocker.patch( + "services.trigger.trigger_provider_service.TriggerManager.subscribe_trigger", + return_value=new_subscription, + ) + mocker.patch( + "services.trigger.trigger_provider_service.generate_plugin_trigger_endpoint_url", + return_value="https://endpoint", + ) + mock_update = mocker.patch.object(TriggerProviderService, "update_trigger_subscription") + + # Act + TriggerProviderService.rebuild_trigger_subscription( + tenant_id="tenant-1", + provider_id=provider_id, + subscription_id="sub-1", + credentials={"api_key": HIDDEN_VALUE, "region": "us"}, + parameters={"event": "push"}, + name="updated", + ) + + # Assert + call_kwargs = mock_subscribe.call_args.kwargs + assert call_kwargs["credentials"]["api_key"] == "old-key" + assert call_kwargs["credentials"]["region"] == "us" + mock_update.assert_called_once_with( + tenant_id="tenant-1", + subscription_id="sub-1", + name="updated", + parameters={"event": "push"}, + credentials={"api_key": "old-key", "region": "us"}, + properties={"project": "new"}, + expires_at=888, + ) diff --git a/api/tests/unit_tests/services/test_web_conversation_service.py b/api/tests/unit_tests/services/test_web_conversation_service.py new file mode 100644 index 0000000000..7687d355e9 --- /dev/null +++ b/api/tests/unit_tests/services/test_web_conversation_service.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from core.app.entities.app_invoke_entities import InvokeFrom +from models import Account +from models.model import App, EndUser +from services.web_conversation_service import WebConversationService + + +@pytest.fixture +def app_model() -> App: + return cast(App, SimpleNamespace(id="app-1")) + + +def _account(**kwargs: Any) -> Account: + return cast(Account, SimpleNamespace(**kwargs)) + + +def _end_user(**kwargs: Any) -> EndUser: + return cast(EndUser, SimpleNamespace(**kwargs)) + + +def test_pagination_by_last_id_should_raise_error_when_user_is_none( + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + session = MagicMock() + mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") + + # Act + Assert + with pytest.raises(ValueError, match="User is required"): + WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=None, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + ) + + +def test_pagination_by_last_id_should_forward_without_pin_filter_when_pinned_is_none( + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + session = MagicMock() + fake_user = _account(id="user-1") + mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") + mock_pagination.return_value = MagicMock() + + # Act + WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=fake_user, + last_id="conv-9", + limit=10, + invoke_from=InvokeFrom.WEB_APP, + pinned=None, + ) + + # Assert + call_kwargs = mock_pagination.call_args.kwargs + assert call_kwargs["include_ids"] is None + assert call_kwargs["exclude_ids"] is None + assert call_kwargs["last_id"] == "conv-9" + assert call_kwargs["sort_by"] == "-updated_at" + + +def test_pagination_by_last_id_should_include_only_pinned_ids_when_pinned_true( + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + session = MagicMock() + fake_account_cls = type("FakeAccount", (), {}) + fake_user = cast(Account, fake_account_cls()) + fake_user.id = "account-1" + mocker.patch("services.web_conversation_service.Account", fake_account_cls) + mocker.patch("services.web_conversation_service.EndUser", type("FakeEndUser", (), {})) + session.scalars.return_value.all.return_value = ["conv-1", "conv-2"] + mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") + mock_pagination.return_value = MagicMock() + + # Act + WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=fake_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + pinned=True, + ) + + # Assert + call_kwargs = mock_pagination.call_args.kwargs + assert call_kwargs["include_ids"] == ["conv-1", "conv-2"] + assert call_kwargs["exclude_ids"] is None + + +def test_pagination_by_last_id_should_exclude_pinned_ids_when_pinned_false( + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + session = MagicMock() + fake_end_user_cls = type("FakeEndUser", (), {}) + fake_user = cast(EndUser, fake_end_user_cls()) + fake_user.id = "end-user-1" + mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {})) + mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls) + session.scalars.return_value.all.return_value = ["conv-3"] + mock_pagination = mocker.patch("services.web_conversation_service.ConversationService.pagination_by_last_id") + mock_pagination.return_value = MagicMock() + + # Act + WebConversationService.pagination_by_last_id( + session=session, + app_model=app_model, + user=fake_user, + last_id=None, + limit=20, + invoke_from=InvokeFrom.WEB_APP, + pinned=False, + ) + + # Assert + call_kwargs = mock_pagination.call_args.kwargs + assert call_kwargs["include_ids"] is None + assert call_kwargs["exclude_ids"] == ["conv-3"] + + +def test_pin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None: + # Arrange + mock_db = mocker.patch("services.web_conversation_service.db") + mocker.patch("services.web_conversation_service.ConversationService.get_conversation") + + # Act + WebConversationService.pin(app_model, "conv-1", None) + + # Assert + mock_db.session.add.assert_not_called() + mock_db.session.commit.assert_not_called() + + +def test_pin_should_return_early_when_conversation_is_already_pinned( + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + fake_account_cls = type("FakeAccount", (), {}) + fake_user = cast(Account, fake_account_cls()) + fake_user.id = "account-1" + mocker.patch("services.web_conversation_service.Account", fake_account_cls) + mock_db = mocker.patch("services.web_conversation_service.db") + mock_db.session.query.return_value.where.return_value.first.return_value = object() + mock_get_conversation = mocker.patch("services.web_conversation_service.ConversationService.get_conversation") + + # Act + WebConversationService.pin(app_model, "conv-1", fake_user) + + # Assert + mock_get_conversation.assert_not_called() + mock_db.session.add.assert_not_called() + mock_db.session.commit.assert_not_called() + + +def test_pin_should_create_pinned_conversation_when_not_already_pinned( + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + fake_account_cls = type("FakeAccount", (), {}) + fake_user = cast(Account, fake_account_cls()) + fake_user.id = "account-2" + mocker.patch("services.web_conversation_service.Account", fake_account_cls) + mock_db = mocker.patch("services.web_conversation_service.db") + mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_conversation = SimpleNamespace(id="conv-2") + mock_get_conversation = mocker.patch( + "services.web_conversation_service.ConversationService.get_conversation", + return_value=mock_conversation, + ) + + # Act + WebConversationService.pin(app_model, "conv-2", fake_user) + + # Assert + mock_get_conversation.assert_called_once_with(app_model=app_model, conversation_id="conv-2", user=fake_user) + added_obj = mock_db.session.add.call_args.args[0] + assert added_obj.app_id == "app-1" + assert added_obj.conversation_id == "conv-2" + assert added_obj.created_by_role == "account" + assert added_obj.created_by == "account-2" + mock_db.session.commit.assert_called_once() + + +def test_unpin_should_return_early_when_user_is_none(app_model: App, mocker: MockerFixture) -> None: + # Arrange + mock_db = mocker.patch("services.web_conversation_service.db") + + # Act + WebConversationService.unpin(app_model, "conv-1", None) + + # Assert + mock_db.session.delete.assert_not_called() + mock_db.session.commit.assert_not_called() + + +def test_unpin_should_return_early_when_conversation_is_not_pinned( + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + fake_end_user_cls = type("FakeEndUser", (), {}) + fake_user = cast(EndUser, fake_end_user_cls()) + fake_user.id = "end-user-3" + mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {})) + mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls) + mock_db = mocker.patch("services.web_conversation_service.db") + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + WebConversationService.unpin(app_model, "conv-7", fake_user) + + # Assert + mock_db.session.delete.assert_not_called() + mock_db.session.commit.assert_not_called() + + +def test_unpin_should_delete_pinned_conversation_when_exists( + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + fake_end_user_cls = type("FakeEndUser", (), {}) + fake_user = cast(EndUser, fake_end_user_cls()) + fake_user.id = "end-user-4" + mocker.patch("services.web_conversation_service.Account", type("FakeAccount", (), {})) + mocker.patch("services.web_conversation_service.EndUser", fake_end_user_cls) + mock_db = mocker.patch("services.web_conversation_service.db") + pinned_obj = SimpleNamespace(id="pin-1") + mock_db.session.query.return_value.where.return_value.first.return_value = pinned_obj + + # Act + WebConversationService.unpin(app_model, "conv-8", fake_user) + + # Assert + mock_db.session.delete.assert_called_once_with(pinned_obj) + mock_db.session.commit.assert_called_once() diff --git a/api/tests/unit_tests/services/test_webapp_auth_service.py b/api/tests/unit_tests/services/test_webapp_auth_service.py new file mode 100644 index 0000000000..262c1f1524 --- /dev/null +++ b/api/tests/unit_tests/services/test_webapp_auth_service.py @@ -0,0 +1,379 @@ +from __future__ import annotations + +from datetime import UTC, datetime +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture +from werkzeug.exceptions import NotFound, Unauthorized + +from models import Account, AccountStatus +from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError +from services.webapp_auth_service import WebAppAuthService, WebAppAuthType + +ACCOUNT_LOOKUP_PATH = "services.webapp_auth_service.AccountService.get_account_by_email_with_case_fallback" +TOKEN_GENERATE_PATH = "services.webapp_auth_service.TokenManager.generate_token" +TOKEN_GET_DATA_PATH = "services.webapp_auth_service.TokenManager.get_token_data" + + +def _account(**kwargs: Any) -> Account: + return cast(Account, SimpleNamespace(**kwargs)) + + +@pytest.fixture +def mock_db(mocker: MockerFixture) -> MagicMock: + # Arrange + mocked_db = mocker.patch("services.webapp_auth_service.db") + mocked_db.session = MagicMock() + return mocked_db + + +def test_authenticate_should_raise_account_not_found_when_email_does_not_exist(mocker: MockerFixture) -> None: + # Arrange + mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None) + + # Act + Assert + with pytest.raises(AccountNotFoundError): + WebAppAuthService.authenticate("user@example.com", "pwd") + + +def test_authenticate_should_raise_account_login_error_when_account_is_banned(mocker: MockerFixture) -> None: + # Arrange + account = SimpleNamespace(status=AccountStatus.BANNED, password="hash", password_salt="salt") + mocker.patch( + ACCOUNT_LOOKUP_PATH, + return_value=account, + ) + + # Act + Assert + with pytest.raises(AccountLoginError, match="Account is banned"): + WebAppAuthService.authenticate("user@example.com", "pwd") + + +@pytest.mark.parametrize("password_value", [None, "hash"]) +def test_authenticate_should_raise_password_error_when_password_is_invalid( + password_value: str | None, + mocker: MockerFixture, +) -> None: + # Arrange + account = SimpleNamespace(status=AccountStatus.ACTIVE, password=password_value, password_salt="salt") + mocker.patch( + ACCOUNT_LOOKUP_PATH, + return_value=account, + ) + mocker.patch("services.webapp_auth_service.compare_password", return_value=False) + + # Act + Assert + with pytest.raises(AccountPasswordError, match="Invalid email or password"): + WebAppAuthService.authenticate("user@example.com", "pwd") + + +def test_authenticate_should_return_account_when_credentials_are_valid(mocker: MockerFixture) -> None: + # Arrange + account = SimpleNamespace(status=AccountStatus.ACTIVE, password="hash", password_salt="salt") + mocker.patch( + ACCOUNT_LOOKUP_PATH, + return_value=account, + ) + mocker.patch("services.webapp_auth_service.compare_password", return_value=True) + + # Act + result = WebAppAuthService.authenticate("user@example.com", "pwd") + + # Assert + assert result is account + + +def test_login_should_return_token_from_internal_token_builder(mocker: MockerFixture) -> None: + # Arrange + account = _account(id="a1", email="u@example.com") + mock_get_token = mocker.patch.object(WebAppAuthService, "_get_account_jwt_token", return_value="jwt-token") + + # Act + result = WebAppAuthService.login(account) + + # Assert + assert result == "jwt-token" + mock_get_token.assert_called_once_with(account=account) + + +def test_get_user_through_email_should_return_none_when_account_not_found(mocker: MockerFixture) -> None: + # Arrange + mocker.patch(ACCOUNT_LOOKUP_PATH, return_value=None) + + # Act + result = WebAppAuthService.get_user_through_email("missing@example.com") + + # Assert + assert result is None + + +def test_get_user_through_email_should_raise_unauthorized_when_account_banned(mocker: MockerFixture) -> None: + # Arrange + account = SimpleNamespace(status=AccountStatus.BANNED) + mocker.patch( + ACCOUNT_LOOKUP_PATH, + return_value=account, + ) + + # Act + Assert + with pytest.raises(Unauthorized, match="Account is banned"): + WebAppAuthService.get_user_through_email("user@example.com") + + +def test_get_user_through_email_should_return_account_when_active(mocker: MockerFixture) -> None: + # Arrange + account = SimpleNamespace(status=AccountStatus.ACTIVE) + mocker.patch( + ACCOUNT_LOOKUP_PATH, + return_value=account, + ) + + # Act + result = WebAppAuthService.get_user_through_email("user@example.com") + + # Assert + assert result is account + + +def test_send_email_code_login_email_should_raise_error_when_email_not_provided() -> None: + # Arrange + # Act + Assert + with pytest.raises(ValueError, match="Email must be provided"): + WebAppAuthService.send_email_code_login_email(account=None, email=None) + + +def test_send_email_code_login_email_should_generate_token_and_send_mail_for_account( + mocker: MockerFixture, +) -> None: + # Arrange + account = _account(email="user@example.com") + mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[1, 2, 3, 4, 5, 6]) + mock_generate_token = mocker.patch(TOKEN_GENERATE_PATH, return_value="token-1") + mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay") + + # Act + result = WebAppAuthService.send_email_code_login_email(account=account, language="en-US") + + # Assert + assert result == "token-1" + mock_generate_token.assert_called_once() + assert mock_generate_token.call_args.kwargs["additional_data"] == {"code": "123456"} + mock_delay.assert_called_once_with(language="en-US", to="user@example.com", code="123456") + + +def test_send_email_code_login_email_should_send_mail_for_email_without_account( + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch("services.webapp_auth_service.secrets.randbelow", side_effect=[0, 0, 0, 0, 0, 0]) + mocker.patch(TOKEN_GENERATE_PATH, return_value="token-2") + mock_delay = mocker.patch("services.webapp_auth_service.send_email_code_login_mail_task.delay") + + # Act + result = WebAppAuthService.send_email_code_login_email(account=None, email="alt@example.com", language="zh-Hans") + + # Assert + assert result == "token-2" + mock_delay.assert_called_once_with(language="zh-Hans", to="alt@example.com", code="000000") + + +def test_get_email_code_login_data_should_delegate_to_token_manager(mocker: MockerFixture) -> None: + # Arrange + mock_get_data = mocker.patch(TOKEN_GET_DATA_PATH, return_value={"code": "123"}) + + # Act + result = WebAppAuthService.get_email_code_login_data("token-abc") + + # Assert + assert result == {"code": "123"} + mock_get_data.assert_called_once_with("token-abc", "email_code_login") + + +def test_revoke_email_code_login_token_should_delegate_to_token_manager(mocker: MockerFixture) -> None: + # Arrange + mock_revoke = mocker.patch("services.webapp_auth_service.TokenManager.revoke_token") + + # Act + WebAppAuthService.revoke_email_code_login_token("token-xyz") + + # Assert + mock_revoke.assert_called_once_with("token-xyz", "email_code_login") + + +def test_create_end_user_should_raise_not_found_when_site_does_not_exist(mock_db: MagicMock) -> None: + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(NotFound, match="Site not found"): + WebAppAuthService.create_end_user("app-code", "user@example.com") + + +def test_create_end_user_should_raise_not_found_when_app_does_not_exist(mock_db: MagicMock) -> None: + # Arrange + site = SimpleNamespace(app_id="app-1") + app_query = MagicMock() + app_query.where.return_value.first.return_value = None + mock_db.session.query.return_value.where.return_value.first.side_effect = [site, None] + + # Act + Assert + with pytest.raises(NotFound, match="App not found"): + WebAppAuthService.create_end_user("app-code", "user@example.com") + + +def test_create_end_user_should_create_and_commit_end_user_when_data_is_valid(mock_db: MagicMock) -> None: + # Arrange + site = SimpleNamespace(app_id="app-1") + app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1") + mock_db.session.query.return_value.where.return_value.first.side_effect = [site, app_model] + + # Act + result = WebAppAuthService.create_end_user("app-code", "user@example.com") + + # Assert + assert result.tenant_id == "tenant-1" + assert result.app_id == "app-1" + assert result.session_id == "user@example.com" + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + +def test_get_account_jwt_token_should_build_payload_and_issue_token(mocker: MockerFixture) -> None: + # Arrange + account = _account(id="a1", email="user@example.com") + mocker.patch("services.webapp_auth_service.dify_config.ACCESS_TOKEN_EXPIRE_MINUTES", 60) + mock_issue = mocker.patch("services.webapp_auth_service.PassportService.issue", return_value="jwt-1") + + # Act + token = WebAppAuthService._get_account_jwt_token(account) + + # Assert + assert token == "jwt-1" + payload = mock_issue.call_args.args[0] + assert payload["user_id"] == "a1" + assert payload["session_id"] == "user@example.com" + assert payload["token_source"] == "webapp_login_token" + assert payload["auth_type"] == "internal" + assert payload["exp"] > int(datetime.now(UTC).timestamp()) + + +@pytest.mark.parametrize( + ("access_mode", "expected"), + [ + ("private", True), + ("private_all", True), + ("public", False), + ], +) +def test_is_app_require_permission_check_should_use_access_mode_when_provided( + access_mode: str, + expected: bool, +) -> None: + # Arrange + # Act + result = WebAppAuthService.is_app_require_permission_check(access_mode=access_mode) + + # Assert + assert result is expected + + +def test_is_app_require_permission_check_should_raise_when_no_identifier_provided() -> None: + # Arrange + # Act + Assert + with pytest.raises(ValueError, match="Either app_code or app_id must be provided"): + WebAppAuthService.is_app_require_permission_check() + + +def test_is_app_require_permission_check_should_raise_when_app_id_cannot_be_determined(mocker: MockerFixture) -> None: + # Arrange + mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value=None) + + # Act + Assert + with pytest.raises(ValueError, match="App ID could not be determined"): + WebAppAuthService.is_app_require_permission_check(app_code="app-code") + + +def test_is_app_require_permission_check_should_return_true_when_enterprise_mode_requires_it( + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1") + mocker.patch( + "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", + return_value=SimpleNamespace(access_mode="private"), + ) + + # Act + result = WebAppAuthService.is_app_require_permission_check(app_code="app-code") + + # Assert + assert result is True + + +def test_is_app_require_permission_check_should_return_false_when_enterprise_settings_do_not_require_it( + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch( + "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", + return_value=SimpleNamespace(access_mode="public"), + ) + + # Act + result = WebAppAuthService.is_app_require_permission_check(app_id="app-1") + + # Assert + assert result is False + + +@pytest.mark.parametrize( + ("access_mode", "expected"), + [ + ("public", WebAppAuthType.PUBLIC), + ("private", WebAppAuthType.INTERNAL), + ("private_all", WebAppAuthType.INTERNAL), + ("sso_verified", WebAppAuthType.EXTERNAL), + ], +) +def test_get_app_auth_type_should_map_access_modes_correctly( + access_mode: str, + expected: WebAppAuthType, +) -> None: + # Arrange + # Act + result = WebAppAuthService.get_app_auth_type(access_mode=access_mode) + + # Assert + assert result == expected + + +def test_get_app_auth_type_should_resolve_from_app_code(mocker: MockerFixture) -> None: + # Arrange + mocker.patch("services.webapp_auth_service.AppService.get_app_id_by_code", return_value="app-1") + mocker.patch( + "services.webapp_auth_service.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", + return_value=SimpleNamespace(access_mode="private_all"), + ) + + # Act + result = WebAppAuthService.get_app_auth_type(app_code="app-code") + + # Assert + assert result == WebAppAuthType.INTERNAL + + +def test_get_app_auth_type_should_raise_when_no_input_provided() -> None: + # Arrange + # Act + Assert + with pytest.raises(ValueError, match="Either app_code or access_mode must be provided"): + WebAppAuthService.get_app_auth_type() + + +def test_get_app_auth_type_should_raise_when_cannot_determine_type_from_invalid_mode() -> None: + # Arrange + # Act + Assert + with pytest.raises(ValueError, match="Could not determine app authentication type"): + WebAppAuthService.get_app_auth_type(access_mode="unknown") diff --git a/api/tests/unit_tests/services/test_workflow_app_service.py b/api/tests/unit_tests/services/test_workflow_app_service.py new file mode 100644 index 0000000000..fa76521f2d --- /dev/null +++ b/api/tests/unit_tests/services/test_workflow_app_service.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +import json +import uuid +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from dify_graph.enums import WorkflowExecutionStatus +from models import App, WorkflowAppLog +from models.enums import AppTriggerType, CreatorUserRole +from services.workflow_app_service import LogView, WorkflowAppService + + +@pytest.fixture +def service() -> WorkflowAppService: + # Arrange + return WorkflowAppService() + + +@pytest.fixture +def app_model() -> App: + # Arrange + return cast(App, SimpleNamespace(id="app-1", tenant_id="tenant-1")) + + +def _workflow_app_log(**kwargs: Any) -> WorkflowAppLog: + return cast(WorkflowAppLog, SimpleNamespace(**kwargs)) + + +def test_log_view_details_should_return_wrapped_details_and_proxy_attributes() -> None: + # Arrange + log = _workflow_app_log(id="log-1", status="succeeded") + view = LogView(log=log, details={"trigger_metadata": {"type": "plugin"}}) + + # Act + details = view.details + proxied_status = view.status + + # Assert + assert details == {"trigger_metadata": {"type": "plugin"}} + assert proxied_status == "succeeded" + + +def test_get_paginate_workflow_app_logs_should_return_paginated_summary_when_detail_false( + service: WorkflowAppService, + app_model: App, +) -> None: + # Arrange + session = MagicMock() + log_1 = SimpleNamespace(id="log-1") + log_2 = SimpleNamespace(id="log-2") + session.scalar.return_value = 3 + session.scalars.return_value.all.return_value = [log_1, log_2] + + # Act + result = service.get_paginate_workflow_app_logs( + session=session, + app_model=app_model, + page=1, + limit=2, + detail=False, + ) + + # Assert + assert result["page"] == 1 + assert result["limit"] == 2 + assert result["total"] == 3 + assert result["has_more"] is True + assert len(result["data"]) == 2 + assert isinstance(result["data"][0], LogView) + assert result["data"][0].details is None + + +def test_get_paginate_workflow_app_logs_should_return_detailed_rows_when_detail_true( + service: WorkflowAppService, + app_model: App, + mocker: MockerFixture, +) -> None: + # Arrange + session = MagicMock() + session.scalar.side_effect = [1] + log_1 = SimpleNamespace(id="log-1") + session.execute.return_value.all.return_value = [(log_1, '{"type":"trigger_plugin"}')] + mock_handle = mocker.patch.object( + service, + "handle_trigger_metadata", + return_value={"type": "trigger_plugin", "icon": "url"}, + ) + + # Act + result = service.get_paginate_workflow_app_logs( + session=session, + app_model=app_model, + keyword="run-1", + status=WorkflowExecutionStatus.SUCCEEDED, + created_at_before=None, + created_at_after=None, + page=1, + limit=20, + detail=True, + ) + + # Assert + assert result["total"] == 1 + assert len(result["data"]) == 1 + assert result["data"][0].details == {"trigger_metadata": {"type": "trigger_plugin", "icon": "url"}} + mock_handle.assert_called_once() + + +def test_get_paginate_workflow_app_logs_should_raise_when_account_filter_email_not_found( + service: WorkflowAppService, + app_model: App, +) -> None: + # Arrange + session = MagicMock() + session.scalar.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="Account not found: account@example.com"): + service.get_paginate_workflow_app_logs( + session=session, + app_model=app_model, + created_by_account="account@example.com", + ) + + +def test_get_paginate_workflow_app_logs_should_filter_by_account_when_account_exists( + service: WorkflowAppService, + app_model: App, +) -> None: + # Arrange + session = MagicMock() + session.scalar.side_effect = [SimpleNamespace(id="account-1"), 0] + session.scalars.return_value.all.return_value = [] + + # Act + result = service.get_paginate_workflow_app_logs( + session=session, + app_model=app_model, + created_by_account="account@example.com", + ) + + # Assert + assert result["total"] == 0 + assert result["data"] == [] + + +def test_get_paginate_workflow_archive_logs_should_return_paginated_archive_items( + service: WorkflowAppService, + app_model: App, +) -> None: + # Arrange + session = MagicMock() + log_account = SimpleNamespace( + id="log-1", + created_by="acc-1", + created_by_role=CreatorUserRole.ACCOUNT, + workflow_run_summary={"run": "1"}, + trigger_metadata='{"type":"trigger-webhook"}', + log_created_at="2026-01-01", + ) + log_end_user = SimpleNamespace( + id="log-2", + created_by="end-1", + created_by_role=CreatorUserRole.END_USER, + workflow_run_summary={"run": "2"}, + trigger_metadata='{"type":"trigger-webhook"}', + log_created_at="2026-01-02", + ) + log_unknown = SimpleNamespace( + id="log-3", + created_by="other", + created_by_role="system", + workflow_run_summary={"run": "3"}, + trigger_metadata='{"type":"trigger-webhook"}', + log_created_at="2026-01-03", + ) + session.scalar.return_value = 3 + session.scalars.side_effect = [ + SimpleNamespace(all=lambda: [log_account, log_end_user, log_unknown]), + SimpleNamespace(all=lambda: [SimpleNamespace(id="acc-1", email="a@example.com")]), + SimpleNamespace(all=lambda: [SimpleNamespace(id="end-1", session_id="session-1")]), + ] + + # Act + result = service.get_paginate_workflow_archive_logs( + session=session, + app_model=app_model, + page=1, + limit=20, + ) + + # Assert + assert result["total"] == 3 + assert len(result["data"]) == 3 + assert result["data"][0]["created_by_account"].id == "acc-1" + assert result["data"][1]["created_by_end_user"].id == "end-1" + assert result["data"][2]["created_by_account"] is None + assert result["data"][2]["created_by_end_user"] is None + + +def test_handle_trigger_metadata_should_return_empty_dict_when_metadata_missing( + service: WorkflowAppService, +) -> None: + # Arrange + # Act + result = service.handle_trigger_metadata("tenant-1", None) + + # Assert + assert result == {} + + +def test_handle_trigger_metadata_should_enrich_plugin_icons_for_trigger_plugin( + service: WorkflowAppService, + mocker: MockerFixture, +) -> None: + # Arrange + meta = { + "type": AppTriggerType.TRIGGER_PLUGIN.value, + "icon_filename": "light.png", + "icon_dark_filename": "dark.png", + } + mock_icon = mocker.patch( + "services.workflow_app_service.PluginService.get_plugin_icon_url", + side_effect=["https://cdn/light.png", "https://cdn/dark.png"], + ) + + # Act + result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) + + # Assert + assert result["icon"] == "https://cdn/light.png" + assert result["icon_dark"] == "https://cdn/dark.png" + assert mock_icon.call_count == 2 + + +def test_handle_trigger_metadata_should_return_non_plugin_metadata_without_icon_lookup( + service: WorkflowAppService, + mocker: MockerFixture, +) -> None: + # Arrange + meta = {"type": AppTriggerType.TRIGGER_WEBHOOK.value} + mock_icon = mocker.patch("services.workflow_app_service.PluginService.get_plugin_icon_url") + + # Act + result = service.handle_trigger_metadata("tenant-1", json.dumps(meta)) + + # Assert + assert result["type"] == AppTriggerType.TRIGGER_WEBHOOK.value + mock_icon.assert_not_called() + + +@pytest.mark.parametrize( + ("value", "expected"), + [ + (None, None), + ("", None), + ('{"k":"v"}', {"k": "v"}), + ("not-json", None), + ({"raw": True}, {"raw": True}), + ], +) +def test_safe_json_loads_should_handle_various_inputs( + value: object, + expected: object, + service: WorkflowAppService, +) -> None: + # Arrange + # Act + result = service._safe_json_loads(value) + + # Assert + assert result == expected + + +def test_safe_parse_uuid_should_return_none_for_short_or_invalid_values(service: WorkflowAppService) -> None: + # Arrange + # Act + short_result = service._safe_parse_uuid("short") + invalid_result = service._safe_parse_uuid("x" * 40) + + # Assert + assert short_result is None + assert invalid_result is None + + +def test_safe_parse_uuid_should_return_uuid_for_valid_uuid_string(service: WorkflowAppService) -> None: + # Arrange + raw_uuid = str(uuid.uuid4()) + + # Act + result = service._safe_parse_uuid(raw_uuid) + + # Assert + assert result is not None + assert str(result) == raw_uuid diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index 753cff8697..d26c2f674f 100644 --- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -10,18 +10,36 @@ This test suite covers: """ import json +import uuid +from typing import Any, cast from unittest.mock import MagicMock, patch import pytest -from dify_graph.enums import BuiltinNodeTypes +from dify_graph.entities import WorkflowNodeExecution +from dify_graph.enums import ( + BuiltinNodeTypes, + ErrorStrategy, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from dify_graph.errors import WorkflowNodeRunFailedError +from dify_graph.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent +from dify_graph.node_events import NodeRunResult from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from dify_graph.variables.input_entities import VariableEntityType from libs.datetime_utils import naive_utc_now +from models.human_input import RecipientType from models.model import App, AppMode from models.workflow import Workflow, WorkflowType from services.errors.app import IsDraftWorkflowError, TriggerNodeLimitExceededError, WorkflowHashNotEqualError from services.errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError -from services.workflow_service import WorkflowService +from services.workflow_service import ( + WorkflowService, + _rebuild_file_for_user_inputs_in_start_node, + _rebuild_single_file, + _setup_variable_pool, +) class TestWorkflowAssociatedDataFactory: @@ -1309,3 +1327,1416 @@ class TestWorkflowService: with pytest.raises(ValueError, match="not supported convert to workflow"): workflow_service.convert_to_workflow(app, account, args) + + +# =========================================================================== +# TestWorkflowServiceCredentialValidation +# Tests for _validate_workflow_credentials and related private helpers +# =========================================================================== + + +class TestWorkflowServiceCredentialValidation: + """ + Tests for the private credential-validation helpers on WorkflowService. + + These helpers gate `publish_workflow` when `PluginManager` is enabled. + Each test focuses on a distinct branch inside `_validate_workflow_credentials`, + `_validate_llm_model_config`, `_check_default_tool_credential`, and the + load-balancing path. + """ + + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + @staticmethod + def _make_workflow(nodes: list[dict]) -> MagicMock: + wf = MagicMock(spec=Workflow) + wf.tenant_id = "tenant-1" + wf.app_id = "app-1" + wf.graph_dict = {"nodes": nodes} + return wf + + # --- _validate_workflow_credentials: tool node (with credential_id) --- + + def test_validate_workflow_credentials_should_check_tool_credential_when_credential_id_present( + self, service: WorkflowService + ) -> None: + # Arrange + nodes = [ + { + "id": "tool-node", + "data": { + "type": "tool", + "provider_id": "my-provider", + "credential_id": "cred-123", + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + Assert + with patch("core.helper.credential_utils.check_credential_policy_compliance") as mock_check: + # Should not raise; mock allows the call + service._validate_workflow_credentials(workflow) + mock_check.assert_called_once() + + def test_validate_workflow_credentials_should_check_default_credential_when_no_credential_id( + self, service: WorkflowService + ) -> None: + # Arrange + nodes = [ + { + "id": "tool-node", + "data": { + "type": "tool", + "provider_id": "my-provider", + # No credential_id — should fall back to default + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + with patch.object(service, "_check_default_tool_credential") as mock_default: + service._validate_workflow_credentials(workflow) + + # Assert + mock_default.assert_called_once_with("tenant-1", "my-provider") + + def test_validate_workflow_credentials_should_skip_tool_node_without_provider( + self, service: WorkflowService + ) -> None: + """Tool nodes without a provider_id should be silently skipped.""" + # Arrange + nodes = [{"id": "tool-node", "data": {"type": "tool"}}] + workflow = self._make_workflow(nodes) + + # Act + Assert (no error raised) + with patch.object(service, "_check_default_tool_credential") as mock_default: + service._validate_workflow_credentials(workflow) + mock_default.assert_not_called() + + def test_validate_workflow_credentials_should_validate_llm_node_with_model_config( + self, service: WorkflowService + ) -> None: + # Arrange + nodes = [ + { + "id": "llm-node", + "data": { + "type": "llm", + "model": {"provider": "openai", "name": "gpt-4"}, + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + with ( + patch.object(service, "_validate_llm_model_config") as mock_llm, + patch.object(service, "_validate_load_balancing_credentials"), + ): + service._validate_workflow_credentials(workflow) + + # Assert + mock_llm.assert_called_once_with("tenant-1", "openai", "gpt-4") + + def test_validate_workflow_credentials_should_raise_for_llm_node_missing_model( + self, service: WorkflowService + ) -> None: + """LLM nodes without provider AND name should raise ValueError.""" + # Arrange + nodes = [ + { + "id": "llm-node", + "data": {"type": "llm", "model": {"provider": "openai"}}, # name missing + } + ] + workflow = self._make_workflow(nodes) + + # Act + Assert + with pytest.raises(ValueError, match="Missing provider or model configuration"): + service._validate_workflow_credentials(workflow) + + def test_validate_workflow_credentials_should_wrap_unexpected_exception_in_value_error( + self, service: WorkflowService + ) -> None: + """Non-ValueError exceptions from validation must be re-raised as ValueError.""" + # Arrange + nodes = [ + { + "id": "llm-node", + "data": { + "type": "llm", + "model": {"provider": "openai", "name": "gpt-4"}, + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + Assert + with patch.object(service, "_validate_llm_model_config", side_effect=RuntimeError("boom")): + with pytest.raises(ValueError, match="boom"): + service._validate_workflow_credentials(workflow) + + def test_validate_workflow_credentials_should_validate_agent_node_model(self, service: WorkflowService) -> None: + # Arrange + nodes = [ + { + "id": "agent-node", + "data": { + "type": "agent", + "agent_parameters": { + "model": {"value": {"provider": "openai", "model": "gpt-4"}}, + "tools": {"value": []}, + }, + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + with ( + patch.object(service, "_validate_llm_model_config") as mock_llm, + patch.object(service, "_validate_load_balancing_credentials"), + ): + service._validate_workflow_credentials(workflow) + + # Assert + mock_llm.assert_called_once_with("tenant-1", "openai", "gpt-4") + + def test_validate_workflow_credentials_should_validate_agent_tools(self, service: WorkflowService) -> None: + """Each agent tool with a provider should be checked for credential compliance.""" + # Arrange + nodes = [ + { + "id": "agent-node", + "data": { + "type": "agent", + "agent_parameters": { + "model": {"value": {}}, # no model config + "tools": { + "value": [ + {"provider_name": "provider-a", "credential_id": "cred-a"}, + {"provider_name": "provider-b"}, # uses default + ] + }, + }, + }, + } + ] + workflow = self._make_workflow(nodes) + + # Act + with ( + patch("core.helper.credential_utils.check_credential_policy_compliance") as mock_check, + patch.object(service, "_check_default_tool_credential") as mock_default, + ): + service._validate_workflow_credentials(workflow) + + # Assert + mock_check.assert_called_once() # provider-a has credential_id + mock_default.assert_called_once_with("tenant-1", "provider-b") + + # --- _validate_llm_model_config --- + + def test_validate_llm_model_config_should_raise_value_error_on_failure(self, service: WorkflowService) -> None: + """If ModelManager raises any exception it must be wrapped into ValueError.""" + # Arrange + with patch("core.model_manager.ModelManager.get_model_instance", side_effect=RuntimeError("no key")): + # Act + Assert + with pytest.raises(ValueError, match="Failed to validate LLM model configuration"): + service._validate_llm_model_config("tenant-1", "openai", "gpt-4") + + def test_validate_llm_model_config_success(self, service: WorkflowService) -> None: + """Test success path with ProviderManager and Model entities.""" + mock_model = MagicMock() + mock_model.model = "gpt-4" + mock_model.provider.provider = "openai" + + mock_configs = MagicMock() + mock_configs.get_models.return_value = [mock_model] + + with ( + patch("core.model_manager.ModelManager.get_model_instance"), + patch("core.provider_manager.ProviderManager") as mock_pm_cls, + ): + mock_pm_cls.return_value.get_configurations.return_value = mock_configs + + # Act + service._validate_llm_model_config("tenant-1", "openai", "gpt-4") + + # Assert + mock_model.raise_for_status.assert_called_once() + + def test_validate_llm_model_config_model_not_found(self, service: WorkflowService) -> None: + """Test ValueError when model is not found in provider configurations.""" + mock_configs = MagicMock() + mock_configs.get_models.return_value = [] # No models + + with ( + patch("core.model_manager.ModelManager.get_model_instance"), + patch("core.provider_manager.ProviderManager") as mock_pm_cls, + ): + mock_pm_cls.return_value.get_configurations.return_value = mock_configs + + # Act + Assert + with pytest.raises(ValueError, match="Model gpt-4 not found for provider openai"): + service._validate_llm_model_config("tenant-1", "openai", "gpt-4") + + # --- _check_default_tool_credential --- + + def test_check_default_tool_credential_should_silently_pass_when_no_provider_found( + self, service: WorkflowService + ) -> None: + """Missing BuiltinToolProvider → plugin requires no credentials → no error.""" + # Arrange + with patch("services.workflow_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + + # Act + Assert (should NOT raise) + service._check_default_tool_credential("tenant-1", "some-provider") + + def test_check_default_tool_credential_should_raise_when_compliance_fails(self, service: WorkflowService) -> None: + # Arrange + mock_provider = MagicMock() + mock_provider.id = "builtin-cred-id" + with ( + patch("services.workflow_service.db") as mock_db, + patch("core.helper.credential_utils.check_credential_policy_compliance", side_effect=Exception("denied")), + ): + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( + mock_provider + ) + + # Act + Assert + with pytest.raises(ValueError, match="Failed to validate default credential"): + service._check_default_tool_credential("tenant-1", "some-provider") + + # --- _is_load_balancing_enabled --- + + def test_is_load_balancing_enabled_should_return_false_when_provider_not_found( + self, service: WorkflowService + ) -> None: + # Arrange + with patch("services.workflow_service.db"): + service_instance = WorkflowService() + + with patch("core.provider_manager.ProviderManager.get_configurations") as mock_get_configs: + mock_configs = MagicMock() + mock_configs.get.return_value = None # provider not found + mock_get_configs.return_value = mock_configs + + # Act + result = service_instance._is_load_balancing_enabled("tenant-1", "openai", "gpt-4") + + # Assert + assert result is False + + def test_is_load_balancing_enabled_should_return_true_when_setting_enabled(self, service: WorkflowService) -> None: + # Arrange + with patch("core.provider_manager.ProviderManager.get_configurations") as mock_get_configs: + mock_provider_config = MagicMock() + mock_provider_model_setting = MagicMock() + mock_provider_model_setting.load_balancing_enabled = True + mock_provider_config.get_provider_model_setting.return_value = mock_provider_model_setting + + mock_configs = MagicMock() + mock_configs.get.return_value = mock_provider_config + mock_get_configs.return_value = mock_configs + + # Act + result = service._is_load_balancing_enabled("tenant-1", "openai", "gpt-4") + + # Assert + assert result is True + + def test_is_load_balancing_enabled_should_return_false_on_exception(self, service: WorkflowService) -> None: + """Any exception should be swallowed and return False.""" + # Arrange + with patch("core.provider_manager.ProviderManager.get_configurations", side_effect=RuntimeError("db down")): + # Act + result = service._is_load_balancing_enabled("tenant-1", "openai", "gpt-4") + + # Assert + assert result is False + + # --- _get_load_balancing_configs --- + + def test_get_load_balancing_configs_should_return_empty_list_on_exception(self, service: WorkflowService) -> None: + """Any exception during LB config retrieval should return an empty list.""" + # Arrange + with patch( + "services.model_load_balancing_service.ModelLoadBalancingService.get_load_balancing_configs", + side_effect=RuntimeError("fail"), + ): + # Act + result = service._get_load_balancing_configs("tenant-1", "openai", "gpt-4") + + # Assert + assert result == [] + + def test_get_load_balancing_configs_should_merge_predefined_and_custom(self, service: WorkflowService) -> None: + # Arrange + predefined = [{"credential_id": "cred-a"}, {"credential_id": None}] + custom = [{"credential_id": "cred-b"}] + with patch( + "services.model_load_balancing_service.ModelLoadBalancingService.get_load_balancing_configs", + side_effect=[ + (None, predefined), # first call: predefined-model + (None, custom), # second call: custom-model + ], + ): + # Act + result = service._get_load_balancing_configs("tenant-1", "openai", "gpt-4") + + # Assert — only entries with a credential_id should be returned + assert len(result) == 2 + assert all(c["credential_id"] for c in result) + + # --- _validate_load_balancing_credentials --- + + def test_validate_load_balancing_credentials_should_skip_when_no_model_config( + self, service: WorkflowService + ) -> None: + """Missing provider or model in node_data should be a no-op.""" + # Arrange + workflow = self._make_workflow([]) + node_data: dict = {} # no model key + + # Act + Assert (no error expected) + service._validate_load_balancing_credentials(workflow, node_data, "node-1") + + def test_validate_load_balancing_credentials_should_skip_when_lb_not_enabled( + self, service: WorkflowService + ) -> None: + # Arrange + workflow = self._make_workflow([]) + node_data = {"model": {"provider": "openai", "name": "gpt-4"}} + + # Act + Assert (no error expected) + with patch.object(service, "_is_load_balancing_enabled", return_value=False): + service._validate_load_balancing_credentials(workflow, node_data, "node-1") + + def test_validate_load_balancing_credentials_should_raise_when_compliance_fails( + self, service: WorkflowService + ) -> None: + # Arrange + workflow = self._make_workflow([]) + node_data = {"model": {"provider": "openai", "name": "gpt-4"}} + lb_configs = [{"credential_id": "cred-lb-1"}] + + # Act + Assert + with ( + patch.object(service, "_is_load_balancing_enabled", return_value=True), + patch.object(service, "_get_load_balancing_configs", return_value=lb_configs), + patch( + "core.helper.credential_utils.check_credential_policy_compliance", + side_effect=Exception("policy violation"), + ), + ): + with pytest.raises(ValueError, match="Invalid load balancing credentials"): + service._validate_load_balancing_credentials(workflow, node_data, "node-1") + + +# =========================================================================== +# TestWorkflowServiceExecutionHelpers +# Tests for _apply_error_strategy, _populate_execution_result, _execute_node_safely +# =========================================================================== + + +class TestWorkflowServiceExecutionHelpers: + """ + Tests for the private execution-result handling methods: + _apply_error_strategy, _populate_execution_result, _execute_node_safely. + """ + + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + # --- _apply_error_strategy --- + + def test_apply_error_strategy_should_return_exception_status_noderunresult(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + node.error_strategy = ErrorStrategy.FAIL_BRANCH + node.default_value_dict = {} + original = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error="something went wrong", + error_type="SomeError", + inputs={"x": 1}, + outputs={}, + ) + + # Act + result = service._apply_error_strategy(node, original) + + # Assert + assert result.status == WorkflowNodeExecutionStatus.EXCEPTION + assert result.error == "something went wrong" + assert result.metadata[WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY] == ErrorStrategy.FAIL_BRANCH + + def test_apply_error_strategy_should_include_default_values_for_default_value_strategy( + self, service: WorkflowService + ) -> None: + # Arrange + node = MagicMock() + node.error_strategy = ErrorStrategy.DEFAULT_VALUE + node.default_value_dict = {"output_key": "fallback"} + original = NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error="err", + ) + + # Act + result = service._apply_error_strategy(node, original) + + # Assert + assert result.outputs.get("output_key") == "fallback" + assert result.status == WorkflowNodeExecutionStatus.EXCEPTION + + # --- _populate_execution_result --- + + def test_populate_execution_result_should_set_succeeded_fields_when_run_succeeded( + self, service: WorkflowService + ) -> None: + # Arrange + node_execution = MagicMock(error=None) + node_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"q": "hello"}, + process_data={"steps": 3}, + outputs={"answer": "hi"}, + metadata={WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 10}, + ) + + # Act + with patch("services.workflow_service.WorkflowEntry.handle_special_values", side_effect=lambda x: x): + service._populate_execution_result(node_execution, node_run_result, True, None) + + # Assert + assert node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert node_execution.outputs == {"answer": "hi"} + assert node_execution.error is None # SUCCEEDED status doesn't set error + + def test_populate_execution_result_should_set_failed_status_and_error_when_not_succeeded( + self, service: WorkflowService + ) -> None: + # Arrange + node_execution = MagicMock(error=None) + + # Act + service._populate_execution_result(node_execution, None, False, "catastrophic failure") + + # Assert + assert node_execution.status == WorkflowNodeExecutionStatus.FAILED + assert node_execution.error == "catastrophic failure" + + def test_populate_execution_result_should_set_error_field_for_exception_status( + self, service: WorkflowService + ) -> None: + """A succeeded=True result with EXCEPTION status should still populate the error field.""" + # Arrange + node_execution = MagicMock() + node_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.EXCEPTION, + error="constraint violated", + ) + + # Act + with patch("services.workflow_service.WorkflowEntry.handle_special_values", side_effect=lambda x: x): + service._populate_execution_result(node_execution, node_run_result, True, None) + + # Assert + assert node_execution.status == WorkflowNodeExecutionStatus.EXCEPTION + assert node_execution.error == "constraint violated" + + # --- _execute_node_safely --- + + def test_execute_node_safely_should_return_succeeded_result_on_happy_path(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + node.error_strategy = None + node_run_result = MagicMock() + node_run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED + node_run_result.error = None + + succeeded_event = MagicMock(spec=NodeRunSucceededEvent) + succeeded_event.node_run_result = node_run_result + + def invoke_fn(): + def _gen(): + yield succeeded_event + + return node, _gen() + + # Act + out_node, out_result, run_succeeded, error = service._execute_node_safely(invoke_fn) + + # Assert + assert out_node is node + assert run_succeeded is True + assert error is None + + def test_execute_node_safely_should_return_failed_result_on_failed_event(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + node.error_strategy = None + node_run_result = MagicMock() + node_run_result.status = WorkflowNodeExecutionStatus.FAILED + node_run_result.error = "node exploded" + + failed_event = MagicMock(spec=NodeRunFailedEvent) + failed_event.node_run_result = node_run_result + + def invoke_fn(): + def _gen(): + yield failed_event + + return node, _gen() + + # Act + _, _, run_succeeded, error = service._execute_node_safely(invoke_fn) + + # Assert + assert run_succeeded is False + assert error == "node exploded" + + def test_execute_node_safely_should_handle_workflow_node_run_failed_error(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + exc = WorkflowNodeRunFailedError(node, "runtime failure") + + def invoke_fn(): + raise exc + + # Act + out_node, out_result, run_succeeded, error = service._execute_node_safely(invoke_fn) + + # Assert + assert out_node is node + assert out_result is None + assert run_succeeded is False + assert error == "runtime failure" + + def test_execute_node_safely_should_raise_when_no_result_event(self, service: WorkflowService) -> None: + """If the generator produces no NodeRunSucceededEvent/NodeRunFailedEvent, ValueError is expected.""" + # Arrange + node = MagicMock() + node.error_strategy = None + + def invoke_fn(): + def _gen(): + yield from [] + + return node, _gen() + + # Act + Assert + with pytest.raises(ValueError, match="no result returned"): + service._execute_node_safely(invoke_fn) + + # --- _apply_error_strategy with FAIL_BRANCH strategy --- + + def test_execute_node_safely_should_apply_error_strategy_on_failed_status(self, service: WorkflowService) -> None: + # Arrange + node = MagicMock() + node.error_strategy = ErrorStrategy.FAIL_BRANCH + node.default_value_dict = {} + + original_result = MagicMock() + original_result.status = WorkflowNodeExecutionStatus.FAILED + original_result.error = "oops" + original_result.error_type = "ValueError" + original_result.inputs = {} + + failed_event = MagicMock(spec=NodeRunFailedEvent) + failed_event.node_run_result = original_result + + def invoke_fn(): + def _gen(): + yield failed_event + + return node, _gen() + + # Act + _, result, run_succeeded, _ = service._execute_node_safely(invoke_fn) + + # Assert — after applying error strategy status becomes EXCEPTION + assert result is not None + assert result.status == WorkflowNodeExecutionStatus.EXCEPTION + # run_succeeded should be True because EXCEPTION is in the succeeded set + assert run_succeeded is True + + +# =========================================================================== +# TestWorkflowServiceGetNodeLastRun +# Tests for get_node_last_run delegation to repository +# =========================================================================== + + +class TestWorkflowServiceGetNodeLastRun: + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + def test_get_node_last_run_should_delegate_to_repository(self, service: WorkflowService) -> None: + # Arrange + app = MagicMock(spec=App) + app.tenant_id = "tenant-1" + app.id = "app-1" + workflow = MagicMock(spec=Workflow) + workflow.id = "wf-1" + expected = MagicMock() + + service._node_execution_service_repo = MagicMock() + service._node_execution_service_repo.get_node_last_execution.return_value = expected + + # Act + result = service.get_node_last_run(app, workflow, "node-42") + + # Assert + assert result is expected + service._node_execution_service_repo.get_node_last_execution.assert_called_once_with( + tenant_id="tenant-1", + app_id="app-1", + workflow_id="wf-1", + node_id="node-42", + ) + + def test_get_node_last_run_should_return_none_when_repository_returns_none(self, service: WorkflowService) -> None: + # Arrange + app = MagicMock(spec=App) + app.tenant_id = "t" + app.id = "a" + workflow = MagicMock(spec=Workflow) + workflow.id = "w" + service._node_execution_service_repo = MagicMock() + service._node_execution_service_repo.get_node_last_execution.return_value = None + + # Act + result = service.get_node_last_run(app, workflow, "node-x") + + # Assert + assert result is None + + +# =========================================================================== +# TestWorkflowServiceModuleLevelHelpers +# Tests for module-level helper functions exported from workflow_service +# =========================================================================== + + +class TestSetupVariablePool: + """ + Tests for the module-level `_setup_variable_pool` function. + This helper initialises the VariablePool used for single-step workflow execution. + """ + + def _make_workflow(self, workflow_type: str = WorkflowType.WORKFLOW.value) -> MagicMock: + wf = MagicMock(spec=Workflow) + wf.app_id = "app-1" + wf.id = "wf-1" + wf.type = workflow_type + wf.environment_variables = [] + return wf + + def test_setup_variable_pool_should_use_full_system_variables_for_start_node( + self, + ) -> None: + # Arrange + workflow = self._make_workflow() + + # Act + with patch("services.workflow_service.VariablePool") as MockPool: + _setup_variable_pool( + query="hello", + files=[], + user_id="u-1", + user_inputs={"k": "v"}, + workflow=workflow, + node_type=BuiltinNodeTypes.START, + conversation_id="conv-1", + conversation_variables=[], + ) + + # Assert — VariablePool should be called with a SystemVariable (non-default) + MockPool.assert_called_once() + call_kwargs = MockPool.call_args.kwargs + assert call_kwargs["user_inputs"] == {"k": "v"} + + def test_setup_variable_pool_should_use_default_system_variables_for_non_start_node( + self, + ) -> None: + # Arrange + workflow = self._make_workflow() + + # Act + with ( + patch("services.workflow_service.VariablePool") as MockPool, + patch("services.workflow_service.SystemVariable.default") as mock_default, + ): + _setup_variable_pool( + query="", + files=[], + user_id="u-1", + user_inputs={}, + workflow=workflow, + node_type=BuiltinNodeTypes.LLM, # not a start/trigger node + conversation_id="conv-1", + conversation_variables=[], + ) + + # Assert — SystemVariable.default() should be used for non-start nodes + mock_default.assert_called_once() + MockPool.assert_called_once() + + def test_setup_variable_pool_should_set_chatflow_specifics_for_non_workflow_type( + self, + ) -> None: + """For ADVANCED_CHAT workflows on a START node, query/conversation_id/dialogue_count should be set.""" + from models.workflow import WorkflowType + + # Arrange + workflow = self._make_workflow(workflow_type=WorkflowType.CHAT.value) + + # Act + with patch("services.workflow_service.VariablePool") as MockPool: + _setup_variable_pool( + query="what is AI?", + files=[], + user_id="u-1", + user_inputs={}, + workflow=workflow, + node_type=BuiltinNodeTypes.START, + conversation_id="conv-abc", + conversation_variables=[], + ) + + # Assert — we just verify VariablePool was called (chatflow path executed) + MockPool.assert_called_once() + + +class TestRebuildSingleFile: + """ + Tests for the module-level `_rebuild_single_file` function. + Ensures correct delegation to `build_from_mapping` / `build_from_mappings`. + """ + + def test_rebuild_single_file_should_call_build_from_mapping_for_file_type( + self, + ) -> None: + # Arrange + tenant_id = "tenant-1" + value = {"url": "https://example.com/file.pdf", "type": "document"} + mock_file = MagicMock() + + # Act + with patch("services.workflow_service.build_from_mapping", return_value=mock_file) as mock_build: + result = _rebuild_single_file(tenant_id, value, VariableEntityType.FILE) + + # Assert + assert result is mock_file + mock_build.assert_called_once_with(mapping=value, tenant_id=tenant_id) + + def test_rebuild_single_file_should_raise_when_file_value_not_dict( + self, + ) -> None: + # Arrange + Act + Assert + with pytest.raises(ValueError, match="expected dict for file object"): + _rebuild_single_file("tenant-1", "not-a-dict", VariableEntityType.FILE) + + def test_rebuild_single_file_should_call_build_from_mappings_for_file_list( + self, + ) -> None: + # Arrange + tenant_id = "tenant-1" + value = [{"url": "https://example.com/a.pdf"}, {"url": "https://example.com/b.pdf"}] + mock_files = [MagicMock(), MagicMock()] + + # Act + with patch("services.workflow_service.build_from_mappings", return_value=mock_files) as mock_build: + result = _rebuild_single_file(tenant_id, value, VariableEntityType.FILE_LIST) + + # Assert + assert result is mock_files + mock_build.assert_called_once_with(mappings=value, tenant_id=tenant_id) + + def test_rebuild_single_file_should_raise_when_file_list_value_not_list( + self, + ) -> None: + # Arrange + Act + Assert + with pytest.raises(ValueError, match="expected list for file list object"): + _rebuild_single_file("tenant-1", "not-a-list", VariableEntityType.FILE_LIST) + + def test_rebuild_single_file_should_return_empty_list_for_empty_file_list( + self, + ) -> None: + # Arrange + Act + result = _rebuild_single_file("tenant-1", [], VariableEntityType.FILE_LIST) + + # Assert + assert result == [] + + def test_rebuild_single_file_should_raise_when_first_element_not_dict( + self, + ) -> None: + # Arrange + Act + Assert + with pytest.raises(ValueError, match="expected dict for first element"): + _rebuild_single_file("tenant-1", ["not-a-dict"], VariableEntityType.FILE_LIST) + + +class TestRebuildFileForUserInputsInStartNode: + """ + Tests for the module-level `_rebuild_file_for_user_inputs_in_start_node` function. + """ + + def _make_start_node_data(self, variables: list) -> MagicMock: + start_data = MagicMock() + start_data.variables = variables + return start_data + + def _make_variable(self, name: str, var_type: VariableEntityType) -> MagicMock: + var = MagicMock() + var.variable = name + var.type = var_type + return var + + def test_rebuild_should_pass_through_non_file_variables( + self, + ) -> None: + # Arrange + text_var = self._make_variable("query", VariableEntityType.TEXT_INPUT) + start_data = self._make_start_node_data([text_var]) + user_inputs = {"query": "hello world"} + + # Act + result = _rebuild_file_for_user_inputs_in_start_node( + tenant_id="tenant-1", + start_node_data=start_data, + user_inputs=user_inputs, + ) + + # Assert — non-file inputs are untouched + assert result["query"] == "hello world" + + def test_rebuild_should_rebuild_file_variable( + self, + ) -> None: + # Arrange + file_var = self._make_variable("attachment", VariableEntityType.FILE) + start_data = self._make_start_node_data([file_var]) + file_value = {"url": "https://example.com/file.pdf"} + user_inputs = {"attachment": file_value} + mock_file = MagicMock() + + # Act + with patch("services.workflow_service.build_from_mapping", return_value=mock_file): + result = _rebuild_file_for_user_inputs_in_start_node( + tenant_id="tenant-1", + start_node_data=start_data, + user_inputs=user_inputs, + ) + + # Assert — the dict value should be replaced by the rebuilt File object + assert result["attachment"] is mock_file + + def test_rebuild_should_skip_variable_not_in_inputs( + self, + ) -> None: + # Arrange + file_var = self._make_variable("attachment", VariableEntityType.FILE) + start_data = self._make_start_node_data([file_var]) + user_inputs: dict = {} # attachment not provided + + # Act + result = _rebuild_file_for_user_inputs_in_start_node( + tenant_id="tenant-1", + start_node_data=start_data, + user_inputs=user_inputs, + ) + + # Assert — no key should be added for missing inputs + assert "attachment" not in result + + +class TestWorkflowServiceResolveDeliveryMethod: + """ + Tests for the static helper `_resolve_human_input_delivery_method`. + """ + + def _make_method(self, method_id) -> MagicMock: + m = MagicMock() + m.id = method_id + return m + + def test_resolve_delivery_method_should_return_method_when_id_matches(self) -> None: + # Arrange + method_a = self._make_method("method-1") + method_b = self._make_method("method-2") + node_data = MagicMock() + node_data.delivery_methods = [method_a, method_b] + + # Act + result = WorkflowService._resolve_human_input_delivery_method( + node_data=node_data, delivery_method_id="method-2" + ) + + # Assert + assert result is method_b + + def test_resolve_delivery_method_should_return_none_when_no_match(self) -> None: + # Arrange + method_a = self._make_method("method-1") + node_data = MagicMock() + node_data.delivery_methods = [method_a] + + # Act + result = WorkflowService._resolve_human_input_delivery_method( + node_data=node_data, delivery_method_id="does-not-exist" + ) + + # Assert + assert result is None + + def test_resolve_delivery_method_should_return_none_for_empty_methods(self) -> None: + # Arrange + node_data = MagicMock() + node_data.delivery_methods = [] + + # Act + result = WorkflowService._resolve_human_input_delivery_method( + node_data=node_data, delivery_method_id="method-1" + ) + + # Assert + assert result is None + + +# =========================================================================== +# TestWorkflowServiceDraftExecution +# Tests for run_draft_workflow_node +# =========================================================================== + + +class TestWorkflowServiceDraftExecution: + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + def test_run_draft_workflow_node_should_execute_start_node_successfully(self, service: WorkflowService) -> None: + # Arrange + app = MagicMock(spec=App) + app.id = "app-1" + app.tenant_id = "tenant-1" + account = MagicMock() + account.id = "user-1" + + draft_workflow = MagicMock(spec=Workflow) + draft_workflow.id = "wf-1" + draft_workflow.tenant_id = "tenant-1" + draft_workflow.app_id = "app-1" + draft_workflow.graph_dict = {"nodes": []} + + node_id = "start-node" + node_config = {"id": node_id, "data": MagicMock(type=BuiltinNodeTypes.START)} + draft_workflow.get_node_config_by_id.return_value = node_config + draft_workflow.get_enclosing_node_type_and_id.return_value = None + + service.get_draft_workflow = MagicMock(return_value=draft_workflow) + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "exec-1" + node_execution.process_data = {} + + # Mocking complex dependencies + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.Session"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("services.workflow_service.StartNodeData") as mock_start_data, + patch( + "services.workflow_service._rebuild_file_for_user_inputs_in_start_node", + side_effect=lambda **kwargs: kwargs["user_inputs"], + ), + patch("services.workflow_service._setup_variable_pool"), + patch("services.workflow_service.DraftVarLoader"), + patch("services.workflow_service.WorkflowEntry.single_step_run") as mock_run, + patch("services.workflow_service.DifyCoreRepositoryFactory") as mock_repo_factory, + patch("services.workflow_service.DraftVariableSaver") as mock_saver_cls, + patch("services.workflow_service.storage"), + ): + mock_node = MagicMock() + mock_node.node_type = BuiltinNodeTypes.START + mock_node.title = "Start Node" + mock_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"result": "ok"} + ) + mock_event = NodeRunSucceededEvent( + id=str(uuid.uuid4()), + node_id="start-node", + node_type=BuiltinNodeTypes.START, + node_run_result=mock_run_result, + start_at=naive_utc_now(), + ) + mock_run.return_value = (mock_node, [mock_event]) + + mock_repo = MagicMock() + mock_repo_factory.create_workflow_node_execution_repository.return_value = mock_repo + + service._node_execution_service_repo = MagicMock() + mock_execution_record = MagicMock() + mock_execution_record.node_type = "start" + mock_execution_record.node_id = "start-node" + mock_execution_record.load_full_outputs.return_value = {} + service._node_execution_service_repo.get_execution_by_id.return_value = mock_execution_record + + # Act + result = service.run_draft_workflow_node( + app_model=app, + draft_workflow=draft_workflow, + account=account, + node_id=node_id, + user_inputs={"key": "val"}, + query="hi", + files=[], + ) + + # Assert + assert result is not None + mock_run.assert_called_once() + mock_repo.save.assert_called_once() + mock_saver_cls.return_value.save.assert_called_once() + + def test_run_draft_workflow_node_should_execute_non_start_node_successfully(self, service: WorkflowService) -> None: + # Arrange + app = MagicMock(spec=App) + account = MagicMock() + draft_workflow = MagicMock(spec=Workflow) + draft_workflow.graph_dict = {"nodes": []} + node_id = "llm-node" + node_config = {"id": node_id, "data": MagicMock(type=BuiltinNodeTypes.LLM)} + draft_workflow.get_node_config_by_id.return_value = node_config + draft_workflow.get_enclosing_node_type_and_id.return_value = None + service.get_draft_workflow = MagicMock(return_value=draft_workflow) + + node_execution = MagicMock(spec=WorkflowNodeExecution) + node_execution.id = "exec-1" + node_execution.process_data = {} + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.Session"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("services.workflow_service.VariablePool") as mock_pool_cls, + patch("services.workflow_service.DraftVarLoader"), + patch("services.workflow_service.WorkflowEntry.single_step_run") as mock_run, + patch("services.workflow_service.DifyCoreRepositoryFactory"), + patch("services.workflow_service.DraftVariableSaver"), + patch("services.workflow_service.storage"), + ): + mock_node = MagicMock() + mock_node.node_type = BuiltinNodeTypes.LLM + mock_node.title = "LLM Node" + mock_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"result": "ok"} + ) + mock_event = NodeRunSucceededEvent( + id=str(uuid.uuid4()), + node_id="llm-node", + node_type=BuiltinNodeTypes.LLM, + node_run_result=mock_run_result, + start_at=naive_utc_now(), + ) + mock_run.return_value = (mock_node, [mock_event]) + + service._node_execution_service_repo = MagicMock() + mock_execution_record = MagicMock() + mock_execution_record.node_type = "llm" + mock_execution_record.node_id = "llm-node" + mock_execution_record.load_full_outputs.return_value = {"answer": "hello"} + service._node_execution_service_repo.get_execution_by_id.return_value = mock_execution_record + + # Act + service.run_draft_workflow_node( + app_model=app, + draft_workflow=draft_workflow, + account=account, + node_id=node_id, + user_inputs={}, + query="", + files=None, + ) + + # Assert + # For non-start nodes, VariablePool should be initialized with environment_variables + mock_pool_cls.assert_called_once() + args, kwargs = mock_pool_cls.call_args + assert "environment_variables" in kwargs + + +# =========================================================================== +# TestWorkflowServiceHumanInputOperations +# Tests for Human Input related methods +# =========================================================================== + + +class TestWorkflowServiceHumanInputOperations: + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + def test_get_human_input_form_preview_should_raise_if_workflow_not_init(self, service: WorkflowService) -> None: + service.get_draft_workflow = MagicMock(return_value=None) + with pytest.raises(ValueError, match="Workflow not initialized"): + service.get_human_input_form_preview(app_model=MagicMock(), account=MagicMock(), node_id="node-1") + + def test_get_human_input_form_preview_should_raise_if_wrong_node_type(self, service: WorkflowService) -> None: + draft = MagicMock() + draft.get_node_config_by_id.return_value = {"data": {"type": "llm"}} + service.get_draft_workflow = MagicMock(return_value=draft) + with patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.LLM): + with pytest.raises(ValueError, match="Node type must be human-input"): + service.get_human_input_form_preview(app_model=MagicMock(), account=MagicMock(), node_id="node-1") + + def test_get_human_input_form_preview_success(self, service: WorkflowService) -> None: + app_model = MagicMock(spec=App) + app_model.id = "app-1" + app_model.tenant_id = "tenant-1" + + account = MagicMock() + account.id = "user-1" + + draft = MagicMock() + draft.id = "wf-1" + draft.tenant_id = "tenant-1" + draft.app_id = "app-1" + draft.graph_dict = {"nodes": []} + draft.get_node_config_by_id.return_value = { + "id": "node-1", + "data": MagicMock(type=BuiltinNodeTypes.HUMAN_INPUT), + } + service.get_draft_workflow = MagicMock(return_value=draft) + + mock_node = MagicMock() + mock_node.render_form_content_before_submission.return_value = "rendered" + mock_node.resolve_default_values.return_value = {"def": 1} + mock_node.title = "Form Title" + mock_node.node_data = MagicMock() + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), + patch.object(service, "_build_human_input_variable_pool"), + patch("services.workflow_service.HumanInputNode", return_value=mock_node), + patch("services.workflow_service.HumanInputRequired") as mock_required_cls, + ): + service.get_human_input_form_preview(app_model=app_model, account=account, node_id="node-1") + mock_node.render_form_content_before_submission.assert_called_once() + mock_required_cls.return_value.model_dump.assert_called_once() + + def test_submit_human_input_form_preview_success(self, service: WorkflowService) -> None: + app_model = MagicMock(spec=App) + app_model.id = "app-1" + app_model.tenant_id = "tenant-1" + + account = MagicMock() + account.id = "user-1" + + draft = MagicMock() + draft.id = "wf-1" + draft.tenant_id = "tenant-1" + draft.app_id = "app-1" + draft.graph_dict = {"nodes": []} + draft.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": "human-input"}} + service.get_draft_workflow = MagicMock(return_value=draft) + + mock_node = MagicMock() + mock_node.node_data = MagicMock() + mock_node.node_data.outputs_field_names.return_value = ["field1"] + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), + patch.object(service, "_build_human_input_variable_pool"), + patch("services.workflow_service.HumanInputNode", return_value=mock_node), + patch("services.workflow_service.validate_human_input_submission"), + patch("services.workflow_service.Session"), + patch("services.workflow_service.DraftVariableSaver") as mock_saver_cls, + ): + result = service.submit_human_input_form_preview( + app_model=app_model, account=account, node_id="node-1", form_inputs={"field1": "val1"}, action="submit" + ) + assert result["__action_id"] == "submit" + mock_saver_cls.return_value.save.assert_called_once() + + def test_test_human_input_delivery_success(self, service: WorkflowService) -> None: + draft = MagicMock() + draft.get_node_config_by_id.return_value = {"data": {"type": "human-input"}} + service.get_draft_workflow = MagicMock(return_value=draft) + + with ( + patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), + patch("services.workflow_service.HumanInputNodeData.model_validate"), + patch.object(service, "_resolve_human_input_delivery_method") as mock_resolve, + patch("services.workflow_service.apply_debug_email_recipient"), + patch.object(service, "_build_human_input_variable_pool"), + patch.object(service, "_build_human_input_node"), + patch.object(service, "_create_human_input_delivery_test_form", return_value=("form-1", [])), + patch("services.workflow_service.HumanInputDeliveryTestService") as mock_test_srv, + ): + mock_resolve.return_value = MagicMock() + service.test_human_input_delivery( + app_model=MagicMock(), account=MagicMock(), node_id="node-1", delivery_method_id="method-1" + ) + mock_test_srv.return_value.send_test.assert_called_once() + + def test_test_human_input_delivery_failure_cases(self, service: WorkflowService) -> None: + draft = MagicMock() + draft.get_node_config_by_id.return_value = {"data": {"type": "human-input"}} + service.get_draft_workflow = MagicMock(return_value=draft) + + with ( + patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), + patch("services.workflow_service.HumanInputNodeData.model_validate"), + patch.object(service, "_resolve_human_input_delivery_method", return_value=None), + ): + with pytest.raises(ValueError, match="Delivery method not found"): + service.test_human_input_delivery( + app_model=MagicMock(), account=MagicMock(), node_id="node-1", delivery_method_id="none" + ) + + def test_load_email_recipients_parsing_failure(self, service: WorkflowService) -> None: + # Arrange + mock_recipient = MagicMock() + mock_recipient.recipient_payload = "invalid json" + mock_recipient.recipient_type = RecipientType.EMAIL_MEMBER + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("services.workflow_service.Session") as mock_session_cls, + patch("services.workflow_service.select"), + patch("services.workflow_service.json.loads", side_effect=ValueError("bad json")), + ): + mock_session = mock_session_cls.return_value.__enter__.return_value + # sqlalchemy assertions check for .bind + mock_session.bind = MagicMock() # removed spec=Engine to avoid import issues for now + mock_session.scalars.return_value.all.return_value = [mock_recipient] + + # Act + # _load_email_recipients(form_id: str) is a static method + result = WorkflowService._load_email_recipients("form-1") + + # Assert + assert result == [] # Should fall back to empty list on parsing error + + def test_build_human_input_variable_pool(self, service: WorkflowService) -> None: + workflow = MagicMock() + workflow.environment_variables = [] + workflow.graph_dict = {} + + with ( + patch("services.workflow_service.db"), + patch("services.workflow_service.Session"), + patch("services.workflow_service.WorkflowDraftVariableService"), + patch("services.workflow_service.VariablePool") as mock_pool_cls, + patch("services.workflow_service.DraftVarLoader"), + patch("services.workflow_service.HumanInputNode.extract_variable_selector_to_variable_mapping"), + patch("services.workflow_service.load_into_variable_pool"), + patch("services.workflow_service.WorkflowEntry.mapping_user_inputs_to_variable_pool"), + ): + service._build_human_input_variable_pool( + app_model=MagicMock(), workflow=workflow, node_config={}, manual_inputs={}, user_id="user-1" + ) + mock_pool_cls.assert_called_once() + + +# =========================================================================== +# TestWorkflowServiceFreeNodeExecution +# Tests for run_free_workflow_node and handle_single_step_result +# =========================================================================== + + +class TestWorkflowServiceFreeNodeExecution: + @pytest.fixture + def service(self) -> WorkflowService: + with patch("services.workflow_service.db"): + return WorkflowService() + + def test_run_free_workflow_node_success(self, service: WorkflowService) -> None: + node_execution = MagicMock() + with ( + patch.object(service, "_handle_single_step_result", return_value=node_execution), + patch("services.workflow_service.WorkflowEntry.run_free_node"), + ): + result = service.run_free_workflow_node({}, "tenant-1", "user-1", "node-1", {}) + assert result == node_execution + + def test_validate_graph_structure_coexist_error(self, service: WorkflowService) -> None: + graph = { + "nodes": [ + {"data": {"type": "start"}}, + {"data": {"type": "trigger-webhook"}}, # is_trigger_node=True + ] + } + with pytest.raises(ValueError, match="Start node and trigger nodes cannot coexist"): + service.validate_graph_structure(graph) + + def test_validate_features_structure_success(self, service: WorkflowService) -> None: + app = MagicMock() + app.mode = "workflow" + features = {} + with patch("services.workflow_service.WorkflowAppConfigManager.config_validate") as mock_val: + service.validate_features_structure(app, features) + mock_val.assert_called_once() + + def test_validate_features_structure_invalid_mode(self, service: WorkflowService) -> None: + app = MagicMock() + app.mode = "invalid" + with pytest.raises(ValueError, match="Invalid app mode"): + service.validate_features_structure(app, {}) + + def test_validate_human_input_node_data_error(self, service: WorkflowService) -> None: + with patch( + "dify_graph.nodes.human_input.entities.HumanInputNodeData.model_validate", side_effect=Exception("error") + ): + with pytest.raises(ValueError, match="Invalid HumanInput node data"): + service._validate_human_input_node_data({}) + + def test_rebuild_single_file_unreachable(self) -> None: + # Test line 1523 (unreachable) + with pytest.raises(Exception, match="unreachable"): + _rebuild_single_file("tenant-1", {}, cast(Any, "invalid_type")) + + def test_build_human_input_node(self, service: WorkflowService) -> None: + """Cover _build_human_input_node (lines 1065-1088).""" + workflow = MagicMock() + workflow.id = "wf-1" + workflow.tenant_id = "t-1" + workflow.app_id = "app-1" + account = MagicMock() + account.id = "u-1" + node_config = {"id": "n-1"} + variable_pool = MagicMock() + + with ( + patch("services.workflow_service.GraphInitParams"), + patch("services.workflow_service.GraphRuntimeState"), + patch("services.workflow_service.HumanInputNode") as mock_node_cls, + patch("services.workflow_service.HumanInputFormRepositoryImpl"), + ): + node = service._build_human_input_node( + workflow=workflow, account=account, node_config=node_config, variable_pool=variable_pool + ) + assert node == mock_node_cls.return_value + mock_node_cls.assert_called_once() diff --git a/api/tests/unit_tests/services/test_workspace_service.py b/api/tests/unit_tests/services/test_workspace_service.py new file mode 100644 index 0000000000..9bfd7eb2c5 --- /dev/null +++ b/api/tests/unit_tests/services/test_workspace_service.py @@ -0,0 +1,576 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from models.account import Tenant + +# --------------------------------------------------------------------------- +# Constants used throughout the tests +# --------------------------------------------------------------------------- + +TENANT_ID = "tenant-abc" +ACCOUNT_ID = "account-xyz" +FILES_BASE_URL = "https://files.example.com" + +DB_PATH = "services.workspace_service.db" +FEATURE_SERVICE_PATH = "services.workspace_service.FeatureService.get_features" +TENANT_SERVICE_PATH = "services.workspace_service.TenantService.has_roles" +DIFY_CONFIG_PATH = "services.workspace_service.dify_config" +CURRENT_USER_PATH = "services.workspace_service.current_user" +CREDIT_POOL_SERVICE_PATH = "services.credit_pool_service.CreditPoolService.get_pool" + + +# --------------------------------------------------------------------------- +# Helpers / factories +# --------------------------------------------------------------------------- + + +def _make_tenant( + tenant_id: str = TENANT_ID, + name: str = "My Workspace", + plan: str = "sandbox", + status: str = "active", + custom_config: dict | None = None, +) -> Tenant: + """Create a minimal Tenant-like namespace.""" + return cast( + Tenant, + SimpleNamespace( + id=tenant_id, + name=name, + plan=plan, + status=status, + created_at="2024-01-01T00:00:00Z", + custom_config_dict=custom_config or {}, + ), + ) + + +def _make_feature( + can_replace_logo: bool = False, + next_credit_reset_date: str | None = None, + billing_plan: str = "sandbox", +) -> MagicMock: + """Create a feature namespace matching what FeatureService.get_features returns.""" + feature = MagicMock() + feature.can_replace_logo = can_replace_logo + feature.next_credit_reset_date = next_credit_reset_date + feature.billing.subscription.plan = billing_plan + return feature + + +def _make_pool(quota_limit: int, quota_used: int) -> MagicMock: + pool = MagicMock() + pool.quota_limit = quota_limit + pool.quota_used = quota_used + return pool + + +def _make_tenant_account_join(role: str = "normal") -> SimpleNamespace: + return SimpleNamespace(role=role) + + +def _tenant_info(result: object) -> dict[str, Any] | None: + return cast(dict[str, Any] | None, result) + + +# --------------------------------------------------------------------------- +# Shared fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def mock_current_user() -> SimpleNamespace: + """Return a lightweight current_user stand-in.""" + return SimpleNamespace(id=ACCOUNT_ID) + + +@pytest.fixture +def basic_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict: + """ + Patch the common external boundaries used by WorkspaceService.get_tenant_info. + + Returns a dict of named mocks so individual tests can customise them. + """ + mocker.patch(CURRENT_USER_PATH, mock_current_user) + + mock_db_session = mocker.patch(f"{DB_PATH}.session") + mock_query_chain = MagicMock() + mock_db_session.query.return_value = mock_query_chain + mock_query_chain.where.return_value = mock_query_chain + mock_query_chain.first.return_value = _make_tenant_account_join(role="owner") + + mock_feature = mocker.patch(FEATURE_SERVICE_PATH, return_value=_make_feature()) + mock_has_roles = mocker.patch(TENANT_SERVICE_PATH, return_value=False) + mock_config = mocker.patch(DIFY_CONFIG_PATH) + mock_config.EDITION = "SELF_HOSTED" + mock_config.FILES_URL = FILES_BASE_URL + + return { + "db_session": mock_db_session, + "query_chain": mock_query_chain, + "get_features": mock_feature, + "has_roles": mock_has_roles, + "config": mock_config, + } + + +# --------------------------------------------------------------------------- +# 1. None Tenant Handling +# --------------------------------------------------------------------------- + + +def test_get_tenant_info_should_return_none_when_tenant_is_none() -> None: + """get_tenant_info should short-circuit and return None for a falsy tenant.""" + from services.workspace_service import WorkspaceService + + # Arrange + tenant = None + + # Act + result = WorkspaceService.get_tenant_info(cast(Tenant, tenant)) + + # Assert + assert result is None + + +def test_get_tenant_info_should_return_none_when_tenant_is_falsy() -> None: + """get_tenant_info treats any falsy value as absent (e.g. empty string, 0).""" + from services.workspace_service import WorkspaceService + + # Arrange / Act / Assert + assert WorkspaceService.get_tenant_info("") is None # type: ignore[arg-type] + + +# --------------------------------------------------------------------------- +# 2. Basic Tenant Info — happy path +# --------------------------------------------------------------------------- + + +def test_get_tenant_info_should_return_base_fields( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """get_tenant_info should always return the six base scalar fields.""" + from services.workspace_service import WorkspaceService + + # Arrange + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["id"] == TENANT_ID + assert result["name"] == "My Workspace" + assert result["plan"] == "sandbox" + assert result["status"] == "active" + assert result["created_at"] == "2024-01-01T00:00:00Z" + assert result["trial_end_reason"] is None + + +def test_get_tenant_info_should_populate_role_from_tenant_account_join( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """The 'role' field should be taken from TenantAccountJoin, not the default.""" + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["query_chain"].first.return_value = _make_tenant_account_join(role="admin") + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["role"] == "admin" + + +def test_get_tenant_info_should_raise_assertion_when_tenant_account_join_missing( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """ + The service asserts that TenantAccountJoin exists. + Missing join should raise AssertionError. + """ + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["query_chain"].first.return_value = None + tenant = _make_tenant() + + # Act + Assert + with pytest.raises(AssertionError, match="TenantAccountJoin not found"): + WorkspaceService.get_tenant_info(tenant) + + +# --------------------------------------------------------------------------- +# 3. Logo Customisation +# --------------------------------------------------------------------------- + + +def test_get_tenant_info_should_include_custom_config_when_logo_allowed_and_admin( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """custom_config block should appear for OWNER/ADMIN when can_replace_logo is True.""" + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) + basic_mocks["has_roles"].return_value = True + tenant = _make_tenant( + custom_config={ + "replace_webapp_logo": True, + "remove_webapp_brand": True, + } + ) + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert "custom_config" in result + assert result["custom_config"]["remove_webapp_brand"] is True + expected_logo_url = f"{FILES_BASE_URL}/files/workspaces/{TENANT_ID}/webapp-logo" + assert result["custom_config"]["replace_webapp_logo"] == expected_logo_url + + +def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """replace_webapp_logo should be None when custom_config_dict does not have the key.""" + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) + basic_mocks["has_roles"].return_value = True + tenant = _make_tenant(custom_config={}) # no replace_webapp_logo key + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["custom_config"]["replace_webapp_logo"] is None + + +def test_get_tenant_info_should_not_include_custom_config_when_logo_not_allowed( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """custom_config should be absent when can_replace_logo is False.""" + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=False) + basic_mocks["has_roles"].return_value = True + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert "custom_config" not in result + + +def test_get_tenant_info_should_not_include_custom_config_when_user_not_admin( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """custom_config block is gated on OWNER or ADMIN role.""" + from services.workspace_service import WorkspaceService + + # Arrange + basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) + basic_mocks["has_roles"].return_value = False # regular member + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert "custom_config" not in result + + +def test_get_tenant_info_should_use_files_url_for_logo_url( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """The logo URL should use dify_config.FILES_URL as the base.""" + from services.workspace_service import WorkspaceService + + # Arrange + custom_base = "https://cdn.mycompany.io" + basic_mocks["config"].FILES_URL = custom_base + basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) + basic_mocks["has_roles"].return_value = True + tenant = _make_tenant(custom_config={"replace_webapp_logo": True}) + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base) + + +# --------------------------------------------------------------------------- +# 4. Cloud-Edition Credit Features +# --------------------------------------------------------------------------- + +CLOUD_BILLING_PLAN_NON_SANDBOX = "professional" # any plan that is not SANDBOX + + +@pytest.fixture +def cloud_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict: + """Patches for CLOUD edition tests, billing plan = professional by default.""" + mocker.patch(CURRENT_USER_PATH, mock_current_user) + + mock_db_session = mocker.patch(f"{DB_PATH}.session") + mock_query_chain = MagicMock() + mock_db_session.query.return_value = mock_query_chain + mock_query_chain.where.return_value = mock_query_chain + mock_query_chain.first.return_value = _make_tenant_account_join(role="owner") + + mock_feature = mocker.patch( + FEATURE_SERVICE_PATH, + return_value=_make_feature( + can_replace_logo=False, + next_credit_reset_date="2025-02-01", + billing_plan=CLOUD_BILLING_PLAN_NON_SANDBOX, + ), + ) + mocker.patch(TENANT_SERVICE_PATH, return_value=False) + mock_config = mocker.patch(DIFY_CONFIG_PATH) + mock_config.EDITION = "CLOUD" + mock_config.FILES_URL = FILES_BASE_URL + + return { + "db_session": mock_db_session, + "query_chain": mock_query_chain, + "get_features": mock_feature, + "config": mock_config, + } + + +def test_get_tenant_info_should_add_next_credit_reset_date_in_cloud_edition( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """next_credit_reset_date should be present in CLOUD edition.""" + from services.workspace_service import WorkspaceService + + # Arrange + mocker.patch( + CREDIT_POOL_SERVICE_PATH, + side_effect=[None, None], # both paid and trial pools absent + ) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["next_credit_reset_date"] == "2025-02-01" + + +def test_get_tenant_info_should_use_paid_pool_when_plan_is_not_sandbox_and_pool_not_full( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """trial_credits/trial_credits_used come from the paid pool when conditions are met.""" + from services.workspace_service import WorkspaceService + + # Arrange + paid_pool = _make_pool(quota_limit=1000, quota_used=200) + mocker.patch(CREDIT_POOL_SERVICE_PATH, return_value=paid_pool) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["trial_credits"] == 1000 + assert result["trial_credits_used"] == 200 + + +def test_get_tenant_info_should_use_paid_pool_when_quota_limit_is_infinite( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """quota_limit == -1 means unlimited; service should still use the paid pool.""" + from services.workspace_service import WorkspaceService + + # Arrange + paid_pool = _make_pool(quota_limit=-1, quota_used=999) + mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, None]) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["trial_credits"] == -1 + assert result["trial_credits_used"] == 999 + + +def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_full( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """When paid pool is exhausted (used >= limit), switch to trial pool.""" + from services.workspace_service import WorkspaceService + + # Arrange + paid_pool = _make_pool(quota_limit=500, quota_used=500) # exactly full + trial_pool = _make_pool(quota_limit=100, quota_used=10) + mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool]) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["trial_credits"] == 100 + assert result["trial_credits_used"] == 10 + + +def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_none( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """When paid_pool is None, fall back to trial pool.""" + from services.workspace_service import WorkspaceService + + # Arrange + trial_pool = _make_pool(quota_limit=50, quota_used=5) + mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, trial_pool]) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["trial_credits"] == 50 + assert result["trial_credits_used"] == 5 + + +def test_get_tenant_info_should_fall_back_to_trial_pool_for_sandbox_plan( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """ + When the subscription plan IS SANDBOX, the paid pool branch is skipped + entirely and we fall back to the trial pool. + """ + from enums.cloud_plan import CloudPlan + from services.workspace_service import WorkspaceService + + # Arrange — override billing plan to SANDBOX + cloud_mocks["get_features"].return_value = _make_feature( + next_credit_reset_date="2025-02-01", + billing_plan=CloudPlan.SANDBOX, + ) + paid_pool = _make_pool(quota_limit=1000, quota_used=0) + trial_pool = _make_pool(quota_limit=200, quota_used=20) + mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool]) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert result["trial_credits"] == 200 + assert result["trial_credits_used"] == 20 + + +def test_get_tenant_info_should_omit_trial_credits_when_both_pools_are_none( + mocker: MockerFixture, + cloud_mocks: dict, +) -> None: + """When both paid and trial pools are absent, trial_credits should not be set.""" + from services.workspace_service import WorkspaceService + + # Arrange + mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, None]) + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert "trial_credits" not in result + assert "trial_credits_used" not in result + + +# --------------------------------------------------------------------------- +# 5. Self-hosted / Non-Cloud Edition +# --------------------------------------------------------------------------- + + +def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """next_credit_reset_date and trial_credits should NOT appear in SELF_HOSTED mode.""" + from services.workspace_service import WorkspaceService + + # Arrange (basic_mocks already sets EDITION = "SELF_HOSTED") + tenant = _make_tenant() + + # Act + result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) + + # Assert + assert result is not None + assert "next_credit_reset_date" not in result + assert "trial_credits" not in result + assert "trial_credits_used" not in result + + +# --------------------------------------------------------------------------- +# 6. DB query integrity +# --------------------------------------------------------------------------- + + +def test_get_tenant_info_should_query_tenant_account_join_with_correct_ids( + mocker: MockerFixture, + basic_mocks: dict, +) -> None: + """ + The DB query for TenantAccountJoin must be scoped to the correct + tenant_id and current_user.id. + """ + from services.workspace_service import WorkspaceService + + # Arrange + tenant = _make_tenant(tenant_id="my-special-tenant") + mock_current_user = mocker.patch(CURRENT_USER_PATH) + mock_current_user.id = "special-user-id" + + # Act + WorkspaceService.get_tenant_info(tenant) + + # Assert — db.session.query was invoked (at least once) + basic_mocks["db_session"].query.assert_called() diff --git a/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py new file mode 100644 index 0000000000..ce44818886 --- /dev/null +++ b/api/tests/unit_tests/services/tools/test_api_tools_manage_service.py @@ -0,0 +1,643 @@ +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture + +from core.tools.entities.tool_entities import ApiProviderSchemaType +from services.tools.api_tools_manage_service import ApiToolManageService + + +@pytest.fixture +def mock_db(mocker: MockerFixture) -> MagicMock: + # Arrange + mocked_db = mocker.patch("services.tools.api_tools_manage_service.db") + mocked_db.session = MagicMock() + return mocked_db + + +def _tool_bundle(operation_id: str = "tool-1") -> SimpleNamespace: + return SimpleNamespace(operation_id=operation_id) + + +def test_parser_api_schema_should_return_schema_payload_when_schema_is_valid(mocker: MockerFixture) -> None: + # Arrange + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI.value), + ) + + # Act + result = ApiToolManageService.parser_api_schema("valid-schema") + + # Assert + assert result["schema_type"] == ApiProviderSchemaType.OPENAPI.value + assert len(result["credentials_schema"]) == 3 + assert "warning" in result + + +def test_parser_api_schema_should_raise_value_error_when_parser_raises(mocker: MockerFixture) -> None: + # Arrange + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + side_effect=RuntimeError("bad schema"), + ) + + # Act + Assert + with pytest.raises(ValueError, match="invalid schema: invalid schema: bad schema"): + ApiToolManageService.parser_api_schema("invalid") + + +def test_convert_schema_to_tool_bundles_should_return_tool_bundles_when_valid(mocker: MockerFixture) -> None: + # Arrange + expected = ([_tool_bundle("a"), _tool_bundle("b")], ApiProviderSchemaType.SWAGGER) + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + return_value=expected, + ) + extra_info: dict[str, str] = {} + + # Act + result = ApiToolManageService.convert_schema_to_tool_bundles("schema", extra_info=extra_info) + + # Assert + assert result == expected + + +def test_convert_schema_to_tool_bundles_should_raise_value_error_when_parser_fails(mocker: MockerFixture) -> None: + # Arrange + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + side_effect=ValueError("parse failed"), + ) + + # Act + Assert + with pytest.raises(ValueError, match="invalid schema: parse failed"): + ApiToolManageService.convert_schema_to_tool_bundles("schema") + + +def test_create_api_tool_provider_should_raise_error_when_provider_already_exists( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = object() + + # Act + Assert + with pytest.raises(ValueError, match="provider provider-a already exists"): + ApiToolManageService.create_api_tool_provider( + user_id="user-1", + tenant_id="tenant-1", + provider_name=" provider-a ", + icon={"emoji": "X"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + privacy_policy="privacy", + custom_disclaimer="custom", + labels=[], + ) + + +def test_create_api_tool_provider_should_raise_error_when_tool_count_exceeds_limit( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + many_tools = [_tool_bundle(str(i)) for i in range(101)] + mocker.patch.object( + ApiToolManageService, + "convert_schema_to_tool_bundles", + return_value=(many_tools, ApiProviderSchemaType.OPENAPI), + ) + + # Act + Assert + with pytest.raises(ValueError, match="the number of apis should be less than 100"): + ApiToolManageService.create_api_tool_provider( + user_id="user-1", + tenant_id="tenant-1", + provider_name="provider-a", + icon={"emoji": "X"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + privacy_policy="privacy", + custom_disclaimer="custom", + labels=[], + ) + + +def test_create_api_tool_provider_should_raise_error_when_auth_type_is_missing( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + mocker.patch.object( + ApiToolManageService, + "convert_schema_to_tool_bundles", + return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), + ) + + # Act + Assert + with pytest.raises(ValueError, match="auth_type is required"): + ApiToolManageService.create_api_tool_provider( + user_id="user-1", + tenant_id="tenant-1", + provider_name="provider-a", + icon={"emoji": "X"}, + credentials={}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + privacy_policy="privacy", + custom_disclaimer="custom", + labels=[], + ) + + +def test_create_api_tool_provider_should_create_provider_when_input_is_valid( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + mocker.patch.object( + ApiToolManageService, + "convert_schema_to_tool_bundles", + return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), + ) + mock_controller = MagicMock() + mocker.patch( + "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", + return_value=mock_controller, + ) + mock_encrypter = MagicMock() + mock_encrypter.encrypt.return_value = {"auth_type": "none"} + mocker.patch( + "services.tools.api_tools_manage_service.create_tool_provider_encrypter", + return_value=(mock_encrypter, MagicMock()), + ) + mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels") + + # Act + result = ApiToolManageService.create_api_tool_provider( + user_id="user-1", + tenant_id="tenant-1", + provider_name="provider-a", + icon={"emoji": "X"}, + credentials={"auth_type": "none"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + privacy_policy="privacy", + custom_disclaimer="custom", + labels=["news"], + ) + + # Assert + assert result == {"result": "success"} + mock_controller.load_bundled_tools.assert_called_once() + mock_db.session.add.assert_called_once() + mock_db.session.commit.assert_called_once() + + +def test_get_api_tool_provider_remote_schema_should_return_schema_when_response_is_valid( + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch( + "services.tools.api_tools_manage_service.get", + return_value=SimpleNamespace(status_code=200, text="schema-content"), + ) + mocker.patch.object(ApiToolManageService, "parser_api_schema", return_value={"ok": True}) + + # Act + result = ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema") + + # Assert + assert result == {"schema": "schema-content"} + + +@pytest.mark.parametrize("status_code", [400, 404, 500]) +def test_get_api_tool_provider_remote_schema_should_raise_error_when_remote_fetch_is_invalid( + status_code: int, + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch( + "services.tools.api_tools_manage_service.get", + return_value=SimpleNamespace(status_code=status_code, text="schema-content"), + ) + mock_logger = mocker.patch("services.tools.api_tools_manage_service.logger") + + # Act + Assert + with pytest.raises(ValueError, match="invalid schema, please check the url you provided"): + ApiToolManageService.get_api_tool_provider_remote_schema("user-1", "tenant-1", "https://schema") + mock_logger.exception.assert_called_once() + + +def test_list_api_tool_provider_tools_should_raise_error_when_provider_not_found( + mock_db: MagicMock, +) -> None: + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="you have not added provider provider-a"): + ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a") + + +def test_list_api_tool_provider_tools_should_return_converted_tools_when_provider_exists( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = SimpleNamespace(tools=[_tool_bundle("tool-a"), _tool_bundle("tool-b")]) + mock_db.session.query.return_value.where.return_value.first.return_value = provider + controller = MagicMock() + mocker.patch( + "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller", + return_value=controller, + ) + mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["search"]) + mock_convert = mocker.patch( + "services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity", + side_effect=[{"name": "tool-a"}, {"name": "tool-b"}], + ) + + # Act + result = ApiToolManageService.list_api_tool_provider_tools("user-1", "tenant-1", "provider-a") + + # Assert + assert result == [{"name": "tool-a"}, {"name": "tool-b"}] + assert mock_convert.call_count == 2 + + +def test_update_api_tool_provider_should_raise_error_when_original_provider_not_found( + mock_db: MagicMock, +) -> None: + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="api provider provider-a does not exists"): + ApiToolManageService.update_api_tool_provider( + user_id="user-1", + tenant_id="tenant-1", + provider_name="provider-a", + original_provider="provider-a", + icon={}, + credentials={"auth_type": "none"}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + privacy_policy=None, + custom_disclaimer="custom", + labels=[], + ) + + +def test_update_api_tool_provider_should_raise_error_when_auth_type_missing( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = SimpleNamespace(credentials={}, name="old") + mock_db.session.query.return_value.where.return_value.first.return_value = provider + mocker.patch.object( + ApiToolManageService, + "convert_schema_to_tool_bundles", + return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), + ) + + # Act + Assert + with pytest.raises(ValueError, match="auth_type is required"): + ApiToolManageService.update_api_tool_provider( + user_id="user-1", + tenant_id="tenant-1", + provider_name="provider-a", + original_provider="provider-a", + icon={}, + credentials={}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + privacy_policy=None, + custom_disclaimer="custom", + labels=[], + ) + + +def test_update_api_tool_provider_should_update_provider_and_preserve_masked_credentials( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = SimpleNamespace( + credentials={"auth_type": "none", "api_key_value": "encrypted-old"}, + name="old", + icon="", + schema="", + description="", + schema_type_str="", + tools_str="", + privacy_policy="", + custom_disclaimer="", + credentials_str="", + ) + mock_db.session.query.return_value.where.return_value.first.return_value = provider + mocker.patch.object( + ApiToolManageService, + "convert_schema_to_tool_bundles", + return_value=([_tool_bundle()], ApiProviderSchemaType.OPENAPI), + ) + controller = MagicMock() + mocker.patch( + "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", + return_value=controller, + ) + cache = MagicMock() + encrypter = MagicMock() + encrypter.decrypt.return_value = {"auth_type": "none", "api_key_value": "plain-old"} + encrypter.mask_plugin_credentials.return_value = {"api_key_value": "***"} + encrypter.encrypt.return_value = {"auth_type": "none", "api_key_value": "encrypted-new"} + mocker.patch( + "services.tools.api_tools_manage_service.create_tool_provider_encrypter", + return_value=(encrypter, cache), + ) + mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.update_tool_labels") + + # Act + result = ApiToolManageService.update_api_tool_provider( + user_id="user-1", + tenant_id="tenant-1", + provider_name="provider-new", + original_provider="provider-old", + icon={"emoji": "E"}, + credentials={"auth_type": "none", "api_key_value": "***"}, + _schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + privacy_policy="privacy", + custom_disclaimer="custom", + labels=["news"], + ) + + # Assert + assert result == {"result": "success"} + assert provider.name == "provider-new" + assert provider.privacy_policy == "privacy" + assert provider.credentials_str != "" + cache.delete.assert_called_once() + mock_db.session.commit.assert_called_once() + + +def test_delete_api_tool_provider_should_raise_error_when_provider_missing(mock_db: MagicMock) -> None: + # Arrange + mock_db.session.query.return_value.where.return_value.first.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="you have not added provider provider-a"): + ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a") + + +def test_delete_api_tool_provider_should_delete_provider_when_exists(mock_db: MagicMock) -> None: + # Arrange + provider = object() + mock_db.session.query.return_value.where.return_value.first.return_value = provider + + # Act + result = ApiToolManageService.delete_api_tool_provider("user-1", "tenant-1", "provider-a") + + # Assert + assert result == {"result": "success"} + mock_db.session.delete.assert_called_once_with(provider) + mock_db.session.commit.assert_called_once() + + +def test_get_api_tool_provider_should_delegate_to_tool_manager(mocker: MockerFixture) -> None: + # Arrange + expected = {"provider": "value"} + mock_get = mocker.patch( + "services.tools.api_tools_manage_service.ToolManager.user_get_api_provider", + return_value=expected, + ) + + # Act + result = ApiToolManageService.get_api_tool_provider("user-1", "tenant-1", "provider-a") + + # Assert + assert result == expected + mock_get.assert_called_once_with(provider="provider-a", tenant_id="tenant-1") + + +def test_test_api_tool_preview_should_raise_error_for_invalid_schema_type() -> None: + # Arrange + schema_type = "bad-schema-type" + + # Act + Assert + with pytest.raises(ValueError, match="invalid schema type"): + ApiToolManageService.test_api_tool_preview( + tenant_id="tenant-1", + provider_name="provider-a", + tool_name="tool-a", + credentials={"auth_type": "none"}, + parameters={}, + schema_type=schema_type, # type: ignore[arg-type] + schema="schema", + ) + + +def test_test_api_tool_preview_should_raise_error_when_schema_parser_fails(mocker: MockerFixture) -> None: + # Arrange + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + side_effect=RuntimeError("invalid"), + ) + + # Act + Assert + with pytest.raises(ValueError, match="invalid schema"): + ApiToolManageService.test_api_tool_preview( + tenant_id="tenant-1", + provider_name="provider-a", + tool_name="tool-a", + credentials={"auth_type": "none"}, + parameters={}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + ) + + +def test_test_api_tool_preview_should_raise_error_when_tool_name_is_invalid( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), + ) + mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id") + + # Act + Assert + with pytest.raises(ValueError, match="invalid tool name tool-b"): + ApiToolManageService.test_api_tool_preview( + tenant_id="tenant-1", + provider_name="provider-a", + tool_name="tool-b", + credentials={"auth_type": "none"}, + parameters={}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + ) + + +def test_test_api_tool_preview_should_raise_error_when_auth_type_missing( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), + ) + mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace(id="provider-id") + + # Act + Assert + with pytest.raises(ValueError, match="auth_type is required"): + ApiToolManageService.test_api_tool_preview( + tenant_id="tenant-1", + provider_name="provider-a", + tool_name="tool-a", + credentials={}, + parameters={}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + ) + + +def test_test_api_tool_preview_should_return_error_payload_when_tool_validation_raises( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"}) + mock_db.session.query.return_value.where.return_value.first.return_value = db_provider + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), + ) + provider_controller = MagicMock() + tool_obj = MagicMock() + tool_obj.fork_tool_runtime.return_value = tool_obj + tool_obj.validate_credentials.side_effect = ValueError("validation failed") + provider_controller.get_tool.return_value = tool_obj + mocker.patch( + "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", + return_value=provider_controller, + ) + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"auth_type": "none"} + mock_encrypter.mask_plugin_credentials.return_value = {} + mocker.patch( + "services.tools.api_tools_manage_service.create_tool_provider_encrypter", + return_value=(mock_encrypter, MagicMock()), + ) + + # Act + result = ApiToolManageService.test_api_tool_preview( + tenant_id="tenant-1", + provider_name="provider-a", + tool_name="tool-a", + credentials={"auth_type": "none"}, + parameters={}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + ) + + # Assert + assert result == {"error": "validation failed"} + + +def test_test_api_tool_preview_should_return_result_payload_when_validation_succeeds( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + db_provider = SimpleNamespace(id="provider-id", credentials={"auth_type": "none"}) + mock_db.session.query.return_value.where.return_value.first.return_value = db_provider + mocker.patch( + "services.tools.api_tools_manage_service.ApiBasedToolSchemaParser.auto_parse_to_tool_bundle", + return_value=([_tool_bundle("tool-a")], ApiProviderSchemaType.OPENAPI), + ) + provider_controller = MagicMock() + tool_obj = MagicMock() + tool_obj.fork_tool_runtime.return_value = tool_obj + tool_obj.validate_credentials.return_value = {"ok": True} + provider_controller.get_tool.return_value = tool_obj + mocker.patch( + "services.tools.api_tools_manage_service.ApiToolProviderController.from_db", + return_value=provider_controller, + ) + mock_encrypter = MagicMock() + mock_encrypter.decrypt.return_value = {"auth_type": "none"} + mock_encrypter.mask_plugin_credentials.return_value = {} + mocker.patch( + "services.tools.api_tools_manage_service.create_tool_provider_encrypter", + return_value=(mock_encrypter, MagicMock()), + ) + + # Act + result = ApiToolManageService.test_api_tool_preview( + tenant_id="tenant-1", + provider_name="provider-a", + tool_name="tool-a", + credentials={"auth_type": "none"}, + parameters={"x": "1"}, + schema_type=ApiProviderSchemaType.OPENAPI, + schema="schema", + ) + + # Assert + assert result == {"result": {"ok": True}} + + +def test_list_api_tools_should_return_all_user_providers_with_converted_tools( + mock_db: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_one = SimpleNamespace(name="p1") + provider_two = SimpleNamespace(name="p2") + mock_db.session.scalars.return_value.all.return_value = [provider_one, provider_two] + + controller_one = MagicMock() + controller_one.get_tools.return_value = ["tool-a"] + controller_two = MagicMock() + controller_two.get_tools.return_value = ["tool-b", "tool-c"] + + user_provider_one = SimpleNamespace(labels=[], tools=[]) + user_provider_two = SimpleNamespace(labels=[], tools=[]) + + mocker.patch( + "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_controller", + side_effect=[controller_one, controller_two], + ) + mocker.patch("services.tools.api_tools_manage_service.ToolLabelManager.get_tool_labels", return_value=["news"]) + mocker.patch( + "services.tools.api_tools_manage_service.ToolTransformService.api_provider_to_user_provider", + side_effect=[user_provider_one, user_provider_two], + ) + mocker.patch("services.tools.api_tools_manage_service.ToolTransformService.repack_provider") + mock_convert = mocker.patch( + "services.tools.api_tools_manage_service.ToolTransformService.convert_tool_entity_to_api_entity", + side_effect=[{"name": "tool-a"}, {"name": "tool-b"}, {"name": "tool-c"}], + ) + + # Act + result = ApiToolManageService.list_api_tools("tenant-1") + + # Assert + assert len(result) == 2 + assert user_provider_one.tools == [{"name": "tool-a"}] + assert user_provider_two.tools == [{"name": "tool-b"}, {"name": "tool-c"}] + assert mock_convert.call_count == 3 diff --git a/api/tests/unit_tests/services/tools/test_mcp_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_mcp_tools_manage_service.py new file mode 100644 index 0000000000..d35e014fab --- /dev/null +++ b/api/tests/unit_tests/services/tools/test_mcp_tools_manage_service.py @@ -0,0 +1,1045 @@ +from __future__ import annotations + +import hashlib +import json +from datetime import datetime +from types import SimpleNamespace +from typing import cast +from unittest.mock import MagicMock + +import pytest +from pytest_mock import MockerFixture +from sqlalchemy.exc import IntegrityError + +from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration, MCPProviderEntity +from core.mcp.entities import AuthActionType +from core.mcp.error import MCPAuthError, MCPError +from models.tools import MCPToolProvider +from services.tools.mcp_tools_manage_service import ( + EMPTY_CREDENTIALS_JSON, + EMPTY_TOOLS_JSON, + UNCHANGED_SERVER_URL_PLACEHOLDER, + MCPToolManageService, + OAuthDataType, + ProviderUrlValidationData, + ReconnectResult, + ServerUrlValidationResult, +) + + +class _ToolStub: + def __init__(self, name: str, description: str | None) -> None: + self._name = name + self._description = description + + def model_dump(self) -> dict[str, str | None]: + return {"name": self._name, "description": self._description} + + +@pytest.fixture +def mock_session() -> MagicMock: + # Arrange + return MagicMock() + + +@pytest.fixture +def service(mock_session: MagicMock) -> MCPToolManageService: + # Arrange + return MCPToolManageService(session=mock_session) + + +def _provider_entity_stub(*, authed: bool = True) -> MCPProviderEntity: + return cast( + MCPProviderEntity, + SimpleNamespace( + authed=authed, + timeout=30.0, + sse_read_timeout=300.0, + provider_id="server-1", + headers={"x-api-key": "enc"}, + decrypt_headers=lambda: {"x-api-key": "key"}, + retrieve_tokens=lambda: SimpleNamespace(token_type="bearer", access_token="token-1"), + decrypt_server_url=lambda: "https://mcp.example.com/sse", + to_api_response=lambda user_name=None: { + "id": "provider-1", + "author": user_name or "Anonymous", + "name": "MCP Tool", + "description": {"en_US": "", "zh_Hans": ""}, + "icon": "icon", + "label": {"en_US": "MCP Tool", "zh_Hans": "MCP Tool"}, + "type": "mcp", + "is_team_authorization": True, + "server_url": "https://mcp.example.com/******", + "updated_at": 1, + "server_identifier": "server-1", + "configuration": {"timeout": "30", "sse_read_timeout": "300"}, + "masked_headers": {}, + "is_dynamic_registration": True, + }, + decrypt_credentials=lambda: {"client_id": "plain-id", "client_secret": "plain-secret"}, + masked_credentials=lambda: {"client_id": "pl***id", "client_secret": "pl***et"}, + masked_headers=lambda: {"x-api-key": "ke***ey"}, + ), + ) + + +def _provider_stub(*, authed: bool = True) -> MCPToolProvider: + entity = _provider_entity_stub(authed=authed) + return cast( + MCPToolProvider, + SimpleNamespace( + id="provider-1", + tenant_id="tenant-1", + user_id="user-1", + name="Provider A", + server_identifier="server-1", + server_url="encrypted-url", + server_url_hash="old-hash", + authed=authed, + tools=EMPTY_TOOLS_JSON, + encrypted_credentials=json.dumps({"existing": "credential"}), + encrypted_headers=json.dumps({"x-api-key": "enc"}), + credentials={"existing": "credential"}, + timeout=30.0, + sse_read_timeout=300.0, + updated_at=datetime.now(), + icon="icon", + to_entity=lambda: entity, + load_user=lambda: SimpleNamespace(name="Tester"), + ), + ) + + +def test_server_url_validation_result_should_update_server_url_when_all_conditions_match() -> None: + # Arrange + result = ServerUrlValidationResult( + needs_validation=True, + validation_passed=True, + reconnect_result=ReconnectResult(authed=True, tools="[]", encrypted_credentials="{}"), + ) + + # Act + should_update = result.should_update_server_url + + # Assert + assert should_update is True + + +def test_get_provider_should_return_provider_when_exists( + service: MCPToolManageService, + mock_session: MagicMock, +) -> None: + # Arrange + provider = _provider_stub() + mock_session.scalar.return_value = provider + + # Act + result = service.get_provider(provider_id="provider-1", tenant_id="tenant-1") + + # Assert + assert result is provider + + +def test_get_provider_should_raise_error_when_provider_not_found( + service: MCPToolManageService, mock_session: MagicMock +) -> None: + # Arrange + mock_session.scalar.return_value = None + + # Act + Assert + with pytest.raises(ValueError, match="MCP tool not found"): + service.get_provider(provider_id="provider-404", tenant_id="tenant-1") + + +def test_get_provider_entity_should_get_entity_by_provider_id_when_by_server_id_is_false( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + mock_get_provider = mocker.patch.object(service, "get_provider", return_value=provider) + + # Act + result = service.get_provider_entity("provider-1", "tenant-1", by_server_id=False) + + # Assert + assert result is provider.to_entity() + mock_get_provider.assert_called_once_with(provider_id="provider-1", tenant_id="tenant-1") + + +def test_get_provider_entity_should_get_entity_by_server_identifier_when_by_server_id_is_true( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + mock_get_provider = mocker.patch.object(service, "get_provider", return_value=provider) + + # Act + result = service.get_provider_entity("server-1", "tenant-1", by_server_id=True) + + # Assert + assert result is provider.to_entity() + mock_get_provider.assert_called_once_with(server_identifier="server-1", tenant_id="tenant-1") + + +def test_create_provider_should_raise_error_when_server_url_is_invalid(service: MCPToolManageService) -> None: + # Arrange + config = MCPConfiguration(timeout=30, sse_read_timeout=300) + + # Act + Assert + with pytest.raises(ValueError, match="Server URL is not valid"): + service.create_provider( + tenant_id="tenant-1", + name="Provider A", + server_url="invalid-url", + user_id="user-1", + icon="icon", + icon_type="emoji", + icon_background="#fff", + server_identifier="server-1", + configuration=config, + ) + + +def test_create_provider_should_create_and_return_user_provider_when_input_is_valid( + service: MCPToolManageService, + mock_session: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + config = MCPConfiguration(timeout=42, sse_read_timeout=123) + auth_data = MCPAuthentication(client_id="client-id", client_secret="secret") + mocker.patch.object(service, "_check_provider_exists") + mocker.patch("services.tools.mcp_tools_manage_service.encrypter.encrypt_token", return_value="encrypted-url") + mocker.patch.object(service, "_prepare_encrypted_dict", return_value='{"x":"enc"}') + mocker.patch.object(service, "_build_and_encrypt_credentials", return_value='{"client_information":{}}') + mocker.patch.object(service, "_prepare_icon", return_value='{"content":"😀"}') + expected_user_provider = {"id": "provider-1"} + mock_convert = mocker.patch( + "services.tools.mcp_tools_manage_service.ToolTransformService.mcp_provider_to_user_provider", + return_value=expected_user_provider, + ) + + # Act + result = service.create_provider( + tenant_id="tenant-1", + name="Provider A", + server_url="https://mcp.example.com", + user_id="user-1", + icon="😀", + icon_type="emoji", + icon_background="#fff", + server_identifier="server-1", + configuration=config, + authentication=auth_data, + headers={"x-api-key": "v1"}, + ) + + # Assert + assert result == expected_user_provider + mock_session.add.assert_called_once() + mock_session.flush.assert_called_once() + mock_convert.assert_called_once() + + +def test_update_provider_should_raise_error_when_new_name_conflicts( + service: MCPToolManageService, + mock_session: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + mocker.patch.object(service, "get_provider", return_value=provider) + mock_session.scalar.return_value = object() + + # Act + Assert + with pytest.raises(ValueError, match="already exists"): + service.update_provider( + tenant_id="tenant-1", + provider_id="provider-1", + name="New Name", + server_url="https://mcp.example.com", + icon="😀", + icon_type="emoji", + icon_background="#fff", + server_identifier="server-1", + configuration=MCPConfiguration(), + ) + + +def test_update_provider_should_update_fields_when_input_is_valid( + service: MCPToolManageService, + mock_session: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + validation = ServerUrlValidationResult( + needs_validation=True, + validation_passed=True, + reconnect_result=ReconnectResult(authed=True, tools='[{"name":"t"}]', encrypted_credentials='{"x":"y"}'), + encrypted_server_url="new-encrypted-url", + server_url_hash="new-hash", + ) + mocker.patch.object(service, "get_provider", return_value=provider) + mock_session.scalar.return_value = None + mocker.patch.object(service, "_prepare_icon", return_value="new-icon") + mocker.patch.object(service, "_process_headers", return_value='{"x":"enc"}') + mocker.patch.object(service, "_process_credentials", return_value='{"client":"enc"}') + + # Act + service.update_provider( + tenant_id="tenant-1", + provider_id="provider-1", + name="Provider B", + server_url="https://mcp.example.com/new", + icon="😎", + icon_type="emoji", + icon_background="#000", + server_identifier="server-2", + headers={"x-api-key": "v2"}, + configuration=MCPConfiguration(timeout=50, sse_read_timeout=120), + authentication=MCPAuthentication(client_id="new-id", client_secret="new-secret"), + validation_result=validation, + ) + + # Assert + assert provider.name == "Provider B" + assert provider.server_identifier == "server-2" + assert provider.server_url == "new-encrypted-url" + assert provider.server_url_hash == "new-hash" + assert provider.authed is True + assert provider.tools == '[{"name":"t"}]' + assert provider.encrypted_credentials == '{"client":"enc"}' + assert provider.encrypted_headers == '{"x":"enc"}' + assert provider.timeout == 50 + assert provider.sse_read_timeout == 120 + mock_session.flush.assert_called_once() + + +def test_update_provider_should_handle_integrity_error_with_readable_message( + service: MCPToolManageService, + mock_session: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + mocker.patch.object(service, "get_provider", return_value=provider) + mock_session.scalar.return_value = None + mocker.patch.object(service, "_prepare_icon", return_value="icon") + mock_session.flush.side_effect = IntegrityError("stmt", {}, Exception("unique_mcp_provider_name")) + + # Act + Assert + with pytest.raises(ValueError, match="MCP tool Provider A already exists"): + service.update_provider( + tenant_id="tenant-1", + provider_id="provider-1", + name="Provider A", + server_url="https://mcp.example.com", + icon="😀", + icon_type="emoji", + icon_background="#fff", + server_identifier="server-1", + configuration=MCPConfiguration(), + ) + + +def test_delete_provider_should_delete_existing_provider( + service: MCPToolManageService, + mock_session: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + mocker.patch.object(service, "get_provider", return_value=provider) + + # Act + service.delete_provider(tenant_id="tenant-1", provider_id="provider-1") + + # Assert + mock_session.delete.assert_called_once_with(provider) + + +def test_list_providers_should_return_empty_list_when_no_provider_exists( + service: MCPToolManageService, + mock_session: MagicMock, +) -> None: + # Arrange + mock_session.scalars.return_value.all.return_value = [] + + # Act + result = service.list_providers(tenant_id="tenant-1") + + # Assert + assert result == [] + + +def test_list_providers_should_convert_all_providers_and_attach_user_names( + service: MCPToolManageService, + mock_session: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider_1 = _provider_stub() + provider_2 = _provider_stub() + provider_2.user_id = "user-2" + mock_session.scalars.return_value.all.return_value = [provider_1, provider_2] + mock_session.query.return_value.where.return_value.all.return_value = [ + SimpleNamespace(id="user-1", name="Alice"), + SimpleNamespace(id="user-2", name="Bob"), + ] + mock_convert = mocker.patch( + "services.tools.mcp_tools_manage_service.ToolTransformService.mcp_provider_to_user_provider", + side_effect=[{"id": "1"}, {"id": "2"}], + ) + + # Act + result = service.list_providers(tenant_id="tenant-1", for_list=True, include_sensitive=False) + + # Assert + assert result == [{"id": "1"}, {"id": "2"}] + assert mock_convert.call_count == 2 + + +def test_list_provider_tools_should_raise_error_when_provider_is_not_authenticated( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub(authed=False) + mocker.patch.object(service, "get_provider", return_value=provider) + + # Act + Assert + with pytest.raises(ValueError, match="Please auth the tool first"): + service.list_provider_tools(tenant_id="tenant-1", provider_id="provider-1") + + +def test_list_provider_tools_should_raise_error_when_remote_client_fails( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub(authed=True) + mocker.patch.object(service, "get_provider", return_value=provider) + mcp_client_instance = MagicMock() + mcp_client_instance.list_tools.side_effect = MCPError("connection failed") + mock_client_cls = mocker.patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") + mock_client_cls.return_value.__enter__.return_value = mcp_client_instance + + # Act + Assert + with pytest.raises(ValueError, match="Failed to connect to MCP server"): + service.list_provider_tools(tenant_id="tenant-1", provider_id="provider-1") + + +def test_list_provider_tools_should_update_db_and_return_response_on_success( + service: MCPToolManageService, + mock_session: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub(authed=True) + mocker.patch.object(service, "get_provider", return_value=provider) + mcp_client_instance = MagicMock() + mcp_client_instance.list_tools.return_value = [ + _ToolStub("tool-a", None), + _ToolStub("tool-b", "desc"), + ] + mock_client_cls = mocker.patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") + mock_client_cls.return_value.__enter__.return_value = mcp_client_instance + mocker.patch("services.tools.mcp_tools_manage_service.ToolTransformService.mcp_tool_to_user_tool", return_value=[]) + + # Act + result = service.list_provider_tools(tenant_id="tenant-1", provider_id="provider-1") + + # Assert + assert result.plugin_unique_identifier == "server-1" + assert provider.authed is True + payload = json.loads(provider.tools) + assert payload[0]["description"] == "" + assert payload[1]["description"] == "desc" + mock_session.flush.assert_called_once() + + +def test_update_provider_credentials_should_update_encrypted_credentials_and_auth_state( + service: MCPToolManageService, + mock_session: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub(authed=True) + provider.encrypted_credentials = json.dumps({"existing": "value"}) + mocker.patch.object(service, "get_provider", return_value=provider) + mock_controller = MagicMock() + mocker.patch("core.tools.mcp_tool.provider.MCPToolProviderController.from_db", return_value=mock_controller) + mock_encryptor = MagicMock() + mock_encryptor.encrypt.return_value = {"access_token": "encrypted-token"} + mocker.patch("services.tools.mcp_tools_manage_service.ProviderConfigEncrypter", return_value=mock_encryptor) + + # Act + service.update_provider_credentials( + provider_id="provider-1", + tenant_id="tenant-1", + credentials={"access_token": "plain-token"}, + authed=False, + ) + + # Assert + assert provider.authed is False + assert provider.tools == EMPTY_TOOLS_JSON + assert json.loads(cast(str, provider.encrypted_credentials))["access_token"] == "encrypted-token" + mock_session.flush.assert_called_once() + + +@pytest.mark.parametrize( + ("data_type", "data", "expected_authed"), + [ + (OAuthDataType.TOKENS, {"access_token": "token"}, True), + (OAuthDataType.MIXED, {"access_token": "token"}, True), + (OAuthDataType.MIXED, {"client_id": "id"}, None), + (OAuthDataType.CLIENT_INFO, {"client_id": "id"}, None), + ], +) +def test_save_oauth_data_should_delegate_with_expected_authed_value( + data_type: OAuthDataType, + data: dict[str, str], + expected_authed: bool | None, + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + mock_update = mocker.patch.object(service, "update_provider_credentials") + + # Act + service.save_oauth_data("provider-1", "tenant-1", data, data_type) + + # Assert + assert mock_update.call_args.kwargs["authed"] == expected_authed + + +def test_clear_provider_credentials_should_reset_provider_state( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub(authed=True) + mocker.patch.object(service, "get_provider", return_value=provider) + + # Act + service.clear_provider_credentials(provider_id="provider-1", tenant_id="tenant-1") + + # Assert + assert provider.tools == EMPTY_TOOLS_JSON + assert provider.encrypted_credentials == EMPTY_CREDENTIALS_JSON + assert provider.authed is False + + +def test_check_provider_exists_should_raise_different_errors_for_conflicts( + service: MCPToolManageService, + mock_session: MagicMock, +) -> None: + # Arrange + mock_session.scalar.return_value = SimpleNamespace( + name="name-a", + server_url_hash="hash-a", + server_identifier="server-a", + ) + + # Act + Assert + with pytest.raises(ValueError, match="MCP tool name-a already exists"): + service._check_provider_exists("tenant-1", "name-a", "hash-b", "server-b") + with pytest.raises(ValueError, match="MCP tool with this server URL already exists"): + service._check_provider_exists("tenant-1", "name-b", "hash-a", "server-b") + with pytest.raises(ValueError, match="MCP tool server-a already exists"): + service._check_provider_exists("tenant-1", "name-b", "hash-b", "server-a") + + +def test_prepare_icon_should_return_json_for_emoji_and_raw_value_for_non_emoji(service: MCPToolManageService) -> None: + # Arrange + # Act + emoji_icon = service._prepare_icon("😀", "emoji", "#fff") + raw_icon = service._prepare_icon("https://icon.png", "file", "#000") + + # Assert + assert json.loads(emoji_icon)["content"] == "😀" + assert raw_icon == "https://icon.png" + + +def test_encrypt_dict_fields_should_encrypt_secret_fields(service: MCPToolManageService, mocker: MockerFixture) -> None: + # Arrange + mock_encryptor = MagicMock() + mock_encryptor.encrypt.return_value = {"Authorization": "enc-token"} + mocker.patch("core.tools.utils.encryption.create_provider_encrypter", return_value=(mock_encryptor, MagicMock())) + + # Act + result = service._encrypt_dict_fields({"Authorization": "token"}, ["Authorization"], "tenant-1") + + # Assert + assert result == {"Authorization": "enc-token"} + + +def test_prepare_encrypted_dict_should_return_json_string(service: MCPToolManageService, mocker: MockerFixture) -> None: + # Arrange + mocker.patch.object(service, "_encrypt_dict_fields", return_value={"x": "enc"}) + + # Act + result = service._prepare_encrypted_dict({"x": "v"}, "tenant-1") + + # Assert + assert result == '{"x": "enc"}' + + +def test_prepare_auth_headers_should_append_authorization_when_tokens_exist(service: MCPToolManageService) -> None: + # Arrange + provider_entity = _provider_entity_stub() + + # Act + headers = service._prepare_auth_headers(provider_entity) + + # Assert + assert headers["Authorization"] == "Bearer token-1" + + +def test_retrieve_remote_mcp_tools_should_return_tools_from_client( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + mcp_client_instance = MagicMock() + mcp_client_instance.list_tools.return_value = [_ToolStub("tool-a", "desc")] + mock_client_cls = mocker.patch("services.tools.mcp_tools_manage_service.MCPClientWithAuthRetry") + mock_client_cls.return_value.__enter__.return_value = mcp_client_instance + + # Act + tools = service._retrieve_remote_mcp_tools("https://mcp.example.com", {}, _provider_entity_stub()) + + # Assert + assert len(tools) == 1 + assert tools[0].model_dump()["name"] == "tool-a" + + +def test_execute_auth_actions_should_dispatch_supported_actions( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + mock_save = mocker.patch.object(service, "save_oauth_data") + auth_result = SimpleNamespace( + actions=[ + SimpleNamespace( + action_type=AuthActionType.SAVE_CLIENT_INFO, + data={"client_id": "c1"}, + provider_id="provider-1", + tenant_id="tenant-1", + ), + SimpleNamespace( + action_type=AuthActionType.SAVE_TOKENS, + data={"access_token": "t1"}, + provider_id="provider-1", + tenant_id="tenant-1", + ), + SimpleNamespace( + action_type=AuthActionType.SAVE_CODE_VERIFIER, + data={"code_verifier": "cv"}, + provider_id="provider-1", + tenant_id="tenant-1", + ), + SimpleNamespace( + action_type=AuthActionType.SAVE_TOKENS, + data={"access_token": "skip"}, + provider_id=None, + tenant_id="tenant-1", + ), + ], + response={"ok": "1"}, + ) + + # Act + result = service.execute_auth_actions(auth_result) + + # Assert + assert result == {"ok": "1"} + assert mock_save.call_count == 3 + + +def test_auth_with_actions_should_call_auth_and_execute_actions( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider_entity = _provider_entity_stub() + auth_result = SimpleNamespace(actions=[], response={"status": "ok"}) + mocker.patch("services.tools.mcp_tools_manage_service.auth", return_value=auth_result) + mock_execute = mocker.patch.object(service, "execute_auth_actions", return_value={"status": "ok"}) + + # Act + result = service.auth_with_actions(provider_entity=provider_entity, authorization_code="code-1") + + # Assert + assert result == {"status": "ok"} + mock_execute.assert_called_once_with(auth_result) + + +def test_get_provider_for_url_validation_should_return_validation_data( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + mocker.patch.object(service, "get_provider", return_value=provider) + + # Act + result = service.get_provider_for_url_validation(tenant_id="tenant-1", provider_id="provider-1") + + # Assert + assert result.current_server_url_hash == "old-hash" + assert result.headers == {"x-api-key": "enc"} + + +def test_validate_server_url_standalone_should_skip_validation_for_unchanged_placeholder() -> None: + # Arrange + data = ProviderUrlValidationData(current_server_url_hash="hash", headers={}, timeout=30, sse_read_timeout=300) + + # Act + result = MCPToolManageService.validate_server_url_standalone( + tenant_id="tenant-1", + new_server_url=UNCHANGED_SERVER_URL_PLACEHOLDER, + validation_data=data, + ) + + # Assert + assert result.needs_validation is False + + +def test_validate_server_url_standalone_should_raise_error_for_invalid_url() -> None: + # Arrange + data = ProviderUrlValidationData(current_server_url_hash="hash", headers={}, timeout=30, sse_read_timeout=300) + + # Act + Assert + with pytest.raises(ValueError, match="Server URL is not valid"): + MCPToolManageService.validate_server_url_standalone( + tenant_id="tenant-1", + new_server_url="bad-url", + validation_data=data, + ) + + +def test_validate_server_url_standalone_should_return_no_validation_when_hash_unchanged(mocker: MockerFixture) -> None: + # Arrange + url = "https://mcp.example.com" + current_hash = hashlib.sha256(url.encode()).hexdigest() + data = ProviderUrlValidationData(current_server_url_hash=current_hash, headers={}, timeout=30, sse_read_timeout=300) + mocker.patch("services.tools.mcp_tools_manage_service.encrypter.encrypt_token", return_value="enc-url") + + # Act + result = MCPToolManageService.validate_server_url_standalone( + tenant_id="tenant-1", + new_server_url=url, + validation_data=data, + ) + + # Assert + assert result.needs_validation is False + assert result.encrypted_server_url == "enc-url" + assert result.server_url_hash == current_hash + + +def test_validate_server_url_standalone_should_reconnect_when_url_changes(mocker: MockerFixture) -> None: + # Arrange + url = "https://mcp-new.example.com" + data = ProviderUrlValidationData(current_server_url_hash="old", headers={}, timeout=30, sse_read_timeout=300) + reconnect_result = ReconnectResult(authed=True, tools='[{"name":"x"}]', encrypted_credentials="{}") + mocker.patch("services.tools.mcp_tools_manage_service.encrypter.encrypt_token", return_value="enc-new") + mock_reconnect = mocker.patch.object(MCPToolManageService, "_reconnect_with_url", return_value=reconnect_result) + + # Act + result = MCPToolManageService.validate_server_url_standalone( + tenant_id="tenant-1", + new_server_url=url, + validation_data=data, + ) + + # Assert + assert result.validation_passed is True + assert result.reconnect_result == reconnect_result + mock_reconnect.assert_called_once() + + +def test_reconnect_with_url_should_delegate_to_private_method(mocker: MockerFixture) -> None: + # Arrange + expected = ReconnectResult(authed=True, tools="[]", encrypted_credentials="{}") + mock_delegate = mocker.patch.object(MCPToolManageService, "_reconnect_with_url", return_value=expected) + + # Act + result = MCPToolManageService.reconnect_with_url( + server_url="https://mcp.example.com", + headers={}, + timeout=30, + sse_read_timeout=300, + ) + + # Assert + assert result == expected + mock_delegate.assert_called_once() + + +def test_private_reconnect_with_url_should_return_authed_true_when_connection_succeeds(mocker: MockerFixture) -> None: + # Arrange + mcp_client_instance = MagicMock() + mcp_client_instance.list_tools.return_value = [_ToolStub("tool-a", None)] + mock_client_cls = mocker.patch("core.mcp.mcp_client.MCPClient") + mock_client_cls.return_value.__enter__.return_value = mcp_client_instance + + # Act + result = MCPToolManageService._reconnect_with_url( + server_url="https://mcp.example.com", + headers={}, + timeout=30, + sse_read_timeout=300, + ) + + # Assert + assert result.authed is True + assert json.loads(result.tools)[0]["description"] == "" + + +def test_private_reconnect_with_url_should_return_authed_false_on_auth_error(mocker: MockerFixture) -> None: + # Arrange + mcp_client_instance = MagicMock() + mcp_client_instance.list_tools.side_effect = MCPAuthError("auth required") + mock_client_cls = mocker.patch("core.mcp.mcp_client.MCPClient") + mock_client_cls.return_value.__enter__.return_value = mcp_client_instance + + # Act + result = MCPToolManageService._reconnect_with_url( + server_url="https://mcp.example.com", + headers={}, + timeout=30, + sse_read_timeout=300, + ) + + # Assert + assert result.authed is False + assert result.tools == EMPTY_TOOLS_JSON + + +def test_private_reconnect_with_url_should_raise_value_error_on_mcp_error(mocker: MockerFixture) -> None: + # Arrange + mcp_client_instance = MagicMock() + mcp_client_instance.list_tools.side_effect = MCPError("network failure") + mock_client_cls = mocker.patch("core.mcp.mcp_client.MCPClient") + mock_client_cls.return_value.__enter__.return_value = mcp_client_instance + + # Act + Assert + with pytest.raises(ValueError, match="Failed to re-connect MCP server: network failure"): + MCPToolManageService._reconnect_with_url( + server_url="https://mcp.example.com", + headers={}, + timeout=30, + sse_read_timeout=300, + ) + + +def test_build_tool_provider_response_should_build_api_entity_with_tools( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + db_provider = _provider_stub() + provider_entity = _provider_entity_stub() + tools = [_ToolStub("tool-a", "desc")] + mocker.patch("services.tools.mcp_tools_manage_service.ToolTransformService.mcp_tool_to_user_tool", return_value=[]) + + # Act + result = service._build_tool_provider_response(db_provider, provider_entity, tools) + + # Assert + assert result.plugin_unique_identifier == "server-1" + assert result.name == "MCP Tool" + + +@pytest.mark.parametrize( + ("orig_message", "expected_error"), + [ + ("unique_mcp_provider_name", "MCP tool name already exists"), + ("unique_mcp_provider_server_url", "MCP tool https://mcp.example.com already exists"), + ("unique_mcp_provider_server_identifier", "MCP tool server-1 already exists"), + ], +) +def test_handle_integrity_error_should_raise_readable_value_errors( + orig_message: str, + expected_error: str, + service: MCPToolManageService, +) -> None: + """Test that known integrity errors raise readable value errors.""" + # Arrange + error = IntegrityError("stmt", {}, Exception(orig_message)) + + # Act + Assert + with pytest.raises(ValueError, match=expected_error): + service._handle_integrity_error(error, "name", "https://mcp.example.com", "server-1") + + +def test_handle_integrity_error_should_reraise_unknown_error(service: MCPToolManageService) -> None: + """Test that unknown integrity errors are re-raised.""" + # Arrange + error = IntegrityError("stmt", {}, Exception("unknown-constraint")) + + # Act + Assert + with pytest.raises(IntegrityError) as exc_info: + service._handle_integrity_error(error, "name", "url", "identifier") + + assert exc_info.value is error + + +@pytest.mark.parametrize( + ("url", "expected"), + [ + ("https://mcp.example.com", True), + ("http://mcp.example.com", True), + ("", False), + ("invalid", False), + ("ftp://mcp.example.com", False), + ], +) +def test_is_valid_url_should_validate_supported_schemes( + url: str, + expected: bool, + service: MCPToolManageService, +) -> None: + # Arrange + # Act + result = service._is_valid_url(url) + + # Assert + assert result is expected + + +def test_update_optional_fields_should_update_only_non_none_values(service: MCPToolManageService) -> None: + # Arrange + provider = _provider_stub() + configuration = MCPConfiguration(timeout=99, sse_read_timeout=300) + + # Act + service._update_optional_fields(provider, configuration) + + # Assert + assert provider.timeout == 99 + assert provider.sse_read_timeout == 300 + + +def test_process_headers_should_return_none_when_empty_headers(service: MCPToolManageService) -> None: + # Arrange + provider = _provider_stub() + + # Act + result = service._process_headers({}, provider, "tenant-1") + + # Assert + assert result is None + + +def test_process_headers_should_merge_and_encrypt_headers( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + mocker.patch.object(service, "_merge_headers_with_masked", return_value={"x-api-key": "plain"}) + mocker.patch.object(service, "_prepare_encrypted_dict", return_value='{"x-api-key":"enc"}') + + # Act + result = service._process_headers({"x-api-key": "*****"}, provider, "tenant-1") + + # Assert + assert result == '{"x-api-key":"enc"}' + + +def test_process_credentials_should_merge_and_encrypt_credentials( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + provider = _provider_stub() + authentication = MCPAuthentication(client_id="masked-id", client_secret="masked-secret") + mocker.patch.object(service, "_merge_credentials_with_masked", return_value=("plain-id", "plain-secret")) + mocker.patch.object(service, "_build_and_encrypt_credentials", return_value='{"client_information":{}}') + + # Act + result = service._process_credentials(authentication, provider, "tenant-1") + + # Assert + assert result == '{"client_information":{}}' + + +def test_merge_headers_with_masked_should_preserve_original_values_for_unchanged_masked_inputs( + service: MCPToolManageService, +) -> None: + # Arrange + provider = _provider_stub() + incoming_headers = {"x-api-key": "ke***ey", "new-header": "new-value", "dropped": "*****"} + + # Act + result = service._merge_headers_with_masked(incoming_headers, provider) + + # Assert + assert result["x-api-key"] == "key" + assert result["new-header"] == "new-value" + assert result["dropped"] == "*****" + + +def test_merge_credentials_with_masked_should_preserve_decrypted_values_when_masked_match( + service: MCPToolManageService, +) -> None: + # Arrange + provider = _provider_stub() + + # Act + client_id, client_secret = service._merge_credentials_with_masked("pl***id", "pl***et", provider) + + # Assert + assert client_id == "plain-id" + assert client_secret == "plain-secret" + + +def test_build_and_encrypt_credentials_should_encrypt_secret_when_client_secret_present( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch.object( + service, + "_encrypt_dict_fields", + return_value={ + "client_id": "id", + "client_name": "Dify", + "is_dynamic_registration": False, + "encrypted_client_secret": "enc-secret", + }, + ) + + # Act + result = service._build_and_encrypt_credentials("id", "secret", "tenant-1") + + # Assert + payload = json.loads(result) + assert payload["client_information"]["encrypted_client_secret"] == "enc-secret" + + +def test_build_and_encrypt_credentials_should_skip_secret_field_when_client_secret_is_none( + service: MCPToolManageService, + mocker: MockerFixture, +) -> None: + # Arrange + mocker.patch.object( + service, + "_encrypt_dict_fields", + return_value={"client_id": "id", "client_name": "Dify", "is_dynamic_registration": False}, + ) + + # Act + result = service._build_and_encrypt_credentials("id", None, "tenant-1") + + # Assert + payload = json.loads(result) + assert "encrypted_client_secret" not in payload["client_information"] diff --git a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py index ae59da0a3d..e9bcc89445 100644 --- a/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py +++ b/api/tests/unit_tests/services/tools/test_workflow_tools_manage_service.py @@ -1,3 +1,9 @@ +""" +Unit tests for services.tools.workflow_tools_manage_service + +Covers WorkflowToolManageService: create, update, list, delete, get, list_single. +""" + import json from types import SimpleNamespace from unittest.mock import MagicMock @@ -9,9 +15,16 @@ from core.tools.errors import WorkflowToolHumanInputNotSupportedError from models.model import App from models.tools import WorkflowToolProvider from services.tools import workflow_tools_manage_service +from services.tools.workflow_tools_manage_service import WorkflowToolManageService + +# --------------------------------------------------------------------------- +# Shared helpers / fake infrastructure +# --------------------------------------------------------------------------- class DummyWorkflow: + """Minimal in-memory Workflow substitute.""" + def __init__(self, graph_dict: dict, version: str = "1.0.0") -> None: self._graph_dict = graph_dict self.version = version @@ -22,72 +35,42 @@ class DummyWorkflow: class FakeQuery: - def __init__(self, result): + """Chainable query object that always returns a fixed result.""" + + def __init__(self, result: object) -> None: self._result = result - def where(self, *args, **kwargs): + def where(self, *args: object, **kwargs: object) -> "FakeQuery": return self - def first(self): + def first(self) -> object: return self._result + def delete(self) -> int: + return 1 + class DummySession: + """Minimal SQLAlchemy session substitute.""" + def __init__(self) -> None: - self.added: list[object] = [] + self.added: list[WorkflowToolProvider] = [] + self.committed: bool = False def __enter__(self) -> "DummySession": return self - def __exit__(self, exc_type, exc, tb) -> bool: + def __exit__(self, exc_type: object, exc: object, tb: object) -> bool: return False - def add(self, obj) -> None: + def add(self, obj: WorkflowToolProvider) -> None: self.added.append(obj) - def begin(self): - return DummyBegin(self) + def begin(self) -> "DummySession": + return self - -class DummyBegin: - def __init__(self, session: DummySession) -> None: - self._session = session - - def __enter__(self) -> DummySession: - return self._session - - def __exit__(self, exc_type, exc, tb) -> bool: - return False - - -class DummySessionContext: - def __init__(self, session: DummySession) -> None: - self._session = session - - def __enter__(self) -> DummySession: - return self._session - - def __exit__(self, exc_type, exc, tb) -> bool: - return False - - -class DummySessionFactory: - def __init__(self, session: DummySession) -> None: - self._session = session - - def create_session(self) -> DummySessionContext: - return DummySessionContext(self._session) - - -def _build_fake_session(app) -> SimpleNamespace: - def query(model): - if model is WorkflowToolProvider: - return FakeQuery(None) - if model is App: - return FakeQuery(app) - return FakeQuery(None) - - return SimpleNamespace(query=query) + def commit(self) -> None: + self.committed = True def _build_parameters() -> list[WorkflowToolParameterConfiguration]: @@ -96,67 +79,877 @@ def _build_parameters() -> list[WorkflowToolParameterConfiguration]: ] -def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch): - workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "human-input"}}]}) - app = SimpleNamespace(workflow=workflow) +def _build_fake_db( + *, + existing_tool: WorkflowToolProvider | None = None, + app: object | None = None, + tool_by_id: WorkflowToolProvider | None = None, +) -> tuple[MagicMock, DummySession]: + """ + Build a fake db object plus a DummySession for Session context-manager. - fake_session = _build_fake_session(app) - monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) + query(WorkflowToolProvider) returns existing_tool on first call, + then tool_by_id on subsequent calls (or None if not provided). + query(App) returns app. + """ + call_counts: dict[str, int] = {"wftp": 0} - mock_from_db = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) - mock_invalidate = MagicMock() + def query(model: type) -> FakeQuery: + if model is WorkflowToolProvider: + call_counts["wftp"] += 1 + if call_counts["wftp"] == 1: + return FakeQuery(existing_tool) + return FakeQuery(tool_by_id) + if model is App: + return FakeQuery(app) + return FakeQuery(None) - with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: - workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool( + fake_db = MagicMock() + fake_db.session = SimpleNamespace(query=query, commit=MagicMock()) + dummy_session = DummySession() + return fake_db, dummy_session + + +# --------------------------------------------------------------------------- +# TestCreateWorkflowTool +# --------------------------------------------------------------------------- + + +class TestCreateWorkflowTool: + """Tests for WorkflowToolManageService.create_workflow_tool.""" + + def test_should_raise_when_human_input_nodes_present(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Human-input nodes must be rejected before any provider is created.""" + # Arrange + workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "n1", "data": {"type": "human-input"}}]}) + app = SimpleNamespace(workflow=workflow) + fake_session = SimpleNamespace(query=lambda m: FakeQuery(None) if m is WorkflowToolProvider else FakeQuery(app)) + monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) + mock_from_db = MagicMock() + monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) + + # Act + Assert + with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info: + WorkflowToolManageService.create_workflow_tool( + user_id="user-id", + tenant_id="tenant-id", + workflow_app_id="app-id", + name="tool_name", + label="Tool", + icon={"type": "emoji", "emoji": "🔧"}, + description="desc", + parameters=_build_parameters(), + ) + + assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" + mock_from_db.assert_not_called() + + def test_should_raise_when_duplicate_name_or_app_id(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Existing provider with same name or app_id raises ValueError.""" + # Arrange + existing = MagicMock(spec=WorkflowToolProvider) + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(existing)), + ) + + # Act + Assert + with pytest.raises(ValueError, match="already exists"): + WorkflowToolManageService.create_workflow_tool( + user_id="u", + tenant_id="t", + workflow_app_id="app-1", + name="dup", + label="Dup", + icon={}, + description="", + parameters=[], + ) + + def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the referenced App does not exist.""" + # Arrange + call_count = {"n": 0} + + def query(m: type) -> FakeQuery: + call_count["n"] += 1 + if m is WorkflowToolProvider: + return FakeQuery(None) + return FakeQuery(None) # App returns None + + monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService.create_workflow_tool( + user_id="u", + tenant_id="t", + workflow_app_id="missing-app", + name="n", + label="L", + icon={}, + description="", + parameters=[], + ) + + def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the App has no attached Workflow.""" + # Arrange + app_no_workflow = SimpleNamespace(workflow=None) + + def query(m: type) -> FakeQuery: + if m is WorkflowToolProvider: + return FakeQuery(None) + return FakeQuery(app_no_workflow) + + monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) + + # Act + Assert + with pytest.raises(ValueError, match="Workflow not found"): + WorkflowToolManageService.create_workflow_tool( + user_id="u", + tenant_id="t", + workflow_app_id="app-id", + name="n", + label="L", + icon={}, + description="", + parameters=[], + ) + + def test_should_raise_when_from_db_fails(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Exceptions from WorkflowToolProviderController.from_db are wrapped as ValueError.""" + # Arrange + workflow = DummyWorkflow(graph_dict={"nodes": []}) + app = SimpleNamespace(workflow=workflow) + + def query(m: type) -> FakeQuery: + if m is WorkflowToolProvider: + return FakeQuery(None) + return FakeQuery(app) + + fake_db = MagicMock() + fake_db.session = SimpleNamespace(query=query) + monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) + dummy_session = DummySession() + monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) + monkeypatch.setattr( + workflow_tools_manage_service.WorkflowToolProviderController, + "from_db", + MagicMock(side_effect=RuntimeError("bad config")), + ) + + # Act + Assert + with pytest.raises(ValueError, match="bad config"): + WorkflowToolManageService.create_workflow_tool( + user_id="u", + tenant_id="t", + workflow_app_id="app-id", + name="n", + label="L", + icon={}, + description="", + parameters=[], + ) + + def test_should_succeed_and_persist_provider(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Happy path: provider is added to session and success dict is returned.""" + # Arrange + workflow = DummyWorkflow(graph_dict={"nodes": []}, version="2.0.0") + app = SimpleNamespace(workflow=workflow) + + def query(m: type) -> FakeQuery: + if m is WorkflowToolProvider: + return FakeQuery(None) + return FakeQuery(app) + + fake_db = MagicMock() + fake_db.session = SimpleNamespace(query=query) + monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) + dummy_session = DummySession() + monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) + monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) + + icon = {"type": "emoji", "emoji": "🔧"} + + # Act + result = WorkflowToolManageService.create_workflow_tool( user_id="user-id", tenant_id="tenant-id", workflow_app_id="app-id", name="tool_name", label="Tool", - icon={"type": "emoji", "emoji": "tool"}, + icon=icon, description="desc", parameters=_build_parameters(), ) - assert exc_info.value.error_code == "workflow_tool_human_input_not_supported" - mock_from_db.assert_not_called() - mock_invalidate.assert_not_called() + # Assert + assert result == {"result": "success"} + assert len(dummy_session.added) == 1 + created: WorkflowToolProvider = dummy_session.added[0] + assert created.name == "tool_name" + assert created.label == "Tool" + assert created.icon == json.dumps(icon) + assert created.version == "2.0.0" + + def test_should_call_label_manager_when_labels_provided(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Labels are forwarded to ToolLabelManager when provided.""" + # Arrange + workflow = DummyWorkflow(graph_dict={"nodes": []}) + app = SimpleNamespace(workflow=workflow) + + def query(m: type) -> FakeQuery: + if m is WorkflowToolProvider: + return FakeQuery(None) + return FakeQuery(app) + + fake_db = MagicMock() + fake_db.session = SimpleNamespace(query=query) + monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) + dummy_session = DummySession() + monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) + monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) + mock_label_mgr = MagicMock() + monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "update_tool_labels", mock_label_mgr) + mock_to_ctrl = MagicMock() + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", mock_to_ctrl + ) + + # Act + WorkflowToolManageService.create_workflow_tool( + user_id="u", + tenant_id="t", + workflow_app_id="app-id", + name="n", + label="L", + icon={}, + description="", + parameters=[], + labels=["tag1", "tag2"], + ) + + # Assert + mock_label_mgr.assert_called_once() -def test_create_workflow_tool_success(monkeypatch): - workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "start"}}]}) - app = SimpleNamespace(workflow=workflow) +# --------------------------------------------------------------------------- +# TestUpdateWorkflowTool +# --------------------------------------------------------------------------- - fake_db = MagicMock() - fake_session = _build_fake_session(app) - fake_db.session = fake_session - monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) - dummy_session = DummySession() - monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session) +class TestUpdateWorkflowTool: + """Tests for WorkflowToolManageService.update_workflow_tool.""" - mock_from_db = MagicMock() - monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db) + def _make_provider(self) -> WorkflowToolProvider: + p = MagicMock(spec=WorkflowToolProvider) + p.app_id = "app-id" + p.tenant_id = "tenant-id" + return p - icon = {"type": "emoji", "emoji": "tool"} + def test_should_raise_when_name_duplicated(self, monkeypatch: pytest.MonkeyPatch) -> None: + """If another tool with the given name already exists, raise ValueError.""" + # Arrange + existing = MagicMock(spec=WorkflowToolProvider) - result = workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool( - user_id="user-id", - tenant_id="tenant-id", - workflow_app_id="app-id", - name="tool_name", - label="Tool", - icon=icon, - description="desc", - parameters=_build_parameters(), - ) + def query(m: type) -> FakeQuery: + return FakeQuery(existing) - assert result == {"result": "success"} - assert len(dummy_session.added) == 1 - created_provider = dummy_session.added[0] - assert created_provider.name == "tool_name" - assert created_provider.label == "Tool" - assert created_provider.icon == json.dumps(icon) - assert created_provider.version == workflow.version - mock_from_db.assert_called_once() + monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) + + # Act + Assert + with pytest.raises(ValueError, match="already exists"): + WorkflowToolManageService.update_workflow_tool( + user_id="u", + tenant_id="t", + workflow_tool_id="tool-1", + name="dup", + label="L", + icon={}, + description="", + parameters=[], + ) + + def test_should_raise_when_tool_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the workflow tool to update does not exist.""" + # Arrange + call_count = {"n": 0} + + def query(m: type) -> FakeQuery: + call_count["n"] += 1 + # 1st call: name uniqueness check → None (no duplicate) + # 2nd call: fetch tool by id → None (not found) + return FakeQuery(None) + + monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService.update_workflow_tool( + user_id="u", + tenant_id="t", + workflow_tool_id="missing", + name="n", + label="L", + icon={}, + description="", + parameters=[], + ) + + def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the tool's referenced App has been removed.""" + # Arrange + provider = self._make_provider() + call_count = {"n": 0} + + def query(m: type) -> FakeQuery: + call_count["n"] += 1 + if m is WorkflowToolProvider: + # 1st: duplicate name check (None), 2nd: fetch provider + return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) + return FakeQuery(None) # App not found + + monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService.update_workflow_tool( + user_id="u", + tenant_id="t", + workflow_tool_id="tool-1", + name="n", + label="L", + icon={}, + description="", + parameters=[], + ) + + def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the App exists but has no Workflow.""" + # Arrange + provider = self._make_provider() + app_no_wf = SimpleNamespace(workflow=None) + call_count = {"n": 0} + + def query(m: type) -> FakeQuery: + call_count["n"] += 1 + if m is WorkflowToolProvider: + return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) + return FakeQuery(app_no_wf) + + monkeypatch.setattr(workflow_tools_manage_service.db, "session", SimpleNamespace(query=query)) + + # Act + Assert + with pytest.raises(ValueError, match="Workflow not found"): + WorkflowToolManageService.update_workflow_tool( + user_id="u", + tenant_id="t", + workflow_tool_id="tool-1", + name="n", + label="L", + icon={}, + description="", + parameters=[], + ) + + def test_should_raise_when_from_db_fails(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Exceptions from from_db are re-raised as ValueError.""" + # Arrange + provider = self._make_provider() + workflow = DummyWorkflow(graph_dict={"nodes": []}) + app = SimpleNamespace(workflow=workflow) + call_count = {"n": 0} + + def query(m: type) -> FakeQuery: + call_count["n"] += 1 + if m is WorkflowToolProvider: + return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) + return FakeQuery(app) + + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=query, commit=MagicMock()), + ) + monkeypatch.setattr( + workflow_tools_manage_service.WorkflowToolProviderController, + "from_db", + MagicMock(side_effect=RuntimeError("from_db error")), + ) + + # Act + Assert + with pytest.raises(ValueError, match="from_db error"): + WorkflowToolManageService.update_workflow_tool( + user_id="u", + tenant_id="t", + workflow_tool_id="tool-1", + name="n", + label="L", + icon={}, + description="", + parameters=[], + ) + + def test_should_succeed_and_call_commit(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Happy path: provider fields are updated and session committed.""" + # Arrange + provider = self._make_provider() + workflow = DummyWorkflow(graph_dict={"nodes": []}, version="3.0.0") + app = SimpleNamespace(workflow=workflow) + call_count = {"n": 0} + + def query(m: type) -> FakeQuery: + call_count["n"] += 1 + if m is WorkflowToolProvider: + return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) + return FakeQuery(app) + + mock_commit = MagicMock() + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=query, commit=mock_commit), + ) + monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) + + icon = {"type": "emoji", "emoji": "🛠"} + + # Act + result = WorkflowToolManageService.update_workflow_tool( + user_id="u", + tenant_id="t", + workflow_tool_id="tool-1", + name="new_name", + label="New Label", + icon=icon, + description="new desc", + parameters=_build_parameters(), + ) + + # Assert + assert result == {"result": "success"} + mock_commit.assert_called_once() + assert provider.name == "new_name" + assert provider.label == "New Label" + assert provider.icon == json.dumps(icon) + assert provider.version == "3.0.0" + + def test_should_call_label_manager_when_labels_provided(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Labels are forwarded to ToolLabelManager during update.""" + # Arrange + provider = self._make_provider() + workflow = DummyWorkflow(graph_dict={"nodes": []}) + app = SimpleNamespace(workflow=workflow) + call_count = {"n": 0} + + def query(m: type) -> FakeQuery: + call_count["n"] += 1 + if m is WorkflowToolProvider: + return FakeQuery(None) if call_count["n"] == 1 else FakeQuery(provider) + return FakeQuery(app) + + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=query, commit=MagicMock()), + ) + monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", MagicMock()) + mock_label_mgr = MagicMock() + monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "update_tool_labels", mock_label_mgr) + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", MagicMock() + ) + + # Act + WorkflowToolManageService.update_workflow_tool( + user_id="u", + tenant_id="t", + workflow_tool_id="tool-1", + name="n", + label="L", + icon={}, + description="", + parameters=[], + labels=["a"], + ) + + # Assert + mock_label_mgr.assert_called_once() + + +# --------------------------------------------------------------------------- +# TestListTenantWorkflowTools +# --------------------------------------------------------------------------- + + +class TestListTenantWorkflowTools: + """Tests for WorkflowToolManageService.list_tenant_workflow_tools.""" + + def test_should_return_empty_list_when_no_tools(self, monkeypatch: pytest.MonkeyPatch) -> None: + """An empty database yields an empty result list.""" + # Arrange + fake_scalars = MagicMock() + fake_scalars.all.return_value = [] + fake_db = MagicMock() + fake_db.session.scalars.return_value = fake_scalars + monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) + + # Act + result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t") + + # Assert + assert result == [] + + def test_should_skip_broken_providers_and_log(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Providers that fail to load are logged and skipped.""" + # Arrange + good_provider = MagicMock(spec=WorkflowToolProvider) + good_provider.id = "good-id" + good_provider.app_id = "app-good" + bad_provider = MagicMock(spec=WorkflowToolProvider) + bad_provider.id = "bad-id" + bad_provider.app_id = "app-bad" + + fake_scalars = MagicMock() + fake_scalars.all.return_value = [good_provider, bad_provider] + fake_db = MagicMock() + fake_db.session.scalars.return_value = fake_scalars + monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) + + good_ctrl = MagicMock() + good_ctrl.provider_id = "good-id" + + def to_controller(provider: WorkflowToolProvider) -> MagicMock: + if provider is bad_provider: + raise RuntimeError("broken provider") + return good_ctrl + + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_controller", to_controller + ) + mock_get_labels = MagicMock(return_value={}) + monkeypatch.setattr(workflow_tools_manage_service.ToolLabelManager, "get_tools_labels", mock_get_labels) + mock_to_user = MagicMock() + mock_to_user.return_value.tools = [] + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, "workflow_provider_to_user_provider", mock_to_user + ) + monkeypatch.setattr(workflow_tools_manage_service.ToolTransformService, "repack_provider", MagicMock()) + mock_get_tools = MagicMock(return_value=[MagicMock()]) + good_ctrl.get_tools = mock_get_tools + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", MagicMock() + ) + + # Act + result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t") + + # Assert - only good provider contributed + assert len(result) == 1 + + def test_should_return_tools_for_all_providers(self, monkeypatch: pytest.MonkeyPatch) -> None: + """All successfully loaded providers appear in the result.""" + # Arrange + provider = MagicMock(spec=WorkflowToolProvider) + provider.id = "p-1" + provider.app_id = "app-1" + + fake_scalars = MagicMock() + fake_scalars.all.return_value = [provider] + fake_db = MagicMock() + fake_db.session.scalars.return_value = fake_scalars + monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db) + + ctrl = MagicMock() + ctrl.provider_id = "p-1" + ctrl.get_tools.return_value = [MagicMock()] + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, + "workflow_provider_to_controller", + MagicMock(return_value=ctrl), + ) + monkeypatch.setattr( + workflow_tools_manage_service.ToolLabelManager, "get_tools_labels", MagicMock(return_value={"p-1": []}) + ) + user_provider = MagicMock() + user_provider.tools = [] + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, + "workflow_provider_to_user_provider", + MagicMock(return_value=user_provider), + ) + monkeypatch.setattr(workflow_tools_manage_service.ToolTransformService, "repack_provider", MagicMock()) + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", MagicMock() + ) + + # Act + result = WorkflowToolManageService.list_tenant_workflow_tools("u", "t") + + # Assert + assert len(result) == 1 + assert result[0] is user_provider + + +# --------------------------------------------------------------------------- +# TestDeleteWorkflowTool +# --------------------------------------------------------------------------- + + +class TestDeleteWorkflowTool: + """Tests for WorkflowToolManageService.delete_workflow_tool.""" + + def test_should_delete_and_commit(self, monkeypatch: pytest.MonkeyPatch) -> None: + """delete_workflow_tool queries, deletes, commits, and returns success.""" + # Arrange + mock_query = MagicMock() + mock_query.where.return_value.delete.return_value = 1 + mock_commit = MagicMock() + fake_session = SimpleNamespace(query=lambda m: mock_query, commit=mock_commit) + monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session) + + # Act + result = WorkflowToolManageService.delete_workflow_tool("u", "t", "tool-1") + + # Assert + assert result == {"result": "success"} + mock_commit.assert_called_once() + + +# --------------------------------------------------------------------------- +# TestGetWorkflowToolByToolId / ByAppId +# --------------------------------------------------------------------------- + + +class TestGetWorkflowToolByToolIdAndAppId: + """Tests for get_workflow_tool_by_tool_id and get_workflow_tool_by_app_id.""" + + def test_get_by_tool_id_should_raise_when_db_tool_is_none(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Raises ValueError when no WorkflowToolProvider found by tool id.""" + # Arrange + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(None)), + ) + + # Act + Assert + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService.get_workflow_tool_by_tool_id("u", "t", "missing") + + def test_get_by_app_id_should_raise_when_db_tool_is_none(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Raises ValueError when no WorkflowToolProvider found by app id.""" + # Arrange + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(None)), + ) + + # Act + Assert + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService.get_workflow_tool_by_app_id("u", "t", "missing-app") + + +# --------------------------------------------------------------------------- +# TestGetWorkflowTool (private _get_workflow_tool) +# --------------------------------------------------------------------------- + + +class TestGetWorkflowTool: + """Tests for the internal _get_workflow_tool helper.""" + + def test_should_raise_when_db_tool_none(self) -> None: + """_get_workflow_tool raises ValueError when db_tool is None.""" + with pytest.raises(ValueError, match="Tool not found"): + WorkflowToolManageService._get_workflow_tool("t", None) + + def test_should_raise_when_app_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the corresponding App row is missing.""" + # Arrange + db_tool = MagicMock(spec=WorkflowToolProvider) + db_tool.app_id = "app-1" + db_tool.tenant_id = "t" + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(None)), + ) + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService._get_workflow_tool("t", db_tool) + + def test_should_raise_when_workflow_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when App has no attached Workflow.""" + # Arrange + db_tool = MagicMock(spec=WorkflowToolProvider) + db_tool.app_id = "app-1" + db_tool.tenant_id = "t" + app = SimpleNamespace(workflow=None) + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(app)), + ) + + # Act + Assert + with pytest.raises(ValueError, match="Workflow not found"): + WorkflowToolManageService._get_workflow_tool("t", db_tool) + + def test_should_raise_when_no_workflow_tools(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the controller returns no WorkflowTool instances.""" + # Arrange + db_tool = MagicMock(spec=WorkflowToolProvider) + db_tool.app_id = "app-1" + db_tool.tenant_id = "t" + db_tool.id = "tool-1" + workflow = DummyWorkflow(graph_dict={"nodes": []}) + app = SimpleNamespace(workflow=workflow) + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(app)), + ) + ctrl = MagicMock() + ctrl.get_tools.return_value = [] + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, + "workflow_provider_to_controller", + MagicMock(return_value=ctrl), + ) + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService._get_workflow_tool("t", db_tool) + + def test_should_return_dict_on_success(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Happy path: returns a dict with name, label, icon, synced, etc.""" + # Arrange + db_tool = MagicMock(spec=WorkflowToolProvider) + db_tool.app_id = "app-1" + db_tool.tenant_id = "t" + db_tool.id = "tool-1" + db_tool.name = "my_tool" + db_tool.label = "My Tool" + db_tool.icon = json.dumps({"emoji": "🔧"}) + db_tool.description = "some desc" + db_tool.privacy_policy = "" + db_tool.version = "1.0" + db_tool.parameter_configurations = [] + workflow = DummyWorkflow(graph_dict={"nodes": []}, version="1.0") + app = SimpleNamespace(workflow=workflow) + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(app)), + ) + + workflow_tool = MagicMock() + workflow_tool.entity.output_schema = {"type": "object"} + ctrl = MagicMock() + ctrl.get_tools.return_value = [workflow_tool] + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, + "workflow_provider_to_controller", + MagicMock(return_value=ctrl), + ) + mock_convert = MagicMock(return_value={"tool": "api_entity"}) + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, "convert_tool_entity_to_api_entity", mock_convert + ) + monkeypatch.setattr( + workflow_tools_manage_service.ToolLabelManager, "get_tool_labels", MagicMock(return_value=[]) + ) + + # Act + result = WorkflowToolManageService._get_workflow_tool("t", db_tool) + + # Assert + assert result["name"] == "my_tool" + assert result["label"] == "My Tool" + assert result["synced"] is True + assert "icon" in result + assert "output_schema" in result + + +# --------------------------------------------------------------------------- +# TestListSingleWorkflowTools +# --------------------------------------------------------------------------- + + +class TestListSingleWorkflowTools: + """Tests for WorkflowToolManageService.list_single_workflow_tools.""" + + def test_should_raise_when_tool_not_found(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the specified tool does not exist in DB.""" + # Arrange + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(None)), + ) + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1") + + def test_should_raise_when_no_workflow_tools(self, monkeypatch: pytest.MonkeyPatch) -> None: + """ValueError when the controller yields no tools for the provider.""" + # Arrange + db_tool = MagicMock(spec=WorkflowToolProvider) + db_tool.id = "tool-1" + db_tool.tenant_id = "t" + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(db_tool)), + ) + ctrl = MagicMock() + ctrl.get_tools.return_value = [] + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, + "workflow_provider_to_controller", + MagicMock(return_value=ctrl), + ) + + # Act + Assert + with pytest.raises(ValueError, match="not found"): + WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1") + + def test_should_return_api_entity_list(self, monkeypatch: pytest.MonkeyPatch) -> None: + """Happy path: returns list with one ToolApiEntity.""" + # Arrange + db_tool = MagicMock(spec=WorkflowToolProvider) + db_tool.id = "tool-1" + db_tool.tenant_id = "t" + monkeypatch.setattr( + workflow_tools_manage_service.db, + "session", + SimpleNamespace(query=lambda m: FakeQuery(db_tool)), + ) + workflow_tool = MagicMock() + ctrl = MagicMock() + ctrl.get_tools.return_value = [workflow_tool] + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, + "workflow_provider_to_controller", + MagicMock(return_value=ctrl), + ) + api_entity = MagicMock() + monkeypatch.setattr( + workflow_tools_manage_service.ToolTransformService, + "convert_tool_entity_to_api_entity", + MagicMock(return_value=api_entity), + ) + monkeypatch.setattr( + workflow_tools_manage_service.ToolLabelManager, "get_tool_labels", MagicMock(return_value=[]) + ) + + # Act + result = WorkflowToolManageService.list_single_workflow_tools("u", "t", "tool-1") + + # Assert + assert result == [api_entity]