From 8906ab8e5210ab0739e13b7b7224855ad6fd8b77 Mon Sep 17 00:00:00 2001 From: rajatagarwal-oss Date: Mon, 9 Mar 2026 14:37:13 +0530 Subject: [PATCH 1/3] test: unit test cases for console.datasets module (#32179) Co-authored-by: akashseth-ifp --- api/controllers/console/datasets/datasets.py | 2 +- .../console/datasets/rag_pipeline/__init__.py | 0 .../rag_pipeline/test_datasource_auth.py | 817 +++++++ .../test_datasource_content_preview.py | 143 ++ .../rag_pipeline/test_rag_pipeline.py | 187 ++ .../test_rag_pipeline_datasets.py | 187 ++ .../test_rag_pipeline_draft_variable.py | 324 +++ .../rag_pipeline/test_rag_pipeline_import.py | 329 +++ .../test_rag_pipeline_workflow.py | 688 ++++++ .../console/datasets/test_data_source.py | 444 ++++ .../console/datasets/test_datasets.py | 1926 +++++++++++++++++ .../datasets/test_datasets_document.py | 1379 ++++++++++++ .../datasets/test_datasets_segments.py | 1252 +++++++++++ .../console/datasets/test_external.py | 399 ++++ .../console/datasets/test_hit_testing.py | 160 ++ .../console/datasets/test_hit_testing_base.py | 207 ++ .../console/datasets/test_metadata.py | 362 ++++ .../console/datasets/test_website.py | 233 ++ .../console/datasets/test_wraps.py | 117 + 19 files changed, 9155 insertions(+), 1 deletion(-) create mode 100644 api/tests/unit_tests/controllers/console/datasets/rag_pipeline/__init__.py create mode 100644 api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py create mode 100644 api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py create mode 100644 api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py create mode 100644 api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py create mode 100644 api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py create mode 100644 api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py create mode 100644 api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py create mode 100644 api/tests/unit_tests/controllers/console/datasets/test_data_source.py create mode 100644 api/tests/unit_tests/controllers/console/datasets/test_datasets.py create mode 100644 api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py create mode 100644 api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py create mode 100644 api/tests/unit_tests/controllers/console/datasets/test_external.py create mode 100644 api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py create mode 100644 api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py create mode 100644 api/tests/unit_tests/controllers/console/datasets/test_metadata.py create mode 100644 api/tests/unit_tests/controllers/console/datasets/test_website.py create mode 100644 api/tests/unit_tests/controllers/console/datasets/test_wraps.py diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 54303b2482..ddad7f40ca 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -807,7 +807,7 @@ class DatasetApiKeyApi(Resource): console_ns.abort( 400, message=f"Cannot create more than {self.max_keys} API keys for this resource type.", - code="max_keys_exceeded", + custom="max_keys_exceeded", ) key = ApiToken.generate_api_key(self.token_prefix, 24) diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/__init__.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py new file mode 100644 index 0000000000..9014edc39e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py @@ -0,0 +1,817 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +from controllers.console import console_ns +from controllers.console.datasets.rag_pipeline.datasource_auth import ( + DatasourceAuth, + DatasourceAuthDefaultApi, + DatasourceAuthDeleteApi, + DatasourceAuthListApi, + DatasourceAuthOauthCustomClient, + DatasourceAuthUpdateApi, + DatasourceHardCodeAuthListApi, + DatasourceOAuthCallback, + DatasourcePluginOAuthAuthorizationUrl, + DatasourceUpdateProviderNameApi, +) +from core.plugin.impl.oauth import OAuthHandler +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError +from services.datasource_provider_service import DatasourceProviderService +from services.plugin.oauth_service import OAuthProxyService + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestDatasourcePluginOAuthAuthorizationUrl: + def test_get_success(self, app): + api = DatasourcePluginOAuthAuthorizationUrl() + method = unwrap(api.get) + + user = MagicMock(id="user-1") + + with ( + app.test_request_context("/?credential_id=cred-1"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value={"client_id": "abc"}, + ), + patch.object( + OAuthProxyService, + "create_proxy_context", + return_value="ctx-1", + ), + patch.object( + OAuthHandler, + "get_authorization_url", + return_value={"url": "http://auth"}, + ), + ): + response = method(api, "notion") + + assert response.status_code == 200 + + def test_get_no_oauth_config(self, app): + api = DatasourcePluginOAuthAuthorizationUrl() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value=None, + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + def test_get_without_credential_id_sets_cookie(self, app): + api = DatasourcePluginOAuthAuthorizationUrl() + method = unwrap(api.get) + + user = MagicMock(id="user-1") + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value={"client_id": "abc"}, + ), + patch.object( + OAuthProxyService, + "create_proxy_context", + return_value="ctx-123", + ), + patch.object( + OAuthHandler, + "get_authorization_url", + return_value={"url": "http://auth"}, + ), + ): + response = method(api, "notion") + + assert response.status_code == 200 + assert "context_id" in response.headers.get("Set-Cookie") + + +class TestDatasourceOAuthCallback: + def test_callback_success_new_credential(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + oauth_response = MagicMock() + oauth_response.credentials = {"token": "abc"} + oauth_response.expires_at = None + oauth_response.metadata = {"name": "test"} + + context = { + "user_id": "user-1", + "tenant_id": "tenant-1", + "credential_id": None, + } + + with ( + app.test_request_context("/?context_id=ctx"), + patch.object( + OAuthProxyService, + "use_proxy_context", + return_value=context, + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value={"client_id": "abc"}, + ), + patch.object( + OAuthHandler, + "get_credentials", + return_value=oauth_response, + ), + patch.object( + DatasourceProviderService, + "add_datasource_oauth_provider", + return_value=None, + ), + ): + response = method(api, "notion") + + assert response.status_code == 302 + + def test_callback_missing_context(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + with app.test_request_context("/"): + with pytest.raises(Forbidden): + method(api, "notion") + + def test_callback_invalid_context(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + with ( + app.test_request_context("/?context_id=bad"), + patch.object( + OAuthProxyService, + "use_proxy_context", + return_value=None, + ), + ): + with pytest.raises(Forbidden): + method(api, "notion") + + def test_callback_oauth_config_not_found(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + context = {"user_id": "u", "tenant_id": "t"} + + with ( + app.test_request_context("/?context_id=ctx"), + patch.object( + OAuthProxyService, + "use_proxy_context", + return_value=context, + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "notion") + + def test_callback_reauthorize_existing_credential(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + oauth_response = MagicMock() + oauth_response.credentials = {"token": "abc"} + oauth_response.expires_at = None + oauth_response.metadata = {} # avatar + name missing + + context = { + "user_id": "user-1", + "tenant_id": "tenant-1", + "credential_id": "cred-1", + } + + with ( + app.test_request_context("/?context_id=ctx"), + patch.object( + OAuthProxyService, + "use_proxy_context", + return_value=context, + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value={"client_id": "abc"}, + ), + patch.object( + OAuthHandler, + "get_credentials", + return_value=oauth_response, + ), + patch.object( + DatasourceProviderService, + "reauthorize_datasource_oauth_provider", + return_value=None, + ), + ): + response = method(api, "notion") + + assert response.status_code == 302 + assert "/oauth-callback" in response.location + + def test_callback_context_id_from_cookie(self, app): + api = DatasourceOAuthCallback() + method = unwrap(api.get) + + oauth_response = MagicMock() + oauth_response.credentials = {"token": "abc"} + oauth_response.expires_at = None + oauth_response.metadata = {} + + context = { + "user_id": "user-1", + "tenant_id": "tenant-1", + "credential_id": None, + } + + with ( + app.test_request_context("/", headers={"Cookie": "context_id=ctx"}), + patch.object( + OAuthProxyService, + "use_proxy_context", + return_value=context, + ), + patch.object( + DatasourceProviderService, + "get_oauth_client", + return_value={"client_id": "abc"}, + ), + patch.object( + OAuthHandler, + "get_credentials", + return_value=oauth_response, + ), + patch.object( + DatasourceProviderService, + "add_datasource_oauth_provider", + return_value=None, + ), + ): + response = method(api, "notion") + + assert response.status_code == 302 + + +class TestDatasourceAuth: + def test_post_success(self, app): + api = DatasourceAuth() + method = unwrap(api.post) + + payload = {"credentials": {"key": "val"}} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "add_datasource_api_key_provider", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_post_invalid_credentials(self, app): + api = DatasourceAuth() + method = unwrap(api.post) + + payload = {"credentials": {"key": "bad"}} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "add_datasource_api_key_provider", + side_effect=CredentialsValidateFailedError("invalid"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + def test_get_success(self, app): + api = DatasourceAuth() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "list_datasource_credentials", + return_value=[{"id": "1"}], + ), + ): + response, status = method(api, "notion") + + assert status == 200 + assert response["result"] + + def test_post_missing_credentials(self, app): + api = DatasourceAuth() + method = unwrap(api.post) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + def test_get_empty_list(self, app): + api = DatasourceAuth() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "list_datasource_credentials", + return_value=[], + ), + ): + response, status = method(api, "notion") + + assert status == 200 + assert response["result"] == [] + + +class TestDatasourceAuthDeleteApi: + def test_delete_success(self, app): + api = DatasourceAuthDeleteApi() + method = unwrap(api.post) + + payload = {"credential_id": "cred-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "remove_datasource_credentials", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_delete_missing_credential_id(self, app): + api = DatasourceAuthDeleteApi() + method = unwrap(api.post) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + +class TestDatasourceAuthUpdateApi: + def test_update_success(self, app): + api = DatasourceAuthUpdateApi() + method = unwrap(api.post) + + payload = {"credential_id": "id", "credentials": {"k": "v"}} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "update_datasource_credentials", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 201 + + def test_update_with_credentials_none(self, app): + api = DatasourceAuthUpdateApi() + method = unwrap(api.post) + + payload = {"credential_id": "id", "credentials": None} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "update_datasource_credentials", + return_value=None, + ) as update_mock, + ): + response, status = method(api, "notion") + + update_mock.assert_called_once() + assert status == 201 + + def test_update_name_only(self, app): + api = DatasourceAuthUpdateApi() + method = unwrap(api.post) + + payload = {"credential_id": "id", "name": "New Name"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "update_datasource_credentials", + return_value=None, + ), + ): + _, status = method(api, "notion") + + assert status == 201 + + def test_update_with_empty_credentials_dict(self, app): + api = DatasourceAuthUpdateApi() + method = unwrap(api.post) + + payload = {"credential_id": "id", "credentials": {}} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "update_datasource_credentials", + return_value=None, + ) as update_mock, + ): + _, status = method(api, "notion") + + update_mock.assert_called_once() + assert status == 201 + + +class TestDatasourceAuthListApi: + def test_list_success(self, app): + api = DatasourceAuthListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_all_datasource_credentials", + return_value=[{"id": "1"}], + ), + ): + response, status = method(api) + + assert status == 200 + + def test_auth_list_empty(self, app): + api = DatasourceAuthListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_all_datasource_credentials", + return_value=[], + ), + ): + response, status = method(api) + + assert status == 200 + assert response["result"] == [] + + def test_hardcode_list_empty(self, app): + api = DatasourceHardCodeAuthListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_hard_code_datasource_credentials", + return_value=[], + ), + ): + response, status = method(api) + + assert status == 200 + assert response["result"] == [] + + +class TestDatasourceHardCodeAuthListApi: + def test_list_success(self, app): + api = DatasourceHardCodeAuthListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "get_hard_code_datasource_credentials", + return_value=[{"id": "1"}], + ), + ): + response, status = method(api) + + assert status == 200 + + +class TestDatasourceAuthOauthCustomClient: + def test_post_success(self, app): + api = DatasourceAuthOauthCustomClient() + method = unwrap(api.post) + + payload = {"client_params": {}, "enable_oauth_custom_client": True} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "setup_oauth_custom_client_params", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_delete_success(self, app): + api = DatasourceAuthOauthCustomClient() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "remove_oauth_custom_client_params", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_post_empty_payload(self, app): + api = DatasourceAuthOauthCustomClient() + method = unwrap(api.post) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "setup_oauth_custom_client_params", + return_value=None, + ), + ): + _, status = method(api, "notion") + + assert status == 200 + + def test_post_disabled_flag(self, app): + api = DatasourceAuthOauthCustomClient() + method = unwrap(api.post) + + payload = { + "client_params": {"a": 1}, + "enable_oauth_custom_client": False, + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "setup_oauth_custom_client_params", + return_value=None, + ) as setup_mock, + ): + _, status = method(api, "notion") + + setup_mock.assert_called_once() + assert status == 200 + + +class TestDatasourceAuthDefaultApi: + def test_set_default_success(self, app): + api = DatasourceAuthDefaultApi() + method = unwrap(api.post) + + payload = {"id": "cred-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "set_default_datasource_provider", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_default_missing_id(self, app): + api = DatasourceAuthDefaultApi() + method = unwrap(api.post) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + +class TestDatasourceUpdateProviderNameApi: + def test_update_name_success(self, app): + api = DatasourceUpdateProviderNameApi() + method = unwrap(api.post) + + payload = {"credential_id": "id", "name": "New Name"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + DatasourceProviderService, + "update_datasource_provider_name", + return_value=None, + ), + ): + response, status = method(api, "notion") + + assert status == 200 + + def test_update_name_too_long(self, app): + api = DatasourceUpdateProviderNameApi() + method = unwrap(api.post) + + payload = { + "credential_id": "id", + "name": "x" * 101, + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") + + def test_update_name_missing_credential_id(self, app): + api = DatasourceUpdateProviderNameApi() + method = unwrap(api.post) + + payload = {"name": "Valid"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_auth.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api, "notion") diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py new file mode 100644 index 0000000000..7a8ccde55a --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_content_preview.py @@ -0,0 +1,143 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden + +from controllers.console import console_ns +from controllers.console.datasets.rag_pipeline.datasource_content_preview import ( + DataSourceContentPreviewApi, +) +from models import Account +from models.dataset import Pipeline + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestDataSourceContentPreviewApi: + def _valid_payload(self): + return { + "inputs": {"query": "hello"}, + "datasource_type": "notion", + "credential_id": "cred-1", + } + + def test_post_success(self, app): + api = DataSourceContentPreviewApi() + method = unwrap(api.post) + + payload = self._valid_payload() + + pipeline = MagicMock(spec=Pipeline) + node_id = "node-1" + account = MagicMock(spec=Account) + + preview_result = {"content": "preview data"} + + service_instance = MagicMock() + service_instance.run_datasource_node_preview.return_value = preview_result + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user", + account, + ), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService", + return_value=service_instance, + ), + ): + response, status = method(api, pipeline, node_id) + + service_instance.run_datasource_node_preview.assert_called_once_with( + pipeline=pipeline, + node_id=node_id, + user_inputs=payload["inputs"], + account=account, + datasource_type=payload["datasource_type"], + is_published=True, + credential_id=payload["credential_id"], + ) + assert status == 200 + assert response == preview_result + + def test_post_forbidden_non_account_user(self, app): + api = DataSourceContentPreviewApi() + method = unwrap(api.post) + + payload = self._valid_payload() + + pipeline = MagicMock(spec=Pipeline) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user", + MagicMock(), # NOT Account + ), + ): + with pytest.raises(Forbidden): + method(api, pipeline, "node-1") + + def test_post_invalid_payload(self, app): + api = DataSourceContentPreviewApi() + method = unwrap(api.post) + + payload = { + "inputs": {"query": "hello"}, + # datasource_type missing + } + + pipeline = MagicMock(spec=Pipeline) + account = MagicMock(spec=Account) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user", + account, + ), + ): + with pytest.raises(ValueError): + method(api, pipeline, "node-1") + + def test_post_without_credential_id(self, app): + api = DataSourceContentPreviewApi() + method = unwrap(api.post) + + payload = { + "inputs": {"query": "hello"}, + "datasource_type": "notion", + "credential_id": None, + } + + pipeline = MagicMock(spec=Pipeline) + account = MagicMock(spec=Account) + + service_instance = MagicMock() + service_instance.run_datasource_node_preview.return_value = {"ok": True} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.current_user", + account, + ), + patch( + "controllers.console.datasets.rag_pipeline.datasource_content_preview.RagPipelineService", + return_value=service_instance, + ), + ): + response, status = method(api, pipeline, "node-1") + + service_instance.run_datasource_node_preview.assert_called_once() + assert status == 200 + assert response == {"ok": True} diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py new file mode 100644 index 0000000000..3b8679f4ec --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline.py @@ -0,0 +1,187 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from controllers.console import console_ns +from controllers.console.datasets.rag_pipeline.rag_pipeline import ( + CustomizedPipelineTemplateApi, + PipelineTemplateDetailApi, + PipelineTemplateListApi, + PublishCustomizedPipelineTemplateApi, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestPipelineTemplateListApi: + def test_get_success(self, app): + api = PipelineTemplateListApi() + method = unwrap(api.get) + + templates = [{"id": "t1"}] + + with ( + app.test_request_context("/?type=built-in&language=en-US"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.get_pipeline_templates", + return_value=templates, + ), + ): + response, status = method(api) + + assert status == 200 + assert response == templates + + +class TestPipelineTemplateDetailApi: + def test_get_success(self, app): + api = PipelineTemplateDetailApi() + method = unwrap(api.get) + + template = {"id": "tpl-1"} + + service = MagicMock() + service.get_pipeline_template_detail.return_value = template + + with ( + app.test_request_context("/?type=built-in"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService", + return_value=service, + ), + ): + response, status = method(api, "tpl-1") + + assert status == 200 + assert response == template + + +class TestCustomizedPipelineTemplateApi: + def test_patch_success(self, app): + api = CustomizedPipelineTemplateApi() + method = unwrap(api.patch) + + payload = { + "name": "Template", + "description": "Desc", + "icon_info": {"icon": "📘"}, + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.update_customized_pipeline_template" + ) as update_mock, + ): + response = method(api, "tpl-1") + + update_mock.assert_called_once() + assert response == 200 + + def test_delete_success(self, app): + api = CustomizedPipelineTemplateApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService.delete_customized_pipeline_template" + ) as delete_mock, + ): + response = method(api, "tpl-1") + + delete_mock.assert_called_once_with("tpl-1") + assert response == 200 + + def test_post_success(self, app): + api = CustomizedPipelineTemplateApi() + method = unwrap(api.post) + + template = MagicMock() + template.yaml_content = "yaml-data" + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = template + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.Session", + return_value=session_ctx, + ), + ): + response, status = method(api, "tpl-1") + + assert status == 200 + assert response == {"data": "yaml-data"} + + def test_post_template_not_found(self, app): + api = CustomizedPipelineTemplateApi() + method = unwrap(api.post) + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session = MagicMock() + session.query.return_value.where.return_value.first.return_value = None + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.Session", + return_value=session_ctx, + ), + ): + with pytest.raises(ValueError): + method(api, "tpl-1") + + +class TestPublishCustomizedPipelineTemplateApi: + def test_post_success(self, app): + api = PublishCustomizedPipelineTemplateApi() + method = unwrap(api.post) + + payload = { + "name": "Template", + "description": "Desc", + "icon_info": {"icon": "📘"}, + } + + service = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline.RagPipelineService", + return_value=service, + ), + ): + response = method(api, "pipeline-1") + + service.publish_customized_pipeline_template.assert_called_once() + assert response == {"result": "success"} diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py new file mode 100644 index 0000000000..fd38fcbb5e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_datasets.py @@ -0,0 +1,187 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden + +import services +from controllers.console import console_ns +from controllers.console.datasets.error import DatasetNameDuplicateError +from controllers.console.datasets.rag_pipeline.rag_pipeline_datasets import ( + CreateEmptyRagPipelineDatasetApi, + CreateRagPipelineDatasetApi, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestCreateRagPipelineDatasetApi: + def _valid_payload(self): + return {"yaml_content": "name: test"} + + def test_post_success(self, app): + api = CreateRagPipelineDatasetApi() + method = unwrap(api.post) + + payload = self._valid_payload() + user = MagicMock(is_dataset_editor=True) + import_info = {"dataset_id": "ds-1"} + + mock_service = MagicMock() + mock_service.create_rag_pipeline_dataset.return_value = import_info + + mock_session_ctx = MagicMock() + mock_session_ctx.__enter__.return_value = MagicMock() + mock_session_ctx.__exit__.return_value = None + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session", + return_value=mock_session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService", + return_value=mock_service, + ), + ): + response, status = method(api) + + assert status == 201 + assert response == import_info + + def test_post_forbidden_non_editor(self, app): + api = CreateRagPipelineDatasetApi() + method = unwrap(api.post) + + payload = self._valid_payload() + user = MagicMock(is_dataset_editor=False) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + ): + with pytest.raises(Forbidden): + method(api) + + def test_post_dataset_name_duplicate(self, app): + api = CreateRagPipelineDatasetApi() + method = unwrap(api.post) + + payload = self._valid_payload() + user = MagicMock(is_dataset_editor=True) + + mock_service = MagicMock() + mock_service.create_rag_pipeline_dataset.side_effect = services.errors.dataset.DatasetNameDuplicateError() + + mock_session_ctx = MagicMock() + mock_session_ctx.__enter__.return_value = MagicMock() + mock_session_ctx.__exit__.return_value = None + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.Session", + return_value=mock_session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.RagPipelineDslService", + return_value=mock_service, + ), + ): + with pytest.raises(DatasetNameDuplicateError): + method(api) + + def test_post_invalid_payload(self, app): + api = CreateRagPipelineDatasetApi() + method = unwrap(api.post) + + payload = {} + user = MagicMock(is_dataset_editor=True) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestCreateEmptyRagPipelineDatasetApi: + def test_post_success(self, app): + api = CreateEmptyRagPipelineDatasetApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=True) + dataset = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.DatasetService.create_empty_rag_pipeline_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.marshal", + return_value={"id": "ds-1"}, + ), + ): + response, status = method(api) + + assert status == 201 + assert response == {"id": "ds-1"} + + def test_post_forbidden_non_editor(self, app): + api = CreateEmptyRagPipelineDatasetApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=False) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + ): + with pytest.raises(Forbidden): + method(api) diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py new file mode 100644 index 0000000000..b4c0903f63 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py @@ -0,0 +1,324 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Response + +from controllers.console import console_ns +from controllers.console.app.error import DraftWorkflowNotExist +from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable import ( + RagPipelineEnvironmentVariableCollectionApi, + RagPipelineNodeVariableCollectionApi, + RagPipelineSystemVariableCollectionApi, + RagPipelineVariableApi, + RagPipelineVariableCollectionApi, + RagPipelineVariableResetApi, +) +from controllers.web.error import InvalidArgumentError, NotFoundError +from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID +from dify_graph.variables.types import SegmentType +from models.account import Account + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def fake_db(): + db = MagicMock() + db.engine = MagicMock() + db.session.return_value = MagicMock() + return db + + +@pytest.fixture +def editor_user(): + user = MagicMock(spec=Account) + user.has_edit_permission = True + return user + + +@pytest.fixture +def restx_config(app): + return patch.dict(app.config, {"RESTX_MASK_HEADER": "X-Fields"}) + + +class TestRagPipelineVariableCollectionApi: + def test_get_variables_success(self, app, fake_db, editor_user, restx_config): + api = RagPipelineVariableCollectionApi() + method = unwrap(api.get) + + pipeline = MagicMock(id="p1") + + rag_srv = MagicMock() + rag_srv.is_workflow_exist.return_value = True + + # IMPORTANT: RESTX expects .variables + var_list = MagicMock() + var_list.variables = [] + + draft_srv = MagicMock() + draft_srv.list_variables_without_values.return_value = var_list + + with ( + app.test_request_context("/?page=1&limit=10"), + restx_config, + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService", + return_value=rag_srv, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=draft_srv, + ), + ): + result = method(api, pipeline) + + assert result["items"] == [] + + def test_get_variables_workflow_not_exist(self, app, fake_db, editor_user): + api = RagPipelineVariableCollectionApi() + method = unwrap(api.get) + + pipeline = MagicMock() + + rag_srv = MagicMock() + rag_srv.is_workflow_exist.return_value = False + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService", + return_value=rag_srv, + ), + ): + with pytest.raises(DraftWorkflowNotExist): + method(api, pipeline) + + def test_delete_variables_success(self, app, fake_db, editor_user): + api = RagPipelineVariableCollectionApi() + method = unwrap(api.delete) + + pipeline = MagicMock(id="p1") + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService"), + ): + result = method(api, pipeline) + + assert isinstance(result, Response) + assert result.status_code == 204 + + +class TestRagPipelineNodeVariableCollectionApi: + def test_get_node_variables_success(self, app, fake_db, editor_user, restx_config): + api = RagPipelineNodeVariableCollectionApi() + method = unwrap(api.get) + + pipeline = MagicMock(id="p1") + + var_list = MagicMock() + var_list.variables = [] + + srv = MagicMock() + srv.list_node_variables.return_value = var_list + + with ( + app.test_request_context("/"), + restx_config, + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + ): + result = method(api, pipeline, "node1") + + assert result["items"] == [] + + def test_get_node_variables_invalid_node(self, app, editor_user): + api = RagPipelineNodeVariableCollectionApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + ): + with pytest.raises(InvalidArgumentError): + method(api, MagicMock(), SYSTEM_VARIABLE_NODE_ID) + + +class TestRagPipelineVariableApi: + def test_get_variable_not_found(self, app, fake_db, editor_user): + api = RagPipelineVariableApi() + method = unwrap(api.get) + + srv = MagicMock() + srv.get_variable.return_value = None + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + ): + with pytest.raises(NotFoundError): + method(api, MagicMock(), "v1") + + def test_patch_variable_invalid_file_payload(self, app, fake_db, editor_user): + api = RagPipelineVariableApi() + method = unwrap(api.patch) + + pipeline = MagicMock(id="p1", tenant_id="t1") + variable = MagicMock(app_id="p1", value_type=SegmentType.FILE) + + srv = MagicMock() + srv.get_variable.return_value = variable + + payload = {"value": "invalid"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + ): + with pytest.raises(InvalidArgumentError): + method(api, pipeline, "v1") + + def test_delete_variable_success(self, app, fake_db, editor_user): + api = RagPipelineVariableApi() + method = unwrap(api.delete) + + pipeline = MagicMock(id="p1") + variable = MagicMock(app_id="p1") + + srv = MagicMock() + srv.get_variable.return_value = variable + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + ): + result = method(api, pipeline, "v1") + + assert result.status_code == 204 + + +class TestRagPipelineVariableResetApi: + def test_reset_variable_success(self, app, fake_db, editor_user): + api = RagPipelineVariableResetApi() + method = unwrap(api.put) + + pipeline = MagicMock(id="p1") + workflow = MagicMock() + variable = MagicMock(app_id="p1") + + srv = MagicMock() + srv.get_variable.return_value = variable + srv.reset_variable.return_value = variable + + rag_srv = MagicMock() + rag_srv.get_draft_workflow.return_value = workflow + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService", + return_value=rag_srv, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.marshal", + return_value={"id": "v1"}, + ), + ): + result = method(api, pipeline, "v1") + + assert result == {"id": "v1"} + + +class TestSystemAndEnvironmentVariablesApi: + def test_system_variables_success(self, app, fake_db, editor_user, restx_config): + api = RagPipelineSystemVariableCollectionApi() + method = unwrap(api.get) + + pipeline = MagicMock(id="p1") + + var_list = MagicMock() + var_list.variables = [] + + srv = MagicMock() + srv.list_system_variables.return_value = var_list + + with ( + app.test_request_context("/"), + restx_config, + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.db", fake_db), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.WorkflowDraftVariableService", + return_value=srv, + ), + ): + result = method(api, pipeline) + + assert result["items"] == [] + + def test_environment_variables_success(self, app, editor_user): + api = RagPipelineEnvironmentVariableCollectionApi() + method = unwrap(api.get) + + env_var = MagicMock( + id="e1", + name="ENV", + description="d", + selector="s", + value_type=MagicMock(value="string"), + value="x", + ) + + workflow = MagicMock(environment_variables=[env_var]) + pipeline = MagicMock(id="p1") + + rag_srv = MagicMock() + rag_srv.get_draft_workflow.return_value = workflow + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.current_user", editor_user), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable.RagPipelineService", + return_value=rag_srv, + ), + ): + result = method(api, pipeline) + + assert len(result["items"]) == 1 diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py new file mode 100644 index 0000000000..a72ad45110 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_import.py @@ -0,0 +1,329 @@ +from unittest.mock import MagicMock, patch + +from controllers.console import console_ns +from controllers.console.datasets.rag_pipeline.rag_pipeline_import import ( + RagPipelineExportApi, + RagPipelineImportApi, + RagPipelineImportCheckDependenciesApi, + RagPipelineImportConfirmApi, +) +from models.dataset import Pipeline +from services.app_dsl_service import ImportStatus + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestRagPipelineImportApi: + def _payload(self, mode="create"): + return { + "mode": mode, + "yaml_content": "content", + "name": "Test", + } + + def test_post_success_200(self, app): + api = RagPipelineImportApi() + method = unwrap(api.post) + + payload = self._payload() + + user = MagicMock() + result = MagicMock() + result.status = "completed" + result.model_dump.return_value = {"status": "success"} + + service = MagicMock() + service.import_rag_pipeline.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api) + + assert status == 200 + assert response == {"status": "success"} + + def test_post_failed_400(self, app): + api = RagPipelineImportApi() + method = unwrap(api.post) + + payload = self._payload() + + user = MagicMock() + result = MagicMock() + result.status = ImportStatus.FAILED + result.model_dump.return_value = {"status": "failed"} + + service = MagicMock() + service.import_rag_pipeline.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api) + + assert status == 400 + assert response == {"status": "failed"} + + def test_post_pending_202(self, app): + api = RagPipelineImportApi() + method = unwrap(api.post) + + payload = self._payload() + + user = MagicMock() + result = MagicMock() + result.status = ImportStatus.PENDING + result.model_dump.return_value = {"status": "pending"} + + service = MagicMock() + service.import_rag_pipeline.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api) + + assert status == 202 + assert response == {"status": "pending"} + + +class TestRagPipelineImportConfirmApi: + def test_confirm_success(self, app): + api = RagPipelineImportConfirmApi() + method = unwrap(api.post) + + user = MagicMock() + result = MagicMock() + result.status = "completed" + result.model_dump.return_value = {"ok": True} + + service = MagicMock() + service.confirm_import.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api, "import-1") + + assert status == 200 + assert response == {"ok": True} + + def test_confirm_failed(self, app): + api = RagPipelineImportConfirmApi() + method = unwrap(api.post) + + user = MagicMock() + result = MagicMock() + result.status = ImportStatus.FAILED + result.model_dump.return_value = {"ok": False} + + service = MagicMock() + service.confirm_import.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api, "import-1") + + assert status == 400 + assert response == {"ok": False} + + +class TestRagPipelineImportCheckDependenciesApi: + def test_get_success(self, app): + api = RagPipelineImportCheckDependenciesApi() + method = unwrap(api.get) + + pipeline = MagicMock(spec=Pipeline) + result = MagicMock() + result.model_dump.return_value = {"deps": []} + + service = MagicMock() + service.check_dependencies.return_value = result + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api, pipeline) + + assert status == 200 + assert response == {"deps": []} + + +class TestRagPipelineExportApi: + def test_get_with_include_secret(self, app): + api = RagPipelineExportApi() + method = unwrap(api.get) + + pipeline = MagicMock(spec=Pipeline) + service = MagicMock() + service.export_rag_pipeline_dsl.return_value = {"yaml": "data"} + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = MagicMock() + session_ctx.__exit__.return_value = None + + with ( + app.test_request_context("/?include_secret=true"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_import.RagPipelineDslService", + return_value=service, + ), + ): + response, status = method(api, pipeline) + + assert status == 200 + assert response == {"data": {"yaml": "data"}} diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py new file mode 100644 index 0000000000..7775cbdd81 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -0,0 +1,688 @@ +from datetime import datetime +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +import services +from controllers.console import console_ns +from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync +from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import ( + DefaultRagPipelineBlockConfigApi, + DraftRagPipelineApi, + DraftRagPipelineRunApi, + PublishedAllRagPipelineApi, + PublishedRagPipelineApi, + PublishedRagPipelineRunApi, + RagPipelineByIdApi, + RagPipelineDatasourceVariableApi, + RagPipelineDraftNodeRunApi, + RagPipelineDraftRunIterationNodeApi, + RagPipelineDraftRunLoopNodeApi, + RagPipelineRecommendedPluginApi, + RagPipelineTaskStopApi, + RagPipelineTransformApi, + RagPipelineWorkflowLastRunApi, +) +from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from services.errors.app import WorkflowHashNotEqualError +from services.errors.llm import InvokeRateLimitError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestDraftWorkflowApi: + def test_get_draft_success(self, app): + api = DraftRagPipelineApi() + method = unwrap(api.get) + + pipeline = MagicMock() + workflow = MagicMock() + + service = MagicMock() + service.get_draft_workflow.return_value = workflow + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline) + assert result == workflow + + def test_get_draft_not_exist(self, app): + api = DraftRagPipelineApi() + method = unwrap(api.get) + + pipeline = MagicMock() + service = MagicMock() + service.get_draft_workflow.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(DraftWorkflowNotExist): + method(api, pipeline) + + def test_sync_hash_not_match(self, app): + api = DraftRagPipelineApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + service = MagicMock() + service.sync_draft_workflow.side_effect = WorkflowHashNotEqualError() + + with ( + app.test_request_context("/", json={"graph": {}, "features": {}}), + patch.object(type(console_ns), "payload", {"graph": {}, "features": {}}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(DraftWorkflowNotSync): + method(api, pipeline) + + def test_sync_invalid_text_plain(self, app): + api = DraftRagPipelineApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context("/", data="bad-json", headers={"Content-Type": "text/plain"}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + ): + response, status = method(api, pipeline) + assert status == 400 + + +class TestDraftRunNodes: + def test_iteration_node_success(self, app): + api = RagPipelineDraftRunIterationNodeApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(type(console_ns), "payload", {"inputs": {}}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response", + return_value={"ok": True}, + ), + ): + result = method(api, pipeline, "node") + assert result == {"ok": True} + + def test_iteration_node_conversation_not_exists(self, app): + api = RagPipelineDraftRunIterationNodeApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(type(console_ns), "payload", {"inputs": {}}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_iteration", + side_effect=services.errors.conversation.ConversationNotExistsError(), + ), + ): + with pytest.raises(NotFound): + method(api, pipeline, "node") + + def test_loop_node_success(self, app): + api = RagPipelineDraftRunLoopNodeApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(type(console_ns), "payload", {"inputs": {}}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate_single_loop", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response", + return_value={"ok": True}, + ), + ): + assert method(api, pipeline, "node") == {"ok": True} + + +class TestPipelineRunApis: + def test_draft_run_success(self, app): + api = DraftRagPipelineRunApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + payload = { + "inputs": {}, + "datasource_type": "x", + "datasource_info_list": [], + "start_node_id": "n", + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response", + return_value={"ok": True}, + ), + ): + assert method(api, pipeline) == {"ok": True} + + def test_draft_run_rate_limit(self, app): + api = DraftRagPipelineRunApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context( + "/", json={"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"} + ), + patch.object( + type(console_ns), + "payload", + {"inputs": {}, "datasource_type": "x", "datasource_info_list": [], "start_node_id": "n"}, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate", + side_effect=InvokeRateLimitError("limit"), + ), + ): + with pytest.raises(InvokeRateLimitHttpError): + method(api, pipeline) + + +class TestDraftNodeRun: + def test_execution_not_found(self, app): + api = RagPipelineDraftNodeRunApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + service = MagicMock() + service.run_draft_workflow_node.return_value = None + + with ( + app.test_request_context("/", json={"inputs": {}}), + patch.object(type(console_ns), "payload", {"inputs": {}}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(ValueError): + method(api, pipeline, "node") + + +class TestPublishedPipelineApis: + def test_publish_success(self, app): + api = PublishedRagPipelineApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="u1") + + workflow = MagicMock( + id="w1", + created_at=datetime.utcnow(), + ) + + session = MagicMock() + session.merge.return_value = pipeline + + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + + service = MagicMock() + service.publish_workflow.return_value = workflow + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline) + + assert result["result"] == "success" + assert "created_at" in result + + +class TestMiscApis: + def test_task_stop(self, app): + api = RagPipelineTaskStopApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock(id="u1") + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.AppQueueManager.set_stop_flag" + ) as stop_mock, + ): + result = method(api, pipeline, "task-1") + stop_mock.assert_called_once() + assert result["result"] == "success" + + def test_transform_forbidden(self, app): + api = RagPipelineTransformApi() + method = unwrap(api.post) + + user = MagicMock(has_edit_permission=False, is_dataset_operator=False) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + ): + with pytest.raises(Forbidden): + method(api, "ds1") + + def test_recommended_plugins(self, app): + api = RagPipelineRecommendedPluginApi() + method = unwrap(api.get) + + service = MagicMock() + service.get_recommended_plugins.return_value = [{"id": "p1"}] + + with ( + app.test_request_context("/?type=all"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api) + assert result == [{"id": "p1"}] + + +class TestPublishedRagPipelineRunApi: + def test_published_run_success(self, app): + api = PublishedRagPipelineRunApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + payload = { + "inputs": {}, + "datasource_type": "x", + "datasource_info_list": [], + "start_node_id": "n", + "response_mode": "blocking", + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.helper.compact_generate_response", + return_value={"ok": True}, + ), + ): + result = method(api, pipeline) + assert result == {"ok": True} + + def test_published_run_rate_limit(self, app): + api = PublishedRagPipelineRunApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + payload = { + "inputs": {}, + "datasource_type": "x", + "datasource_info_list": [], + "start_node_id": "n", + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.PipelineGenerateService.generate", + side_effect=InvokeRateLimitError("limit"), + ), + ): + with pytest.raises(InvokeRateLimitHttpError): + method(api, pipeline) + + +class TestDefaultBlockConfigApi: + def test_get_block_config_success(self, app): + api = DefaultRagPipelineBlockConfigApi() + method = unwrap(api.get) + + pipeline = MagicMock() + + service = MagicMock() + service.get_default_block_config.return_value = {"k": "v"} + + with ( + app.test_request_context("/?q={}"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline, "llm") + assert result == {"k": "v"} + + def test_get_block_config_invalid_json(self, app): + api = DefaultRagPipelineBlockConfigApi() + method = unwrap(api.get) + + pipeline = MagicMock() + + with app.test_request_context("/?q=bad-json"): + with pytest.raises(ValueError): + method(api, pipeline, "llm") + + +class TestPublishedAllRagPipelineApi: + def test_get_published_workflows_success(self, app): + api = PublishedAllRagPipelineApi() + method = unwrap(api.get) + + pipeline = MagicMock() + user = MagicMock(id="u1") + + service = MagicMock() + service.get_all_published_workflow.return_value = ([{"id": "w1"}], False) + + session = MagicMock() + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline) + + assert result["items"] == [{"id": "w1"}] + assert result["has_more"] is False + + def test_get_published_workflows_forbidden(self, app): + api = PublishedAllRagPipelineApi() + method = unwrap(api.get) + + pipeline = MagicMock() + user = MagicMock(id="u1") + + with ( + app.test_request_context("/?user_id=u2"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + ): + with pytest.raises(Forbidden): + method(api, pipeline) + + +class TestRagPipelineByIdApi: + def test_patch_success(self, app): + api = RagPipelineByIdApi() + method = unwrap(api.patch) + + pipeline = MagicMock(tenant_id="t1") + user = MagicMock(id="u1") + + workflow = MagicMock() + + service = MagicMock() + service.update_workflow.return_value = workflow + + session = MagicMock() + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + payload = {"marked_name": "test"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", + fake_db, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline, "w1") + + assert result == workflow + + def test_patch_no_fields(self, app): + api = RagPipelineByIdApi() + method = unwrap(api.patch) + + pipeline = MagicMock() + user = MagicMock() + + with ( + app.test_request_context("/", json={}), + patch.object(type(console_ns), "payload", {}), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + ): + result, status = method(api, pipeline, "w1") + assert status == 400 + + +class TestRagPipelineWorkflowLastRunApi: + def test_last_run_success(self, app): + api = RagPipelineWorkflowLastRunApi() + method = unwrap(api.get) + + pipeline = MagicMock() + workflow = MagicMock() + node_exec = MagicMock() + + service = MagicMock() + service.get_draft_workflow.return_value = workflow + service.get_node_last_run.return_value = node_exec + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline, "node1") + assert result == node_exec + + def test_last_run_not_found(self, app): + api = RagPipelineWorkflowLastRunApi() + method = unwrap(api.get) + + pipeline = MagicMock() + + service = MagicMock() + service.get_draft_workflow.return_value = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + with pytest.raises(NotFound): + method(api, pipeline, "node1") + + +class TestRagPipelineDatasourceVariableApi: + def test_set_datasource_variables_success(self, app): + api = RagPipelineDatasourceVariableApi() + method = unwrap(api.post) + + pipeline = MagicMock() + user = MagicMock() + + payload = { + "datasource_type": "db", + "datasource_info": {}, + "start_node_id": "n1", + "start_node_title": "Node", + } + + service = MagicMock() + service.set_datasource_variables.return_value = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.current_account_with_tenant", + return_value=(user, "t"), + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.RagPipelineService", + return_value=service, + ), + ): + result = method(api, pipeline) + assert result is not None diff --git a/api/tests/unit_tests/controllers/console/datasets/test_data_source.py b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py new file mode 100644 index 0000000000..3060062adf --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_data_source.py @@ -0,0 +1,444 @@ +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from werkzeug.exceptions import NotFound + +from controllers.console.datasets import data_source +from controllers.console.datasets.data_source import ( + DataSourceApi, + DataSourceNotionApi, + DataSourceNotionDatasetSyncApi, + DataSourceNotionDocumentSyncApi, + DataSourceNotionListApi, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def tenant_ctx(): + return (MagicMock(id="u1"), "tenant-1") + + +@pytest.fixture +def patch_tenant(tenant_ctx): + with patch( + "controllers.console.datasets.data_source.current_account_with_tenant", + return_value=tenant_ctx, + ): + yield + + +@pytest.fixture +def mock_engine(): + with patch.object( + type(data_source.db), + "engine", + new_callable=PropertyMock, + return_value=MagicMock(), + ): + yield + + +class TestDataSourceApi: + def test_get_success(self, app, patch_tenant): + api = DataSourceApi() + method = unwrap(api.get) + + binding = MagicMock( + id="b1", + provider="notion", + created_at="now", + disabled=False, + source_info={}, + ) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.data_source.db.session.scalars", + return_value=MagicMock(all=lambda: [binding]), + ), + ): + response, status = method(api) + + assert status == 200 + assert response["data"][0]["is_bound"] is True + + def test_get_no_bindings(self, app, patch_tenant): + api = DataSourceApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.data_source.db.session.scalars", + return_value=MagicMock(all=lambda: []), + ), + ): + response, status = method(api) + + assert status == 200 + assert response["data"] == [] + + def test_patch_enable_binding(self, app, patch_tenant, mock_engine): + api = DataSourceApi() + method = unwrap(api.patch) + + binding = MagicMock(id="b1", disabled=True) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + patch("controllers.console.datasets.data_source.db.session.add"), + patch("controllers.console.datasets.data_source.db.session.commit"), + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalar_one_or_none.return_value = binding + + response, status = method(api, "b1", "enable") + + assert status == 200 + assert binding.disabled is False + + def test_patch_disable_binding(self, app, patch_tenant, mock_engine): + api = DataSourceApi() + method = unwrap(api.patch) + + binding = MagicMock(id="b1", disabled=False) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + patch("controllers.console.datasets.data_source.db.session.add"), + patch("controllers.console.datasets.data_source.db.session.commit"), + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalar_one_or_none.return_value = binding + + response, status = method(api, "b1", "disable") + + assert status == 200 + assert binding.disabled is True + + def test_patch_binding_not_found(self, app, patch_tenant, mock_engine): + api = DataSourceApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalar_one_or_none.return_value = None + + with pytest.raises(NotFound): + method(api, "b1", "enable") + + def test_patch_enable_already_enabled(self, app, patch_tenant, mock_engine): + api = DataSourceApi() + method = unwrap(api.patch) + + binding = MagicMock(id="b1", disabled=False) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalar_one_or_none.return_value = binding + + with pytest.raises(ValueError): + method(api, "b1", "enable") + + def test_patch_disable_already_disabled(self, app, patch_tenant, mock_engine): + api = DataSourceApi() + method = unwrap(api.patch) + + binding = MagicMock(id="b1", disabled=True) + + with ( + app.test_request_context("/"), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalar_one_or_none.return_value = binding + + with pytest.raises(ValueError): + method(api, "b1", "disable") + + +class TestDataSourceNotionListApi: + def test_get_credential_not_found(self, app, patch_tenant): + api = DataSourceNotionListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?credential_id=c1"), + patch( + "controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api) + + def test_get_success_no_dataset_id(self, app, patch_tenant, mock_engine): + api = DataSourceNotionListApi() + method = unwrap(api.get) + + page = MagicMock( + page_id="p1", + page_name="Page 1", + type="page", + parent_id="parent", + page_icon=None, + ) + + online_document_message = MagicMock( + result=[ + MagicMock( + workspace_id="w1", + workspace_name="My Workspace", + workspace_icon="icon", + pages=[page], + ) + ] + ) + + with ( + app.test_request_context("/?credential_id=c1"), + patch( + "controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials", + return_value={"token": "t"}, + ), + patch( + "core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime", + return_value=MagicMock( + get_online_document_pages=lambda **kw: iter([online_document_message]), + datasource_provider_type=lambda: None, + ), + ), + ): + response, status = method(api) + + assert status == 200 + + def test_get_success_with_dataset_id(self, app, patch_tenant, mock_engine): + api = DataSourceNotionListApi() + method = unwrap(api.get) + + page = MagicMock( + page_id="p1", + page_name="Page 1", + type="page", + parent_id="parent", + page_icon=None, + ) + + online_document_message = MagicMock( + result=[ + MagicMock( + workspace_id="w1", + workspace_name="My Workspace", + workspace_icon="icon", + pages=[page], + ) + ] + ) + + dataset = MagicMock(data_source_type="notion_import") + document = MagicMock(data_source_info='{"notion_page_id": "p1"}') + + with ( + app.test_request_context("/?credential_id=c1&dataset_id=ds1"), + patch( + "controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials", + return_value={"token": "t"}, + ), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=dataset, + ), + patch("controllers.console.datasets.data_source.Session") as mock_session_class, + patch( + "core.datasource.datasource_manager.DatasourceManager.get_datasource_runtime", + return_value=MagicMock( + get_online_document_pages=lambda **kw: iter([online_document_message]), + datasource_provider_type=lambda: None, + ), + ), + ): + mock_session = MagicMock() + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session.scalars.return_value.all.return_value = [document] + + response, status = method(api) + + assert status == 200 + + def test_get_invalid_dataset_type(self, app, patch_tenant, mock_engine): + api = DataSourceNotionListApi() + method = unwrap(api.get) + + dataset = MagicMock(data_source_type="other_type") + + with ( + app.test_request_context("/?credential_id=c1&dataset_id=ds1"), + patch( + "controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials", + return_value={"token": "t"}, + ), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=dataset, + ), + patch("controllers.console.datasets.data_source.Session"), + ): + with pytest.raises(ValueError): + method(api) + + +class TestDataSourceNotionApi: + def test_get_preview_success(self, app, patch_tenant): + api = DataSourceNotionApi() + method = unwrap(api.get) + + extractor = MagicMock(extract=lambda: [MagicMock(page_content="hello")]) + + with ( + app.test_request_context("/?credential_id=c1"), + patch( + "controllers.console.datasets.data_source.DatasourceProviderService.get_datasource_credentials", + return_value={"integration_secret": "t"}, + ), + patch( + "controllers.console.datasets.data_source.NotionExtractor", + return_value=extractor, + ), + ): + response, status = method(api, "p1", "page") + + assert status == 200 + + def test_post_indexing_estimate_success(self, app, patch_tenant): + api = DataSourceNotionApi() + method = unwrap(api.post) + + payload = { + "notion_info_list": [ + { + "workspace_id": "w1", + "credential_id": "c1", + "pages": [{"page_id": "p1", "type": "page"}], + } + ], + "process_rule": {"rules": {}}, + "doc_form": "text_model", + "doc_language": "English", + } + + with ( + app.test_request_context("/", method="POST", json=payload, headers={"Content-Type": "application/json"}), + patch( + "controllers.console.datasets.data_source.DocumentService.estimate_args_validate", + ), + patch( + "controllers.console.datasets.data_source.IndexingRunner.indexing_estimate", + return_value=MagicMock(model_dump=lambda: {"total_pages": 1}), + ), + ): + response, status = method(api) + + assert status == 200 + + +class TestDataSourceNotionDatasetSyncApi: + def test_get_success(self, app, patch_tenant): + api = DataSourceNotionDatasetSyncApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.data_source.DocumentService.get_document_by_dataset_id", + return_value=[MagicMock(id="d1")], + ), + patch( + "controllers.console.datasets.data_source.document_indexing_sync_task.delay", + return_value=None, + ), + ): + response, status = method(api, "ds-1") + + assert status == 200 + + def test_get_dataset_not_found(self, app, patch_tenant): + api = DataSourceNotionDatasetSyncApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1") + + +class TestDataSourceNotionDocumentSyncApi: + def test_get_success(self, app, patch_tenant): + api = DataSourceNotionDocumentSyncApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.data_source.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.data_source.document_indexing_sync_task.delay", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + def test_get_document_not_found(self, app, patch_tenant): + api = DataSourceNotionDocumentSyncApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.data_source.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.data_source.DocumentService.get_document", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py new file mode 100644 index 0000000000..f9fc2ac397 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -0,0 +1,1926 @@ +import datetime +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from werkzeug.exceptions import BadRequest, Forbidden, NotFound + +import services +from controllers.console import console_ns +from controllers.console.app.error import ProviderNotInitializeError +from controllers.console.datasets.datasets import ( + DatasetApi, + DatasetApiBaseUrlApi, + DatasetApiDeleteApi, + DatasetApiKeyApi, + DatasetAutoDisableLogApi, + DatasetEnableApiApi, + DatasetErrorDocs, + DatasetIndexingEstimateApi, + DatasetIndexingStatusApi, + DatasetListApi, + DatasetPermissionUserListApi, + DatasetQueryApi, + DatasetRelatedAppListApi, + DatasetRetrievalSettingApi, + DatasetRetrievalSettingMockApi, + DatasetUseCheckApi, +) +from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError +from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.provider_manager import ProviderManager +from models.enums import CreatorUserRole +from models.model import ApiToken, UploadFile +from services.dataset_service import DatasetPermissionService, DatasetService + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestDatasetList: + def _mock_dataset_dict(self, **overrides): + base = { + "id": "ds-1", + "indexing_technique": "economy", + "embedding_model": None, + "embedding_model_provider": None, + "permission": "only_me", + } + base.update(overrides) + return base + + def _mock_user(self): + user = MagicMock() + user.is_dataset_editor = True + return user + + def test_get_success_basic(self, app): + api = DatasetListApi() + method = unwrap(api.get) + + current_user = self._mock_user() + datasets = [MagicMock()] + marshaled = [self._mock_dataset_dict()] + + with app.test_request_context("/datasets"): + with ( + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_datasets", + return_value=(datasets, 1), + ), + patch( + "controllers.console.datasets.datasets.marshal", + return_value=marshaled, + ), + patch.object( + ProviderManager, + "get_configurations", + return_value=MagicMock(get_models=lambda **_: []), + ), + ): + resp, status = method(api) + + assert status == 200 + assert resp["total"] == 1 + assert resp["data"][0]["embedding_available"] is True + + def test_get_with_ids_filter(self, app): + api = DatasetListApi() + method = unwrap(api.get) + + current_user = self._mock_user() + datasets = [MagicMock()] + marshaled = [self._mock_dataset_dict()] + + with app.test_request_context("/datasets?ids=1&ids=2"): + with ( + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_datasets_by_ids", + return_value=(datasets, 2), + ) as by_ids_mock, + patch( + "controllers.console.datasets.datasets.marshal", + return_value=marshaled, + ), + patch.object( + ProviderManager, + "get_configurations", + return_value=MagicMock(get_models=lambda **_: []), + ), + ): + resp, status = method(api) + + by_ids_mock.assert_called_once() + assert status == 200 + assert resp["total"] == 2 + + def test_get_with_tag_ids(self, app): + api = DatasetListApi() + method = unwrap(api.get) + + current_user = self._mock_user() + datasets = [MagicMock()] + marshaled = [self._mock_dataset_dict()] + + with app.test_request_context("/datasets?tag_ids=tag1"): + with ( + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_datasets", + return_value=(datasets, 1), + ), + patch( + "controllers.console.datasets.datasets.marshal", + return_value=marshaled, + ), + patch.object( + ProviderManager, + "get_configurations", + return_value=MagicMock(get_models=lambda **_: []), + ), + ): + resp, status = method(api) + + assert status == 200 + + def test_embedding_available_false(self, app): + api = DatasetListApi() + method = unwrap(api.get) + + current_user = self._mock_user() + datasets = [MagicMock()] + marshaled = [ + self._mock_dataset_dict( + indexing_technique="high_quality", + embedding_model="text-embed", + embedding_model_provider="openai", + ) + ] + + config = MagicMock() + config.get_models.return_value = [] # model not available + + with app.test_request_context("/datasets"): + with ( + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_datasets", + return_value=(datasets, 1), + ), + patch( + "controllers.console.datasets.datasets.marshal", + return_value=marshaled, + ), + patch.object( + ProviderManager, + "get_configurations", + return_value=config, + ), + ): + resp, status = method(api) + + assert resp["data"][0]["embedding_available"] is False + + def test_partial_members_permission(self, app): + api = DatasetListApi() + method = unwrap(api.get) + + current_user = self._mock_user() + datasets = [MagicMock()] + marshaled = [self._mock_dataset_dict(permission="partial_members")] + + with app.test_request_context("/datasets"): + with ( + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_datasets", + return_value=(datasets, 1), + ), + patch( + "controllers.console.datasets.datasets.db.session.execute", + return_value=MagicMock(all=lambda: [("ds-1", "u1")]), + ), + patch( + "controllers.console.datasets.datasets.marshal", + return_value=marshaled, + ), + patch.object( + ProviderManager, + "get_configurations", + return_value=MagicMock(get_models=lambda **_: []), + ), + ): + resp, status = method(api) + + assert resp["data"][0]["partial_member_list"] == ["u1"] + + +class TestDatasetListApiPost: + def test_post_success(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + payload = { + "name": "My Dataset", + "description": "desc", + "indexing_technique": "economy", + "provider": "vendor", + } + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + # ---- minimal required fields for marshal ---- + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + with ( + app.test_request_context("/datasets", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch.object( + DatasetService, + "create_empty_dataset", + return_value=dataset, + ), + ): + _, status = method(api) + + assert status == 201 + + def test_post_forbidden(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + payload = {"name": "test"} + + user = MagicMock() + user.is_dataset_editor = False + + with ( + app.test_request_context("/datasets", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + ): + with pytest.raises(Forbidden): + method(api) + + def test_post_duplicate_name(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + payload = {"name": "duplicate"} + + user = MagicMock() + user.is_dataset_editor = True + + with ( + app.test_request_context("/datasets", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch.object( + DatasetService, + "create_empty_dataset", + side_effect=services.errors.dataset.DatasetNameDuplicateError(), + ), + ): + with pytest.raises(DatasetNameDuplicateError): + method(api) + + def test_post_invalid_payload_missing_name(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + with app.test_request_context("/datasets", json={}), patch.object(type(console_ns), "payload", {}): + with pytest.raises(ValueError): + method(api) + + def test_post_invalid_indexing_technique(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + payload = { + "name": "bad", + "indexing_technique": "invalid-tech", + } + + with app.test_request_context("/datasets", json=payload), patch.object(type(console_ns), "payload", payload): + with pytest.raises(ValueError, match="Invalid indexing technique"): + method(api) + + def test_post_invalid_provider(self, app): + api = DatasetListApi() + method = unwrap(api.post) + + payload = { + "name": "bad", + "provider": "unknown", + } + + with app.test_request_context("/datasets", json=payload), patch.object(type(console_ns), "payload", payload): + with pytest.raises(ValueError, match="Invalid provider"): + method(api) + + +class TestDatasetApiGet: + def test_get_success_basic(self, app): + api = DatasetApi() + method = unwrap(api.get) + + dataset_id = "123e4567-e89b-12d3-a456-426614174000" + + user = MagicMock() + tenant_id = "tenant-1" + + dataset = MagicMock() + dataset.id = dataset_id + dataset.indexing_technique = "economy" + dataset.embedding_model_provider = None + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + dataset.permission = "only_me" + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + return_value=None, + ), + patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + ): + # embedding models exist → embedding_available stays True + provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] + + data, status = method(api, dataset_id) + + assert status == 200 + assert data["embedding_available"] is True + + def test_get_dataset_not_found(self, app): + api = DatasetApi() + method = unwrap(api.get) + + dataset_id = "missing-id" + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, dataset_id) + + def test_get_permission_denied(self, app): + api = DatasetApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + dataset = MagicMock() + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("no access"), + ), + ): + with pytest.raises(Forbidden, match="no access"): + method(api, dataset_id) + + def test_get_high_quality_embedding_unavailable(self, app): + api = DatasetApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + user = MagicMock() + tenant_id = "tenant-1" + + dataset = MagicMock() + dataset.id = dataset_id + dataset.indexing_technique = "high_quality" + dataset.embedding_model = "text-embedding" + dataset.embedding_model_provider = "openai" + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + dataset.permission = "only_me" + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + return_value=None, + ), + patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + ): + # embedding model NOT configured + provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] + + data, _ = method(api, dataset_id) + + assert data["embedding_available"] is False + + def test_get_partial_members_permission(self, app): + api = DatasetApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + + dataset = MagicMock() + dataset.id = dataset_id + dataset.indexing_technique = "economy" + dataset.embedding_model_provider = None + dataset.permission = "partial_members" + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + partial_members = [{"id": "u1"}, {"id": "u2"}] + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + return_value=None, + ), + patch.object( + DatasetPermissionService, + "get_dataset_partial_member_list", + return_value=partial_members, + ), + patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + ): + provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] + + data, _ = method(api, dataset_id) + + assert data["partial_member_list"] == partial_members + + +class TestDatasetApiPatch: + def test_patch_success_basic(self, app): + api = DatasetApi() + method = unwrap(api.patch) + + dataset_id = "dataset-id" + + payload = { + "name": "updated-name", + "description": "updated description", + } + + user = MagicMock() + tenant_id = "tenant-1" + + dataset = MagicMock() + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.permission = "only_me" + dataset.indexing_technique = "economy" + dataset.embedding_model_provider = None + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetPermissionService, + "check_permission", + return_value=None, + ), + patch.object( + DatasetService, + "update_dataset", + return_value=dataset, + ), + patch.object( + DatasetPermissionService, + "get_dataset_partial_member_list", + return_value=[], + ), + ): + result, status = method(api, dataset_id) + + assert status == 200 + assert result["partial_member_list"] == [] + + def test_patch_dataset_not_found(self, app): + api = DatasetApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/datasets/missing"), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, "missing") + + def test_patch_permission_denied(self, app): + api = DatasetApi() + method = unwrap(api.patch) + + dataset_id = "dataset-id" + dataset = MagicMock() + + payload = {"name": "x"} + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch.object(type(console_ns), "payload", payload), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetPermissionService, + "check_permission", + side_effect=Forbidden("no permission"), + ), + ): + with pytest.raises(Forbidden): + method(api, dataset_id) + + def test_patch_partial_members_update(self, app): + api = DatasetApi() + method = unwrap(api.patch) + + dataset_id = "dataset-id" + + payload = { + "permission": "partial_members", + "partial_member_list": [{"id": "u1"}, {"id": "u2"}], + } + + dataset = MagicMock() + dataset.id = dataset_id + dataset.permission = "partial_members" + dataset.indexing_technique = "economy" + dataset.embedding_model_provider = None + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetPermissionService, + "check_permission", + return_value=None, + ), + patch.object( + DatasetService, + "update_dataset", + return_value=dataset, + ), + patch.object( + DatasetPermissionService, + "update_partial_member_list", + return_value=None, + ), + patch.object( + DatasetPermissionService, + "get_dataset_partial_member_list", + return_value=payload["partial_member_list"], + ), + ): + result, _ = method(api, dataset_id) + + assert result["partial_member_list"] == payload["partial_member_list"] + + def test_patch_clear_partial_members(self, app): + api = DatasetApi() + method = unwrap(api.patch) + + dataset_id = "dataset-id" + + payload = { + "permission": "only_me", + } + + dataset = MagicMock() + dataset.id = dataset_id + dataset.permission = "only_me" + dataset.indexing_technique = "economy" + dataset.embedding_model_provider = None + + dataset.embedding_available = True + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.is_multimodal = False + dataset.documents = [] + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetPermissionService, + "check_permission", + return_value=None, + ), + patch.object( + DatasetService, + "update_dataset", + return_value=dataset, + ), + patch.object( + DatasetPermissionService, + "clear_partial_member_list", + return_value=None, + ), + patch.object( + DatasetPermissionService, + "get_dataset_partial_member_list", + return_value=[], + ), + ): + result, _ = method(api, dataset_id) + + assert result["partial_member_list"] == [] + + +class TestDatasetApiDelete: + def test_delete_success(self, app): + api = DatasetApi() + method = unwrap(api.delete) + + dataset_id = "dataset-id" + user = MagicMock() + user.has_edit_permission = True + user.is_dataset_operator = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch.object( + DatasetService, + "delete_dataset", + return_value=True, + ), + patch.object( + DatasetPermissionService, + "clear_partial_member_list", + return_value=None, + ), + ): + result, status = method(api, dataset_id) + + assert status == 204 + assert result == {"result": "success"} + + def test_delete_forbidden_no_permission(self, app): + api = DatasetApi() + method = unwrap(api.delete) + + dataset_id = "dataset-id" + user = MagicMock() + user.has_edit_permission = False + user.is_dataset_operator = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant"), + ), + ): + with pytest.raises(Forbidden): + method(api, dataset_id) + + def test_delete_dataset_not_found(self, app): + api = DatasetApi() + method = unwrap(api.delete) + + dataset_id = "missing-dataset" + user = MagicMock() + user.has_edit_permission = True + user.is_dataset_operator = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch.object( + DatasetService, + "delete_dataset", + return_value=False, + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, dataset_id) + + def test_delete_dataset_in_use(self, app): + api = DatasetApi() + method = unwrap(api.delete) + + dataset_id = "dataset-id" + user = MagicMock() + user.has_edit_permission = True + user.is_dataset_operator = False + + with ( + app.test_request_context(f"/datasets/{dataset_id}"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch.object( + DatasetService, + "delete_dataset", + side_effect=services.errors.dataset.DatasetInUseError(), + ), + ): + with pytest.raises(DatasetInUseError): + method(api, dataset_id) + + +class TestDatasetUseCheckApi: + def test_get_use_check_true(self, app): + api = DatasetUseCheckApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + + with ( + app.test_request_context(f"/datasets/{dataset_id}/use-check"), + patch.object( + DatasetService, + "dataset_use_check", + return_value=True, + ), + ): + result, status = method(api, dataset_id) + + assert status == 200 + assert result == {"is_using": True} + + def test_get_use_check_false(self, app): + api = DatasetUseCheckApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + + with ( + app.test_request_context(f"/datasets/{dataset_id}/use-check"), + patch.object( + DatasetService, + "dataset_use_check", + return_value=False, + ), + ): + result, status = method(api, dataset_id) + + assert status == 200 + assert result == {"is_using": False} + + +class TestDatasetQueryApi: + def test_get_queries_success(self, app): + api = DatasetQueryApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + + current_user = MagicMock() + + dataset = MagicMock() + dataset.id = dataset_id + + queries = [MagicMock(), MagicMock()] + + with ( + app.test_request_context("/datasets/queries?page=1&limit=20"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + return_value=None, + ), + patch.object( + DatasetService, + "get_dataset_queries", + return_value=(queries, 2), + ), + ): + response, status = method(api, dataset_id) + + assert status == 200 + assert response["total"] == 2 + assert response["page"] == 1 + assert response["limit"] == 20 + assert response["has_more"] is False + assert len(response["data"]) == 2 + + def test_get_queries_dataset_not_found(self, app): + api = DatasetQueryApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + current_user = MagicMock() + + with ( + app.test_request_context("/datasets/queries"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, dataset_id) + + def test_get_queries_permission_denied(self, app): + api = DatasetQueryApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + current_user = MagicMock() + + dataset = MagicMock() + + with ( + app.test_request_context("/datasets/queries"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("no access"), + ), + ): + with pytest.raises(Forbidden): + method(api, dataset_id) + + def test_get_queries_pagination_has_more(self, app): + api = DatasetQueryApi() + method = unwrap(api.get) + + dataset_id = "dataset-id" + current_user = MagicMock() + + dataset = MagicMock() + dataset.id = dataset_id + + queries = [MagicMock() for _ in range(20)] + + with ( + app.test_request_context("/datasets/queries?page=1&limit=20"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + return_value=None, + ), + patch.object( + DatasetService, + "get_dataset_queries", + return_value=(queries, 40), + ), + ): + response, status = method(api, dataset_id) + + assert status == 200 + assert response["has_more"] is True + assert len(response["data"]) == 20 + + +class TestDatasetIndexingEstimateApi: + def _upload_file(self, *, tenant_id: str = "tenant-1", file_id: str = "file-1") -> UploadFile: + upload_file = UploadFile( + tenant_id=tenant_id, + storage_type="local", + key="key", + name="name.txt", + size=1, + extension="txt", + mime_type="text/plain", + created_by_role=CreatorUserRole.ACCOUNT, + created_by="user-1", + created_at=datetime.datetime.now(tz=datetime.UTC), + used=False, + ) + upload_file.id = file_id + return upload_file + + def _base_payload(self): + return { + "info_list": { + "data_source_type": "upload_file", + "file_info_list": { + "file_ids": ["file-1"], + }, + }, + "process_rule": {"chunk_size": 100}, + "indexing_technique": "high_quality", + "doc_form": "text_model", + "doc_language": "English", + "dataset_id": None, + } + + def test_post_success_upload_file(self, app): + api = DatasetIndexingEstimateApi() + method = unwrap(api.post) + + payload = self._base_payload() + + mock_file = self._upload_file() + + mock_response = MagicMock() + mock_response.model_dump.return_value = {"tokens": 100} + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [mock_file]), + ), + patch( + "controllers.console.datasets.datasets.IndexingRunner.indexing_estimate", + return_value=mock_response, + ), + ): + response, status = method(api) + + assert status == 200 + assert response == {"tokens": 100} + + def test_post_file_not_found(self, app): + api = DatasetIndexingEstimateApi() + method = unwrap(api.post) + + payload = self._base_payload() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: None), + ), + ): + with pytest.raises(NotFound): + method(api) + + def test_post_llm_bad_request_error(self, app): + api = DatasetIndexingEstimateApi() + method = unwrap(api.post) + mock_file = self._upload_file() + + payload = self._base_payload() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [mock_file]), + ), + patch( + "controllers.console.datasets.datasets.IndexingRunner.indexing_estimate", + side_effect=LLMBadRequestError(), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api) + + def test_post_provider_token_not_init(self, app): + api = DatasetIndexingEstimateApi() + method = unwrap(api.post) + mock_file = self._upload_file() + + payload = self._base_payload() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [mock_file]), + ), + patch( + "controllers.console.datasets.datasets.IndexingRunner.indexing_estimate", + side_effect=ProviderTokenNotInitError("token missing"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api) + + def test_post_generic_exception(self, app): + api = DatasetIndexingEstimateApi() + method = unwrap(api.post) + mock_file = self._upload_file() + + payload = self._base_payload() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [mock_file]), + ), + patch( + "controllers.console.datasets.datasets.IndexingRunner.indexing_estimate", + side_effect=Exception("boom"), + ), + ): + with pytest.raises(IndexingEstimateError): + method(api) + + +class TestDatasetRelatedAppListApi: + def test_get_success(self, app): + api = DatasetRelatedAppListApi() + method = unwrap(api.get) + + dataset = MagicMock() + dataset.id = "dataset-1" + + app1 = MagicMock() + app2 = MagicMock() + + join1 = MagicMock(app=app1) + join2 = MagicMock(app=app2) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_related_apps", + return_value=[join1, join2], + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response["total"] == 2 + assert response["data"] == [app1, app2] + + def test_get_dataset_not_found(self, app): + api = DatasetRelatedAppListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "dataset-1") + + def test_get_permission_denied(self, app): + api = DatasetRelatedAppListApi() + method = unwrap(api.get) + + dataset = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("no permission"), + ), + ): + with pytest.raises(Forbidden): + method(api, "dataset-1") + + def test_get_filters_none_apps(self, app): + api = DatasetRelatedAppListApi() + method = unwrap(api.get) + + dataset = MagicMock() + dataset.id = "dataset-1" + + app1 = MagicMock() + + join1 = MagicMock(app=app1) + join2 = MagicMock(app=None) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_related_apps", + return_value=[join1, join2], + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response["total"] == 1 + assert response["data"] == [app1] + + +class TestDatasetIndexingStatusApi: + def test_get_success_with_documents(self, app): + api = DatasetIndexingStatusApi() + method = unwrap(api.get) + + document = MagicMock() + document.id = "doc-1" + document.indexing_status = "completed" + document.processing_started_at = None + document.parsing_completed_at = None + document.cleaning_completed_at = None + document.splitting_completed_at = None + document.completed_at = None + document.paused_at = None + document.error = None + document.stopped_at = None + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [document]), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)), + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert "data" in response + assert len(response["data"]) == 1 + + item = response["data"][0] + assert item["completed_segments"] == 3 + assert item["total_segments"] == 3 + + def test_get_success_no_documents(self, app): + api = DatasetIndexingStatusApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: []), + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response == {"data": []} + + def test_segment_counts_different_values(self, app): + api = DatasetIndexingStatusApi() + method = unwrap(api.get) + + document = MagicMock() + document.id = "doc-1" + document.indexing_status = "indexing" + document.processing_started_at = None + document.parsing_completed_at = None + document.cleaning_completed_at = None + document.splitting_completed_at = None + document.completed_at = None + document.paused_at = None + document.error = None + document.stopped_at = None + + # First count = completed segments, second = total segments + query_mock = MagicMock() + query_mock.where.side_effect = [ + MagicMock(count=lambda: 2), + MagicMock(count=lambda: 5), + ] + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [document]), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + return_value=query_mock, + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + item = response["data"][0] + assert item["completed_segments"] == 2 + assert item["total_segments"] == 5 + + +class TestDatasetApiKeyApi: + def test_get_api_keys_success(self, app): + api = DatasetApiKeyApi() + method = unwrap(api.get) + + mock_key_1 = MagicMock(spec=ApiToken) + mock_key_2 = MagicMock(spec=ApiToken) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.scalars", + return_value=MagicMock(all=lambda: [mock_key_1, mock_key_2]), + ), + ): + response = method(api) + + assert "items" in response + assert response["items"] == [mock_key_1, mock_key_2] + + def test_post_create_api_key_success(self, app): + api = DatasetApiKeyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3)), + ), + patch( + "controllers.console.datasets.datasets.ApiToken.generate_api_key", + return_value="dataset-abc123", + ), + patch( + "controllers.console.datasets.datasets.db.session.add", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.commit", + return_value=None, + ), + ): + response, status = method(api) + + assert status == 200 + assert isinstance(response, ApiToken) + assert response.token == "dataset-abc123" + assert response.type == "dataset" + + def test_post_exceed_max_keys(self, app): + api = DatasetApiKeyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 10)), + ), + ): + with pytest.raises(BadRequest) as exc_info: + method(api) + + assert exc_info.value.code == 400 + assert exc_info.value.data == { + "message": "Cannot create more than 10 API keys for this resource type.", + "custom": "max_keys_exceeded", + } + + +class TestDatasetApiDeleteApi: + def test_delete_success(self, app): + api = DatasetApiDeleteApi() + method = unwrap(api.delete) + + mock_key = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: mock_key)), + ), + patch( + "controllers.console.datasets.datasets.db.session.commit", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.db.session.delete", + return_value=None, + ), + ): + response, status = method(api, "api-key-id") + + assert status == 204 + assert response["result"] == "success" + + def test_delete_key_not_found(self, app): + api = DatasetApiDeleteApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.db.session.query", + return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: None)), + ), + ): + with pytest.raises(NotFound): + method(api, "api-key-id") + + +class TestDatasetEnableApiApi: + def test_enable_api(self, app): + api = DatasetEnableApiApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.update_dataset_api_status", + return_value=None, + ), + ): + response, status = method(api, "dataset-1", "enable") + + assert status == 200 + assert response["result"] == "success" + + def test_disable_api(self, app): + api = DatasetEnableApiApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.update_dataset_api_status", + return_value=None, + ), + ): + response, status = method(api, "dataset-1", "disable") + + assert status == 200 + assert response["result"] == "success" + + +class TestDatasetApiBaseUrlApi: + def test_get_api_base_url_from_config(self, app): + api = DatasetApiBaseUrlApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL", + "https://example.com", + ), + ): + response = method(api) + + assert response["api_base_url"] == "https://example.com/v1" + + def test_get_api_base_url_from_request(self, app): + api = DatasetApiBaseUrlApi() + method = unwrap(api.get) + + with ( + app.test_request_context("http://localhost:5000/"), + patch( + "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL", + None, + ), + ): + response = method(api) + + assert response["api_base_url"] == "http://localhost:5000/v1" + + +class TestDatasetRetrievalSettingApi: + def test_get_success(self, app): + api = DatasetRetrievalSettingApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.dify_config.VECTOR_STORE", + "qdrant", + ), + patch( + "controllers.console.datasets.datasets._get_retrieval_methods_by_vector_type", + return_value={"retrieval_method": ["semantic", "hybrid"]}, + ), + ): + response = method(api) + + assert "retrieval_method" in response + + +class TestDatasetRetrievalSettingMockApi: + def test_get_success(self, app): + api = DatasetRetrievalSettingMockApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets._get_retrieval_methods_by_vector_type", + return_value={"retrieval_method": ["semantic"]}, + ), + ): + response = method(api, "milvus") + + assert response["retrieval_method"] == ["semantic"] + + +class TestDatasetErrorDocs: + def test_get_success(self, app): + api = DatasetErrorDocs() + method = unwrap(api.get) + + dataset = MagicMock() + error_doc = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DocumentService.get_error_documents_by_dataset_id", + return_value=[error_doc], + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response["total"] == 1 + + def test_get_dataset_not_found(self, app): + api = DatasetErrorDocs() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "dataset-1") + + +class TestDatasetPermissionUserListApi: + def test_get_success(self, app): + api = DatasetPermissionUserListApi() + method = unwrap(api.get) + + dataset = MagicMock() + users = [{"id": "u1"}, {"id": "u2"}] + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets.DatasetPermissionService.get_dataset_partial_member_list", + return_value=users, + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response["data"] == users + + def test_get_permission_denied(self, app): + api = DatasetPermissionUserListApi() + method = unwrap(api.get) + + dataset = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("no permission"), + ), + ): + with pytest.raises(Forbidden): + method(api, "dataset-1") + + +class TestDatasetAutoDisableLogApi: + def test_get_success(self, app): + api = DatasetAutoDisableLogApi() + method = unwrap(api.get) + + dataset = MagicMock() + logs = [{"reason": "quota"}] + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset_auto_disable_logs", + return_value=logs, + ), + ): + response, status = method(api, "dataset-1") + + assert status == 200 + assert response == logs + + def test_get_dataset_not_found(self, app): + api = DatasetAutoDisableLogApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "dataset-1") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py new file mode 100644 index 0000000000..dbe54ccb99 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_document.py @@ -0,0 +1,1379 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +import services +from controllers.console import console_ns +from controllers.console.datasets.datasets_document import ( + DatasetDocumentListApi, + DocumentApi, + DocumentBatchDownloadZipApi, + DocumentBatchIndexingEstimateApi, + DocumentBatchIndexingStatusApi, + DocumentDownloadApi, + DocumentGenerateSummaryApi, + DocumentIndexingEstimateApi, + DocumentIndexingStatusApi, + DocumentMetadataApi, + DocumentPipelineExecutionLogApi, + DocumentProcessingApi, + DocumentRetryApi, + DocumentStatusApi, + DocumentSummaryStatusApi, + GetProcessRuleApi, +) +from controllers.console.datasets.error import ( + DocumentAlreadyFinishedError, + DocumentIndexingError, + IndexingEstimateError, + InvalidActionError, + InvalidMetadataError, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def tenant_ctx(): + return (MagicMock(is_dataset_editor=True, id="u1"), "tenant-1") + + +@pytest.fixture +def patch_tenant(tenant_ctx): + with patch( + "controllers.console.datasets.datasets_document.current_account_with_tenant", + return_value=tenant_ctx, + ): + yield + + +@pytest.fixture +def dataset(): + return MagicMock(id="ds-1", indexing_technique="economy", summary_index_setting={"enable": True}) + + +@pytest.fixture +def document(): + return MagicMock( + id="doc-1", + tenant_id="tenant-1", + indexing_status="indexing", + data_source_type="upload_file", + data_source_info_dict={"upload_file_id": "file-1"}, + doc_form="text", + archived=False, + is_paused=False, + dataset_process_rule=None, + ) + + +@pytest.fixture +def patch_dataset(dataset): + with patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=dataset, + ): + yield + + +@pytest.fixture +def patch_permission(): + with patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + return_value=None, + ): + yield + + +class TestGetProcessRuleApi: + def test_get_default_success(self, app, patch_tenant): + api = GetProcessRuleApi() + method = unwrap(api.get) + + with app.test_request_context("/"): + response = method(api) + + assert "rules" in response + + def test_get_with_document_dataset_not_found(self, app, patch_tenant): + api = GetProcessRuleApi() + method = unwrap(api.get) + + document = MagicMock(dataset_id="ds-1") + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.db.get_or_404", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api) + + +class TestDatasetDocumentListApi: + def test_get_with_fetch_true_counts_segments(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.get) + + doc = MagicMock(id="doc-1") + pagination = MagicMock(items=[doc], total=1) + + count_mock = MagicMock(return_value=2) + + with ( + app.test_request_context("/?fetch=true"), + patch( + "controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(count=count_mock)), + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.marshal", + return_value=[{"id": "doc-1"}], + ), + ): + resp = method(api, "ds-1") + + assert resp["data"] + + def test_get_with_search_status_and_created_at_sort(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.get) + + pagination = MagicMock(items=[MagicMock()], total=1) + + with ( + app.test_request_context("/?keyword=test&status=enabled&sort=created_at"), + patch( + "controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.apply_display_status_filter", + side_effect=lambda q, s: q, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.marshal", + return_value=[{"id": "doc-1"}], + ), + ): + resp = method(api, "ds-1") + + assert resp["total"] == 1 + + def test_get_success(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.get) + + pagination = MagicMock(items=[MagicMock()], total=1) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.marshal", + return_value=[{"id": "doc-1"}], + ), + ): + response = method(api, "ds-1") + + assert response["total"] == 1 + + def test_post_success(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.post) + + payload = {"indexing_technique": "economy"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DocumentService.document_create_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.save_document_with_dataset_id", + return_value=([MagicMock()], "batch-1"), + ), + ): + response = method(api, "ds-1") + + assert "documents" in response + + def test_post_forbidden(self, app): + api = DatasetDocumentListApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=False) + + with ( + app.test_request_context("/", json={}), + patch.object(type(console_ns), "payload", {}), + patch( + "controllers.console.datasets.datasets_document.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + ): + with pytest.raises(Forbidden): + method(api, "ds-1") + + def test_get_with_fetch_true_and_invalid_fetch(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.get) + + pagination = MagicMock(items=[MagicMock()], total=1) + + with ( + app.test_request_context("/?fetch=maybe"), + patch( + "controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.marshal", + return_value=[{"id": "doc-1"}], + ), + ): + response = method(api, "ds-1") + + assert response["total"] == 1 + + def test_get_sort_hit_count(self, app, patch_tenant, patch_dataset, patch_permission): + api = DatasetDocumentListApi() + method = unwrap(api.get) + + pagination = MagicMock(items=[], total=0) + + with ( + app.test_request_context("/?sort=hit_count"), + patch( + "controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + ): + response = method(api, "ds-1") + + assert response["total"] == 0 + + +class TestDocumentApi: + def test_get_success(self, app, patch_tenant): + api = DocumentApi() + method = unwrap(api.get) + + document = MagicMock(dataset_process_rule=None) + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_process_rules", + return_value={}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + def test_get_invalid_metadata(self, app, patch_tenant): + api = DocumentApi() + method = unwrap(api.get) + + with app.test_request_context("/?metadata=wrong"), patch.object(api, "get_document", return_value=MagicMock()): + with pytest.raises(InvalidMetadataError): + method(api, "ds-1", "doc-1") + + def test_delete_success(self, app, patch_tenant, patch_dataset): + api = DocumentApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch.object(api, "get_document", return_value=MagicMock()), + patch( + "controllers.console.datasets.datasets_document.DocumentService.delete_document", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 204 + + def test_delete_indexing_error(self, app, patch_tenant, patch_dataset): + api = DocumentApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch.object(api, "get_document", return_value=MagicMock()), + patch( + "controllers.console.datasets.datasets_document.DocumentService.delete_document", + side_effect=services.errors.document.DocumentIndexingError(), + ), + ): + with pytest.raises(DocumentIndexingError): + method(api, "ds-1", "doc-1") + + +class TestDocumentDownloadApi: + def test_download_success(self, app, patch_tenant): + api = DocumentDownloadApi() + method = unwrap(api.get) + + document = MagicMock() + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_document_download_url", + return_value="url", + ), + ): + response = method(api, "ds-1", "doc-1") + + assert response["url"] == "url" + + +class TestDocumentProcessingApi: + def test_processing_forbidden_when_not_editor(self, app): + api = DocumentProcessingApi() + method = unwrap(api.patch) + + user = MagicMock(is_dataset_editor=False) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.current_account_with_tenant", + return_value=(user, "tenant"), + ), + patch.object(api, "get_document", return_value=MagicMock()), + ): + with pytest.raises(Forbidden): + method(api, "ds-1", "doc-1", "pause") + + def test_resume_from_error_state(self, app, patch_tenant): + api = DocumentProcessingApi() + method = unwrap(api.patch) + + doc = MagicMock(indexing_status="error", is_paused=True) + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=doc), + patch( + "controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + _, status = method(api, "ds-1", "doc-1", "resume") + + assert status == 200 + + def test_resume_success(self, app, patch_tenant): + api = DocumentProcessingApi() + method = unwrap(api.patch) + + document = MagicMock(indexing_status="paused", is_paused=True) + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1", "resume") + + assert status == 200 + + def test_pause_success(self, app, patch_tenant): + api = DocumentProcessingApi() + method = unwrap(api.patch) + + document = MagicMock(indexing_status="indexing") + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1", "pause") + + assert status == 200 + + def test_pause_invalid(self, app, patch_tenant): + api = DocumentProcessingApi() + method = unwrap(api.patch) + + document = MagicMock(indexing_status="completed") + + with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): + with pytest.raises(InvalidActionError): + method(api, "ds-1", "doc-1", "pause") + + +class TestDocumentMetadataApi: + def test_put_metadata_schema_filtering(self, app, patch_tenant): + api = DocumentMetadataApi() + method = unwrap(api.put) + + doc = MagicMock() + + payload = { + "doc_type": "invoice", + "doc_metadata": {"amount": 10, "invalid": "x"}, + } + + schema = {"amount": int} + + with ( + app.test_request_context("/", json=payload), + patch.object(api, "get_document", return_value=doc), + patch( + "controllers.console.datasets.datasets_document.DocumentService.DOCUMENT_METADATA_SCHEMA", + {"invoice": schema}, + ), + patch( + "controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + method(api, "ds-1", "doc-1") + + assert doc.doc_metadata == {"amount": 10} + + def test_put_success(self, app, patch_tenant): + api = DocumentMetadataApi() + method = unwrap(api.put) + + document = MagicMock() + + payload = {"doc_type": "others", "doc_metadata": {"a": 1}} + + with ( + app.test_request_context("/", json=payload), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.DocumentService.DOCUMENT_METADATA_SCHEMA", + {"others": {}}, + ), + patch( + "controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + def test_put_invalid_payload(self, app, patch_tenant): + api = DocumentMetadataApi() + method = unwrap(api.put) + + with app.test_request_context("/", json={}), patch.object(api, "get_document", return_value=MagicMock()): + with pytest.raises(ValueError): + method(api, "ds-1", "doc-1") + + def test_put_invalid_doc_type(self, app, patch_tenant): + api = DocumentMetadataApi() + method = unwrap(api.put) + + payload = {"doc_type": "invalid", "doc_metadata": {}} + + with ( + app.test_request_context("/", json=payload), + patch.object(api, "get_document", return_value=MagicMock()), + patch( + "controllers.console.datasets.datasets_document.DocumentService.DOCUMENT_METADATA_SCHEMA", + {"others": {}}, + ), + ): + with pytest.raises(ValueError): + method(api, "ds-1", "doc-1") + + +class TestDocumentStatusApi: + def test_patch_success(self, app, patch_tenant, patch_dataset): + api = DocumentStatusApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.batch_update_document_status", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "enable") + + assert status == 200 + + def test_patch_invalid_action(self, app, patch_tenant, patch_dataset): + api = DocumentStatusApi() + method = unwrap(api.patch) + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.batch_update_document_status", + side_effect=ValueError("x"), + ), + ): + with pytest.raises(InvalidActionError): + method(api, "ds-1", "enable") + + +class TestDocumentRetryApi: + def test_retry_archived_document_skipped(self, app, patch_tenant, patch_dataset): + api = DocumentRetryApi() + method = unwrap(api.post) + + payload = {"document_ids": ["doc-1"]} + + doc = MagicMock(indexing_status="indexing") + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_document", + return_value=doc, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.check_archived", + return_value=True, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.retry_document", + ) as retry_mock, + ): + resp, status = method(api, "ds-1") + + assert status == 204 + retry_mock.assert_called_once_with("ds-1", []) + + def test_retry_success(self, app, patch_tenant, patch_dataset): + api = DocumentRetryApi() + method = unwrap(api.post) + + payload = {"document_ids": ["doc-1"]} + + document = MagicMock(indexing_status="indexing", archived=False) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.check_archived", + return_value=False, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.retry_document", + return_value=None, + ) as retry_mock, + ): + response, status = method(api, "ds-1") + + assert status == 204 + retry_mock.assert_called_once_with("ds-1", [document]) + + def test_retry_skips_completed_document(self, app, patch_tenant, patch_dataset): + api = DocumentRetryApi() + method = unwrap(api.post) + + payload = {"document_ids": ["doc-1"]} + + document = MagicMock(indexing_status="completed", archived=False) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.retry_document", + return_value=None, + ) as retry_mock, + ): + response, status = method(api, "ds-1") + + assert status == 204 + retry_mock.assert_called_once_with("ds-1", []) + + +class TestDocumentPipelineExecutionLogApi: + def test_get_log_success(self, app, patch_tenant, patch_dataset): + api = DocumentPipelineExecutionLogApi() + method = unwrap(api.get) + + log = MagicMock( + datasource_info="{}", + datasource_type="file", + input_data={}, + datasource_node_id="n1", + ) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=MagicMock( + filter_by=lambda **k: MagicMock(order_by=lambda *a: MagicMock(first=lambda: log)) + ), + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + +class TestDocumentGenerateSummaryApi: + def test_generate_summary_missing_documents(self, app, patch_tenant, patch_permission): + api = DocumentGenerateSummaryApi() + method = unwrap(api.post) + + dataset = MagicMock( + indexing_technique="high_quality", + summary_index_setting={"enable": True}, + ) + + payload = {"document_list": ["doc-1", "doc-2"]} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_documents_by_ids", + return_value=[MagicMock(id="doc-1")], + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1") + + def test_generate_not_enabled(self, app, patch_tenant, patch_permission): + api = DocumentGenerateSummaryApi() + method = unwrap(api.post) + + dataset = MagicMock(indexing_technique="high_quality", summary_index_setting={"enable": False}) + + payload = {"document_list": ["doc-1"]} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=dataset, + ), + ): + with pytest.raises(ValueError): + method(api, "ds-1") + + def test_generate_summary_success_with_qa_skip(self, app, patch_tenant, patch_permission): + api = DocumentGenerateSummaryApi() + method = unwrap(api.post) + + dataset = MagicMock( + indexing_technique="high_quality", + summary_index_setting={"enable": True}, + ) + + doc1 = MagicMock(id="doc-1", doc_form="qa_model") + doc2 = MagicMock(id="doc-2", doc_form="text") + + payload = {"document_list": ["doc-1", "doc-2"]} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_documents_by_ids", + return_value=[doc1, doc2], + ), + patch( + "controllers.console.datasets.datasets_document.generate_summary_index_task.delay", + return_value=None, + ), + ): + response, status = method(api, "ds-1") + + assert status == 200 + + +class TestDocumentSummaryStatusApi: + def test_get_success(self, app, patch_tenant, patch_permission): + api = DocumentSummaryStatusApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "services.summary_index_service.SummaryIndexService.get_document_summary_status_detail", + return_value={"total_segments": 0}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + +class TestDocumentIndexingEstimateApi: + def test_indexing_estimate_file_not_found(self, app, patch_tenant): + api = DocumentIndexingEstimateApi() + method = unwrap(api.get) + + document = MagicMock( + indexing_status="indexing", + data_source_type="upload_file", + data_source_info_dict={"upload_file_id": "file-1"}, + tenant_id="tenant-1", + doc_form="text", + dataset_process_rule=None, + ) + + query_mock = MagicMock() + query_mock.where.return_value.first.return_value = None + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=query_mock, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_indexing_estimate_generic_exception(self, app, patch_tenant): + api = DocumentIndexingEstimateApi() + method = unwrap(api.get) + + document = MagicMock( + indexing_status="indexing", + data_source_type="upload_file", + data_source_info_dict={"upload_file_id": "file-1"}, + tenant_id="tenant-1", + doc_form="text", + dataset_process_rule=None, + ) + + upload_file = MagicMock() + + mock_indexing_runner = MagicMock() + mock_indexing_runner.indexing_estimate.side_effect = RuntimeError("Some indexing error") + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=MagicMock( + where=MagicMock(return_value=MagicMock(first=MagicMock(return_value=upload_file))) + ), + ), + patch( + "controllers.console.datasets.datasets_document.ExtractSetting", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.IndexingRunner", + return_value=mock_indexing_runner, + ), + ): + with pytest.raises(IndexingEstimateError): + method(api, "ds-1", "doc-1") + + def test_get_finished(self, app, patch_tenant): + api = DocumentIndexingEstimateApi() + method = unwrap(api.get) + + document = MagicMock(indexing_status="completed") + + with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): + with pytest.raises(DocumentAlreadyFinishedError): + method(api, "ds-1", "doc-1") + + +class TestDocumentBatchDownloadZipApi: + def test_post_no_documents(self, app, patch_tenant): + api = DocumentBatchDownloadZipApi() + method = unwrap(api.post) + + payload = {"document_ids": []} + + with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload): + with pytest.raises(ValueError): + method(api, "ds-1") + + +class TestDatasetDocumentListApiDelete: + def test_delete_success(self, app, patch_tenant, patch_dataset): + """Test successful deletion of documents""" + api = DatasetDocumentListApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/?document_id=doc-1&document_id=doc-2"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.delete_documents", + return_value=None, + ), + ): + response, status = method(api, "ds-1") + + assert status == 204 + + def test_delete_indexing_error(self, app, patch_tenant, patch_dataset): + """Test deletion with indexing error""" + api = DatasetDocumentListApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.delete_documents", + side_effect=services.errors.document.DocumentIndexingError(), + ), + ): + with pytest.raises(DocumentIndexingError): + method(api, "ds-1") + + def test_delete_dataset_not_found(self, app, patch_tenant): + """Test deletion when dataset not found""" + api = DatasetDocumentListApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1") + + +class TestDocumentBatchIndexingEstimateApi: + def test_batch_indexing_estimate_website(self, app, patch_tenant): + api = DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + doc = MagicMock( + indexing_status="indexing", + data_source_type="website_crawl", + data_source_info_dict={ + "provider": "firecrawl", + "job_id": "j1", + "url": "https://x.com", + "mode": "single", + "only_main_content": True, + }, + doc_form="text", + ) + + with ( + app.test_request_context("/"), + patch.object(api, "get_batch_documents", return_value=[doc]), + patch( + "controllers.console.datasets.datasets_document.IndexingRunner.indexing_estimate", + return_value=MagicMock(model_dump=lambda: {"tokens": 2}), + ), + ): + resp, status = method(api, "ds-1", "batch-1") + + assert status == 200 + + def test_batch_indexing_estimate_notion(self, app, patch_tenant): + api = DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + doc = MagicMock( + indexing_status="indexing", + data_source_type="notion_import", + data_source_info_dict={ + "credential_id": "c1", + "notion_workspace_id": "w1", + "notion_page_id": "p1", + "type": "page", + }, + doc_form="text", + ) + + with ( + app.test_request_context("/"), + patch.object(api, "get_batch_documents", return_value=[doc]), + patch( + "controllers.console.datasets.datasets_document.IndexingRunner.indexing_estimate", + return_value=MagicMock(model_dump=lambda: {"tokens": 1}), + ), + ): + resp, status = method(api, "ds-1", "batch-1") + + assert status == 200 + + def test_batch_estimate_unsupported_datasource(self, app, patch_tenant): + api = DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + document = MagicMock( + indexing_status="indexing", + data_source_type="unknown", + data_source_info_dict={}, + doc_form="text", + ) + + with app.test_request_context("/"), patch.object(api, "get_batch_documents", return_value=[document]): + with pytest.raises(ValueError): + method(api, "ds-1", "batch-1") + + def test_get_batch_estimate_invalid_batch(self, app, patch_tenant): + """Test batch estimation with invalid batch""" + api = DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + with app.test_request_context("/"), patch.object(api, "get_batch_documents", side_effect=NotFound()): + with pytest.raises(NotFound): + method(api, "ds-1", "invalid-batch") + + +class TestDocumentBatchIndexingStatusApi: + def test_get_batch_status_invalid_batch(self, app, patch_tenant): + """Test batch status with invalid batch""" + api = DocumentBatchIndexingStatusApi() + method = unwrap(api.get) + + with app.test_request_context("/"), patch.object(api, "get_batch_documents", side_effect=NotFound()): + with pytest.raises(NotFound): + method(api, "ds-1", "invalid-batch") + + +class TestDocumentIndexingStatusApi: + def test_get_status_document_not_found(self, app, patch_tenant): + """Test getting status for non-existent document""" + api = DocumentIndexingStatusApi() + method = unwrap(api.get) + + with app.test_request_context("/"), patch.object(api, "get_document", side_effect=NotFound()): + with pytest.raises(NotFound): + method(api, "ds-1", "invalid-doc") + + +class TestDocumentApiMetadata: + def test_get_with_only_option(self, app, patch_tenant): + """Test get with 'only' metadata option""" + api = DocumentApi() + method = unwrap(api.get) + + document = MagicMock(dataset_process_rule=None, doc_metadata_details=[]) + + with ( + app.test_request_context("/?metadata=only"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_process_rules", + return_value={}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + def test_get_with_without_option(self, app, patch_tenant): + """Test get with 'without' metadata option""" + api = DocumentApi() + method = unwrap(api.get) + + document = MagicMock(dataset_process_rule=None) + + with ( + app.test_request_context("/?metadata=without"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_process_rules", + return_value={}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + +class TestDocumentGenerateSummaryApiSuccess: + def test_generate_not_enabled_high_quality(self, app, patch_tenant, patch_permission): + """Test summary generation on non-high-quality dataset""" + api = DocumentGenerateSummaryApi() + method = unwrap(api.post) + + dataset = MagicMock(indexing_technique="economy", summary_index_setting={"enable": True}) + + payload = {"document_list": ["doc-1"]} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=dataset, + ), + ): + with pytest.raises(ValueError): + method(api, "ds-1") + + +class TestDocumentProcessingApiResume: + def test_resume_invalid_status(self, app, patch_tenant): + """Test resume on non-paused document""" + api = DocumentProcessingApi() + method = unwrap(api.patch) + + document = MagicMock(indexing_status="completed", is_paused=False) + + with app.test_request_context("/"), patch.object(api, "get_document", return_value=document): + with pytest.raises(InvalidActionError): + method(api, "ds-1", "doc-1", "resume") + + +class TestDocumentPermissionCases: + def test_document_batch_get_permission_denied(self, app, patch_tenant): + api = DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("No permission"), + ), + ): + with pytest.raises(Forbidden): + method(api, "ds-1", "batch-1") + + def test_document_batch_get_documents_not_found(self, app, patch_tenant): + api = DocumentBatchIndexingEstimateApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + return_value=None, + ), + patch.object(api, "get_batch_documents", return_value=None), + ): + response, status = method(api, "ds-1", "batch-1") + + assert status == 200 + assert response == { + "tokens": 0, + "total_price": 0, + "currency": "USD", + "total_segments": 0, + "preview": [], + } + + def test_document_tenant_mismatch(self, app): + api = DocumentApi() + method = unwrap(api.get) + + user = MagicMock(is_dataset_editor=True) + document = MagicMock( + tenant_id="other-tenant", + dataset_process_rule=None, + ) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_document.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), # ✅ prevents real DB call + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_process_rules", + return_value={}, + ), + ): + with pytest.raises(Forbidden): + method(api, "ds-1", "doc-1") + + def test_process_rule_get_by_document_success(self, app, patch_tenant): + api = GetProcessRuleApi() + method = unwrap(api.get) + + document = MagicMock(dataset_id="ds-1") + process_rule = MagicMock(mode="custom", rules_dict={"a": 1}) + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.db.get_or_404", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=MagicMock( + where=lambda *a: MagicMock( + order_by=lambda *b: MagicMock(limit=lambda n: MagicMock(one_or_none=lambda: process_rule)) + ) + ), + ), + ): + result = method(api) + + if isinstance(result, tuple): + response, status = result + else: + response, status = result, 200 + + assert status == 200 + assert response["mode"] == "custom" + + def test_process_rule_permission_denied(self, app): + api = GetProcessRuleApi() + method = unwrap(api.get) + + document = MagicMock(dataset_id="ds-1") + + with ( + app.test_request_context("/?document_id=doc-1"), + patch( + "controllers.console.datasets.datasets_document.current_account_with_tenant", + return_value=(MagicMock(is_dataset_editor=True), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_document.db.get_or_404", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.DatasetService.check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("No permission"), + ), + ): + with pytest.raises(Forbidden): + method(api) + + +class TestDocumentListAdvancedCases: + def test_document_list_with_multiple_sort_options(self, app, patch_tenant, patch_dataset, patch_permission): + """Test document list with different sort options""" + api = DatasetDocumentListApi() + method = unwrap(api.get) + + pagination = MagicMock(items=[MagicMock()], total=1) + + with ( + app.test_request_context("/?sort=updated_at"), + patch( + "controllers.console.datasets.datasets_document.db.paginate", + return_value=pagination, + ), + patch( + "controllers.console.datasets.datasets_document.DocumentService.enrich_documents_with_summary_index_status", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_document.marshal", + return_value=[{"id": "doc-1"}], + ), + ): + response = method(api, "ds-1") + + assert response["total"] == 1 + + def test_document_metadata_with_schema_validation(self, app, patch_tenant): + """Test document metadata update with schema validation""" + api = DocumentMetadataApi() + method = unwrap(api.put) + + doc = MagicMock() + payload = { + "doc_type": "contract", + "doc_metadata": {"amount": 5000, "currency": "USD", "invalid_field": "x"}, + } + + schema = {"amount": int, "currency": str} + + with ( + app.test_request_context("/", json=payload), + patch.object(api, "get_document", return_value=doc), + patch( + "controllers.console.datasets.datasets_document.DocumentService.DOCUMENT_METADATA_SCHEMA", + {"contract": schema}, + ), + patch( + "controllers.console.datasets.datasets_document.db.session.commit", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + assert doc.doc_metadata == {"amount": 5000, "currency": "USD"} + + +class TestDocumentIndexingEdgeCases: + def test_document_indexing_with_extraction_setting(self, app, patch_tenant): + api = DocumentIndexingEstimateApi() + method = unwrap(api.get) + + document = MagicMock( + indexing_status="indexing", + data_source_type="upload_file", + data_source_info_dict={"upload_file_id": "file-1"}, + tenant_id="tenant-1", + doc_form="text", + dataset_process_rule=None, + ) + + upload_file = MagicMock() + + with ( + app.test_request_context("/"), + patch.object(api, "get_document", return_value=document), + patch( + "controllers.console.datasets.datasets_document.db.session.query", + return_value=MagicMock(where=lambda *a: MagicMock(first=lambda: upload_file)), + ), + patch( + "controllers.console.datasets.datasets_document.ExtractSetting", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_document.IndexingRunner.indexing_estimate", + return_value=MagicMock(model_dump=lambda: {"tokens": 5}), + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py new file mode 100644 index 0000000000..e67e4daad9 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -0,0 +1,1252 @@ +from datetime import datetime +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, NotFound + +import services +from controllers.console import console_ns +from controllers.console.app.error import ProviderNotInitializeError +from controllers.console.datasets.datasets_segments import ( + ChildChunkAddApi, + ChildChunkUpdateApi, + DatasetDocumentSegmentAddApi, + DatasetDocumentSegmentApi, + DatasetDocumentSegmentBatchImportApi, + DatasetDocumentSegmentListApi, + DatasetDocumentSegmentUpdateApi, + _get_segment_with_summary, +) +from controllers.console.datasets.error import ( + ChildChunkDeleteIndexError, + ChildChunkIndexingError, + InvalidActionError, +) +from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from models.dataset import ChildChunk, DocumentSegment +from models.model import UploadFile + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def _segment(): + return SimpleNamespace( + id="s1", + position=1, + document_id="d1", + content="c", + sign_content="c", + answer="a", + word_count=1, + tokens=1, + keywords=[], + index_node_id="n1", + index_node_hash="h", + hit_count=0, + enabled=True, + disabled_at=None, + disabled_by=None, + status="normal", + created_by="u1", + created_at=datetime.utcnow(), + updated_at=datetime.utcnow(), + updated_by="u1", + indexing_at=None, + completed_at=None, + error=None, + stopped_at=None, + child_chunks=[], + attachments=[], + summary=None, + ) + + +def test_get_segment_with_summary(monkeypatch): + segment = _segment() + summary = SimpleNamespace(summary_content="summary") + + monkeypatch.setattr( + "services.summary_index_service.SummaryIndexService.get_segment_summary", + lambda *_args, **_kwargs: summary, + ) + + result = _get_segment_with_summary(segment, dataset_id="d1") + + assert result["summary"] == "summary" + + +class TestDatasetDocumentSegmentListApi: + def test_get_success(self, app): + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + dataset = MagicMock() + document = MagicMock() + + segment = MagicMock(spec=DocumentSegment) + segment.id = "seg-1" + + pagination = MagicMock() + pagination.items = [segment] + pagination.total = 1 + pagination.pages = 1 + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.paginate", + return_value=pagination, + ), + patch( + "services.summary_index_service.SummaryIndexService.get_segments_summaries", + return_value={}, + ), + patch( + "controllers.console.datasets.datasets_segments.marshal", + return_value={"id": "seg-1"}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + + def test_get_dataset_not_found(self, app): + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_get_permission_denied(self, app): + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + dataset = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("no access"), + ), + ): + with pytest.raises(Forbidden): + method(api, "ds-1", "doc-1") + + +class TestDatasetDocumentSegmentApi: + def test_patch_success(self, app): + api = DatasetDocumentSegmentApi() + method = unwrap(api.patch) + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + dataset.indexing_technique = "economy" + + document = MagicMock() + document.id = "doc-1" + + with ( + app.test_request_context("/?segment_id=s1&segment_id=s2"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.redis_client.get", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.update_segments_status", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1", "enable") + + assert status == 200 + assert response["result"] == "success" + + def test_patch_document_indexing_in_progress(self, app): + api = DatasetDocumentSegmentApi() + method = unwrap(api.patch) + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + dataset.indexing_technique = "economy" + + document = MagicMock() + document.id = "doc-1" + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.redis_client.get", + return_value=b"running", + ), + ): + with pytest.raises(InvalidActionError): + method(api, "ds-1", "doc-1", "disable") + + def test_patch_llm_bad_request(self, app): + api = DatasetDocumentSegmentApi() + method = unwrap(api.patch) + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock( + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embed", + ) + + document = MagicMock(id="doc-1") + + with ( + app.test_request_context("/?segment_id=s1"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.ModelManager.get_model_instance", + side_effect=LLMBadRequestError(), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, "ds-1", "doc-1", "enable") + + def test_patch_provider_token_not_init(self, app): + api = DatasetDocumentSegmentApi() + method = unwrap(api.patch) + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock( + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embed", + ) + + document = MagicMock(id="doc-1") + + with ( + app.test_request_context("/?segment_id=s1"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.ModelManager.get_model_instance", + side_effect=ProviderTokenNotInitError("token missing"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, "ds-1", "doc-1", "enable") + + +class TestDatasetDocumentSegmentAddApi: + def test_post_success(self, app): + api = DatasetDocumentSegmentAddApi() + method = unwrap(api.post) + + payload = {"content": "hello"} + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + dataset.indexing_technique = "economy" + + document = MagicMock() + document.doc_form = "text" + + segment = MagicMock() + segment.id = "seg-1" + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.segment_create_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.create_segment", + return_value=segment, + ), + patch( + "controllers.console.datasets.datasets_segments.marshal", + return_value={"id": "seg-1"}, + ), + patch( + "controllers.console.datasets.datasets_segments._get_segment_with_summary", + return_value={"id": "seg-1"}, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + assert response["data"]["id"] == "seg-1" + + def test_post_llm_bad_request(self, app): + api = DatasetDocumentSegmentAddApi() + method = unwrap(api.post) + + payload = {"content": "x"} + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock( + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embed", + ) + + document = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.ModelManager.get_model_instance", + side_effect=LLMBadRequestError(), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, "ds-1", "doc-1") + + def test_post_provider_token_not_init(self, app): + api = DatasetDocumentSegmentAddApi() + method = unwrap(api.post) + + payload = {"content": "x"} + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock( + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embed", + ) + + document = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.ModelManager.get_model_instance", + side_effect=ProviderTokenNotInitError("token missing"), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, "ds-1", "doc-1") + + +class TestDatasetDocumentSegmentUpdateApi: + def test_patch_success(self, app): + api = DatasetDocumentSegmentUpdateApi() + method = unwrap(api.patch) + + payload = {"content": "updated"} + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + dataset.indexing_technique = "economy" + + document = MagicMock() + document.doc_form = "text" + + segment = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.segment_create_args_validate", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.update_segment", + return_value=segment, + ), + patch( + "controllers.console.datasets.datasets_segments._get_segment_with_summary", + return_value={"id": "seg-1"}, + ), + ): + response, status = method(api, "ds-1", "doc-1", "seg-1") + + assert status == 200 + assert "data" in response + + def test_patch_llm_bad_request(self, app): + api = DatasetDocumentSegmentUpdateApi() + method = unwrap(api.patch) + + payload = {"content": "x"} + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock( + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embed", + ) + + document = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_model_setting", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.ModelManager.get_model_instance", + side_effect=LLMBadRequestError(), + ), + ): + with pytest.raises(ProviderNotInitializeError): + method(api, "ds-1", "doc-1", "seg-1") + + +class TestDatasetDocumentSegmentBatchImportApi: + def test_post_success(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = {"upload_file_id": "file-1"} + + upload_file = MagicMock(spec=UploadFile) + upload_file.name = "test.csv" + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + ), + patch( + "controllers.console.datasets.datasets_segments.redis_client.setnx", + return_value=True, + ), + patch( + "controllers.console.datasets.datasets_segments.batch_create_segment_to_index_task.delay", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 200 + assert response["job_status"] == "waiting" + + def test_post_dataset_not_found(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = {"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_post_document_not_found(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = {"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_post_upload_file_not_found(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = {"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: None)), + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_post_invalid_file_type(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = {"upload_file_id": "file-1"} + + upload_file = MagicMock() + upload_file.name = "test.txt" + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + ), + ): + with pytest.raises(ValueError): + method(api, "ds-1", "doc-1") + + def test_post_async_task_failure(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + payload = {"upload_file_id": "file-1"} + + upload_file = MagicMock() + upload_file.name = "test.csv" + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + ), + patch( + "controllers.console.datasets.datasets_segments.redis_client.setnx", + side_effect=Exception("redis down"), + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 500 + assert "error" in response + + def test_get_job_not_found_in_redis(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.redis_client.get", + return_value=None, + ), + ): + with pytest.raises(ValueError): + method(api, job_id="job-1") + + +class TestChildChunkAddApi: + def test_post_success(self, app): + api = ChildChunkAddApi() + method = unwrap(api.post) + + payload = {"content": "child"} + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + dataset.indexing_technique = "economy" + + document = MagicMock() + segment = MagicMock() + child_chunk = MagicMock(spec=ChildChunk) + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.create_child_chunk", + return_value=child_chunk, + ), + patch( + "controllers.console.datasets.datasets_segments.marshal", + return_value={"id": "cc-1"}, + ), + ): + response, status = method(api, "ds-1", "doc-1", "seg-1") + + assert status == 200 + assert response["data"]["id"] == "cc-1" + + def test_post_child_chunk_indexing_error(self, app): + api = ChildChunkAddApi() + method = unwrap(api.post) + + payload = {"content": "child"} + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock(indexing_technique="economy") + document = MagicMock() + segment = MagicMock() + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.create_child_chunk", + side_effect=services.errors.chunk.ChildChunkIndexingError("fail"), + ), + ): + with pytest.raises(ChildChunkIndexingError): + method(api, "ds-1", "doc-1", "seg-1") + + +class TestChildChunkUpdateApi: + def test_delete_success(self, app): + api = ChildChunkUpdateApi() + method = unwrap(api.delete) + + user = MagicMock() + user.is_dataset_editor = True + + dataset = MagicMock() + document = MagicMock() + segment = MagicMock() + child_chunk = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + side_effect=[ + MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), + ], + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.delete_child_chunk", + return_value=None, + ), + ): + response, status = method(api, "ds-1", "doc-1", "seg-1", "cc-1") + + assert status == 204 + assert response["result"] == "success" + + def test_delete_child_chunk_index_error(self, app): + api = ChildChunkUpdateApi() + method = unwrap(api.delete) + + user = MagicMock(is_dataset_editor=True) + + dataset = MagicMock() + document = MagicMock() + segment = MagicMock() + child_chunk = MagicMock() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + side_effect=[ + MagicMock(where=lambda *a, **k: MagicMock(first=lambda: segment)), + MagicMock(where=lambda *a, **k: MagicMock(first=lambda: child_chunk)), + ], + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.delete_child_chunk", + side_effect=services.errors.chunk.ChildChunkDeleteIndexError("fail"), + ), + ): + with pytest.raises(ChildChunkDeleteIndexError): + method(api, "ds-1", "doc-1", "seg-1", "cc-1") + + +class TestSegmentListAdvancedCases: + def test_segment_list_with_keyword_filter(self, app): + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + dataset = MagicMock() + document = MagicMock() + + segment = MagicMock(spec=DocumentSegment) + segment.id = "seg-1" + segment.keywords = ["test"] + segment.enabled = True + + pagination = MagicMock(items=[segment], total=1, pages=1) + + with ( + app.test_request_context("/?keyword=test"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.paginate", + return_value=pagination, + ), + patch( + "services.summary_index_service.SummaryIndexService.get_segments_summaries", + return_value={}, + ), + ): + result = method(api, "ds-1", "doc-1") + + if isinstance(result, tuple): + response, status = result + else: + response, status = result, 200 + + assert status == 200 + assert response["total"] == 1 + + def test_segment_list_permission_denied(self, app): + """Test segment list with permission denied""" + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=MagicMock(), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("No permission"), + ), + ): + with pytest.raises(Forbidden): + method(api, "ds-1", "doc-1") + + def test_segment_list_dataset_not_found(self, app): + """Test segment list with dataset not found""" + api = DatasetDocumentSegmentListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(MagicMock(), "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + +class TestSegmentOperationCases: + def test_segment_add_with_provider_token_error(self, app): + """Test segment add with provider token not initialized""" + api = DatasetDocumentSegmentAddApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=True) + dataset = MagicMock() + document = MagicMock() + + payload = {"content": "new content", "answer": None} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.SegmentService.create_segment", + side_effect=ProviderTokenNotInitError("Token not init"), + ), + ): + with pytest.raises(ProviderTokenNotInitError): + method(api, "ds-1", "doc-1") + + def test_batch_import_with_document_not_found(self, app): + """Test batch import with document not found""" + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=True) + dataset = MagicMock() + + payload = {"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_batch_import_with_invalid_file(self, app): + """Test batch import with invalid file type""" + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=True) + dataset = MagicMock() + document = MagicMock() + upload_file = None # File not found + + payload = {"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1", "doc-1") + + def test_batch_import_with_async_task_failure(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.post) + + user = MagicMock(is_dataset_editor=True) + dataset = MagicMock() + document = MagicMock() + upload_file = MagicMock(spec=UploadFile, extension="csv", id="file-1") + upload_file.name = "test.csv" + + payload = {"upload_file_id": "file-1"} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.get_dataset", + return_value=dataset, + ), + patch( + "controllers.console.datasets.datasets_segments.DocumentService.get_document", + return_value=document, + ), + patch( + "controllers.console.datasets.datasets_segments.db.session.query", + return_value=MagicMock(where=lambda *a, **k: MagicMock(first=lambda: upload_file)), + ), + patch( + "controllers.console.datasets.datasets_segments.DatasetService.check_dataset_permission", + return_value=None, + ), + patch( + "controllers.console.datasets.datasets_segments.batch_create_segment_to_index_task.delay", + side_effect=Exception("Task failed"), + ), + ): + response, status = method(api, "ds-1", "doc-1") + + assert status == 500 + assert "error" in response + + def test_batch_import_get_job_not_found(self, app): + api = DatasetDocumentSegmentBatchImportApi() + method = unwrap(api.get) + + user = MagicMock(is_dataset_editor=True) + + with ( + app.test_request_context("/?job_id=invalid-job"), + patch( + "controllers.console.datasets.datasets_segments.current_account_with_tenant", + return_value=(user, "tenant-1"), + ), + patch( + "controllers.console.datasets.datasets_segments.redis_client.get", + return_value=None, + ), + ): + with pytest.raises(ValueError): + method(api, "invalid-job") diff --git a/api/tests/unit_tests/controllers/console/datasets/test_external.py b/api/tests/unit_tests/controllers/console/datasets/test_external.py new file mode 100644 index 0000000000..161d0c41e8 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_external.py @@ -0,0 +1,399 @@ +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import Forbidden, NotFound + +import services +from controllers.console import console_ns +from controllers.console.datasets.error import DatasetNameDuplicateError +from controllers.console.datasets.external import ( + BedrockRetrievalApi, + ExternalApiTemplateApi, + ExternalApiTemplateListApi, + ExternalDatasetCreateApi, + ExternalKnowledgeHitTestingApi, +) +from services.dataset_service import DatasetService +from services.external_knowledge_service import ExternalDatasetService +from services.hit_testing_service import HitTestingService +from services.knowledge_service import ExternalDatasetTestService + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask("test_external_dataset") + app.config["TESTING"] = True + return app + + +@pytest.fixture +def current_user(): + user = MagicMock() + user.id = "user-1" + user.is_dataset_editor = True + user.has_edit_permission = True + user.is_dataset_operator = True + return user + + +@pytest.fixture(autouse=True) +def mock_auth(mocker, current_user): + mocker.patch( + "controllers.console.datasets.external.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ) + + +class TestExternalApiTemplateListApi: + def test_get_success(self, app): + api = ExternalApiTemplateListApi() + method = unwrap(api.get) + + api_item = MagicMock() + api_item.to_dict.return_value = {"id": "1"} + + with ( + app.test_request_context("/?page=1&limit=20"), + patch.object( + ExternalDatasetService, + "get_external_knowledge_apis", + return_value=([api_item], 1), + ), + ): + resp, status = method(api) + + assert status == 200 + assert resp["total"] == 1 + assert resp["data"][0]["id"] == "1" + + def test_post_forbidden(self, app, current_user): + current_user.is_dataset_editor = False + api = ExternalApiTemplateListApi() + method = unwrap(api.post) + + payload = {"name": "x", "settings": {"k": "v"}} + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object(ExternalDatasetService, "validate_api_list"), + ): + with pytest.raises(Forbidden): + method(api) + + def test_post_duplicate_name(self, app): + api = ExternalApiTemplateListApi() + method = unwrap(api.post) + + payload = {"name": "x", "settings": {"k": "v"}} + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object(ExternalDatasetService, "validate_api_list"), + patch.object( + ExternalDatasetService, + "create_external_knowledge_api", + side_effect=services.errors.dataset.DatasetNameDuplicateError(), + ), + ): + with pytest.raises(DatasetNameDuplicateError): + method(api) + + +class TestExternalApiTemplateApi: + def test_get_not_found(self, app): + api = ExternalApiTemplateApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object( + ExternalDatasetService, + "get_external_knowledge_api", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "api-id") + + def test_delete_forbidden(self, app, current_user): + current_user.has_edit_permission = False + current_user.is_dataset_operator = False + + api = ExternalApiTemplateApi() + method = unwrap(api.delete) + + with app.test_request_context("/"): + with pytest.raises(Forbidden): + method(api, "api-id") + + +class TestExternalDatasetCreateApi: + def test_create_success(self, app): + api = ExternalDatasetCreateApi() + method = unwrap(api.post) + + payload = { + "external_knowledge_api_id": "api", + "external_knowledge_id": "kid", + "name": "dataset", + } + + dataset = MagicMock() + + dataset.embedding_available = False + dataset.built_in_field_enabled = False + dataset.is_published = False + dataset.enable_api = False + dataset.enable_qa = False + dataset.enable_vector_store = False + dataset.vector_store_setting = None + dataset.is_multimodal = False + + dataset.retrieval_model_dict = {} + dataset.tags = [] + dataset.external_knowledge_info = None + dataset.external_retrieval_model = None + dataset.doc_metadata = [] + dataset.icon_info = None + + dataset.summary_index_setting = MagicMock() + dataset.summary_index_setting.enable = False + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object( + ExternalDatasetService, + "create_external_dataset", + return_value=dataset, + ), + ): + _, status = method(api) + + assert status == 201 + + def test_create_forbidden(self, app, current_user): + current_user.is_dataset_editor = False + api = ExternalDatasetCreateApi() + method = unwrap(api.post) + + payload = { + "external_knowledge_api_id": "api", + "external_knowledge_id": "kid", + "name": "dataset", + } + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + ): + with pytest.raises(Forbidden): + method(api) + + +class TestExternalKnowledgeHitTestingApi: + def test_hit_testing_dataset_not_found(self, app): + api = ExternalKnowledgeHitTestingApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "dataset-id") + + def test_hit_testing_success(self, app): + api = ExternalKnowledgeHitTestingApi() + method = unwrap(api.post) + + payload = {"query": "hello"} + + dataset = MagicMock() + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission"), + patch.object( + HitTestingService, + "external_retrieve", + return_value={"ok": True}, + ), + ): + resp = method(api, "dataset-id") + + assert resp["ok"] is True + + +class TestBedrockRetrievalApi: + def test_bedrock_retrieval(self, app): + api = BedrockRetrievalApi() + method = unwrap(api.post) + + payload = { + "retrieval_setting": {}, + "query": "hello", + "knowledge_id": "kid", + } + + with ( + app.test_request_context("/"), + patch.object(type(console_ns), "payload", new_callable=PropertyMock, return_value=payload), + patch.object( + ExternalDatasetTestService, + "knowledge_retrieval", + return_value={"ok": True}, + ), + ): + resp, status = method() + + assert status == 200 + assert resp["ok"] is True + + +class TestExternalApiTemplateListApiAdvanced: + def test_post_duplicate_name_error(self, app, mock_auth, current_user): + api = ExternalApiTemplateListApi() + method = unwrap(api.post) + + payload = {"name": "duplicate_api", "settings": {"key": "value"}} + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch("controllers.console.datasets.external.ExternalDatasetService.validate_api_list"), + patch( + "controllers.console.datasets.external.ExternalDatasetService.create_external_knowledge_api", + side_effect=services.errors.dataset.DatasetNameDuplicateError("Duplicate"), + ), + ): + with pytest.raises(DatasetNameDuplicateError): + method(api) + + def test_get_with_pagination(self, app, mock_auth, current_user): + api = ExternalApiTemplateListApi() + method = unwrap(api.get) + + templates = [MagicMock(id=f"api-{i}") for i in range(3)] + + with ( + app.test_request_context("/?page=1&limit=20"), + patch( + "controllers.console.datasets.external.ExternalDatasetService.get_external_knowledge_apis", + return_value=(templates, 25), + ), + ): + resp, status = method(api) + + assert status == 200 + assert resp["total"] == 25 + assert len(resp["data"]) == 3 + + +class TestExternalDatasetCreateApiAdvanced: + def test_create_forbidden(self, app, mock_auth, current_user): + """Test creating external dataset without permission""" + api = ExternalDatasetCreateApi() + method = unwrap(api.post) + + current_user.is_dataset_editor = False + + payload = { + "external_knowledge_api_id": "api-1", + "external_knowledge_id": "ek-1", + "name": "new_dataset", + "description": "A dataset", + } + + with app.test_request_context("/", json=payload), patch.object(type(console_ns), "payload", payload): + with pytest.raises(Forbidden): + method(api) + + +class TestExternalKnowledgeHitTestingApiAdvanced: + def test_hit_testing_dataset_not_found(self, app, mock_auth, current_user): + """Test hit testing on non-existent dataset""" + api = ExternalKnowledgeHitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "test query", + "external_retrieval_model": None, + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.external.DatasetService.get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, "ds-1") + + def test_hit_testing_with_custom_retrieval_model(self, app, mock_auth, current_user): + api = ExternalKnowledgeHitTestingApi() + method = unwrap(api.post) + + dataset = MagicMock() + payload = { + "query": "test query", + "external_retrieval_model": {"type": "bm25"}, + "metadata_filtering_conditions": {"status": "active"}, + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.external.DatasetService.get_dataset", + return_value=dataset, + ), + patch("controllers.console.datasets.external.DatasetService.check_dataset_permission"), + patch( + "controllers.console.datasets.external.HitTestingService.external_retrieve", + return_value={"results": []}, + ), + ): + resp = method(api, "ds-1") + + assert resp["results"] == [] + + +class TestBedrockRetrievalApiAdvanced: + def test_bedrock_retrieval_with_invalid_setting(self, app, mock_auth, current_user): + api = BedrockRetrievalApi() + method = unwrap(api.post) + + payload = { + "retrieval_setting": {}, + "query": "test", + "knowledge_id": "k-1", + } + + with ( + app.test_request_context("/", json=payload), + patch.object(type(console_ns), "payload", payload), + patch( + "controllers.console.datasets.external.ExternalDatasetTestService.knowledge_retrieval", + side_effect=ValueError("Invalid settings"), + ), + ): + with pytest.raises(ValueError): + method() diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py new file mode 100644 index 0000000000..55fb038156 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing.py @@ -0,0 +1,160 @@ +import uuid +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound + +from controllers.console import console_ns +from controllers.console.datasets.hit_testing import HitTestingApi +from controllers.console.datasets.hit_testing_base import HitTestingPayload + + +def unwrap(func): + """Recursively unwrap decorated functions.""" + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask("test_hit_testing") + app.config["TESTING"] = True + return app + + +@pytest.fixture +def dataset_id(): + return uuid.uuid4() + + +@pytest.fixture +def dataset(): + return MagicMock(id="dataset-1") + + +@pytest.fixture(autouse=True) +def bypass_decorators(mocker): + """Bypass all decorators on the API method.""" + mocker.patch( + "controllers.console.datasets.hit_testing.setup_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.hit_testing.login_required", + return_value=lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.hit_testing.account_initialization_required", + return_value=lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.hit_testing.cloud_edition_billing_rate_limit_check", + return_value=lambda *_: (lambda f: f), + ) + + +class TestHitTestingApi: + def test_hit_testing_success(self, app, dataset, dataset_id): + api = HitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "what is vector search", + "top_k": 3, + } + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch.object( + HitTestingPayload, + "model_validate", + return_value=MagicMock(model_dump=lambda **_: payload), + ), + patch.object( + HitTestingApi, + "get_and_validate_dataset", + return_value=dataset, + ), + patch.object( + HitTestingApi, + "hit_testing_args_check", + ), + patch.object( + HitTestingApi, + "perform_hit_testing", + return_value={"query": "what is vector search", "records": []}, + ), + ): + result = method(api, dataset_id) + + assert "query" in result + assert "records" in result + assert result["records"] == [] + + def test_hit_testing_dataset_not_found(self, app, dataset_id): + api = HitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "test", + } + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch.object( + HitTestingApi, + "get_and_validate_dataset", + side_effect=NotFound("Dataset not found"), + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, dataset_id) + + def test_hit_testing_invalid_args(self, app, dataset, dataset_id): + api = HitTestingApi() + method = unwrap(api.post) + + payload = { + "query": "", + } + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch.object( + HitTestingPayload, + "model_validate", + return_value=MagicMock(model_dump=lambda **_: payload), + ), + patch.object( + HitTestingApi, + "get_and_validate_dataset", + return_value=dataset, + ), + patch.object( + HitTestingApi, + "hit_testing_args_check", + side_effect=ValueError("Invalid parameters"), + ), + ): + with pytest.raises(ValueError, match="Invalid parameters"): + method(api, dataset_id) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py new file mode 100644 index 0000000000..e7ae37ae45 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -0,0 +1,207 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound + +import services +from controllers.console.app.error import ( + CompletionRequestError, + ProviderModelCurrentlyNotSupportError, + ProviderNotInitializeError, + ProviderQuotaExceededError, +) +from controllers.console.datasets.error import DatasetNotInitializedError +from controllers.console.datasets.hit_testing_base import ( + DatasetsHitTestingBase, +) +from core.errors.error import ( + LLMBadRequestError, + ModelCurrentlyNotSupportError, + ProviderTokenNotInitError, + QuotaExceededError, +) +from dify_graph.model_runtime.errors.invoke import InvokeError +from models.account import Account +from services.dataset_service import DatasetService +from services.hit_testing_service import HitTestingService + + +@pytest.fixture +def account(): + acc = MagicMock(spec=Account) + return acc + + +@pytest.fixture(autouse=True) +def patch_current_user(mocker, account): + """Patch current_user to a valid Account.""" + mocker.patch( + "controllers.console.datasets.hit_testing_base.current_user", + account, + ) + + +@pytest.fixture +def dataset(): + return MagicMock(id="dataset-1") + + +class TestGetAndValidateDataset: + def test_success(self, dataset): + with ( + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + ): + result = DatasetsHitTestingBase.get_and_validate_dataset("dataset-1") + + assert result == dataset + + def test_dataset_not_found(self): + with patch.object( + DatasetService, + "get_dataset", + return_value=None, + ): + with pytest.raises(NotFound, match="Dataset not found"): + DatasetsHitTestingBase.get_and_validate_dataset("dataset-1") + + def test_permission_denied(self, dataset): + with ( + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + side_effect=services.errors.account.NoPermissionError("no access"), + ), + ): + with pytest.raises(Forbidden, match="no access"): + DatasetsHitTestingBase.get_and_validate_dataset("dataset-1") + + +class TestHitTestingArgsCheck: + def test_args_check_called(self): + args = {"query": "test"} + + with patch.object( + HitTestingService, + "hit_testing_args_check", + ) as check_mock: + DatasetsHitTestingBase.hit_testing_args_check(args) + + check_mock.assert_called_once_with(args) + + +class TestParseArgs: + def test_parse_args_success(self): + payload = {"query": "hello"} + + result = DatasetsHitTestingBase.parse_args(payload) + + assert result["query"] == "hello" + + def test_parse_args_invalid(self): + payload = {"query": "x" * 300} + + with pytest.raises(ValueError): + DatasetsHitTestingBase.parse_args(payload) + + +class TestPerformHitTesting: + def test_success(self, dataset): + response = { + "query": "hello", + "records": [], + } + + with patch.object( + HitTestingService, + "retrieve", + return_value=response, + ): + result = DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + assert result["query"] == "hello" + assert result["records"] == [] + + def test_index_not_initialized(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=services.errors.index.IndexNotInitializedError(), + ): + with pytest.raises(DatasetNotInitializedError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_provider_token_not_init(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=ProviderTokenNotInitError("token missing"), + ): + with pytest.raises(ProviderNotInitializeError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_quota_exceeded(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=QuotaExceededError(), + ): + with pytest.raises(ProviderQuotaExceededError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_model_not_supported(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=ModelCurrentlyNotSupportError(), + ): + with pytest.raises(ProviderModelCurrentlyNotSupportError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_llm_bad_request(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=LLMBadRequestError("bad request"), + ): + with pytest.raises(ProviderNotInitializeError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_invoke_error(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=InvokeError("invoke failed"), + ): + with pytest.raises(CompletionRequestError): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_value_error(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=ValueError("bad args"), + ): + with pytest.raises(ValueError, match="bad args"): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) + + def test_unexpected_error(self, dataset): + with patch.object( + HitTestingService, + "retrieve", + side_effect=Exception("boom"), + ): + with pytest.raises(InternalServerError, match="boom"): + DatasetsHitTestingBase.perform_hit_testing(dataset, {"query": "hello"}) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_metadata.py b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py new file mode 100644 index 0000000000..de834c2d4d --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_metadata.py @@ -0,0 +1,362 @@ +import uuid +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest +from flask import Flask +from werkzeug.exceptions import NotFound + +from controllers.console import console_ns +from controllers.console.datasets.metadata import ( + DatasetMetadataApi, + DatasetMetadataBuiltInFieldActionApi, + DatasetMetadataBuiltInFieldApi, + DatasetMetadataCreateApi, + DocumentMetadataEditApi, +) +from services.dataset_service import DatasetService +from services.entities.knowledge_entities.knowledge_entities import ( + MetadataArgs, + MetadataOperationData, +) +from services.metadata_service import MetadataService + + +def unwrap(func): + """Recursively unwrap decorated functions.""" + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask("test_dataset_metadata") + app.config["TESTING"] = True + return app + + +@pytest.fixture +def current_user(): + user = MagicMock() + user.id = "user-1" + return user + + +@pytest.fixture +def dataset(): + ds = MagicMock() + ds.id = "dataset-1" + return ds + + +@pytest.fixture +def dataset_id(): + return uuid.uuid4() + + +@pytest.fixture +def metadata_id(): + return uuid.uuid4() + + +@pytest.fixture(autouse=True) +def bypass_decorators(mocker): + """Bypass setup/login/license decorators.""" + mocker.patch( + "controllers.console.datasets.metadata.setup_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.metadata.login_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.metadata.account_initialization_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.metadata.enterprise_license_required", + lambda f: f, + ) + + +class TestDatasetMetadataCreateApi: + def test_create_metadata_success(self, app, current_user, dataset, dataset_id): + api = DatasetMetadataCreateApi() + method = unwrap(api.post) + + payload = {"name": "author"} + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + MetadataArgs, + "model_validate", + return_value=MagicMock(), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + patch.object( + MetadataService, + "create_metadata", + return_value={"id": "m1", "name": "author"}, + ), + ): + result, status = method(api, dataset_id) + + assert status == 201 + assert result["name"] == "author" + + def test_create_metadata_dataset_not_found(self, app, current_user, dataset_id): + api = DatasetMetadataCreateApi() + method = unwrap(api.post) + + valid_payload = { + "type": "string", + "name": "author", + } + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=valid_payload, + ), + patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + MetadataArgs, + "model_validate", + return_value=MagicMock(), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound, match="Dataset not found"): + method(api, dataset_id) + + +class TestDatasetMetadataGetApi: + def test_get_metadata_success(self, app, dataset, dataset_id): + api = DatasetMetadataCreateApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + MetadataService, + "get_dataset_metadatas", + return_value=[{"id": "m1"}], + ), + ): + result, status = method(api, dataset_id) + + assert status == 200 + assert isinstance(result, list) + + def test_get_metadata_dataset_not_found(self, app, dataset_id): + api = DatasetMetadataCreateApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object( + DatasetService, + "get_dataset", + return_value=None, + ), + ): + with pytest.raises(NotFound): + method(api, dataset_id) + + +class TestDatasetMetadataApi: + def test_update_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id): + api = DatasetMetadataApi() + method = unwrap(api.patch) + + payload = {"name": "updated-name"} + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + patch.object( + MetadataService, + "update_metadata_name", + return_value={"id": "m1", "name": "updated-name"}, + ), + ): + result, status = method(api, dataset_id, metadata_id) + + assert status == 200 + assert result["name"] == "updated-name" + + def test_delete_metadata_success(self, app, current_user, dataset, dataset_id, metadata_id): + api = DatasetMetadataApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + patch.object( + MetadataService, + "delete_metadata", + ), + ): + result, status = method(api, dataset_id, metadata_id) + + assert status == 204 + assert result["result"] == "success" + + +class TestDatasetMetadataBuiltInFieldApi: + def test_get_built_in_fields(self, app): + api = DatasetMetadataBuiltInFieldApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch.object( + MetadataService, + "get_built_in_fields", + return_value=["title", "source"], + ), + ): + result, status = method(api) + + assert status == 200 + assert result["fields"] == ["title", "source"] + + +class TestDatasetMetadataBuiltInFieldActionApi: + def test_enable_built_in_field(self, app, current_user, dataset, dataset_id): + api = DatasetMetadataBuiltInFieldActionApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + patch.object( + MetadataService, + "enable_built_in_field", + ), + ): + result, status = method(api, dataset_id, "enable") + + assert status == 200 + assert result["result"] == "success" + + +class TestDocumentMetadataEditApi: + def test_update_document_metadata_success(self, app, current_user, dataset, dataset_id): + api = DocumentMetadataEditApi() + method = unwrap(api.post) + + payload = {"operation": "add", "metadata": {}} + + with ( + app.test_request_context("/"), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.datasets.metadata.current_account_with_tenant", + return_value=(current_user, "tenant-1"), + ), + patch.object( + DatasetService, + "get_dataset", + return_value=dataset, + ), + patch.object( + DatasetService, + "check_dataset_permission", + ), + patch.object( + MetadataOperationData, + "model_validate", + return_value=MagicMock(), + ), + patch.object( + MetadataService, + "update_documents_metadata", + ), + ): + result, status = method(api, dataset_id) + + assert status == 200 + assert result["result"] == "success" diff --git a/api/tests/unit_tests/controllers/console/datasets/test_website.py b/api/tests/unit_tests/controllers/console/datasets/test_website.py new file mode 100644 index 0000000000..9f0da6e76f --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_website.py @@ -0,0 +1,233 @@ +from unittest.mock import Mock, PropertyMock, patch + +import pytest +from flask import Flask + +from controllers.console import console_ns +from controllers.console.datasets.error import WebsiteCrawlError +from controllers.console.datasets.website import ( + WebsiteCrawlApi, + WebsiteCrawlStatusApi, +) +from services.website_service import ( + WebsiteCrawlApiRequest, + WebsiteCrawlStatusApiRequest, + WebsiteService, +) + + +def unwrap(func): + """Recursively unwrap decorated functions.""" + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def app(): + app = Flask("test_website_crawl") + app.config["TESTING"] = True + return app + + +@pytest.fixture(autouse=True) +def bypass_auth_and_setup(mocker): + """Bypass setup/login/account decorators.""" + mocker.patch( + "controllers.console.datasets.website.login_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.website.setup_required", + lambda f: f, + ) + mocker.patch( + "controllers.console.datasets.website.account_initialization_required", + lambda f: f, + ) + + +class TestWebsiteCrawlApi: + def test_crawl_success(self, app, mocker): + api = WebsiteCrawlApi() + method = unwrap(api.post) + + payload = { + "provider": "firecrawl", + "url": "https://example.com", + "options": {"depth": 1}, + } + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + ): + mock_request = Mock(spec=WebsiteCrawlApiRequest) + mocker.patch.object( + WebsiteCrawlApiRequest, + "from_args", + return_value=mock_request, + ) + + mocker.patch.object( + WebsiteService, + "crawl_url", + return_value={"job_id": "job-1"}, + ) + + result, status = method(api) + + assert status == 200 + assert result["job_id"] == "job-1" + + def test_crawl_invalid_payload(self, app, mocker): + api = WebsiteCrawlApi() + method = unwrap(api.post) + + payload = { + "provider": "firecrawl", + "url": "bad-url", + "options": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + ): + mocker.patch.object( + WebsiteCrawlApiRequest, + "from_args", + side_effect=ValueError("invalid payload"), + ) + + with pytest.raises(WebsiteCrawlError, match="invalid payload"): + method(api) + + def test_crawl_service_error(self, app, mocker): + api = WebsiteCrawlApi() + method = unwrap(api.post) + + payload = { + "provider": "firecrawl", + "url": "https://example.com", + "options": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + ): + mock_request = Mock(spec=WebsiteCrawlApiRequest) + mocker.patch.object( + WebsiteCrawlApiRequest, + "from_args", + return_value=mock_request, + ) + + mocker.patch.object( + WebsiteService, + "crawl_url", + side_effect=Exception("crawl failed"), + ) + + with pytest.raises(WebsiteCrawlError, match="crawl failed"): + method(api) + + +class TestWebsiteCrawlStatusApi: + def test_get_status_success(self, app, mocker): + api = WebsiteCrawlStatusApi() + method = unwrap(api.get) + + job_id = "job-123" + args = {"provider": "firecrawl"} + + with app.test_request_context("/?provider=firecrawl"): + mocker.patch( + "controllers.console.datasets.website.request.args.to_dict", + return_value=args, + ) + + mock_request = Mock(spec=WebsiteCrawlStatusApiRequest) + mocker.patch.object( + WebsiteCrawlStatusApiRequest, + "from_args", + return_value=mock_request, + ) + + mocker.patch.object( + WebsiteService, + "get_crawl_status_typed", + return_value={"status": "completed"}, + ) + + result, status = method(api, job_id) + + assert status == 200 + assert result["status"] == "completed" + + def test_get_status_invalid_provider(self, app, mocker): + api = WebsiteCrawlStatusApi() + method = unwrap(api.get) + + job_id = "job-123" + args = {"provider": "firecrawl"} + + with app.test_request_context("/?provider=firecrawl"): + mocker.patch( + "controllers.console.datasets.website.request.args.to_dict", + return_value=args, + ) + + mocker.patch.object( + WebsiteCrawlStatusApiRequest, + "from_args", + side_effect=ValueError("invalid provider"), + ) + + with pytest.raises(WebsiteCrawlError, match="invalid provider"): + method(api, job_id) + + def test_get_status_service_error(self, app, mocker): + api = WebsiteCrawlStatusApi() + method = unwrap(api.get) + + job_id = "job-123" + args = {"provider": "firecrawl"} + + with app.test_request_context("/?provider=firecrawl"): + mocker.patch( + "controllers.console.datasets.website.request.args.to_dict", + return_value=args, + ) + + mock_request = Mock(spec=WebsiteCrawlStatusApiRequest) + mocker.patch.object( + WebsiteCrawlStatusApiRequest, + "from_args", + return_value=mock_request, + ) + + mocker.patch.object( + WebsiteService, + "get_crawl_status_typed", + side_effect=Exception("status lookup failed"), + ) + + with pytest.raises(WebsiteCrawlError, match="status lookup failed"): + method(api, job_id) diff --git a/api/tests/unit_tests/controllers/console/datasets/test_wraps.py b/api/tests/unit_tests/controllers/console/datasets/test_wraps.py new file mode 100644 index 0000000000..90f00711c1 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/datasets/test_wraps.py @@ -0,0 +1,117 @@ +from unittest.mock import Mock + +import pytest + +from controllers.console.datasets.error import PipelineNotFoundError +from controllers.console.datasets.wraps import get_rag_pipeline +from models.dataset import Pipeline + + +class TestGetRagPipeline: + def test_missing_pipeline_id(self): + @get_rag_pipeline + def dummy_view(**kwargs): + return "ok" + + with pytest.raises(ValueError, match="missing pipeline_id"): + dummy_view() + + def test_pipeline_not_found(self, mocker): + @get_rag_pipeline + def dummy_view(**kwargs): + return "ok" + + mocker.patch( + "controllers.console.datasets.wraps.current_account_with_tenant", + return_value=(Mock(), "tenant-1"), + ) + + mock_query = Mock() + mock_query.where.return_value.first.return_value = None + + mocker.patch( + "controllers.console.datasets.wraps.db.session.query", + return_value=mock_query, + ) + + with pytest.raises(PipelineNotFoundError): + dummy_view(pipeline_id="pipeline-1") + + def test_pipeline_found_and_injected(self, mocker): + pipeline = Mock(spec=Pipeline) + pipeline.id = "pipeline-1" + pipeline.tenant_id = "tenant-1" + + @get_rag_pipeline + def dummy_view(**kwargs): + return kwargs["pipeline"] + + mocker.patch( + "controllers.console.datasets.wraps.current_account_with_tenant", + return_value=(Mock(), "tenant-1"), + ) + + mock_query = Mock() + mock_query.where.return_value.first.return_value = pipeline + + mocker.patch( + "controllers.console.datasets.wraps.db.session.query", + return_value=mock_query, + ) + + result = dummy_view(pipeline_id="pipeline-1") + + assert result is pipeline + + def test_pipeline_id_removed_from_kwargs(self, mocker): + pipeline = Mock(spec=Pipeline) + + @get_rag_pipeline + def dummy_view(**kwargs): + assert "pipeline_id" not in kwargs + return "ok" + + mocker.patch( + "controllers.console.datasets.wraps.current_account_with_tenant", + return_value=(Mock(), "tenant-1"), + ) + + mock_query = Mock() + mock_query.where.return_value.first.return_value = pipeline + + mocker.patch( + "controllers.console.datasets.wraps.db.session.query", + return_value=mock_query, + ) + + result = dummy_view(pipeline_id="pipeline-1") + + assert result == "ok" + + def test_pipeline_id_cast_to_string(self, mocker): + pipeline = Mock(spec=Pipeline) + + @get_rag_pipeline + def dummy_view(**kwargs): + return kwargs["pipeline"] + + mocker.patch( + "controllers.console.datasets.wraps.current_account_with_tenant", + return_value=(Mock(), "tenant-1"), + ) + + def where_side_effect(*args, **kwargs): + assert args[0].right.value == "123" + return Mock(first=lambda: pipeline) + + mock_query = Mock() + mock_query.where.side_effect = where_side_effect + + mocker.patch( + "controllers.console.datasets.wraps.db.session.query", + return_value=mock_query, + ) + + result = dummy_view(pipeline_id=123) + + assert result is pipeline From 497feac48e92bbf25087677cd01224855ce77254 Mon Sep 17 00:00:00 2001 From: rajatagarwal-oss Date: Mon, 9 Mar 2026 14:37:40 +0530 Subject: [PATCH 2/3] test: unit test case for controllers.console.workspace module (#32181) --- .../console/workspace/test_accounts.py | 341 ++++++ .../console/workspace/test_agent_providers.py | 139 +++ .../console/workspace/test_endpoint.py | 305 +++++ .../console/workspace/test_members.py | 607 ++++++++++ .../console/workspace/test_model_providers.py | 388 +++++++ .../console/workspace/test_models.py | 447 ++++++++ .../console/workspace/test_plugin.py | 1019 +++++++++++++++++ .../console/workspace/test_tool_provider.py | 643 ++++++++++- .../workspace/test_trigger_providers.py | 558 +++++++++ .../console/workspace/test_workspace.py | 605 ++++++++++ .../console/workspace/test_workspace_wraps.py | 142 +++ 11 files changed, 5190 insertions(+), 4 deletions(-) create mode 100644 api/tests/unit_tests/controllers/console/workspace/test_accounts.py create mode 100644 api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py create mode 100644 api/tests/unit_tests/controllers/console/workspace/test_endpoint.py create mode 100644 api/tests/unit_tests/controllers/console/workspace/test_members.py create mode 100644 api/tests/unit_tests/controllers/console/workspace/test_model_providers.py create mode 100644 api/tests/unit_tests/controllers/console/workspace/test_models.py create mode 100644 api/tests/unit_tests/controllers/console/workspace/test_plugin.py create mode 100644 api/tests/unit_tests/controllers/console/workspace/test_trigger_providers.py create mode 100644 api/tests/unit_tests/controllers/console/workspace/test_workspace.py create mode 100644 api/tests/unit_tests/controllers/console/workspace/test_workspace_wraps.py diff --git a/api/tests/unit_tests/controllers/console/workspace/test_accounts.py b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py new file mode 100644 index 0000000000..00d322fdea --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_accounts.py @@ -0,0 +1,341 @@ +from unittest.mock import MagicMock, PropertyMock, patch + +import pytest + +from controllers.console import console_ns +from controllers.console.auth.error import ( + EmailAlreadyInUseError, + EmailCodeError, +) +from controllers.console.error import AccountInFreezeError +from controllers.console.workspace.account import ( + AccountAvatarApi, + AccountDeleteApi, + AccountDeleteVerifyApi, + AccountInitApi, + AccountIntegrateApi, + AccountInterfaceLanguageApi, + AccountInterfaceThemeApi, + AccountNameApi, + AccountPasswordApi, + AccountProfileApi, + AccountTimezoneApi, + ChangeEmailCheckApi, + ChangeEmailResetApi, + CheckEmailUnique, +) +from controllers.console.workspace.error import ( + AccountAlreadyInitedError, + CurrentPasswordIncorrectError, + InvalidAccountDeletionCodeError, +) +from services.errors.account import CurrentPasswordIncorrectError as ServicePwdError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestAccountInitApi: + def test_init_success(self, app): + api = AccountInitApi() + method = unwrap(api.post) + + account = MagicMock(status="inactive") + payload = { + "interface_language": "en-US", + "timezone": "UTC", + "invitation_code": "code123", + } + + with ( + app.test_request_context("/account/init", json=payload), + patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")), + patch("controllers.console.workspace.account.db.session.commit", return_value=None), + patch("controllers.console.workspace.account.dify_config.EDITION", "CLOUD"), + patch("controllers.console.workspace.account.db.session.query") as query_mock, + ): + query_mock.return_value.where.return_value.first.return_value = MagicMock(status="unused") + resp = method(api) + + assert resp["result"] == "success" + + def test_init_already_initialized(self, app): + api = AccountInitApi() + method = unwrap(api.post) + + account = MagicMock(status="active") + + with ( + app.test_request_context("/account/init"), + patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")), + ): + with pytest.raises(AccountAlreadyInitedError): + method(api) + + +class TestAccountProfileApi: + def test_get_profile_success(self, app): + api = AccountProfileApi() + method = unwrap(api.get) + + user = MagicMock() + user.id = "u1" + user.name = "John" + user.email = "john@test.com" + user.avatar = "avatar.png" + user.interface_language = "en-US" + user.interface_theme = "light" + user.timezone = "UTC" + user.last_login_ip = "127.0.0.1" + + with ( + app.test_request_context("/account/profile"), + patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")), + ): + result = method(api) + + assert result["id"] == "u1" + + +class TestAccountUpdateApis: + @pytest.mark.parametrize( + ("api_cls", "payload"), + [ + (AccountNameApi, {"name": "test"}), + (AccountAvatarApi, {"avatar": "img.png"}), + (AccountInterfaceLanguageApi, {"interface_language": "en-US"}), + (AccountInterfaceThemeApi, {"interface_theme": "dark"}), + (AccountTimezoneApi, {"timezone": "UTC"}), + ], + ) + def test_update_success(self, app, api_cls, payload): + api = api_cls() + method = unwrap(api.post) + + user = MagicMock() + user.id = "u1" + user.name = "John" + user.email = "john@test.com" + user.avatar = "avatar.png" + user.interface_language = "en-US" + user.interface_theme = "light" + user.timezone = "UTC" + user.last_login_ip = "127.0.0.1" + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.account.AccountService.update_account", return_value=user), + ): + result = method(api) + + assert result["id"] == "u1" + + +class TestAccountPasswordApi: + def test_password_success(self, app): + api = AccountPasswordApi() + method = unwrap(api.post) + + payload = { + "password": "old", + "new_password": "new123", + "repeat_new_password": "new123", + } + + user = MagicMock() + user.id = "u1" + user.name = "John" + user.email = "john@test.com" + user.avatar = "avatar.png" + user.interface_language = "en-US" + user.interface_theme = "light" + user.timezone = "UTC" + user.last_login_ip = "127.0.0.1" + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.account.AccountService.update_account_password", return_value=None), + ): + result = method(api) + + assert result["id"] == "u1" + + def test_password_wrong_current(self, app): + api = AccountPasswordApi() + method = unwrap(api.post) + + payload = { + "password": "bad", + "new_password": "new123", + "repeat_new_password": "new123", + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.account.AccountService.update_account_password", + side_effect=ServicePwdError(), + ), + ): + with pytest.raises(CurrentPasswordIncorrectError): + method(api) + + +class TestAccountIntegrateApi: + def test_get_integrates(self, app): + api = AccountIntegrateApi() + method = unwrap(api.get) + + account = MagicMock(id="acc1") + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.account.current_account_with_tenant", return_value=(account, "t1")), + patch("controllers.console.workspace.account.db.session.scalars") as scalars_mock, + ): + scalars_mock.return_value.all.return_value = [] + result = method(api) + + assert "data" in result + assert len(result["data"]) == 2 + + +class TestAccountDeleteApi: + def test_delete_verify_success(self, app): + api = AccountDeleteVerifyApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.account.AccountService.generate_account_deletion_verification_code", + return_value=("token", "1234"), + ), + patch( + "controllers.console.workspace.account.AccountService.send_account_deletion_verification_email", + return_value=None, + ), + ): + result = method(api) + + assert result["result"] == "success" + + def test_delete_invalid_code(self, app): + api = AccountDeleteApi() + method = unwrap(api.post) + + payload = {"token": "t", "code": "x"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.account.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.account.AccountService.verify_account_deletion_code", + return_value=False, + ), + ): + with pytest.raises(InvalidAccountDeletionCodeError): + method(api) + + +class TestChangeEmailApis: + def test_check_email_code_invalid(self, app): + api = ChangeEmailCheckApi() + method = unwrap(api.post) + + payload = {"email": "a@test.com", "code": "x", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch( + "controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit", + return_value=False, + ), + patch( + "controllers.console.workspace.account.AccountService.get_change_email_data", + return_value={"email": "a@test.com", "code": "y"}, + ), + ): + with pytest.raises(EmailCodeError): + method(api) + + def test_reset_email_already_used(self, app): + api = ChangeEmailResetApi() + method = unwrap(api.post) + + payload = {"new_email": "x@test.com", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False), + patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=False), + ): + with pytest.raises(EmailAlreadyInUseError): + method(api) + + +class TestCheckEmailUniqueApi: + def test_email_unique_success(self, app): + api = CheckEmailUnique() + method = unwrap(api.post) + + payload = {"email": "ok@test.com"} + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=False), + patch("controllers.console.workspace.account.AccountService.check_email_unique", return_value=True), + ): + result = method(api) + + assert result["result"] == "success" + + def test_email_in_freeze(self, app): + api = CheckEmailUnique() + method = unwrap(api.post) + + payload = {"email": "x@test.com"} + + with ( + app.test_request_context("/", json=payload), + patch.object( + type(console_ns), + "payload", + new_callable=PropertyMock, + return_value=payload, + ), + patch("controllers.console.workspace.account.AccountService.is_account_in_freeze", return_value=True), + ): + with pytest.raises(AccountInFreezeError): + method(api) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py new file mode 100644 index 0000000000..b4e03f681d --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_agent_providers.py @@ -0,0 +1,139 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from controllers.console.error import AccountNotFound +from controllers.console.workspace.agent_providers import ( + AgentProviderApi, + AgentProviderListApi, +) + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestAgentProviderListApi: + def test_get_success(self, app): + api = AgentProviderListApi() + method = unwrap(api.get) + + user = MagicMock(id="user1") + tenant_id = "tenant1" + providers = [{"name": "openai"}, {"name": "anthropic"}] + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch( + "controllers.console.workspace.agent_providers.AgentService.list_agent_providers", + return_value=providers, + ), + ): + result = method(api) + + assert result == providers + + def test_get_empty_list(self, app): + api = AgentProviderListApi() + method = unwrap(api.get) + + user = MagicMock(id="user1") + tenant_id = "tenant1" + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch( + "controllers.console.workspace.agent_providers.AgentService.list_agent_providers", + return_value=[], + ), + ): + result = method(api) + + assert result == [] + + def test_get_account_not_found(self, app): + api = AgentProviderListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", + side_effect=AccountNotFound(), + ), + ): + with pytest.raises(AccountNotFound): + method(api) + + +class TestAgentProviderApi: + def test_get_success(self, app): + api = AgentProviderApi() + method = unwrap(api.get) + + user = MagicMock(id="user1") + tenant_id = "tenant1" + provider_name = "openai" + provider_data = {"name": "openai", "models": ["gpt-4"]} + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch( + "controllers.console.workspace.agent_providers.AgentService.get_agent_provider", + return_value=provider_data, + ), + ): + result = method(api, provider_name) + + assert result == provider_data + + def test_get_provider_not_found(self, app): + api = AgentProviderApi() + method = unwrap(api.get) + + user = MagicMock(id="user1") + tenant_id = "tenant1" + provider_name = "unknown" + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", + return_value=(user, tenant_id), + ), + patch( + "controllers.console.workspace.agent_providers.AgentService.get_agent_provider", + return_value=None, + ), + ): + result = method(api, provider_name) + + assert result is None + + def test_get_account_not_found(self, app): + api = AgentProviderApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.agent_providers.current_account_with_tenant", + side_effect=AccountNotFound(), + ), + ): + with pytest.raises(AccountNotFound): + method(api, "openai") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py b/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py new file mode 100644 index 0000000000..51f76af172 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_endpoint.py @@ -0,0 +1,305 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from controllers.console.workspace.endpoint import ( + EndpointCreateApi, + EndpointDeleteApi, + EndpointDisableApi, + EndpointEnableApi, + EndpointListApi, + EndpointListForSinglePluginApi, + EndpointUpdateApi, +) +from core.plugin.impl.exc import PluginPermissionDeniedError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def user_and_tenant(): + return MagicMock(id="u1"), "t1" + + +@pytest.fixture +def patch_current_account(user_and_tenant): + with patch( + "controllers.console.workspace.endpoint.current_account_with_tenant", + return_value=user_and_tenant, + ): + yield + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointCreateApi: + def test_create_success(self, app): + api = EndpointCreateApi() + method = unwrap(api.post) + + payload = { + "plugin_unique_identifier": "plugin-1", + "name": "endpoint", + "settings": {"a": 1}, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.create_endpoint", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_create_permission_denied(self, app): + api = EndpointCreateApi() + method = unwrap(api.post) + + payload = { + "plugin_unique_identifier": "plugin-1", + "name": "endpoint", + "settings": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.endpoint.EndpointService.create_endpoint", + side_effect=PluginPermissionDeniedError("denied"), + ), + ): + with pytest.raises(ValueError): + method(api) + + def test_create_validation_error(self, app): + api = EndpointCreateApi() + method = unwrap(api.post) + + payload = { + "plugin_unique_identifier": "p1", + "name": "", + "settings": {}, + } + + with ( + app.test_request_context("/", json=payload), + ): + with pytest.raises(ValueError): + method(api) + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointListApi: + def test_list_success(self, app): + api = EndpointListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=1&page_size=10"), + patch("controllers.console.workspace.endpoint.EndpointService.list_endpoints", return_value=[{"id": "e1"}]), + ): + result = method(api) + + assert "endpoints" in result + assert len(result["endpoints"]) == 1 + + def test_list_invalid_query(self, app): + api = EndpointListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=0&page_size=10"), + ): + with pytest.raises(ValueError): + method(api) + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointListForSinglePluginApi: + def test_list_for_plugin_success(self, app): + api = EndpointListForSinglePluginApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=1&page_size=10&plugin_id=p1"), + patch( + "controllers.console.workspace.endpoint.EndpointService.list_endpoints_for_single_plugin", + return_value=[{"id": "e1"}], + ), + ): + result = method(api) + + assert "endpoints" in result + + def test_list_for_plugin_missing_param(self, app): + api = EndpointListForSinglePluginApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=1&page_size=10"), + ): + with pytest.raises(ValueError): + method(api) + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointDeleteApi: + def test_delete_success(self, app): + api = EndpointDeleteApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_delete_invalid_payload(self, app): + api = EndpointDeleteApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + ): + with pytest.raises(ValueError): + method(api) + + def test_delete_service_failure(self, app): + api = EndpointDeleteApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.delete_endpoint", return_value=False), + ): + result = method(api) + + assert result["success"] is False + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointUpdateApi: + def test_update_success(self, app): + api = EndpointUpdateApi() + method = unwrap(api.post) + + payload = { + "endpoint_id": "e1", + "name": "new-name", + "settings": {"x": 1}, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_update_validation_error(self, app): + api = EndpointUpdateApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1", "settings": {}} + + with ( + app.test_request_context("/", json=payload), + ): + with pytest.raises(ValueError): + method(api) + + def test_update_service_failure(self, app): + api = EndpointUpdateApi() + method = unwrap(api.post) + + payload = { + "endpoint_id": "e1", + "name": "n", + "settings": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.update_endpoint", return_value=False), + ): + result = method(api) + + assert result["success"] is False + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointEnableApi: + def test_enable_success(self, app): + api = EndpointEnableApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_enable_invalid_payload(self, app): + api = EndpointEnableApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + ): + with pytest.raises(ValueError): + method(api) + + def test_enable_service_failure(self, app): + api = EndpointEnableApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.enable_endpoint", return_value=False), + ): + result = method(api) + + assert result["success"] is False + + +@pytest.mark.usefixtures("patch_current_account") +class TestEndpointDisableApi: + def test_disable_success(self, app): + api = EndpointDisableApi() + method = unwrap(api.post) + + payload = {"endpoint_id": "e1"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.endpoint.EndpointService.disable_endpoint", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_disable_invalid_payload(self, app): + api = EndpointDisableApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={}), + ): + with pytest.raises(ValueError): + method(api) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_members.py b/api/tests/unit_tests/controllers/console/workspace/test_members.py new file mode 100644 index 0000000000..b6708d1f6f --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_members.py @@ -0,0 +1,607 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import HTTPException + +import services +from controllers.console.auth.error import ( + CannotTransferOwnerToSelfError, + EmailCodeError, + InvalidEmailError, + InvalidTokenError, + MemberNotInTenantError, + NotOwnerError, + OwnerTransferLimitError, +) +from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded +from controllers.console.workspace.members import ( + DatasetOperatorMemberListApi, + MemberCancelInviteApi, + MemberInviteEmailApi, + MemberListApi, + MemberUpdateRoleApi, + OwnerTransfer, + OwnerTransferCheckApi, + SendOwnerTransferEmailApi, +) +from services.errors.account import AccountAlreadyInTenantError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestMemberListApi: + def test_get_success(self, app): + api = MemberListApi() + method = unwrap(api.get) + + tenant = MagicMock() + user = MagicMock(current_tenant=tenant) + member = MagicMock() + member.id = "m1" + member.name = "Member" + member.email = "member@test.com" + member.avatar = "avatar.png" + member.role = "admin" + member.status = "active" + members = [member] + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.get_tenant_members", return_value=members), + ): + result, status = method(api) + + assert status == 200 + assert len(result["accounts"]) == 1 + + def test_get_no_tenant(self, app): + api = MemberListApi() + method = unwrap(api.get) + + user = MagicMock(current_tenant=None) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + ): + with pytest.raises(ValueError): + method(api) + + +class TestMemberInviteEmailApi: + def test_invite_success(self, app): + api = MemberInviteEmailApi() + method = unwrap(api.post) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + features = MagicMock() + features.workspace_members.is_available.return_value = True + + payload = { + "emails": ["a@test.com"], + "role": "normal", + "language": "en-US", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features), + patch("controllers.console.workspace.members.RegisterService.invite_new_member", return_value="token"), + patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"), + ): + result, status = method(api) + + assert status == 201 + assert result["result"] == "success" + + def test_invite_limit_exceeded(self, app): + api = MemberInviteEmailApi() + method = unwrap(api.post) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + features = MagicMock() + features.workspace_members.is_available.return_value = False + + payload = { + "emails": ["a@test.com"], + "role": "normal", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features), + ): + with pytest.raises(WorkspaceMembersLimitExceeded): + method(api) + + def test_invite_already_member(self, app): + api = MemberInviteEmailApi() + method = unwrap(api.post) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + features = MagicMock() + features.workspace_members.is_available.return_value = True + + payload = { + "emails": ["a@test.com"], + "role": "normal", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features), + patch( + "controllers.console.workspace.members.RegisterService.invite_new_member", + side_effect=AccountAlreadyInTenantError(), + ), + patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"), + ): + result, status = method(api) + + assert result["invitation_results"][0]["status"] == "success" + + def test_invite_invalid_role(self, app): + api = MemberInviteEmailApi() + method = unwrap(api.post) + + payload = { + "emails": ["a@test.com"], + "role": "owner", + } + + with app.test_request_context("/", json=payload): + result, status = method(api) + + assert status == 400 + assert result["code"] == "invalid-role" + + def test_invite_generic_exception(self, app): + api = MemberInviteEmailApi() + method = unwrap(api.post) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + features = MagicMock() + features.workspace_members.is_available.return_value = True + + payload = { + "emails": ["a@test.com"], + "role": "normal", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.FeatureService.get_features", return_value=features), + patch( + "controllers.console.workspace.members.RegisterService.invite_new_member", + side_effect=Exception("boom"), + ), + patch("controllers.console.workspace.members.dify_config.CONSOLE_WEB_URL", "http://x"), + ): + result, _ = method(api) + + assert result["invitation_results"][0]["status"] == "failed" + + +class TestMemberCancelInviteApi: + def test_cancel_success(self, app): + api = MemberCancelInviteApi() + method = unwrap(api.delete) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + member = MagicMock() + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.query") as q, + patch("controllers.console.workspace.members.TenantService.remove_member_from_tenant"), + ): + q.return_value.where.return_value.first.return_value = member + result, status = method(api, member.id) + + assert status == 200 + assert result["result"] == "success" + + def test_cancel_not_found(self, app): + api = MemberCancelInviteApi() + method = unwrap(api.delete) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.query") as q, + ): + q.return_value.where.return_value.first.return_value = None + + with pytest.raises(HTTPException): + method(api, "x") + + def test_cancel_cannot_operate_self(self, app): + api = MemberCancelInviteApi() + method = unwrap(api.delete) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + member = MagicMock() + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.query") as q, + patch( + "controllers.console.workspace.members.TenantService.remove_member_from_tenant", + side_effect=services.errors.account.CannotOperateSelfError("x"), + ), + ): + q.return_value.where.return_value.first.return_value = member + result, status = method(api, member.id) + + assert status == 400 + + def test_cancel_no_permission(self, app): + api = MemberCancelInviteApi() + method = unwrap(api.delete) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + member = MagicMock() + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.query") as q, + patch( + "controllers.console.workspace.members.TenantService.remove_member_from_tenant", + side_effect=services.errors.account.NoPermissionError("x"), + ), + ): + q.return_value.where.return_value.first.return_value = member + result, status = method(api, member.id) + + assert status == 403 + + def test_cancel_member_not_in_tenant(self, app): + api = MemberCancelInviteApi() + method = unwrap(api.delete) + + tenant = MagicMock(id="t1") + user = MagicMock(current_tenant=tenant) + member = MagicMock() + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.query") as q, + patch( + "controllers.console.workspace.members.TenantService.remove_member_from_tenant", + side_effect=services.errors.account.MemberNotInTenantError(), + ), + ): + q.return_value.where.return_value.first.return_value = member + result, status = method(api, member.id) + + assert status == 404 + + +class TestMemberUpdateRoleApi: + def test_update_success(self, app): + api = MemberUpdateRoleApi() + method = unwrap(api.put) + + tenant = MagicMock() + user = MagicMock(current_tenant=tenant) + member = MagicMock() + + payload = {"role": "normal"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.db.session.get", return_value=member), + patch("controllers.console.workspace.members.TenantService.update_member_role"), + ): + result = method(api, "id") + + if isinstance(result, tuple): + result = result[0] + + assert result["result"] == "success" + + def test_update_invalid_role(self, app): + api = MemberUpdateRoleApi() + method = unwrap(api.put) + + payload = {"role": "invalid-role"} + + with app.test_request_context("/", json=payload): + result, status = method(api, "id") + + assert status == 400 + + def test_update_member_not_found(self, app): + api = MemberUpdateRoleApi() + method = unwrap(api.put) + + payload = {"role": "normal"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.members.current_account_with_tenant", + return_value=(MagicMock(current_tenant=MagicMock()), "t1"), + ), + patch("controllers.console.workspace.members.db.session.get", return_value=None), + ): + with pytest.raises(HTTPException): + method(api, "id") + + +class TestDatasetOperatorMemberListApi: + def test_get_success(self, app): + api = DatasetOperatorMemberListApi() + method = unwrap(api.get) + + tenant = MagicMock() + user = MagicMock(current_tenant=tenant) + member = MagicMock() + member.id = "op1" + member.name = "Operator" + member.email = "operator@test.com" + member.avatar = "avatar.png" + member.role = "operator" + member.status = "active" + members = [member] + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch( + "controllers.console.workspace.members.TenantService.get_dataset_operator_members", return_value=members + ), + ): + result, status = method(api) + + assert status == 200 + assert len(result["accounts"]) == 1 + + def test_get_no_tenant(self, app): + api = DatasetOperatorMemberListApi() + method = unwrap(api.get) + + user = MagicMock(current_tenant=None) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + ): + with pytest.raises(ValueError): + method(api) + + +class TestSendOwnerTransferEmailApi: + def test_send_success(self, app): + api = SendOwnerTransferEmailApi() + method = unwrap(api.post) + + tenant = MagicMock(name="ws") + user = MagicMock(email="a@test.com", current_tenant=tenant) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"), + patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.send_owner_transfer_email", return_value="token" + ), + ): + result = method(api) + + assert result["result"] == "success" + + def test_send_ip_limit(self, app): + api = SendOwnerTransferEmailApi() + method = unwrap(api.post) + + payload = {} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"), + patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=True), + ): + with pytest.raises(EmailSendIpLimitError): + method(api) + + def test_send_not_owner(self, app): + api = SendOwnerTransferEmailApi() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(current_tenant=tenant) + + with ( + app.test_request_context("/", json={}), + patch("controllers.console.workspace.members.extract_remote_ip", return_value="1.1.1.1"), + patch("controllers.console.workspace.members.AccountService.is_email_send_ip_limit", return_value=False), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=False), + ): + with pytest.raises(NotOwnerError): + method(api) + + +class TestOwnerTransferCheckApi: + def test_check_invalid_code(self, app): + api = OwnerTransferCheckApi() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(email="a@test.com", current_tenant=tenant) + + payload = {"code": "x", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit", + return_value=False, + ), + patch( + "controllers.console.workspace.members.AccountService.get_owner_transfer_data", + return_value={"email": "a@test.com", "code": "y"}, + ), + ): + with pytest.raises(EmailCodeError): + method(api) + + def test_rate_limited(self, app): + api = OwnerTransferCheckApi() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(email="a@test.com", current_tenant=tenant) + + payload = {"code": "x", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit", + return_value=True, + ), + ): + with pytest.raises(OwnerTransferLimitError): + method(api) + + def test_invalid_token(self, app): + api = OwnerTransferCheckApi() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(email="a@test.com", current_tenant=tenant) + + payload = {"code": "x", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit", + return_value=False, + ), + patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None), + ): + with pytest.raises(InvalidTokenError): + method(api) + + def test_invalid_email(self, app): + api = OwnerTransferCheckApi() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(email="a@test.com", current_tenant=tenant) + + payload = {"code": "x", "token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.is_owner_transfer_error_rate_limit", + return_value=False, + ), + patch( + "controllers.console.workspace.members.AccountService.get_owner_transfer_data", + return_value={"email": "b@test.com", "code": "x"}, + ), + ): + with pytest.raises(InvalidEmailError): + method(api) + + +class TestOwnerTransferApi: + def test_transfer_self(self, app): + api = OwnerTransfer() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(id="1", email="a@test.com", current_tenant=tenant) + + payload = {"token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + ): + with pytest.raises(CannotTransferOwnerToSelfError): + method(api, "1") + + def test_invalid_token(self, app): + api = OwnerTransfer() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(id="1", email="a@test.com", current_tenant=tenant) + + payload = {"token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch("controllers.console.workspace.members.AccountService.get_owner_transfer_data", return_value=None), + ): + with pytest.raises(InvalidTokenError): + method(api, "2") + + def test_member_not_in_tenant(self, app): + api = OwnerTransfer() + method = unwrap(api.post) + + tenant = MagicMock() + user = MagicMock(id="1", email="a@test.com", current_tenant=tenant) + member = MagicMock() + + payload = {"token": "t"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.members.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.members.TenantService.is_owner", return_value=True), + patch( + "controllers.console.workspace.members.AccountService.get_owner_transfer_data", + return_value={"email": "a@test.com"}, + ), + patch("controllers.console.workspace.members.db.session.get", return_value=member), + patch("controllers.console.workspace.members.TenantService.is_member", return_value=False), + ): + with pytest.raises(MemberNotInTenantError): + method(api, "2") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py new file mode 100644 index 0000000000..af0c2c5594 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py @@ -0,0 +1,388 @@ +from unittest.mock import MagicMock, patch + +import pytest +from pydantic_core import ValidationError +from werkzeug.exceptions import Forbidden + +from controllers.console.workspace.model_providers import ( + ModelProviderCredentialApi, + ModelProviderCredentialSwitchApi, + ModelProviderIconApi, + ModelProviderListApi, + ModelProviderPaymentCheckoutUrlApi, + ModelProviderValidateApi, + PreferredProviderTypeUpdateApi, +) +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError + +VALID_UUID = "123e4567-e89b-12d3-a456-426614174000" +INVALID_UUID = "123" + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestModelProviderListApi: + def test_get_success(self, app): + api = ModelProviderListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?model_type=llm"), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.get_provider_list", + return_value=[{"name": "openai"}], + ), + ): + result = method(api) + + assert "data" in result + + +class TestModelProviderCredentialApi: + def test_get_success(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.get) + + with ( + app.test_request_context(f"/?credential_id={VALID_UUID}"), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.get_provider_credential", + return_value={"key": "value"}, + ), + ): + result = method(api, provider="openai") + + assert "credentials" in result + + def test_get_invalid_uuid(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.get) + + with ( + app.test_request_context(f"/?credential_id={INVALID_UUID}"), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + ): + with pytest.raises(ValidationError): + method(api, provider="openai") + + def test_post_create_success(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.post) + + payload = {"credentials": {"a": "b"}, "name": "test"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential", + return_value=None, + ), + ): + result, status = method(api, provider="openai") + + assert result["result"] == "success" + assert status == 201 + + def test_post_create_validation_error(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.post) + + payload = {"credentials": {"a": "b"}} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.create_provider_credential", + side_effect=CredentialsValidateFailedError("bad"), + ), + ): + with pytest.raises(ValueError): + method(api, provider="openai") + + def test_put_update_success(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.put) + + payload = {"credential_id": VALID_UUID, "credentials": {"a": "b"}} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.update_provider_credential", + return_value=None, + ), + ): + result = method(api, provider="openai") + + assert result["result"] == "success" + + def test_put_invalid_uuid(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.put) + + payload = {"credential_id": INVALID_UUID, "credentials": {"a": "b"}} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + ): + with pytest.raises(ValidationError): + method(api, provider="openai") + + def test_delete_success(self, app): + api = ModelProviderCredentialApi() + method = unwrap(api.delete) + + payload = {"credential_id": VALID_UUID} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.remove_provider_credential", + return_value=None, + ), + ): + result, status = method(api, provider="openai") + + assert result["result"] == "success" + assert status == 204 + + +class TestModelProviderCredentialSwitchApi: + def test_switch_success(self, app): + api = ModelProviderCredentialSwitchApi() + method = unwrap(api.post) + + payload = {"credential_id": VALID_UUID} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.switch_active_provider_credential", + return_value=None, + ), + ): + result = method(api, provider="openai") + + assert result["result"] == "success" + + def test_switch_invalid_uuid(self, app): + api = ModelProviderCredentialSwitchApi() + method = unwrap(api.post) + + payload = {"credential_id": INVALID_UUID} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + ): + with pytest.raises(ValidationError): + method(api, provider="openai") + + +class TestModelProviderValidateApi: + def test_validate_success(self, app): + api = ModelProviderValidateApi() + method = unwrap(api.post) + + payload = {"credentials": {"a": "b"}} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials", + return_value=None, + ), + ): + result = method(api, provider="openai") + + assert result["result"] == "success" + + def test_validate_failure(self, app): + api = ModelProviderValidateApi() + method = unwrap(api.post) + + payload = {"credentials": {"a": "b"}} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.validate_provider_credentials", + side_effect=CredentialsValidateFailedError("bad"), + ), + ): + result = method(api, provider="openai") + + assert result["result"] == "error" + + +class TestModelProviderIconApi: + def test_icon_success(self, app): + api = ModelProviderIconApi() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon", + return_value=(b"123", "image/png"), + ), + ): + response = api.get("t1", "openai", "logo", "en") + + assert response.mimetype == "image/png" + + def test_icon_not_found(self, app): + api = ModelProviderIconApi() + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.get_model_provider_icon", + return_value=(None, None), + ), + ): + with pytest.raises(ValueError): + api.get("t1", "openai", "logo", "en") + + +class TestPreferredProviderTypeUpdateApi: + def test_update_success(self, app): + api = PreferredProviderTypeUpdateApi() + method = unwrap(api.post) + + payload = {"preferred_provider_type": "custom"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.ModelProviderService.switch_preferred_provider", + return_value=None, + ), + ): + result = method(api, provider="openai") + + assert result["result"] == "success" + + def test_invalid_enum(self, app): + api = PreferredProviderTypeUpdateApi() + method = unwrap(api.post) + + payload = {"preferred_provider_type": "invalid"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + ): + with pytest.raises(ValidationError): + method(api, provider="openai") + + +class TestModelProviderPaymentCheckoutUrlApi: + def test_checkout_success(self, app): + api = ModelProviderPaymentCheckoutUrlApi() + method = unwrap(api.get) + + user = MagicMock(id="u1", email="x@test.com") + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(user, "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin", + return_value=None, + ), + patch( + "controllers.console.workspace.model_providers.BillingService.get_model_provider_payment_link", + return_value={"url": "x"}, + ), + ): + result = method(api, provider="anthropic") + + assert "url" in result + + def test_invalid_provider(self, app): + api = ModelProviderPaymentCheckoutUrlApi() + method = unwrap(api.get) + + with app.test_request_context("/"): + with pytest.raises(ValueError): + method(api, provider="openai") + + def test_permission_denied(self, app): + api = ModelProviderPaymentCheckoutUrlApi() + method = unwrap(api.get) + + user = MagicMock(id="u1", email="x@test.com") + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.model_providers.current_account_with_tenant", + return_value=(user, "tenant1"), + ), + patch( + "controllers.console.workspace.model_providers.BillingService.is_tenant_owner_or_admin", + side_effect=Forbidden(), + ), + ): + with pytest.raises(Forbidden): + method(api, provider="anthropic") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py new file mode 100644 index 0000000000..43b8e1ac2e --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py @@ -0,0 +1,447 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.console.workspace.models import ( + DefaultModelApi, + ModelProviderAvailableModelApi, + ModelProviderModelApi, + ModelProviderModelCredentialApi, + ModelProviderModelCredentialSwitchApi, + ModelProviderModelDisableApi, + ModelProviderModelEnableApi, + ModelProviderModelParameterRuleApi, + ModelProviderModelValidateApi, +) +from dify_graph.model_runtime.entities.model_entities import ModelType +from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestDefaultModelApi: + def test_get_success(self, app: Flask): + api = DefaultModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context( + "/", + query_string={"model_type": ModelType.LLM.value}, + ), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as service_mock, + ): + service_mock.return_value.get_default_model_of_model_type.return_value = {"model": "gpt-4"} + + result = method(api) + + assert "data" in result + + def test_post_success(self, app: Flask): + api = DefaultModelApi() + method = unwrap(api.post) + + payload = { + "model_settings": [ + { + "model_type": ModelType.LLM.value, + "provider": "openai", + "model": "gpt-4", + } + ] + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result = method(api) + + assert result["result"] == "success" + + def test_get_returns_empty_when_no_default(self, app): + api = DefaultModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/", query_string={"model_type": ModelType.LLM.value}), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService") as service, + ): + service.return_value.get_default_model_of_model_type.return_value = None + + result = method(api) + + assert "data" in result + + +class TestModelProviderModelApi: + def test_get_models_success(self, app: Flask): + api = ModelProviderModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as service_mock, + ): + service_mock.return_value.get_models_by_provider.return_value = [] + + result = method(api, "openai") + + assert "data" in result + + def test_post_models_success(self, app: Flask): + api = ModelProviderModelApi() + method = unwrap(api.post) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + "load_balancing": { + "configs": [{"weight": 1}], + "enabled": True, + }, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + patch("controllers.console.workspace.models.ModelLoadBalancingService"), + ): + result, status = method(api, "openai") + + assert status == 200 + + def test_delete_model_success(self, app: Flask): + api = ModelProviderModelApi() + method = unwrap(api.delete) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result, status = method(api, "openai") + + assert status == 204 + + def test_get_models_returns_empty(self, app): + api = ModelProviderModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService") as service, + ): + service.return_value.get_models_by_provider.return_value = [] + + result = method(api, "openai") + + assert "data" in result + + +class TestModelProviderModelCredentialApi: + def test_get_credentials_success(self, app: Flask): + api = ModelProviderModelCredentialApi() + method = unwrap(api.get) + + with ( + app.test_request_context( + "/", + query_string={ + "model": "gpt-4", + "model_type": ModelType.LLM.value, + }, + ), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as provider_service, + patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb_service, + ): + provider_service.return_value.get_model_credential.return_value = { + "credentials": {}, + "current_credential_id": None, + "current_credential_name": None, + } + provider_service.return_value.provider_manager.get_provider_model_available_credentials.return_value = [] + lb_service.return_value.get_load_balancing_configs.return_value = (False, []) + + result = method(api, "openai") + + assert "credentials" in result + + def test_create_credential_success(self, app: Flask): + api = ModelProviderModelCredentialApi() + method = unwrap(api.post) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + "credentials": {"key": "val"}, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result, status = method(api, "openai") + + assert status == 201 + + def test_get_empty_credentials(self, app): + api = ModelProviderModelCredentialApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/", query_string={"model": "gpt", "model_type": ModelType.LLM.value}), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService") as service, + patch("controllers.console.workspace.models.ModelLoadBalancingService") as lb, + ): + service.return_value.get_model_credential.return_value = None + service.return_value.provider_manager.get_provider_model_available_credentials.return_value = [] + lb.return_value.get_load_balancing_configs.return_value = (False, []) + + result = method(api, "openai") + + assert result["credentials"] == {} + + def test_delete_success(self, app): + api = ModelProviderModelCredentialApi() + method = unwrap(api.delete) + + payload = { + "model": "gpt", + "model_type": ModelType.LLM.value, + "credential_id": "123e4567-e89b-12d3-a456-426614174000", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result, status = method(api, "openai") + + assert status == 204 + + +class TestModelProviderModelCredentialSwitchApi: + def test_switch_success(self, app: Flask): + api = ModelProviderModelCredentialSwitchApi() + method = unwrap(api.post) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + "credential_id": "abc", + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result = method(api, "openai") + + assert result["result"] == "success" + + +class TestModelEnableDisableApis: + def test_enable_model(self, app: Flask): + api = ModelProviderModelEnableApi() + method = unwrap(api.patch) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result = method(api, "openai") + + assert result["result"] == "success" + + def test_disable_model(self, app: Flask): + api = ModelProviderModelDisableApi() + method = unwrap(api.patch) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result = method(api, "openai") + + assert result["result"] == "success" + + +class TestModelProviderModelValidateApi: + def test_validate_success(self, app: Flask): + api = ModelProviderModelValidateApi() + method = unwrap(api.post) + + payload = { + "model": "gpt-4", + "model_type": ModelType.LLM.value, + "credentials": {"key": "val"}, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService"), + ): + result = method(api, "openai") + + assert result["result"] == "success" + + @pytest.mark.parametrize("model_name", ["gpt-4", "gpt"]) + def test_validate_failure(self, app: Flask, model_name: str): + api = ModelProviderModelValidateApi() + method = unwrap(api.post) + + payload = { + "model": model_name, + "model_type": ModelType.LLM.value, + "credentials": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as service_mock, + ): + service_mock.return_value.validate_model_credentials.side_effect = CredentialsValidateFailedError("invalid") + + result = method(api, "openai") + + assert result["result"] == "error" + + +class TestParameterAndAvailableModels: + def test_parameter_rules(self, app: Flask): + api = ModelProviderModelParameterRuleApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/", query_string={"model": "gpt-4"}), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as service_mock, + ): + service_mock.return_value.get_model_parameter_rules.return_value = [] + + result = method(api, "openai") + + assert "data" in result + + def test_available_models(self, app: Flask): + api = ModelProviderAvailableModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.models.current_account_with_tenant", + return_value=(MagicMock(), "tenant1"), + ), + patch("controllers.console.workspace.models.ModelProviderService") as service_mock, + ): + service_mock.return_value.get_models_by_model_type.return_value = [] + + result = method(api, ModelType.LLM.value) + + assert "data" in result + + def test_empty_rules(self, app): + api = ModelProviderModelParameterRuleApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/", query_string={"model": "gpt"}), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService") as service, + ): + service.return_value.get_model_parameter_rules.return_value = [] + + result = method(api, "openai") + + assert result["data"] == [] + + def test_no_models(self, app): + api = ModelProviderAvailableModelApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.models.current_account_with_tenant", return_value=(MagicMock(), "t1")), + patch("controllers.console.workspace.models.ModelProviderService") as service, + ): + service.return_value.get_models_by_model_type.return_value = [] + + result = method(api, ModelType.LLM.value) + + assert result["data"] == [] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py new file mode 100644 index 0000000000..f6db55db5b --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py @@ -0,0 +1,1019 @@ +import io +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import Forbidden + +from controllers.console.workspace.plugin import ( + PluginAssetApi, + PluginAutoUpgradeExcludePluginApi, + PluginChangePermissionApi, + PluginChangePreferencesApi, + PluginDebuggingKeyApi, + PluginDeleteAllInstallTaskItemsApi, + PluginDeleteInstallTaskApi, + PluginDeleteInstallTaskItemApi, + PluginFetchDynamicSelectOptionsApi, + PluginFetchDynamicSelectOptionsWithCredentialsApi, + PluginFetchInstallTaskApi, + PluginFetchInstallTasksApi, + PluginFetchManifestApi, + PluginFetchMarketplacePkgApi, + PluginFetchPermissionApi, + PluginFetchPreferencesApi, + PluginIconApi, + PluginInstallFromGithubApi, + PluginInstallFromMarketplaceApi, + PluginInstallFromPkgApi, + PluginListApi, + PluginListInstallationsFromIdsApi, + PluginListLatestVersionsApi, + PluginReadmeApi, + PluginUninstallApi, + PluginUpgradeFromGithubApi, + PluginUpgradeFromMarketplaceApi, + PluginUploadFromBundleApi, + PluginUploadFromGithubApi, + PluginUploadFromPkgApi, +) +from core.plugin.impl.exc import PluginDaemonClientSideError +from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +@pytest.fixture +def user(): + u = MagicMock() + u.id = "u1" + u.is_admin_or_owner = True + return u + + +@pytest.fixture +def tenant(): + return "t1" + + +class TestPluginListLatestVersionsApi: + def test_success(self, app): + api = PluginListLatestVersionsApi() + method = unwrap(api.post) + + payload = {"plugin_ids": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.plugin.PluginService.list_latest_versions", return_value={"p1": "1.0"} + ), + ): + result = method(api) + + assert "versions" in result + + def test_daemon_error(self, app): + api = PluginListLatestVersionsApi() + method = unwrap(api.post) + + payload = {"plugin_ids": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.plugin.PluginService.list_latest_versions", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginDebuggingKeyApi: + def test_debugging_key_success(self, app): + api = PluginDebuggingKeyApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.get_debugging_key", return_value="k"), + ): + result = method(api) + + assert result["key"] == "k" + + def test_debugging_key_error(self, app): + api = PluginDebuggingKeyApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.get_debugging_key", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginListApi: + def test_plugin_list(self, app): + api = PluginListApi() + method = unwrap(api.get) + + mock_list = MagicMock(list=[{"id": 1}], total=1) + + with ( + app.test_request_context("/?page=1&page_size=10"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.list_with_total", return_value=mock_list), + ): + result = method(api) + + assert result["total"] == 1 + + +class TestPluginIconApi: + def test_plugin_icon(self, app): + api = PluginIconApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?tenant_id=t1&filename=a.png"), + patch("controllers.console.workspace.plugin.PluginService.get_asset", return_value=(b"x", "image/png")), + ): + response = method(api) + + assert response.mimetype == "image/png" + + +class TestPluginAssetApi: + def test_plugin_asset(self, app): + api = PluginAssetApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_unique_identifier=p&file_name=a.bin"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.extract_asset", return_value=b"x"), + ): + response = method(api) + + assert response.mimetype == "application/octet-stream" + + +class TestPluginUploadFromPkgApi: + def test_upload_pkg_success(self, app): + api = PluginUploadFromPkgApi() + method = unwrap(api.post) + + data = { + "pkg": (io.BytesIO(b"x"), "test.pkg"), + } + + with ( + app.test_request_context("/", data=data, content_type="multipart/form-data"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.upload_pkg", return_value={"ok": True}), + ): + result = method(api) + + assert result["ok"] is True + + def test_upload_pkg_too_large(self, app): + api = PluginUploadFromPkgApi() + method = unwrap(api.post) + + data = { + "pkg": (io.BytesIO(b"x"), "test.pkg"), + } + + with ( + app.test_request_context("/", data=data, content_type="multipart/form-data"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_PACKAGE_SIZE", 0), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginInstallFromPkgApi: + def test_install_from_pkg(self, app): + api = PluginInstallFromPkgApi() + method = unwrap(api.post) + + payload = {"plugin_unique_identifiers": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.install_from_local_pkg", return_value={"ok": True} + ), + ): + result = method(api) + + assert result["ok"] is True + + +class TestPluginUninstallApi: + def test_uninstall(self, app): + api = PluginUninstallApi() + method = unwrap(api.post) + + payload = {"plugin_installation_id": "x"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.uninstall", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + +class TestPluginChangePermissionApi: + def test_change_permission_forbidden(self, app): + api = PluginChangePermissionApi() + method = unwrap(api.post) + + user = MagicMock(is_admin_or_owner=False) + + payload = { + "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, + "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + ): + with pytest.raises(Forbidden): + method(api) + + def test_change_permission_success(self, app): + api = PluginChangePermissionApi() + method = unwrap(api.post) + + user = MagicMock(is_admin_or_owner=True) + + payload = { + "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, + "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + +class TestPluginFetchPermissionApi: + def test_fetch_permission_default(self, app): + api = PluginFetchPermissionApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginPermissionService.get_permission", return_value=None), + ): + result = method(api) + + assert result["install_permission"] is not None + + +class TestPluginFetchDynamicSelectOptionsApi: + def test_fetch_dynamic_options(self, app, user): + api = PluginFetchDynamicSelectOptionsApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_id=p&provider=x&action=y¶meter=z&provider_type=tool"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + patch( + "controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options", + return_value=[1, 2], + ), + ): + result = method(api) + + assert result["options"] == [1, 2] + + +class TestPluginReadmeApi: + def test_fetch_readme(self, app): + api = PluginReadmeApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_unique_identifier=p"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.fetch_plugin_readme", return_value="readme"), + ): + result = method(api) + + assert result["readme"] == "readme" + + +class TestPluginListInstallationsFromIdsApi: + def test_success(self, app): + api = PluginListInstallationsFromIdsApi() + method = unwrap(api.post) + + payload = {"plugin_ids": ["p1", "p2"]} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.list_installations_from_ids", + return_value=[{"id": "p1"}], + ), + ): + result = method(api) + + assert "plugins" in result + + def test_daemon_error(self, app): + api = PluginListInstallationsFromIdsApi() + method = unwrap(api.post) + + payload = {"plugin_ids": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.list_installations_from_ids", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginUploadFromGithubApi: + def test_success(self, app): + api = PluginUploadFromGithubApi() + method = unwrap(api.post) + + payload = {"repo": "r", "version": "v", "package": "p"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.upload_pkg_from_github", return_value={"ok": True} + ), + ): + result = method(api) + + assert result["ok"] is True + + def test_daemon_error(self, app): + api = PluginUploadFromGithubApi() + method = unwrap(api.post) + + payload = {"repo": "r", "version": "v", "package": "p"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.upload_pkg_from_github", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginUploadFromBundleApi: + def test_success(self, app): + api = PluginUploadFromBundleApi() + method = unwrap(api.post) + + file = FileStorage( + stream=io.BytesIO(b"x"), + filename="test.bundle", + content_type="application/octet-stream", + ) + + with ( + app.test_request_context( + "/", + data={"bundle": file}, + content_type="multipart/form-data", + ), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.upload_bundle", return_value={"ok": True}), + ): + result = method(api) + + assert result["ok"] is True + + def test_too_large(self, app): + api = PluginUploadFromBundleApi() + method = unwrap(api.post) + + file = FileStorage( + stream=io.BytesIO(b"x"), + filename="test.bundle", + content_type="application/octet-stream", + ) + + with ( + app.test_request_context( + "/", + data={"bundle": file}, + content_type="multipart/form-data", + ), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_BUNDLE_SIZE", 0), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginInstallFromGithubApi: + def test_success(self, app): + api = PluginInstallFromGithubApi() + method = unwrap(api.post) + + payload = { + "plugin_unique_identifier": "p", + "repo": "r", + "version": "v", + "package": "pkg", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.install_from_github", return_value={"ok": True}), + ): + result = method(api) + + assert result["ok"] is True + + def test_daemon_error(self, app): + api = PluginInstallFromGithubApi() + method = unwrap(api.post) + + payload = { + "plugin_unique_identifier": "p", + "repo": "r", + "version": "v", + "package": "pkg", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.install_from_github", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginInstallFromMarketplaceApi: + def test_success(self, app): + api = PluginInstallFromMarketplaceApi() + method = unwrap(api.post) + + payload = {"plugin_unique_identifiers": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.install_from_marketplace_pkg", + return_value={"ok": True}, + ), + ): + result = method(api) + + assert result["ok"] is True + + def test_daemon_error(self, app): + api = PluginInstallFromMarketplaceApi() + method = unwrap(api.post) + + payload = {"plugin_unique_identifiers": ["p1"]} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.install_from_marketplace_pkg", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginFetchMarketplacePkgApi: + def test_success(self, app): + api = PluginFetchMarketplacePkgApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_unique_identifier=p"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.fetch_marketplace_pkg", return_value={"m": 1}), + ): + result = method(api) + + assert "manifest" in result + + def test_daemon_error(self, app): + api = PluginFetchMarketplacePkgApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_unique_identifier=p"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.fetch_marketplace_pkg", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginFetchManifestApi: + def test_success(self, app): + api = PluginFetchManifestApi() + method = unwrap(api.get) + + manifest = MagicMock() + manifest.model_dump.return_value = {"x": 1} + + with ( + app.test_request_context("/?plugin_unique_identifier=p"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.fetch_plugin_manifest", return_value=manifest), + ): + result = method(api) + + assert "manifest" in result + + def test_daemon_error(self, app): + api = PluginFetchManifestApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?plugin_unique_identifier=p"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.fetch_plugin_manifest", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginFetchInstallTasksApi: + def test_success(self, app): + api = PluginFetchInstallTasksApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=1&page_size=10"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.fetch_install_tasks", return_value=[{"id": 1}]), + ): + result = method(api) + + assert "tasks" in result + + def test_daemon_error(self, app): + api = PluginFetchInstallTasksApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?page=1&page_size=10"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.fetch_install_tasks", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginFetchInstallTaskApi: + def test_success(self, app): + api = PluginFetchInstallTaskApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.fetch_install_task", return_value={"id": "x"}), + ): + result = method(api, "x") + + assert "task" in result + + def test_daemon_error(self, app): + api = PluginFetchInstallTaskApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.fetch_install_task", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api, "t") + + +class TestPluginDeleteInstallTaskApi: + def test_success(self, app): + api = PluginDeleteInstallTaskApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.delete_install_task", return_value=True), + ): + result = method(api, "x") + + assert result["success"] is True + + def test_daemon_error(self, app): + api = PluginDeleteInstallTaskApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.delete_install_task", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api, "t") + + +class TestPluginDeleteAllInstallTaskItemsApi: + def test_success(self, app): + api = PluginDeleteAllInstallTaskItemsApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.delete_all_install_task_items", return_value=True + ), + ): + result = method(api) + + assert result["success"] is True + + def test_daemon_error(self, app): + api = PluginDeleteAllInstallTaskItemsApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.delete_all_install_task_items", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginDeleteInstallTaskItemApi: + def test_success(self, app): + api = PluginDeleteInstallTaskItemApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginService.delete_install_task_item", return_value=True), + ): + result = method(api, "task1", "item1") + + assert result["success"] is True + + def test_daemon_error(self, app): + api = PluginDeleteInstallTaskItemApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.delete_install_task_item", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api, "task1", "item1") + + +class TestPluginUpgradeFromMarketplaceApi: + def test_success(self, app): + api = PluginUpgradeFromMarketplaceApi() + method = unwrap(api.post) + + payload = { + "original_plugin_unique_identifier": "p1", + "new_plugin_unique_identifier": "p2", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_marketplace", + return_value={"ok": True}, + ), + ): + result = method(api) + + assert result["ok"] is True + + def test_daemon_error(self, app): + api = PluginUpgradeFromMarketplaceApi() + method = unwrap(api.post) + + payload = { + "original_plugin_unique_identifier": "p1", + "new_plugin_unique_identifier": "p2", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_marketplace", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginUpgradeFromGithubApi: + def test_success(self, app): + api = PluginUpgradeFromGithubApi() + method = unwrap(api.post) + + payload = { + "original_plugin_unique_identifier": "p1", + "new_plugin_unique_identifier": "p2", + "repo": "r", + "version": "v", + "package": "pkg", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_github", + return_value={"ok": True}, + ), + ): + result = method(api) + + assert result["ok"] is True + + def test_daemon_error(self, app): + api = PluginUpgradeFromGithubApi() + method = unwrap(api.post) + + payload = { + "original_plugin_unique_identifier": "p1", + "new_plugin_unique_identifier": "p2", + "repo": "r", + "version": "v", + "package": "pkg", + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginService.upgrade_plugin_with_github", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: + def test_success(self, app): + api = PluginFetchDynamicSelectOptionsWithCredentialsApi() + method = unwrap(api.post) + + user = MagicMock(id="u1", is_admin_or_owner=True) + + payload = { + "plugin_id": "p", + "provider": "x", + "action": "y", + "parameter": "z", + "credential_id": "c", + "credentials": {"k": "v"}, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + patch( + "controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options_with_credentials", + return_value=[1], + ), + ): + result = method(api) + + assert result["options"] == [1] + + def test_daemon_error(self, app): + api = PluginFetchDynamicSelectOptionsWithCredentialsApi() + method = unwrap(api.post) + + user = MagicMock(id="u1", is_admin_or_owner=True) + + payload = { + "plugin_id": "p", + "provider": "x", + "action": "y", + "parameter": "z", + "credential_id": "c", + "credentials": {"k": "v"}, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + patch( + "controllers.console.workspace.plugin.PluginParameterService.get_dynamic_select_options_with_credentials", + side_effect=PluginDaemonClientSideError("error"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestPluginChangePreferencesApi: + def test_success(self, app): + api = PluginChangePreferencesApi() + method = unwrap(api.post) + + user = MagicMock(is_admin_or_owner=True) + + payload = { + "permission": { + "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, + "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, + }, + "auto_upgrade": { + "strategy_setting": TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + "upgrade_time_of_day": 0, + "upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE, + "exclude_plugins": [], + "include_plugins": [], + }, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=True), + patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.change_strategy", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_permission_fail(self, app): + api = PluginChangePreferencesApi() + method = unwrap(api.post) + + user = MagicMock(is_admin_or_owner=True) + + payload = { + "permission": { + "install_permission": TenantPluginPermission.InstallPermission.EVERYONE, + "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE, + }, + "auto_upgrade": { + "strategy_setting": TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + "upgrade_time_of_day": 0, + "upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE, + "exclude_plugins": [], + "include_plugins": [], + }, + } + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.plugin.PluginPermissionService.change_permission", return_value=False), + ): + result = method(api) + + assert result["success"] is False + + +class TestPluginFetchPreferencesApi: + def test_success(self, app): + api = PluginFetchPreferencesApi() + method = unwrap(api.get) + + permission = MagicMock( + install_permission=TenantPluginPermission.InstallPermission.EVERYONE, + debug_permission=TenantPluginPermission.DebugPermission.EVERYONE, + ) + + auto_upgrade = MagicMock( + strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY, + upgrade_time_of_day=1, + upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE, + exclude_plugins=[], + include_plugins=[], + ) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch( + "controllers.console.workspace.plugin.PluginPermissionService.get_permission", return_value=permission + ), + patch( + "controllers.console.workspace.plugin.PluginAutoUpgradeService.get_strategy", return_value=auto_upgrade + ), + ): + result = method(api) + + assert "permission" in result + assert "auto_upgrade" in result + + +class TestPluginAutoUpgradeExcludePluginApi: + def test_success(self, app): + api = PluginAutoUpgradeExcludePluginApi() + method = unwrap(api.post) + + payload = {"plugin_id": "p"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.exclude_plugin", return_value=True), + ): + result = method(api) + + assert result["success"] is True + + def test_fail(self, app): + api = PluginAutoUpgradeExcludePluginApi() + method = unwrap(api.post) + + payload = {"plugin_id": "p"} + + with ( + app.test_request_context("/", json=payload), + patch("controllers.console.workspace.plugin.current_account_with_tenant", return_value=(None, "t1")), + patch("controllers.console.workspace.plugin.PluginAutoUpgradeService.exclude_plugin", return_value=False), + ): + result = method(api) + + assert result["success"] is False diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py index b15676d9b7..16ea1bf509 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_provider.py @@ -4,16 +4,52 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask from flask_restx import Api +from werkzeug.exceptions import Forbidden -from controllers.console.workspace.tool_providers import ToolProviderMCPApi +from controllers.console.workspace.tool_providers import ( + ToolApiListApi, + ToolApiProviderAddApi, + ToolApiProviderDeleteApi, + ToolApiProviderGetApi, + ToolApiProviderGetRemoteSchemaApi, + ToolApiProviderListToolsApi, + ToolApiProviderUpdateApi, + ToolBuiltinListApi, + ToolBuiltinProviderAddApi, + ToolBuiltinProviderCredentialsSchemaApi, + ToolBuiltinProviderDeleteApi, + ToolBuiltinProviderGetCredentialInfoApi, + ToolBuiltinProviderGetCredentialsApi, + ToolBuiltinProviderGetOauthClientSchemaApi, + ToolBuiltinProviderIconApi, + ToolBuiltinProviderInfoApi, + ToolBuiltinProviderListToolsApi, + ToolBuiltinProviderSetDefaultApi, + ToolBuiltinProviderUpdateApi, + ToolLabelsApi, + ToolOAuthCallback, + ToolOAuthCustomClient, + ToolPluginOAuthApi, + ToolProviderListApi, + ToolProviderMCPApi, + ToolWorkflowListApi, + ToolWorkflowProviderCreateApi, + ToolWorkflowProviderDeleteApi, + ToolWorkflowProviderGetApi, + ToolWorkflowProviderUpdateApi, + is_valid_url, +) from core.db.session_factory import configure_session_factory from extensions.ext_database import db from services.tools.mcp_tools_manage_service import ReconnectResult -# Backward-compat fixtures referenced by @pytest.mark.usefixtures in this file. -# They are intentionally no-ops because the test already patches the required -# behaviors explicitly via @patch and context managers below. +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + @pytest.fixture def _mock_cache(): return @@ -107,3 +143,602 @@ def test_create_mcp_provider_populates_tools(mock_reconnect, mock_session, mock_ # 若 transform 后包含 tools 字段,确保非空 assert isinstance(body.get("tools"), list) assert body["tools"] + + +class TestUtils: + def test_is_valid_url(self): + assert is_valid_url("https://example.com") + assert is_valid_url("http://example.com") + assert not is_valid_url("") + assert not is_valid_url("ftp://example.com") + assert not is_valid_url("not-a-url") + assert not is_valid_url(None) + + +class TestToolProviderListApi: + def test_get_success(self, app): + api = ToolProviderListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u1"), "t1"), + ), + patch( + "controllers.console.workspace.tool_providers.ToolCommonService.list_tool_providers", + return_value=["p1"], + ), + ): + assert method(api) == ["p1"] + + +class TestBuiltinProviderApis: + def test_list_tools(self, app): + api = ToolBuiltinProviderListToolsApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(None, "t1"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tool_provider_tools", + return_value=[{"a": 1}], + ), + ): + assert method(api, "provider") == [{"a": 1}] + + def test_info(self, app): + api = ToolBuiltinProviderInfoApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(None, "t1"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_info", + return_value={"x": 1}, + ), + ): + assert method(api, "provider") == {"x": 1} + + def test_delete(self, app): + api = ToolBuiltinProviderDeleteApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credential_id": "cid"}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(None, "t1"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_builtin_tool_provider", + return_value={"result": "success"}, + ), + ): + assert method(api, "provider")["result"] == "success" + + def test_add_invalid_type(self, app): + api = ToolBuiltinProviderAddApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credentials": {}, "type": "invalid"}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + ): + with pytest.raises(ValueError): + method(api, "provider") + + def test_add_success(self, app): + api = ToolBuiltinProviderAddApi() + method = unwrap(api.post) + + payload = {"credentials": {}, "type": "oauth2", "name": "n"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.add_builtin_tool_provider", + return_value={"id": 1}, + ), + ): + assert method(api, "provider")["id"] == 1 + + def test_update(self, app): + api = ToolBuiltinProviderUpdateApi() + method = unwrap(api.post) + + payload = {"credential_id": "c1", "credentials": {}, "name": "n"} + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.update_builtin_tool_provider", + return_value={"ok": True}, + ), + ): + assert method(api, "provider")["ok"] + + def test_get_credentials(self, app): + api = ToolBuiltinProviderGetCredentialsApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(None, "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credentials", + return_value={"k": "v"}, + ), + ): + assert method(api, "provider") == {"k": "v"} + + def test_icon(self, app): + api = ToolBuiltinProviderIconApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_icon", + return_value=(b"x", "image/png"), + ), + ): + response = method(api, "provider") + assert response.mimetype == "image/png" + + def test_credentials_schema(self, app): + api = ToolBuiltinProviderCredentialsSchemaApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_provider_credentials_schema", + return_value={"schema": {}}, + ), + ): + assert method(api, "provider", "oauth2") == {"schema": {}} + + def test_set_default_credential(self, app): + api = ToolBuiltinProviderSetDefaultApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"id": "c1"}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.set_default_provider", + return_value={"ok": True}, + ), + ): + assert method(api, "provider")["ok"] + + def test_get_credential_info(self, app): + api = ToolBuiltinProviderGetCredentialInfoApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_credential_info", + return_value={"info": "x"}, + ), + ): + assert method(api, "provider") == {"info": "x"} + + def test_get_oauth_client_schema(self, app): + api = ToolBuiltinProviderGetOauthClientSchemaApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema", + return_value={"schema": {}}, + ), + ): + assert method(api, "provider") == {"schema": {}} + + +class TestApiProviderApis: + def test_add(self, app): + api = ToolApiProviderAddApi() + method = unwrap(api.post) + + payload = { + "credentials": {}, + "schema_type": "openapi", + "schema": "{}", + "provider": "p", + "icon": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.create_api_tool_provider", + return_value={"id": 1}, + ), + ): + assert method(api)["id"] == 1 + + def test_remote_schema(self, app): + api = ToolApiProviderGetRemoteSchemaApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?url=http://x.com"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider_remote_schema", + return_value={"schema": "x"}, + ), + ): + assert method(api)["schema"] == "x" + + def test_list_tools(self, app): + api = ToolApiProviderListToolsApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?provider=p"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tool_provider_tools", + return_value=[{"tool": 1}], + ), + ): + assert method(api) == [{"tool": 1}] + + def test_update(self, app): + api = ToolApiProviderUpdateApi() + method = unwrap(api.post) + + payload = { + "credentials": {}, + "schema_type": "openapi", + "schema": "{}", + "provider": "p", + "original_provider": "o", + "icon": {}, + "privacy_policy": "", + "custom_disclaimer": "", + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.update_api_tool_provider", + return_value={"ok": True}, + ), + ): + assert method(api)["ok"] + + def test_delete(self, app): + api = ToolApiProviderDeleteApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"provider": "p"}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.delete_api_tool_provider", + return_value={"result": "success"}, + ), + ): + assert method(api)["result"] == "success" + + def test_get(self, app): + api = ToolApiProviderGetApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/?provider=p"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.get_api_tool_provider", + return_value={"x": 1}, + ), + ): + assert method(api) == {"x": 1} + + +class TestWorkflowApis: + def test_create(self, app): + api = ToolWorkflowProviderCreateApi() + method = unwrap(api.post) + + payload = { + "workflow_app_id": "123e4567-e89b-12d3-a456-426614174000", + "name": "n", + "label": "l", + "description": "d", + "icon": {}, + "parameters": [], + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.WorkflowToolManageService.create_workflow_tool", + return_value={"id": 1}, + ), + ): + assert method(api)["id"] == 1 + + def test_update_invalid(self, app): + api = ToolWorkflowProviderUpdateApi() + method = unwrap(api.post) + + payload = { + "workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000", + "name": "Tool", + "label": "Tool Label", + "description": "A tool", + "icon": {}, + } + + with ( + app.test_request_context("/", json=payload), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.WorkflowToolManageService.update_workflow_tool", + return_value={"ok": True}, + ), + ): + result = method(api) + assert result["ok"] + + def test_delete(self, app): + api = ToolWorkflowProviderDeleteApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"workflow_tool_id": "123e4567-e89b-12d3-a456-426614174000"}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.WorkflowToolManageService.delete_workflow_tool", + return_value={"ok": True}, + ), + ): + assert method(api)["ok"] + + def test_get_error(self, app): + api = ToolWorkflowProviderGetApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestLists: + def test_builtin_list(self, app): + api = ToolBuiltinListApi() + method = unwrap(api.get) + + m = MagicMock() + m.to_dict.return_value = {"x": 1} + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.list_builtin_tools", + return_value=[m], + ), + ): + assert method(api) == [{"x": 1}] + + def test_api_list(self, app): + api = ToolApiListApi() + method = unwrap(api.get) + + m = MagicMock() + m.to_dict.return_value = {"x": 1} + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(None, "t"), + ), + patch( + "controllers.console.workspace.tool_providers.ApiToolManageService.list_api_tools", + return_value=[m], + ), + ): + assert method(api) == [{"x": 1}] + + def test_workflow_list(self, app): + api = ToolWorkflowListApi() + method = unwrap(api.get) + + m = MagicMock() + m.to_dict.return_value = {"x": 1} + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.WorkflowToolManageService.list_tenant_workflow_tools", + return_value=[m], + ), + ): + assert method(api) == [{"x": 1}] + + +class TestLabels: + def test_labels(self, app): + api = ToolLabelsApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.ToolLabelsService.list_tool_labels", + return_value=["l1"], + ), + ): + assert method(api) == ["l1"] + + +class TestOAuth: + def test_oauth_no_client(self, app): + api = ToolPluginOAuthApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(id="u"), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_oauth_client", + return_value=None, + ), + ): + with pytest.raises(Forbidden): + method(api, "provider") + + def test_oauth_callback_no_cookie(self, app): + api = ToolOAuthCallback() + method = unwrap(api.get) + + with app.test_request_context("/"): + with pytest.raises(Forbidden): + method(api, "provider") + + +class TestOAuthCustomClient: + def test_save_custom_client(self, app): + api = ToolOAuthCustomClient() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"client_params": {"a": 1}}), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.save_custom_oauth_client_params", + return_value={"ok": True}, + ), + ): + assert method(api, "provider")["ok"] + + def test_get_custom_client(self, app): + api = ToolOAuthCustomClient() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.get_custom_oauth_client_params", + return_value={"client_id": "x"}, + ), + ): + assert method(api, "provider") == {"client_id": "x"} + + def test_delete_custom_client(self, app): + api = ToolOAuthCustomClient() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.tool_providers.current_account_with_tenant", + return_value=(MagicMock(), "t"), + ), + patch( + "controllers.console.workspace.tool_providers.BuiltinToolManageService.delete_custom_oauth_client_params", + return_value={"ok": True}, + ), + ): + assert method(api, "provider")["ok"] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_trigger_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_trigger_providers.py new file mode 100644 index 0000000000..4776bc7af0 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_trigger_providers.py @@ -0,0 +1,558 @@ +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import BadRequest, Forbidden + +from controllers.console.workspace.trigger_providers import ( + TriggerOAuthAuthorizeApi, + TriggerOAuthCallbackApi, + TriggerOAuthClientManageApi, + TriggerProviderIconApi, + TriggerProviderInfoApi, + TriggerProviderListApi, + TriggerSubscriptionBuilderBuildApi, + TriggerSubscriptionBuilderCreateApi, + TriggerSubscriptionBuilderGetApi, + TriggerSubscriptionBuilderLogsApi, + TriggerSubscriptionBuilderUpdateApi, + TriggerSubscriptionBuilderVerifyApi, + TriggerSubscriptionDeleteApi, + TriggerSubscriptionListApi, + TriggerSubscriptionUpdateApi, + TriggerSubscriptionVerifyApi, +) +from controllers.web.error import NotFoundError +from core.plugin.entities.plugin_daemon import CredentialType +from models.account import Account + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +def mock_user(): + user = MagicMock(spec=Account) + user.id = "u1" + user.current_tenant_id = "t1" + return user + + +class TestTriggerProviderApis: + def test_icon_success(self, app): + api = TriggerProviderIconApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_plugin_icon", + return_value="icon", + ), + ): + assert method(api, "github") == "icon" + + def test_list_providers(self, app): + api = TriggerProviderListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_providers", + return_value=[], + ), + ): + assert method(api) == [] + + def test_provider_info(self, app): + api = TriggerProviderInfoApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_trigger_provider", + return_value={"id": "p1"}, + ), + ): + assert method(api, "github") == {"id": "p1"} + + +class TestTriggerSubscriptionListApi: + def test_list_success(self, app): + api = TriggerSubscriptionListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions", + return_value=[], + ), + ): + assert method(api, "github") == [] + + def test_list_invalid_provider(self, app): + api = TriggerSubscriptionListApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.list_trigger_provider_subscriptions", + side_effect=ValueError("bad"), + ), + ): + result, status = method(api, "bad") + assert status == 404 + + +class TestTriggerSubscriptionBuilderApis: + def test_create_builder(self, app): + api = TriggerSubscriptionBuilderCreateApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credential_type": "UNAUTHORIZED"}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder", + return_value={"id": "b1"}, + ), + ): + result = method(api, "github") + assert "subscription_builder" in result + + def test_get_builder(self, app): + api = TriggerSubscriptionBuilderGetApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.get_subscription_builder_by_id", + return_value={"id": "b1"}, + ), + ): + assert method(api, "github", "b1") == {"id": "b1"} + + def test_verify_builder(self, app): + api = TriggerSubscriptionBuilderVerifyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credentials": {"a": 1}}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder", + return_value={"ok": True}, + ), + ): + assert method(api, "github", "b1") == {"ok": True} + + def test_verify_builder_error(self, app): + api = TriggerSubscriptionBuilderVerifyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credentials": {}}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_verify_builder", + side_effect=Exception("err"), + ), + ): + with pytest.raises(ValueError): + method(api, "github", "b1") + + def test_update_builder(self, app): + api = TriggerSubscriptionBuilderUpdateApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"name": "n"}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder", + return_value={"id": "b1"}, + ), + ): + assert method(api, "github", "b1") == {"id": "b1"} + + def test_logs(self, app): + api = TriggerSubscriptionBuilderLogsApi() + method = unwrap(api.get) + + log = MagicMock() + log.model_dump.return_value = {"a": 1} + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.list_logs", + return_value=[log], + ), + ): + assert "logs" in method(api, "github", "b1") + + def test_build(self, app): + api = TriggerSubscriptionBuilderBuildApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"name": "x"}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_and_build_builder", + return_value=None, + ), + ): + assert method(api, "github", "b1") == 200 + + +class TestTriggerSubscriptionCrud: + def test_update_rename_only(self, app): + api = TriggerSubscriptionUpdateApi() + method = unwrap(api.post) + + sub = MagicMock() + sub.provider_id = "github" + sub.credential_type = CredentialType.UNAUTHORIZED + + with ( + app.test_request_context("/", json={"name": "x"}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id", + return_value=sub, + ), + patch("controllers.console.workspace.trigger_providers.TriggerProviderService.update_trigger_subscription"), + ): + assert method(api, "s1") == 200 + + def test_update_not_found(self, app): + api = TriggerSubscriptionUpdateApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"name": "x"}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id", + return_value=None, + ), + ): + with pytest.raises(NotFoundError): + method(api, "x") + + def test_update_rebuild(self, app): + api = TriggerSubscriptionUpdateApi() + method = unwrap(api.post) + + sub = MagicMock() + sub.provider_id = "github" + sub.credential_type = CredentialType.OAUTH2 + sub.credentials = {} + sub.parameters = {} + + with ( + app.test_request_context("/", json={"credentials": {}}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_subscription_by_id", + return_value=sub, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.rebuild_trigger_subscription" + ), + ): + assert method(api, "s1") == 200 + + def test_delete_subscription(self, app): + api = TriggerSubscriptionDeleteApi() + method = unwrap(api.post) + + mock_session = MagicMock() + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch("controllers.console.workspace.trigger_providers.db") as mock_db, + patch("controllers.console.workspace.trigger_providers.Session") as mock_session_cls, + patch("controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider"), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionOperatorService.delete_plugin_trigger_by_subscription" + ), + ): + mock_db.engine = MagicMock() + mock_session_cls.return_value.__enter__.return_value = mock_session + + result = method(api, "sub1") + + assert result["result"] == "success" + + def test_delete_subscription_value_error(self, app): + api = TriggerSubscriptionDeleteApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch("controllers.console.workspace.trigger_providers.db") as mock_db, + patch("controllers.console.workspace.trigger_providers.Session") as session_cls, + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.delete_trigger_provider", + side_effect=ValueError("bad"), + ), + ): + mock_db.engine = MagicMock() + session_cls.return_value.__enter__.return_value = MagicMock() + + with pytest.raises(BadRequest): + method(api, "sub1") + + +class TestTriggerOAuthApis: + def test_oauth_authorize_success(self, app): + api = TriggerOAuthAuthorizeApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client", + return_value={"a": 1}, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.create_trigger_subscription_builder", + return_value=MagicMock(id="b1"), + ), + patch( + "controllers.console.workspace.trigger_providers.OAuthProxyService.create_proxy_context", + return_value="ctx", + ), + patch( + "controllers.console.workspace.trigger_providers.OAuthHandler.get_authorization_url", + return_value=MagicMock(authorization_url="url"), + ), + ): + resp = method(api, "github") + assert resp.status_code == 200 + + def test_oauth_authorize_no_client(self, app): + api = TriggerOAuthAuthorizeApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client", + return_value=None, + ), + ): + with pytest.raises(NotFoundError): + method(api, "github") + + def test_oauth_callback_forbidden(self, app): + api = TriggerOAuthCallbackApi() + method = unwrap(api.get) + + with app.test_request_context("/"): + with pytest.raises(Forbidden): + method(api, "github") + + def test_oauth_callback_success(self, app): + api = TriggerOAuthCallbackApi() + method = unwrap(api.get) + + ctx = { + "user_id": "u1", + "tenant_id": "t1", + "subscription_builder_id": "b1", + } + + with ( + app.test_request_context("/", headers={"Cookie": "context_id=ctx"}), + patch( + "controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context", return_value=ctx + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client", + return_value={"a": 1}, + ), + patch( + "controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials", + return_value=MagicMock(credentials={"a": 1}, expires_at=1), + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerSubscriptionBuilderService.update_trigger_subscription_builder" + ), + ): + resp = method(api, "github") + assert resp.status_code == 302 + + def test_oauth_callback_no_oauth_client(self, app): + api = TriggerOAuthCallbackApi() + method = unwrap(api.get) + + ctx = { + "user_id": "u1", + "tenant_id": "t1", + "subscription_builder_id": "b1", + } + + with ( + app.test_request_context("/", headers={"Cookie": "context_id=ctx"}), + patch( + "controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context", + return_value=ctx, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client", + return_value=None, + ), + ): + with pytest.raises(Forbidden): + method(api, "github") + + def test_oauth_callback_empty_credentials(self, app): + api = TriggerOAuthCallbackApi() + method = unwrap(api.get) + + ctx = { + "user_id": "u1", + "tenant_id": "t1", + "subscription_builder_id": "b1", + } + + with ( + app.test_request_context("/", headers={"Cookie": "context_id=ctx"}), + patch( + "controllers.console.workspace.trigger_providers.OAuthProxyService.use_proxy_context", + return_value=ctx, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_oauth_client", + return_value={"a": 1}, + ), + patch( + "controllers.console.workspace.trigger_providers.OAuthHandler.get_credentials", + return_value=MagicMock(credentials=None, expires_at=None), + ), + ): + with pytest.raises(ValueError): + method(api, "github") + + +class TestTriggerOAuthClientManageApi: + def test_get_client(self, app): + api = TriggerOAuthClientManageApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.get_custom_oauth_client_params", + return_value={}, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_custom_client_enabled", + return_value=False, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.is_oauth_system_client_exists", + return_value=True, + ), + patch( + "controllers.console.workspace.trigger_providers.TriggerManager.get_trigger_provider", + return_value=MagicMock(get_oauth_client_schema=lambda: {}), + ), + ): + result = method(api, "github") + assert "configured" in result + + def test_post_client(self, app): + api = TriggerOAuthClientManageApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"enabled": True}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params", + return_value={"ok": True}, + ), + ): + assert method(api, "github") == {"ok": True} + + def test_delete_client(self, app): + api = TriggerOAuthClientManageApi() + method = unwrap(api.delete) + + with ( + app.test_request_context("/"), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.delete_custom_oauth_client_params", + return_value={"ok": True}, + ), + ): + assert method(api, "github") == {"ok": True} + + def test_oauth_client_post_value_error(self, app): + api = TriggerOAuthClientManageApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"enabled": True}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.save_custom_oauth_client_params", + side_effect=ValueError("bad"), + ), + ): + with pytest.raises(BadRequest): + method(api, "github") + + +class TestTriggerSubscriptionVerifyApi: + def test_verify_success(self, app): + api = TriggerSubscriptionVerifyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credentials": {}}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials", + return_value={"ok": True}, + ), + ): + assert method(api, "github", "s1") == {"ok": True} + + @pytest.mark.parametrize("raised_exception", [ValueError("bad"), Exception("boom")]) + def test_verify_errors(self, app, raised_exception): + api = TriggerSubscriptionVerifyApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/", json={"credentials": {}}), + patch("controllers.console.workspace.trigger_providers.current_user", mock_user()), + patch( + "controllers.console.workspace.trigger_providers.TriggerProviderService.verify_subscription_credentials", + side_effect=raised_exception, + ), + ): + with pytest.raises(BadRequest): + method(api, "github", "s1") diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py new file mode 100644 index 0000000000..06f666fa60 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -0,0 +1,605 @@ +from datetime import datetime +from io import BytesIO +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.datastructures import FileStorage +from werkzeug.exceptions import Unauthorized + +import services +from controllers.common.errors import ( + FilenameNotExistsError, + FileTooLargeError, + NoFileUploadedError, + TooManyFilesError, + UnsupportedFileTypeError, +) +from controllers.console.error import AccountNotLinkTenantError +from controllers.console.workspace.workspace import ( + CustomConfigWorkspaceApi, + SwitchWorkspaceApi, + TenantApi, + TenantListApi, + WebappLogoWorkspaceApi, + WorkspaceInfoApi, + WorkspaceListApi, + WorkspacePermissionApi, +) +from enums.cloud_plan import CloudPlan +from models.account import TenantStatus + + +def unwrap(func): + while hasattr(func, "__wrapped__"): + func = func.__wrapped__ + return func + + +class TestTenantListApi: + def test_get_success(self, app): + api = TenantListApi() + method = unwrap(api.get) + + tenant1 = MagicMock( + id="t1", + name="Tenant 1", + status="active", + created_at=datetime.utcnow(), + ) + tenant2 = MagicMock( + id="t2", + name="Tenant 2", + status="active", + created_at=datetime.utcnow(), + ) + + features = MagicMock() + features.billing.enabled = True + features.billing.subscription.plan = CloudPlan.SANDBOX + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[tenant1, tenant2], + ), + patch("controllers.console.workspace.workspace.FeatureService.get_features", return_value=features), + ): + result, status = method(api) + + assert status == 200 + assert len(result["workspaces"]) == 2 + assert result["workspaces"][0]["current"] is True + + def test_get_billing_disabled(self, app): + api = TenantListApi() + method = unwrap(api.get) + + tenant = MagicMock( + id="t1", + name="Tenant", + status="active", + created_at=datetime.utcnow(), + ) + + features = MagicMock() + features.billing.enabled = False + + with ( + app.test_request_context("/workspaces"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + patch( + "controllers.console.workspace.workspace.TenantService.get_join_tenants", + return_value=[tenant], + ), + patch( + "controllers.console.workspace.workspace.FeatureService.get_features", + return_value=features, + ), + ): + result, status = method(api) + + assert status == 200 + assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX + + +class TestWorkspaceListApi: + def test_get_success(self, app): + api = WorkspaceListApi() + method = unwrap(api.get) + + tenant = MagicMock(id="t1", name="T", status="active", created_at=datetime.utcnow()) + + paginate_result = MagicMock( + items=[tenant], + has_next=False, + total=1, + ) + + with ( + app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 20}), + patch("controllers.console.workspace.workspace.db.paginate", return_value=paginate_result), + ): + result, status = method(api) + + assert status == 200 + assert result["total"] == 1 + assert result["has_more"] is False + + def test_get_has_next_true(self, app): + api = WorkspaceListApi() + method = unwrap(api.get) + + tenant = MagicMock( + id="t1", + name="T", + status="active", + created_at=datetime.utcnow(), + ) + + paginate_result = MagicMock( + items=[tenant], + has_next=True, + total=10, + ) + + with ( + app.test_request_context("/all-workspaces", query_string={"page": 1, "limit": 1}), + patch( + "controllers.console.workspace.workspace.db.paginate", + return_value=paginate_result, + ), + ): + result, status = method(api) + + assert status == 200 + assert result["has_more"] is True + + +class TestTenantApi: + def test_post_active_tenant(self, app): + api = TenantApi() + method = unwrap(api.post) + + tenant = MagicMock(status="active") + + user = MagicMock(current_tenant=tenant) + + with ( + app.test_request_context("/workspaces/current"), + patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"} + ), + ): + result, status = method(api) + + assert status == 200 + assert result["id"] == "t1" + + def test_post_archived_with_switch(self, app): + api = TenantApi() + method = unwrap(api.post) + + archived = MagicMock(status=TenantStatus.ARCHIVE) + new_tenant = MagicMock(status="active") + + user = MagicMock(current_tenant=archived) + + with ( + app.test_request_context("/workspaces/current"), + patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[new_tenant]), + patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "new"} + ), + ): + result, status = method(api) + + assert result["id"] == "new" + + def test_post_archived_no_tenant(self, app): + api = TenantApi() + method = unwrap(api.post) + + user = MagicMock(current_tenant=MagicMock(status=TenantStatus.ARCHIVE)) + + with ( + app.test_request_context("/workspaces/current"), + patch("controllers.console.workspace.workspace.current_account_with_tenant", return_value=(user, "t1")), + patch("controllers.console.workspace.workspace.TenantService.get_join_tenants", return_value=[]), + ): + with pytest.raises(Unauthorized): + method(api) + + def test_post_info_path(self, app): + api = TenantApi() + method = unwrap(api.post) + + tenant = MagicMock(status="active") + user = MagicMock(current_tenant=tenant) + + with ( + app.test_request_context("/info"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(user, "t1"), + ), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", + return_value={"id": "t1"}, + ), + patch("controllers.console.workspace.workspace.logger.warning") as warn_mock, + ): + result, status = method(api) + + warn_mock.assert_called_once() + assert status == 200 + + +class TestSwitchWorkspaceApi: + def test_switch_success(self, app): + api = SwitchWorkspaceApi() + method = unwrap(api.post) + + payload = {"tenant_id": "t2"} + tenant = MagicMock(id="t2") + + with ( + app.test_request_context("/workspaces/switch", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), + patch("controllers.console.workspace.workspace.db.session.query") as query_mock, + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t2"} + ), + ): + query_mock.return_value.get.return_value = tenant + result = method(api) + + assert result["result"] == "success" + + def test_switch_not_linked(self, app): + api = SwitchWorkspaceApi() + method = unwrap(api.post) + + payload = {"tenant_id": "bad"} + + with ( + app.test_request_context("/workspaces/switch", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch("controllers.console.workspace.workspace.TenantService.switch_tenant", side_effect=Exception), + ): + with pytest.raises(AccountNotLinkTenantError): + method(api) + + def test_switch_tenant_not_found(self, app): + api = SwitchWorkspaceApi() + method = unwrap(api.post) + + payload = {"tenant_id": "missing"} + + with ( + app.test_request_context("/workspaces/switch", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + patch("controllers.console.workspace.workspace.TenantService.switch_tenant"), + patch("controllers.console.workspace.workspace.db.session.query") as query_mock, + ): + query_mock.return_value.get.return_value = None + + with pytest.raises(ValueError): + method(api) + + +class TestCustomConfigWorkspaceApi: + def test_post_success(self, app): + api = CustomConfigWorkspaceApi() + method = unwrap(api.post) + + tenant = MagicMock(custom_config_dict={}) + + payload = {"remove_webapp_brand": True} + + with ( + app.test_request_context("/workspaces/custom-config", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant), + patch("controllers.console.workspace.workspace.db.session.commit"), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", return_value={"id": "t1"} + ), + ): + result = method(api) + + assert result["result"] == "success" + + def test_logo_fallback(self, app): + api = CustomConfigWorkspaceApi() + method = unwrap(api.post) + + tenant = MagicMock(custom_config_dict={"replace_webapp_logo": "old-logo"}) + + payload = {"remove_webapp_brand": False} + + with ( + app.test_request_context("/workspaces/custom-config", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + patch( + "controllers.console.workspace.workspace.db.get_or_404", + return_value=tenant, + ), + patch("controllers.console.workspace.workspace.db.session.commit"), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", + return_value={"id": "t1"}, + ), + ): + result = method(api) + + assert tenant.custom_config_dict["replace_webapp_logo"] == "old-logo" + assert result["result"] == "success" + + +class TestWebappLogoWorkspaceApi: + def test_no_file(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + with ( + app.test_request_context("/upload", data={}), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + ): + with pytest.raises(NoFileUploadedError): + method(api) + + def test_too_many_files(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + data = { + "file": MagicMock(), + "extra": MagicMock(), + } + + with ( + app.test_request_context("/upload", data=data), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + ): + with pytest.raises(TooManyFilesError): + method(api) + + def test_invalid_extension(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + file = MagicMock(filename="test.txt") + + with ( + app.test_request_context("/upload", data={"file": file}), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + ): + with pytest.raises(UnsupportedFileTypeError): + method(api) + + def test_upload_success(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + file = FileStorage( + stream=BytesIO(b"data"), + filename="logo.png", + content_type="image/png", + ) + + upload = MagicMock(id="file1") + + with ( + app.test_request_context( + "/upload", + data={"file": file}, + content_type="multipart/form-data", + ), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch("controllers.console.workspace.workspace.FileService") as fs, + patch("controllers.console.workspace.workspace.db") as mock_db, + ): + mock_db.engine = MagicMock() + fs.return_value.upload_file.return_value = upload + + result, status = method(api) + + assert status == 201 + assert result["id"] == "file1" + + def test_filename_missing(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + file = FileStorage( + stream=BytesIO(b"data"), + filename="", + content_type="image/png", + ) + + with ( + app.test_request_context( + "/upload", + data={"file": file}, + content_type="multipart/form-data", + ), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + ): + with pytest.raises(FilenameNotExistsError): + method(api) + + def test_file_too_large(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + file = FileStorage( + stream=BytesIO(b"x"), + filename="logo.png", + content_type="image/png", + ) + + with ( + app.test_request_context( + "/upload", + data={"file": file}, + content_type="multipart/form-data", + ), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + patch("controllers.console.workspace.workspace.FileService") as fs, + patch("controllers.console.workspace.workspace.db") as mock_db, + ): + mock_db.engine = MagicMock() + fs.return_value.upload_file.side_effect = services.errors.file.FileTooLargeError("too big") + + with pytest.raises(FileTooLargeError): + method(api) + + def test_service_unsupported_file(self, app): + api = WebappLogoWorkspaceApi() + method = unwrap(api.post) + + file = FileStorage( + stream=BytesIO(b"x"), + filename="logo.png", + content_type="image/png", + ) + + with ( + app.test_request_context( + "/upload", + data={"file": file}, + content_type="multipart/form-data", + ), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), "t1"), + ), + patch("controllers.console.workspace.workspace.FileService") as fs, + patch("controllers.console.workspace.workspace.db") as mock_db, + ): + mock_db.engine = MagicMock() + fs.return_value.upload_file.side_effect = services.errors.file.UnsupportedFileTypeError() + + with pytest.raises(UnsupportedFileTypeError): + method(api) + + +class TestWorkspaceInfoApi: + def test_post_success(self, app): + api = WorkspaceInfoApi() + method = unwrap(api.post) + + tenant = MagicMock() + + payload = {"name": "New Name"} + + with ( + app.test_request_context("/workspaces/info", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch("controllers.console.workspace.workspace.db.get_or_404", return_value=tenant), + patch("controllers.console.workspace.workspace.db.session.commit"), + patch( + "controllers.console.workspace.workspace.WorkspaceService.get_tenant_info", + return_value={"name": "New Name"}, + ), + ): + result = method(api) + + assert result["result"] == "success" + + def test_no_current_tenant(self, app): + api = WorkspaceInfoApi() + method = unwrap(api.post) + + payload = {"name": "X"} + + with ( + app.test_request_context("/workspaces/info", json=payload), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), None), + ), + ): + with pytest.raises(ValueError): + method(api) + + +class TestWorkspacePermissionApi: + def test_get_success(self, app): + api = WorkspacePermissionApi() + method = unwrap(api.get) + + permission = MagicMock( + workspace_id="t1", + allow_member_invite=True, + allow_owner_transfer=False, + ) + + with ( + app.test_request_context("/permission"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1") + ), + patch( + "controllers.console.workspace.workspace.EnterpriseService.WorkspacePermissionService.get_permission", + return_value=permission, + ), + ): + result, status = method(api) + + assert status == 200 + assert result["workspace_id"] == "t1" + + def test_no_current_tenant(self, app): + api = WorkspacePermissionApi() + method = unwrap(api.get) + + with ( + app.test_request_context("/permission"), + patch( + "controllers.console.workspace.workspace.current_account_with_tenant", + return_value=(MagicMock(), None), + ), + ): + with pytest.raises(ValueError): + method(api) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace_wraps.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace_wraps.py new file mode 100644 index 0000000000..b290748155 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace_wraps.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import importlib +from types import SimpleNamespace + +import pytest +from werkzeug.exceptions import Forbidden + +from controllers.console.workspace import plugin_permission_required +from models.account import TenantPluginPermission + + +class _SessionStub: + def __init__(self, permission): + self._permission = permission + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def query(self, *_args, **_kwargs): + return self + + def where(self, *_args, **_kwargs): + return self + + def first(self): + return self._permission + + +def _workspace_module(): + return importlib.import_module(plugin_permission_required.__module__) + + +def _patch_session(monkeypatch: pytest.MonkeyPatch, permission): + module = _workspace_module() + monkeypatch.setattr(module, "Session", lambda *_args, **_kwargs: _SessionStub(permission)) + monkeypatch.setattr(module, "db", SimpleNamespace(engine=object())) + + +def test_plugin_permission_allows_without_permission(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=False) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, None) + + @plugin_permission_required() + def handler(): + return "ok" + + assert handler() == "ok" + + +def test_plugin_permission_install_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=True) + permission = SimpleNamespace( + install_permission=TenantPluginPermission.InstallPermission.NOBODY, + debug_permission=TenantPluginPermission.DebugPermission.EVERYONE, + ) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, permission) + + @plugin_permission_required(install_required=True) + def handler(): + return "ok" + + with pytest.raises(Forbidden): + handler() + + +def test_plugin_permission_install_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=False) + permission = SimpleNamespace( + install_permission=TenantPluginPermission.InstallPermission.ADMINS, + debug_permission=TenantPluginPermission.DebugPermission.EVERYONE, + ) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, permission) + + @plugin_permission_required(install_required=True) + def handler(): + return "ok" + + with pytest.raises(Forbidden): + handler() + + +def test_plugin_permission_install_admin_allows_admin(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=True) + permission = SimpleNamespace( + install_permission=TenantPluginPermission.InstallPermission.ADMINS, + debug_permission=TenantPluginPermission.DebugPermission.EVERYONE, + ) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, permission) + + @plugin_permission_required(install_required=True) + def handler(): + return "ok" + + assert handler() == "ok" + + +def test_plugin_permission_debug_nobody_forbidden(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=True) + permission = SimpleNamespace( + install_permission=TenantPluginPermission.InstallPermission.EVERYONE, + debug_permission=TenantPluginPermission.DebugPermission.NOBODY, + ) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, permission) + + @plugin_permission_required(debug_required=True) + def handler(): + return "ok" + + with pytest.raises(Forbidden): + handler() + + +def test_plugin_permission_debug_admin_requires_admin(monkeypatch: pytest.MonkeyPatch) -> None: + user = SimpleNamespace(is_admin_or_owner=False) + permission = SimpleNamespace( + install_permission=TenantPluginPermission.InstallPermission.EVERYONE, + debug_permission=TenantPluginPermission.DebugPermission.ADMINS, + ) + module = _workspace_module() + monkeypatch.setattr(module, "current_account_with_tenant", lambda: (user, "t1")) + _patch_session(monkeypatch, permission) + + @plugin_permission_required(debug_required=True) + def handler(): + return "ok" + + with pytest.raises(Forbidden): + handler() From c72ac8a4347f8f6319b2b8221adf85d727385bba Mon Sep 17 00:00:00 2001 From: Stephen Zhou Date: Mon, 9 Mar 2026 17:24:56 +0800 Subject: [PATCH 3/3] ci: ignore some major update (#33161) --- .github/dependabot.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 917e0f6b07..63b3f05dfa 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -24,6 +24,21 @@ updates: schedule: interval: "weekly" open-pull-requests-limit: 2 + ignore: + - dependency-name: "@sentry/react" + update-types: ["version-update:semver-major"] + - dependency-name: "tailwindcss" + update-types: ["version-update:semver-major"] + - dependency-name: "echarts" + update-types: ["version-update:semver-major"] + - dependency-name: "uuid" + update-types: ["version-update:semver-major"] + - dependency-name: "react-markdown" + update-types: ["version-update:semver-major"] + - dependency-name: "react-syntax-highlighter" + update-types: ["version-update:semver-major"] + - dependency-name: "react-window" + update-types: ["version-update:semver-major"] groups: lexical: patterns: @@ -33,6 +48,9 @@ updates: patterns: - "storybook" - "@storybook/*" + eslint-group: + patterns: + - "*eslint*" npm-dependencies: patterns: - "*" @@ -41,3 +59,4 @@ updates: - "@lexical/*" - "storybook" - "@storybook/*" + - "*eslint*"