import datetime import json from contextlib import ExitStack from unittest.mock import MagicMock, PropertyMock, patch import pytest from flask import Flask from werkzeug.exceptions import BadRequest, Forbidden, NotFound import services from controllers.console import console_ns from controllers.console.app.error import ProviderNotInitializeError from controllers.console.datasets.datasets import ( DatasetApi, DatasetApiBaseUrlApi, DatasetApiDeleteApi, DatasetApiKeyApi, DatasetAutoDisableLogApi, DatasetEnableApiApi, DatasetErrorDocs, DatasetIndexingEstimateApi, DatasetIndexingStatusApi, DatasetListApi, DatasetPermissionUserListApi, DatasetQueryApi, DatasetRelatedAppListApi, DatasetRetrievalSettingApi, DatasetRetrievalSettingMockApi, DatasetUseCheckApi, ) from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.provider_manager import ProviderManager from core.rag.index_processor.constant.index_type import IndexStructureType from extensions.storage.storage_type import StorageType from models.dataset import Dataset, DatasetQuery, Document from models.enums import CreatorUserRole, DataSourceType, DocumentCreatedFrom, IndexingStatus from models.model import ApiToken, App, AppMode, IconType, UploadFile from services.dataset_service import DatasetPermissionService, DatasetService def unwrap(func): while hasattr(func, "__wrapped__"): func = func.__wrapped__ return func @pytest.fixture(autouse=True) def dataset_model_property_defaults(): properties: dict[str, object] = { "app_count": 0, "document_count": 0, "word_count": 0, "author_name": None, "tags": [], "doc_form": None, "external_knowledge_info": None, "doc_metadata": [], "is_published": False, "total_documents": 0, "total_available_documents": 0, } with ExitStack() as stack: for name, value in properties.items(): property_mock = stack.enter_context(patch.object(Dataset, name, new_callable=PropertyMock)) property_mock.return_value = value yield def make_dataset(**overrides) -> Dataset: base = { "id": "ds-1", "tenant_id": "tenant-1", "name": "Dataset", "description": "desc", "provider": "vendor", "permission": "only_me", "data_source_type": None, "indexing_technique": "economy", "created_by": "account-1", "created_at": datetime.datetime(2024, 1, 1, 12, 0, 0, tzinfo=datetime.UTC), "updated_by": None, "updated_at": datetime.datetime(2024, 1, 1, 12, 0, 0, tzinfo=datetime.UTC), "embedding_model": None, "embedding_model_provider": None, "retrieval_model": None, "summary_index_setting": None, "built_in_field_enabled": False, "pipeline_id": None, "runtime_mode": "general", "chunk_structure": None, "icon_info": None, "enable_api": False, "is_multimodal": False, } base.update(overrides) return Dataset(**base) def make_related_app(**overrides) -> App: base = { "id": "app-1", "tenant_id": "tenant-1", "name": "App", "description": "desc", "mode": AppMode.CHAT, "icon_type": IconType.EMOJI, "icon": "🤖", "icon_background": "#fff", "app_model_config_id": None, "workflow_id": None, "enable_site": False, "enable_api": False, "created_by": "account-1", } base.update(overrides) return App(**base) def make_document_status(**overrides) -> Document: base = { "id": "doc-1", "tenant_id": "tenant-1", "dataset_id": "dataset-1", "position": 1, "data_source_type": DataSourceType.UPLOAD_FILE, "batch": "batch-1", "name": "doc.txt", "created_from": DocumentCreatedFrom.WEB, "created_by": "account-1", "indexing_status": IndexingStatus.COMPLETED, "enabled": True, "archived": False, "processing_started_at": None, "parsing_completed_at": None, "cleaning_completed_at": None, "splitting_completed_at": None, "completed_at": None, "paused_at": None, "error": None, "stopped_at": None, } base.update(overrides) return Document(**base) class TestDatasetList: def _mock_user(self): user = MagicMock() user.is_dataset_editor = True return user def test_get_success_basic(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) current_user = self._mock_user() datasets = [make_dataset(icon_info={"icon": "📙", "icon_type": "emoji"})] with app.test_request_context("/datasets"): with ( patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(current_user, "tenant-1"), ), patch.object( DatasetService, "get_datasets", return_value=(datasets, 1), ), patch.object( ProviderManager, "get_configurations", return_value=MagicMock(get_models=lambda **_: []), ), ): resp, status = method(api) assert status == 200 assert resp["total"] == 1 assert resp["data"][0]["embedding_available"] is True assert resp["data"][0]["icon_info"] == { "icon": "📙", "icon_background": None, "icon_type": "emoji", "icon_url": None, } def test_get_with_ids_filter(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) current_user = self._mock_user() datasets = [make_dataset()] with app.test_request_context("/datasets?ids=1&ids=2"): with ( patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(current_user, "tenant-1"), ), patch.object( DatasetService, "get_datasets_by_ids", return_value=(datasets, 2), ) as by_ids_mock, patch.object( ProviderManager, "get_configurations", return_value=MagicMock(get_models=lambda **_: []), ), ): resp, status = method(api) by_ids_mock.assert_called_once() assert status == 200 assert resp["total"] == 2 def test_get_with_tag_ids(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) current_user = self._mock_user() datasets = [make_dataset()] with app.test_request_context("/datasets?tag_ids=tag1"): with ( patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(current_user, "tenant-1"), ), patch.object( DatasetService, "get_datasets", return_value=(datasets, 1), ), patch.object( ProviderManager, "get_configurations", return_value=MagicMock(get_models=lambda **_: []), ), ): resp, status = method(api) assert status == 200 def test_embedding_available_false(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) current_user = self._mock_user() datasets = [ make_dataset( indexing_technique="high_quality", embedding_model="text-embed", embedding_model_provider="openai", ) ] config = MagicMock() config.get_models.return_value = [] # model not available with app.test_request_context("/datasets"): with ( patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(current_user, "tenant-1"), ), patch.object( DatasetService, "get_datasets", return_value=(datasets, 1), ), patch.object( ProviderManager, "get_configurations", return_value=config, ), ): resp, status = method(api) assert resp["data"][0]["embedding_available"] is False def test_partial_members_permission(self, app: Flask): api = DatasetListApi() method = unwrap(api.get) current_user = self._mock_user() datasets = [make_dataset(permission="partial_members")] with app.test_request_context("/datasets"): with ( patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(current_user, "tenant-1"), ), patch.object( DatasetService, "get_datasets", return_value=(datasets, 1), ), patch( "controllers.console.datasets.datasets.db.session.execute", return_value=MagicMock(all=lambda: [("ds-1", "u1")]), ), patch.object( ProviderManager, "get_configurations", return_value=MagicMock(get_models=lambda **_: []), ), ): resp, status = method(api) assert resp["data"][0]["partial_member_list"] == ["u1"] class TestDatasetListApiPost: def test_post_success(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) payload = { "name": "My Dataset", "description": "desc", "indexing_technique": "economy", "provider": "vendor", } user = MagicMock() user.is_dataset_editor = True dataset = make_dataset(name=payload["name"], description=payload["description"]) with ( app.test_request_context("/datasets", json=payload), patch.object(type(console_ns), "payload", payload), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(user, "tenant-1"), ), patch.object( DatasetService, "create_empty_dataset", return_value=dataset, ), ): _, status = method(api) assert status == 201 def test_post_forbidden(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) payload = {"name": "test"} user = MagicMock() user.is_dataset_editor = False with ( app.test_request_context("/datasets", json=payload), patch.object(type(console_ns), "payload", payload), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(user, "tenant-1"), ), ): with pytest.raises(Forbidden): method(api) def test_post_duplicate_name(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) payload = {"name": "duplicate"} user = MagicMock() user.is_dataset_editor = True with ( app.test_request_context("/datasets", json=payload), patch.object(type(console_ns), "payload", payload), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(user, "tenant-1"), ), patch.object( DatasetService, "create_empty_dataset", side_effect=services.errors.dataset.DatasetNameDuplicateError(), ), ): with pytest.raises(DatasetNameDuplicateError): method(api) def test_post_invalid_payload_missing_name(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) with app.test_request_context("/datasets", json={}), patch.object(type(console_ns), "payload", {}): with pytest.raises(ValueError): method(api) def test_post_invalid_indexing_technique(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) payload = { "name": "bad", "indexing_technique": "invalid-tech", } with app.test_request_context("/datasets", json=payload), patch.object(type(console_ns), "payload", payload): with pytest.raises(ValueError, match="Invalid indexing technique"): method(api) def test_post_invalid_provider(self, app: Flask): api = DatasetListApi() method = unwrap(api.post) payload = { "name": "bad", "provider": "unknown", } with app.test_request_context("/datasets", json=payload), patch.object(type(console_ns), "payload", payload): with pytest.raises(ValueError, match="Invalid provider"): method(api) class TestDatasetApiGet: def test_get_success_basic(self, app: Flask): api = DatasetApi() method = unwrap(api.get) dataset_id = "123e4567-e89b-12d3-a456-426614174000" user = MagicMock() tenant_id = "tenant-1" dataset = make_dataset(id=dataset_id) with ( app.test_request_context(f"/datasets/{dataset_id}"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(user, tenant_id), ), patch.object( DatasetService, "get_dataset", return_value=dataset, ), patch.object( DatasetService, "check_dataset_permission", return_value=None, ), patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): # embedding models exist → embedding_available stays True provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] data, status = method(api, dataset_id) assert status == 200 assert data["embedding_available"] is True def test_get_uses_default_external_retrieval_model(self, app: Flask): api = DatasetApi() method = unwrap(api.get) dataset_id = "dataset-id" dataset = make_dataset(id=dataset_id, retrieval_model=None) with ( app.test_request_context(f"/datasets/{dataset_id}"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant"), ), patch.object( DatasetService, "get_dataset", return_value=dataset, ), patch.object( DatasetService, "check_dataset_permission", return_value=None, ), patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] data, status = method(api, dataset_id) assert status == 200 assert data["external_retrieval_model"] == { "top_k": 2, "score_threshold": 0.0, "score_threshold_enabled": None, } def test_get_dataset_not_found(self, app: Flask): api = DatasetApi() method = unwrap(api.get) dataset_id = "missing-id" with ( app.test_request_context(f"/datasets/{dataset_id}"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant"), ), patch.object( DatasetService, "get_dataset", return_value=None, ), ): with pytest.raises(NotFound, match="Dataset not found"): method(api, dataset_id) def test_get_permission_denied(self, app: Flask): api = DatasetApi() method = unwrap(api.get) dataset_id = "dataset-id" dataset = MagicMock() with ( app.test_request_context(f"/datasets/{dataset_id}"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant"), ), patch.object( DatasetService, "get_dataset", return_value=dataset, ), patch.object( DatasetService, "check_dataset_permission", side_effect=services.errors.account.NoPermissionError("no access"), ), ): with pytest.raises(Forbidden, match="no access"): method(api, dataset_id) def test_get_high_quality_embedding_unavailable(self, app: Flask): api = DatasetApi() method = unwrap(api.get) dataset_id = "dataset-id" user = MagicMock() tenant_id = "tenant-1" dataset = make_dataset( id=dataset_id, indexing_technique="high_quality", embedding_model="text-embedding", embedding_model_provider="openai", ) with ( app.test_request_context(f"/datasets/{dataset_id}"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(user, tenant_id), ), patch.object( DatasetService, "get_dataset", return_value=dataset, ), patch.object( DatasetService, "check_dataset_permission", return_value=None, ), patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): # embedding model NOT configured provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] data, _ = method(api, dataset_id) assert data["embedding_available"] is False def test_get_partial_members_permission(self, app: Flask): api = DatasetApi() method = unwrap(api.get) dataset_id = "dataset-id" dataset = make_dataset(id=dataset_id, permission="partial_members") partial_members = ["u1", "u2"] with ( app.test_request_context(f"/datasets/{dataset_id}"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant"), ), patch.object( DatasetService, "get_dataset", return_value=dataset, ), patch.object( DatasetService, "check_dataset_permission", return_value=None, ), patch.object( DatasetPermissionService, "get_dataset_partial_member_list", return_value=partial_members, ), patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] data, _ = method(api, dataset_id) assert data["partial_member_list"] == partial_members class TestDatasetApiPatch: def test_patch_success_basic(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) dataset_id = "dataset-id" payload = { "name": "updated-name", "description": "updated description", } user = MagicMock() tenant_id = "tenant-1" dataset = make_dataset(id=dataset_id, tenant_id=tenant_id) with ( app.test_request_context(f"/datasets/{dataset_id}"), patch.object(type(console_ns), "payload", payload), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(user, tenant_id), ), patch.object( DatasetService, "get_dataset", return_value=dataset, ), patch.object( DatasetPermissionService, "check_permission", return_value=None, ), patch.object( DatasetService, "update_dataset", return_value=dataset, ), patch.object( DatasetPermissionService, "get_dataset_partial_member_list", return_value=[], ), ): result, status = method(api, dataset_id) assert status == 200 assert result["partial_member_list"] == [] def test_patch_dataset_not_found(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) with ( app.test_request_context("/datasets/missing"), patch.object( DatasetService, "get_dataset", return_value=None, ), ): with pytest.raises(NotFound, match="Dataset not found"): method(api, "missing") def test_patch_permission_denied(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) dataset_id = "dataset-id" dataset = MagicMock() payload = {"name": "x"} with ( app.test_request_context(f"/datasets/{dataset_id}"), patch.object(type(console_ns), "payload", payload), patch.object( DatasetService, "get_dataset", return_value=dataset, ), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant"), ), patch.object( DatasetPermissionService, "check_permission", side_effect=Forbidden("no permission"), ), ): with pytest.raises(Forbidden): method(api, dataset_id) def test_patch_partial_members_update(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) dataset_id = "dataset-id" payload = { "permission": "partial_members", "partial_member_list": [{"user_id": "u1"}, {"user_id": "u2"}], } dataset = make_dataset(id=dataset_id, permission="partial_members") with ( app.test_request_context(f"/datasets/{dataset_id}"), patch.object(type(console_ns), "payload", payload), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant"), ), patch.object( DatasetService, "get_dataset", return_value=dataset, ), patch.object( DatasetPermissionService, "check_permission", return_value=None, ), patch.object( DatasetService, "update_dataset", return_value=dataset, ), patch.object( DatasetPermissionService, "update_partial_member_list", return_value=None, ), patch.object( DatasetPermissionService, "get_dataset_partial_member_list", return_value=["u1", "u2"], ), ): result, _ = method(api, dataset_id) assert result["partial_member_list"] == ["u1", "u2"] def test_patch_clear_partial_members(self, app: Flask): api = DatasetApi() method = unwrap(api.patch) dataset_id = "dataset-id" payload = { "permission": "only_me", } dataset = make_dataset(id=dataset_id) with ( app.test_request_context(f"/datasets/{dataset_id}"), patch.object(type(console_ns), "payload", payload), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant"), ), patch.object( DatasetService, "get_dataset", return_value=dataset, ), patch.object( DatasetPermissionService, "check_permission", return_value=None, ), patch.object( DatasetService, "update_dataset", return_value=dataset, ), patch.object( DatasetPermissionService, "clear_partial_member_list", return_value=None, ), patch.object( DatasetPermissionService, "get_dataset_partial_member_list", return_value=[], ), ): result, _ = method(api, dataset_id) assert result["partial_member_list"] == [] class TestDatasetApiDelete: def test_delete_success(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) dataset_id = "dataset-id" user = MagicMock() user.has_edit_permission = True user.is_dataset_operator = False with ( app.test_request_context(f"/datasets/{dataset_id}"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(user, "tenant"), ), patch.object( DatasetService, "delete_dataset", return_value=True, ), patch.object( DatasetPermissionService, "clear_partial_member_list", return_value=None, ), ): result, status = method(api, dataset_id) assert status == 204 assert result == "" def test_delete_forbidden_no_permission(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) dataset_id = "dataset-id" user = MagicMock() user.has_edit_permission = False user.is_dataset_operator = False with ( app.test_request_context(f"/datasets/{dataset_id}"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(user, "tenant"), ), ): with pytest.raises(Forbidden): method(api, dataset_id) def test_delete_dataset_not_found(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) dataset_id = "missing-dataset" user = MagicMock() user.has_edit_permission = True user.is_dataset_operator = False with ( app.test_request_context(f"/datasets/{dataset_id}"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(user, "tenant"), ), patch.object( DatasetService, "delete_dataset", return_value=False, ), ): with pytest.raises(NotFound, match="Dataset not found"): method(api, dataset_id) def test_delete_dataset_in_use(self, app: Flask): api = DatasetApi() method = unwrap(api.delete) dataset_id = "dataset-id" user = MagicMock() user.has_edit_permission = True user.is_dataset_operator = False with ( app.test_request_context(f"/datasets/{dataset_id}"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(user, "tenant"), ), patch.object( DatasetService, "delete_dataset", side_effect=services.errors.dataset.DatasetInUseError(), ), ): with pytest.raises(DatasetInUseError): method(api, dataset_id) class TestDatasetUseCheckApi: def test_get_use_check_true(self, app: Flask): api = DatasetUseCheckApi() method = unwrap(api.get) dataset_id = "dataset-id" with ( app.test_request_context(f"/datasets/{dataset_id}/use-check"), patch.object( DatasetService, "dataset_use_check", return_value=True, ), ): result, status = method(api, dataset_id) assert status == 200 assert result == {"is_using": True} def test_get_use_check_false(self, app: Flask): api = DatasetUseCheckApi() method = unwrap(api.get) dataset_id = "dataset-id" with ( app.test_request_context(f"/datasets/{dataset_id}/use-check"), patch.object( DatasetService, "dataset_use_check", return_value=False, ), ): result, status = method(api, dataset_id) assert status == 200 assert result == {"is_using": False} class TestDatasetQueryApi: def _query_record(self, index: int = 1) -> DatasetQuery: query = DatasetQuery( dataset_id="dataset-id", content=json.dumps( [ { "content_type": "text_query", "content": f"question {index}", "file_info": None, } ] ), source="hit_testing", source_app_id=None, created_by_role=CreatorUserRole.ACCOUNT, created_by=f"account-{index}", ) query.id = f"query-{index}" query.created_at = datetime.datetime(2024, 1, index, 12, 0, 0, tzinfo=datetime.UTC) return query def test_get_queries_success(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) dataset_id = "dataset-id" current_user = MagicMock() dataset = MagicMock() dataset.id = dataset_id queries = [self._query_record(1), self._query_record(2)] with ( app.test_request_context("/datasets/queries?page=1&limit=20"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(current_user, "tenant-1"), ), patch.object( DatasetService, "get_dataset", return_value=dataset, ), patch.object( DatasetService, "check_dataset_permission", return_value=None, ), patch.object( DatasetService, "get_dataset_queries", return_value=(queries, 2), ), ): response, status = method(api, dataset_id) assert status == 200 assert response["total"] == 2 assert response["page"] == 1 assert response["limit"] == 20 assert response["has_more"] is False assert len(response["data"]) == 2 assert response["data"][0] == { "id": "query-1", "queries": [ { "content_type": "text_query", "content": "question 1", "file_info": None, } ], "source": "hit_testing", "source_app_id": None, "created_by_role": "account", "created_by": "account-1", "created_at": 1704110400, } def test_get_queries_dataset_not_found(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) dataset_id = "dataset-id" current_user = MagicMock() with ( app.test_request_context("/datasets/queries"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(current_user, "tenant-1"), ), patch.object( DatasetService, "get_dataset", return_value=None, ), ): with pytest.raises(NotFound, match="Dataset not found"): method(api, dataset_id) def test_get_queries_permission_denied(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) dataset_id = "dataset-id" current_user = MagicMock() dataset = MagicMock() with ( app.test_request_context("/datasets/queries"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(current_user, "tenant-1"), ), patch.object( DatasetService, "get_dataset", return_value=dataset, ), patch.object( DatasetService, "check_dataset_permission", side_effect=services.errors.account.NoPermissionError("no access"), ), ): with pytest.raises(Forbidden): method(api, dataset_id) def test_get_queries_pagination_has_more(self, app: Flask): api = DatasetQueryApi() method = unwrap(api.get) dataset_id = "dataset-id" current_user = MagicMock() dataset = MagicMock() dataset.id = dataset_id queries = [self._query_record(index) for index in range(1, 21)] with ( app.test_request_context("/datasets/queries?page=1&limit=20"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(current_user, "tenant-1"), ), patch.object( DatasetService, "get_dataset", return_value=dataset, ), patch.object( DatasetService, "check_dataset_permission", return_value=None, ), patch.object( DatasetService, "get_dataset_queries", return_value=(queries, 40), ), ): response, status = method(api, dataset_id) assert status == 200 assert response["has_more"] is True assert len(response["data"]) == 20 class TestDatasetIndexingEstimateApi: def _upload_file(self, *, tenant_id: str = "tenant-1", file_id: str = "file-1") -> UploadFile: upload_file = UploadFile( tenant_id=tenant_id, storage_type=StorageType.LOCAL, key="key", name="name.txt", size=1, extension="txt", mime_type="text/plain", created_by_role=CreatorUserRole.ACCOUNT, created_by="user-1", created_at=datetime.datetime.now(tz=datetime.UTC), used=False, ) upload_file.id = file_id return upload_file def _base_payload(self): return { "info_list": { "data_source_type": "upload_file", "file_info_list": { "file_ids": ["file-1"], }, }, "process_rule": {"chunk_size": 100}, "indexing_technique": "high_quality", "doc_form": IndexStructureType.PARAGRAPH_INDEX, "doc_language": "English", "dataset_id": None, } def test_post_success_upload_file(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) payload = self._base_payload() mock_file = self._upload_file() mock_response = MagicMock() mock_response.model_dump.return_value = {"tokens": 100} with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch.object( type(console_ns), "payload", new_callable=PropertyMock, return_value=payload, ), patch( "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", return_value=None, ), patch( "controllers.console.datasets.datasets.db.session.scalars", return_value=MagicMock(all=lambda: [mock_file]), ), patch( "controllers.console.datasets.datasets.IndexingRunner.indexing_estimate", return_value=mock_response, ), ): response, status = method(api) assert status == 200 assert response == {"tokens": 100} def test_post_file_not_found(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) payload = self._base_payload() with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch.object( type(console_ns), "payload", new_callable=PropertyMock, return_value=payload, ), patch( "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", return_value=None, ), patch( "controllers.console.datasets.datasets.db.session.scalars", return_value=MagicMock(all=lambda: None), ), ): with pytest.raises(NotFound): method(api) def test_post_llm_bad_request_error(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) mock_file = self._upload_file() payload = self._base_payload() with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch.object( type(console_ns), "payload", new_callable=PropertyMock, return_value=payload, ), patch( "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", return_value=None, ), patch( "controllers.console.datasets.datasets.db.session.scalars", return_value=MagicMock(all=lambda: [mock_file]), ), patch( "controllers.console.datasets.datasets.IndexingRunner.indexing_estimate", side_effect=LLMBadRequestError(), ), ): with pytest.raises(ProviderNotInitializeError): method(api) def test_post_provider_token_not_init(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) mock_file = self._upload_file() payload = self._base_payload() with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch.object( type(console_ns), "payload", new_callable=PropertyMock, return_value=payload, ), patch( "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", return_value=None, ), patch( "controllers.console.datasets.datasets.db.session.scalars", return_value=MagicMock(all=lambda: [mock_file]), ), patch( "controllers.console.datasets.datasets.IndexingRunner.indexing_estimate", side_effect=ProviderTokenNotInitError("token missing"), ), ): with pytest.raises(ProviderNotInitializeError): method(api) def test_post_generic_exception(self, app: Flask): api = DatasetIndexingEstimateApi() method = unwrap(api.post) mock_file = self._upload_file() payload = self._base_payload() with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch.object( type(console_ns), "payload", new_callable=PropertyMock, return_value=payload, ), patch( "controllers.console.datasets.datasets.DocumentService.estimate_args_validate", return_value=None, ), patch( "controllers.console.datasets.datasets.db.session.scalars", return_value=MagicMock(all=lambda: [mock_file]), ), patch( "controllers.console.datasets.datasets.IndexingRunner.indexing_estimate", side_effect=Exception("boom"), ), ): with pytest.raises(IndexingEstimateError): method(api) class TestDatasetRelatedAppListApi: def test_get_success(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) dataset = MagicMock() dataset.id = "dataset-1" app1 = make_related_app(id="app-1", name="App 1") app2 = make_related_app(id="app-2", name="App 2") join1 = MagicMock(app=app1) join2 = MagicMock(app=app2) with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch( "controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=dataset, ), patch( "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", return_value=None, ), patch( "controllers.console.datasets.datasets.DatasetService.get_related_apps", return_value=[join1, join2], ), ): response, status = method(api, "dataset-1") assert status == 200 assert response["total"] == 2 assert response["data"] == [ { "id": "app-1", "name": "App 1", "description": "desc", "mode": "chat", "icon_type": "emoji", "icon": "🤖", "icon_background": "#fff", "icon_url": None, }, { "id": "app-2", "name": "App 2", "description": "desc", "mode": "chat", "icon_type": "emoji", "icon": "🤖", "icon_background": "#fff", "icon_url": None, }, ] def test_get_dataset_not_found(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch( "controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=None, ), ): with pytest.raises(NotFound): method(api, "dataset-1") def test_get_permission_denied(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) dataset = MagicMock() with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch( "controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=dataset, ), patch( "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", side_effect=services.errors.account.NoPermissionError("no permission"), ), ): with pytest.raises(Forbidden): method(api, "dataset-1") def test_get_filters_none_apps(self, app: Flask): api = DatasetRelatedAppListApi() method = unwrap(api.get) dataset = MagicMock() dataset.id = "dataset-1" app1 = make_related_app() join1 = MagicMock(app=app1) join2 = MagicMock(app=None) with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch( "controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=dataset, ), patch( "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", return_value=None, ), patch( "controllers.console.datasets.datasets.DatasetService.get_related_apps", return_value=[join1, join2], ), ): response, status = method(api, "dataset-1") assert status == 200 assert response["total"] == 1 assert response["data"] == [ { "id": "app-1", "name": "App", "description": "desc", "mode": "chat", "icon_type": "emoji", "icon": "🤖", "icon_background": "#fff", "icon_url": None, } ] class TestDatasetIndexingStatusApi: def test_get_success_with_documents(self, app: Flask): api = DatasetIndexingStatusApi() method = unwrap(api.get) document = MagicMock() document.id = "doc-1" document.indexing_status = "completed" document.processing_started_at = None document.parsing_completed_at = None document.cleaning_completed_at = None document.splitting_completed_at = None document.completed_at = None document.paused_at = None document.error = None document.stopped_at = None with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch( "controllers.console.datasets.datasets.db.session.scalars", return_value=MagicMock(all=lambda: [document]), ), patch( "controllers.console.datasets.datasets.db.session.scalar", return_value=3, ), ): response, status = method(api, "dataset-1") assert status == 200 assert "data" in response assert len(response["data"]) == 1 item = response["data"][0] assert item["completed_segments"] == 3 assert item["total_segments"] == 3 def test_get_success_no_documents(self, app: Flask): api = DatasetIndexingStatusApi() method = unwrap(api.get) with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch( "controllers.console.datasets.datasets.db.session.scalars", return_value=MagicMock(all=lambda: []), ), ): response, status = method(api, "dataset-1") assert status == 200 assert response == {"data": []} def test_segment_counts_different_values(self, app: Flask): api = DatasetIndexingStatusApi() method = unwrap(api.get) document = MagicMock() document.id = "doc-1" document.indexing_status = "indexing" document.processing_started_at = None document.parsing_completed_at = None document.cleaning_completed_at = None document.splitting_completed_at = None document.completed_at = None document.paused_at = None document.error = None document.stopped_at = None with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch( "controllers.console.datasets.datasets.db.session.scalars", return_value=MagicMock(all=lambda: [document]), ), patch( "controllers.console.datasets.datasets.db.session.scalar", side_effect=[2, 5], ), ): response, status = method(api, "dataset-1") assert status == 200 item = response["data"][0] assert item["completed_segments"] == 2 assert item["total_segments"] == 5 class TestDatasetApiKeyApi: def test_get_api_keys_success(self, app: Flask): api = DatasetApiKeyApi() method = unwrap(api.get) mock_key_1 = MagicMock(spec=ApiToken) mock_key_1.id = "key-1" mock_key_1.type = "dataset" mock_key_1.token = "ds-abc" mock_key_1.last_used_at = None mock_key_1.created_at = None mock_key_2 = MagicMock(spec=ApiToken) mock_key_2.id = "key-2" mock_key_2.type = "dataset" mock_key_2.token = "ds-def" mock_key_2.last_used_at = None mock_key_2.created_at = None with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch( "controllers.console.datasets.datasets.db.session.scalars", return_value=MagicMock(all=lambda: [mock_key_1, mock_key_2]), ), ): response = method(api) assert "data" in response assert len(response["data"]) == 2 assert response["data"][0]["id"] == "key-1" assert response["data"][0]["token"] == "ds-abc" assert response["data"][1]["id"] == "key-2" assert response["data"][1]["token"] == "ds-def" def test_post_create_api_key_success(self, app: Flask): api = DatasetApiKeyApi() method = unwrap(api.post) mock_token = MagicMock() mock_token.id = "new-key-id" mock_token.last_used_at = None mock_token.created_at = datetime.datetime(2024, 1, 1, 0, 0, 0, tzinfo=datetime.UTC) mock_api_token_cls = MagicMock() mock_api_token_cls.return_value = mock_token mock_api_token_cls.generate_api_key.return_value = "dataset-abc123" with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch( "controllers.console.datasets.datasets.db.session.scalar", return_value=3, ), patch( "controllers.console.datasets.datasets.ApiToken", mock_api_token_cls, ), patch( "controllers.console.datasets.datasets.db.session.add", return_value=None, ), patch( "controllers.console.datasets.datasets.db.session.commit", return_value=None, ), ): response, status = method(api) assert status == 200 assert isinstance(response, dict) assert response["id"] == "new-key-id" assert response["token"] == "dataset-abc123" assert response["type"] == "dataset" assert response["created_at"] is not None def test_post_exceed_max_keys(self, app: Flask): api = DatasetApiKeyApi() method = unwrap(api.post) with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch( "controllers.console.datasets.datasets.db.session.scalar", return_value=10, ), ): with pytest.raises(BadRequest) as exc_info: method(api) assert exc_info.value.code == 400 assert vars(exc_info.value)["data"] == { "message": "Cannot create more than 10 API keys for this resource type.", "custom": "max_keys_exceeded", } class TestDatasetApiDeleteApi: def test_delete_success(self, app: Flask): api = DatasetApiDeleteApi() method = unwrap(api.delete) mock_key = MagicMock() with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch( "controllers.console.datasets.datasets.db.session.scalar", return_value=mock_key, ), patch( "controllers.console.datasets.datasets.db.session.commit", return_value=None, ), patch( "controllers.console.datasets.datasets.db.session.delete", return_value=None, ), ): response, status = method(api, "api-key-id") assert status == 204 assert response == "" def test_delete_key_not_found(self, app: Flask): api = DatasetApiDeleteApi() method = unwrap(api.delete) with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch( "controllers.console.datasets.datasets.db.session.scalar", return_value=None, ), ): with pytest.raises(NotFound): method(api, "api-key-id") class TestDatasetEnableApiApi: def test_enable_api(self, app: Flask): api = DatasetEnableApiApi() method = unwrap(api.post) with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.DatasetService.update_dataset_api_status", return_value=None, ), ): response, status = method(api, "dataset-1", "enable") assert status == 200 assert response["result"] == "success" def test_disable_api(self, app: Flask): api = DatasetEnableApiApi() method = unwrap(api.post) with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.DatasetService.update_dataset_api_status", return_value=None, ), ): response, status = method(api, "dataset-1", "disable") assert status == 200 assert response["result"] == "success" class TestDatasetApiBaseUrlApi: def test_get_api_base_url_from_config(self, app: Flask): api = DatasetApiBaseUrlApi() method = unwrap(api.get) with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL", "https://example.com", ), ): response = method(api) assert response["api_base_url"] == "https://example.com/v1" def test_get_api_base_url_from_request(self, app: Flask): api = DatasetApiBaseUrlApi() method = unwrap(api.get) with ( app.test_request_context("http://localhost:5000/"), patch( "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL", None, ), ): response = method(api) assert response["api_base_url"] == "http://localhost:5000/v1" def test_get_api_base_url_no_double_v1(self, app: Flask): api = DatasetApiBaseUrlApi() method = unwrap(api.get) with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL", "https://example.com/v1", ), ): response = method(api) assert response["api_base_url"] == "https://example.com/v1" class TestDatasetRetrievalSettingApi: def test_get_success(self, app: Flask): api = DatasetRetrievalSettingApi() method = unwrap(api.get) with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.dify_config.VECTOR_STORE", "qdrant", ), patch( "controllers.console.datasets.datasets._get_retrieval_methods_by_vector_type", return_value={"retrieval_method": ["semantic", "hybrid"]}, ), ): response = method(api) assert "retrieval_method" in response class TestDatasetRetrievalSettingMockApi: def test_get_success(self, app: Flask): api = DatasetRetrievalSettingMockApi() method = unwrap(api.get) with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets._get_retrieval_methods_by_vector_type", return_value={"retrieval_method": ["semantic"]}, ), ): response = method(api, "milvus") assert response["retrieval_method"] == ["semantic"] class TestDatasetErrorDocs: def test_get_success(self, app: Flask): api = DatasetErrorDocs() method = unwrap(api.get) dataset = MagicMock() error_doc = make_document_status(id="error-doc", indexing_status=IndexingStatus.ERROR, error="failed") with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=dataset, ), patch( "controllers.console.datasets.datasets.DocumentService.get_error_documents_by_dataset_id", return_value=[error_doc], ), ): response, status = method(api, "dataset-1") assert status == 200 assert response["total"] == 1 def test_get_dataset_not_found(self, app: Flask): api = DatasetErrorDocs() method = unwrap(api.get) with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=None, ), ): with pytest.raises(NotFound): method(api, "dataset-1") class TestDatasetPermissionUserListApi: def test_get_success(self, app: Flask): api = DatasetPermissionUserListApi() method = unwrap(api.get) dataset = MagicMock() users = ["u1", "u2"] with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch( "controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=dataset, ), patch( "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", return_value=None, ), patch( "controllers.console.datasets.datasets.DatasetPermissionService.get_dataset_partial_member_list", return_value=users, ), ): response, status = method(api, "dataset-1") assert status == 200 assert response["data"] == users def test_get_permission_denied(self, app: Flask): api = DatasetPermissionUserListApi() method = unwrap(api.get) dataset = MagicMock() with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.current_account_with_tenant", return_value=(MagicMock(), "tenant-1"), ), patch( "controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=dataset, ), patch( "controllers.console.datasets.datasets.DatasetService.check_dataset_permission", side_effect=services.errors.account.NoPermissionError("no permission"), ), ): with pytest.raises(Forbidden): method(api, "dataset-1") class TestDatasetAutoDisableLogApi: def test_get_success(self, app: Flask): api = DatasetAutoDisableLogApi() method = unwrap(api.get) dataset = MagicMock() logs = {"document_ids": ["doc-1"], "count": 1} with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=dataset, ), patch( "controllers.console.datasets.datasets.DatasetService.get_dataset_auto_disable_logs", return_value=logs, ), ): response, status = method(api, "dataset-1") assert status == 200 assert response == logs def test_get_dataset_not_found(self, app: Flask): api = DatasetAutoDisableLogApi() method = unwrap(api.get) with ( app.test_request_context("/"), patch( "controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=None, ), ): with pytest.raises(NotFound): method(api, "dataset-1")