Files
dify/api/tests/unit_tests/controllers/console/datasets/test_datasets.py
2026-03-09 17:07:13 +08:00

1927 lines
62 KiB
Python

import datetime
from contextlib import ExitStack
from unittest.mock import MagicMock, PropertyMock, patch

import pytest
from werkzeug.exceptions import BadRequest, Forbidden, NotFound

import services
from controllers.console import console_ns
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.datasets import (
    DatasetApi,
    DatasetApiBaseUrlApi,
    DatasetApiDeleteApi,
    DatasetApiKeyApi,
    DatasetAutoDisableLogApi,
    DatasetEnableApiApi,
    DatasetErrorDocs,
    DatasetIndexingEstimateApi,
    DatasetIndexingStatusApi,
    DatasetListApi,
    DatasetPermissionUserListApi,
    DatasetQueryApi,
    DatasetRelatedAppListApi,
    DatasetRetrievalSettingApi,
    DatasetRetrievalSettingMockApi,
    DatasetUseCheckApi,
)
from controllers.console.datasets.error import DatasetInUseError, DatasetNameDuplicateError, IndexingEstimateError
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.provider_manager import ProviderManager
from models.enums import CreatorUserRole
from models.model import ApiToken, UploadFile
from services.dataset_service import DatasetPermissionService, DatasetService
def unwrap(func):
    """Return the innermost callable beneath a chain of ``__wrapped__`` links.

    The tests call ``unwrap(api.get)`` and then ``method(api, ...)``, i.e. they
    invoke the endpoint body directly, bypassing whatever decorators recorded
    ``__wrapped__`` (as ``functools.wraps`` does). For a decorated bound method
    the attribute lookup is forwarded to the underlying function, so the result
    is the plain function and the instance must be passed explicitly.

    Delegates to :func:`inspect.unwrap`, which walks the same chain as a manual
    ``while hasattr(func, "__wrapped__")`` loop but raises ``ValueError`` on a
    cyclic chain instead of looping forever.
    """
    import inspect

    return inspect.unwrap(func)
class TestDatasetList:
    """Tests for listing datasets through ``DatasetListApi.get``."""

    def _mock_dataset_dict(self, **overrides):
        """Return one marshalled dataset dict; ``overrides`` replace or add keys."""
        base = {
            "id": "ds-1",
            "indexing_technique": "economy",
            "embedding_model": None,
            "embedding_model_provider": None,
            "permission": "only_me",
        }
        base.update(overrides)
        return base

    def _mock_user(self):
        """Return an account mock allowed to edit datasets."""
        user = MagicMock()
        user.is_dataset_editor = True
        return user

    def _call_get(
        self,
        app,
        url,
        marshaled,
        *,
        lookup="get_datasets",
        total=1,
        provider_config=None,
        extra_patches=(),
    ):
        """Invoke the undecorated ``DatasetListApi.get`` with the shared mocks installed.

        Args:
            app: the Flask app fixture.
            url: request URL, including any query string under test.
            marshaled: list returned by the patched ``marshal``.
            lookup: name of the ``DatasetService`` method that is patched to return
                ``([MagicMock()], total)``.
            total: total count reported by that lookup.
            provider_config: object returned by ``ProviderManager.get_configurations``;
                defaults to a config whose ``get_models`` yields no models.
            extra_patches: additional patchers entered after the shared ones.

        Returns:
            ``(resp, status, lookup_mock)``; ``lookup_mock`` keeps its call record
            after the patch is undone, so callers can assert on it.
        """
        api = DatasetListApi()
        method = unwrap(api.get)
        if provider_config is None:
            provider_config = MagicMock(get_models=lambda **_: [])
        with app.test_request_context(url), ExitStack() as stack:
            stack.enter_context(
                patch(
                    "controllers.console.datasets.datasets.current_account_with_tenant",
                    return_value=(self._mock_user(), "tenant-1"),
                )
            )
            lookup_mock = stack.enter_context(
                patch.object(DatasetService, lookup, return_value=([MagicMock()], total))
            )
            stack.enter_context(patch("controllers.console.datasets.datasets.marshal", return_value=marshaled))
            stack.enter_context(patch.object(ProviderManager, "get_configurations", return_value=provider_config))
            for extra in extra_patches:
                stack.enter_context(extra)
            resp, status = method(api)
        return resp, status, lookup_mock

    def test_get_success_basic(self, app):
        resp, status, _ = self._call_get(app, "/datasets", [self._mock_dataset_dict()])
        assert status == 200
        assert resp["total"] == 1
        assert resp["data"][0]["embedding_available"] is True

    def test_get_with_ids_filter(self, app):
        resp, status, by_ids_mock = self._call_get(
            app,
            "/datasets?ids=1&ids=2",
            [self._mock_dataset_dict()],
            lookup="get_datasets_by_ids",
            total=2,
        )
        by_ids_mock.assert_called_once()
        assert status == 200
        assert resp["total"] == 2

    def test_get_with_tag_ids(self, app):
        _, status, get_datasets_mock = self._call_get(app, "/datasets?tag_ids=tag1", [self._mock_dataset_dict()])
        assert status == 200
        # Without ``ids`` the request must go through the regular listing lookup.
        get_datasets_mock.assert_called_once()

    def test_embedding_available_false(self, app):
        marshaled = [
            self._mock_dataset_dict(
                indexing_technique="high_quality",
                embedding_model="text-embed",
                embedding_model_provider="openai",
            )
        ]
        config = MagicMock()
        config.get_models.return_value = []  # model not available
        resp, status, _ = self._call_get(app, "/datasets", marshaled, provider_config=config)
        assert status == 200
        assert resp["data"][0]["embedding_available"] is False

    def test_partial_members_permission(self, app):
        # Rows presumably pair a dataset id with a member account id.
        member_rows = patch(
            "controllers.console.datasets.datasets.db.session.execute",
            return_value=MagicMock(all=lambda: [("ds-1", "u1")]),
        )
        resp, status, _ = self._call_get(
            app,
            "/datasets",
            [self._mock_dataset_dict(permission="partial_members")],
            extra_patches=(member_rows,),
        )
        assert status == 200
        assert resp["data"][0]["partial_member_list"] == ["u1"]
class TestDatasetListApiPost:
    """Tests for creating a dataset through ``DatasetListApi.post``."""

    @staticmethod
    def _account(*, editor):
        """Return an account mock whose ``is_dataset_editor`` flag equals ``editor``."""
        account = MagicMock()
        account.is_dataset_editor = editor
        return account

    @staticmethod
    def _created_dataset():
        """Return a dataset mock carrying the minimal fields needed for marshalling."""
        created = MagicMock()
        minimal_fields = {
            "embedding_available": True,
            "built_in_field_enabled": False,
            "is_published": False,
            "enable_api": False,
            "is_multimodal": False,
            "documents": [],
            "retrieval_model_dict": {},
            "tags": [],
            "external_knowledge_info": None,
            "external_retrieval_model": None,
            "doc_metadata": [],
            "icon_info": None,
        }
        for attr, value in minimal_fields.items():
            setattr(created, attr, value)
        created.summary_index_setting = MagicMock(enable=False)
        return created

    def _expect_value_error(self, app, body, match=None):
        """POST ``body`` with no account patch and require a ``ValueError`` (optionally matching ``match``)."""
        api = DatasetListApi()
        post = unwrap(api.post)
        with (
            app.test_request_context("/datasets", json=body),
            patch.object(type(console_ns), "payload", body),
            pytest.raises(ValueError, match=match),
        ):
            post(api)

    def test_post_success(self, app):
        body = {
            "name": "My Dataset",
            "description": "desc",
            "indexing_technique": "economy",
            "provider": "vendor",
        }
        api = DatasetListApi()
        post = unwrap(api.post)
        with (
            app.test_request_context("/datasets", json=body),
            patch.object(type(console_ns), "payload", body),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(self._account(editor=True), "tenant-1"),
            ),
            patch.object(DatasetService, "create_empty_dataset", return_value=self._created_dataset()),
        ):
            _, status = post(api)
        assert status == 201

    def test_post_forbidden(self, app):
        body = {"name": "test"}
        api = DatasetListApi()
        post = unwrap(api.post)
        with (
            app.test_request_context("/datasets", json=body),
            patch.object(type(console_ns), "payload", body),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(self._account(editor=False), "tenant-1"),
            ),
            pytest.raises(Forbidden),
        ):
            post(api)

    def test_post_duplicate_name(self, app):
        body = {"name": "duplicate"}
        api = DatasetListApi()
        post = unwrap(api.post)
        with (
            app.test_request_context("/datasets", json=body),
            patch.object(type(console_ns), "payload", body),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(self._account(editor=True), "tenant-1"),
            ),
            patch.object(
                DatasetService,
                "create_empty_dataset",
                side_effect=services.errors.dataset.DatasetNameDuplicateError(),
            ),
            pytest.raises(DatasetNameDuplicateError),
        ):
            post(api)

    def test_post_invalid_payload_missing_name(self, app):
        self._expect_value_error(app, {})

    def test_post_invalid_indexing_technique(self, app):
        self._expect_value_error(
            app,
            {"name": "bad", "indexing_technique": "invalid-tech"},
            match="Invalid indexing technique",
        )

    def test_post_invalid_provider(self, app):
        self._expect_value_error(app, {"name": "bad", "provider": "unknown"}, match="Invalid provider")
class TestDatasetApiGet:
    """Tests for fetching a single dataset through ``DatasetApi.get``."""

    @staticmethod
    def _make_dataset(dataset_id, **overrides):
        """Build a dataset mock with the fields needed for marshalling.

        Defaults describe an ``economy`` dataset with ``only_me`` permission;
        each keyword in ``overrides`` replaces (or adds) an attribute.
        """
        dataset = MagicMock()
        dataset.id = dataset_id
        dataset.indexing_technique = "economy"
        dataset.embedding_model_provider = None
        dataset.embedding_available = True
        dataset.built_in_field_enabled = False
        dataset.is_published = False
        dataset.enable_api = False
        dataset.is_multimodal = False
        dataset.documents = []
        dataset.retrieval_model_dict = {}
        dataset.tags = []
        dataset.external_knowledge_info = None
        dataset.external_retrieval_model = None
        dataset.doc_metadata = []
        dataset.icon_info = None
        dataset.summary_index_setting = MagicMock()
        dataset.summary_index_setting.enable = False
        dataset.permission = "only_me"
        for name, value in overrides.items():
            setattr(dataset, name, value)
        return dataset

    def _get_with_access(self, app, dataset, *, tenant_id="tenant-1", extra_patches=()):
        """Call the undecorated GET for ``dataset`` with the lookup and permission check passing.

        ``ProviderManager`` is replaced so its configurations report no
        embedding models. ``extra_patches`` are entered after the shared ones.
        Returns the endpoint's ``(data, status)`` tuple.
        """
        api = DatasetApi()
        method = unwrap(api.get)
        with app.test_request_context(f"/datasets/{dataset.id}"), ExitStack() as stack:
            stack.enter_context(
                patch(
                    "controllers.console.datasets.datasets.current_account_with_tenant",
                    return_value=(MagicMock(), tenant_id),
                )
            )
            stack.enter_context(patch.object(DatasetService, "get_dataset", return_value=dataset))
            stack.enter_context(patch.object(DatasetService, "check_dataset_permission", return_value=None))
            provider_manager_mock = stack.enter_context(patch("controllers.console.datasets.datasets.ProviderManager"))
            provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = []
            for extra in extra_patches:
                stack.enter_context(extra)
            return method(api, dataset.id)

    def test_get_success_basic(self, app):
        dataset = self._make_dataset("123e4567-e89b-12d3-a456-426614174000")
        # No embedding models are reported, yet the flag stays True: an economy
        # dataset is presumably not checked against the provider's embedding
        # models (contrast with the high_quality case in this class).
        data, status = self._get_with_access(app, dataset)
        assert status == 200
        assert data["embedding_available"] is True

    def test_get_dataset_not_found(self, app):
        api = DatasetApi()
        method = unwrap(api.get)
        dataset_id = "missing-id"
        with (
            app.test_request_context(f"/datasets/{dataset_id}"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant"),
            ),
            patch.object(DatasetService, "get_dataset", return_value=None),
        ):
            with pytest.raises(NotFound, match="Dataset not found"):
                method(api, dataset_id)

    def test_get_permission_denied(self, app):
        api = DatasetApi()
        method = unwrap(api.get)
        dataset_id = "dataset-id"
        dataset = MagicMock()
        with (
            app.test_request_context(f"/datasets/{dataset_id}"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant"),
            ),
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch.object(
                DatasetService,
                "check_dataset_permission",
                side_effect=services.errors.account.NoPermissionError("no access"),
            ),
        ):
            with pytest.raises(Forbidden, match="no access"):
                method(api, dataset_id)

    def test_get_high_quality_embedding_unavailable(self, app):
        dataset = self._make_dataset(
            "dataset-id",
            indexing_technique="high_quality",
            embedding_model="text-embedding",
            embedding_model_provider="openai",
        )
        # The configured embedding model is not among the (empty) provider models.
        data, _ = self._get_with_access(app, dataset)
        assert data["embedding_available"] is False

    def test_get_partial_members_permission(self, app):
        dataset = self._make_dataset("dataset-id", permission="partial_members")
        partial_members = [{"id": "u1"}, {"id": "u2"}]
        member_list_patch = patch.object(
            DatasetPermissionService,
            "get_dataset_partial_member_list",
            return_value=partial_members,
        )
        data, _ = self._get_with_access(app, dataset, tenant_id="tenant", extra_patches=(member_list_patch,))
        assert data["partial_member_list"] == partial_members
class TestDatasetApiPatch:
    """Tests for updating a dataset through ``DatasetApi.patch``."""

    @staticmethod
    def _make_dataset(dataset_id, **overrides):
        """Build an ``economy`` / ``only_me`` dataset mock with the fields needed for marshalling.

        Each keyword in ``overrides`` replaces (or adds) an attribute.
        """
        dataset = MagicMock()
        dataset.id = dataset_id
        dataset.permission = "only_me"
        dataset.indexing_technique = "economy"
        dataset.embedding_model_provider = None
        dataset.embedding_available = True
        dataset.built_in_field_enabled = False
        dataset.is_published = False
        dataset.enable_api = False
        dataset.is_multimodal = False
        dataset.documents = []
        dataset.retrieval_model_dict = {}
        dataset.tags = []
        dataset.external_knowledge_info = None
        dataset.external_retrieval_model = None
        dataset.doc_metadata = []
        dataset.icon_info = None
        dataset.summary_index_setting = MagicMock()
        dataset.summary_index_setting.enable = False
        for name, value in overrides.items():
            setattr(dataset, name, value)
        return dataset

    def _call_patch(self, app, dataset, payload, partial_members, *, tenant_id="tenant", extra_patches=()):
        """Run the undecorated PATCH for ``dataset`` with the lookup and permission check passing.

        ``update_dataset`` returns ``dataset`` unchanged and
        ``get_dataset_partial_member_list`` returns ``partial_members``.
        ``extra_patches`` are entered after the shared ones. Returns the
        endpoint's ``(result, status)`` tuple.
        """
        api = DatasetApi()
        method = unwrap(api.patch)
        with app.test_request_context(f"/datasets/{dataset.id}"), ExitStack() as stack:
            stack.enter_context(patch.object(type(console_ns), "payload", payload))
            stack.enter_context(
                patch(
                    "controllers.console.datasets.datasets.current_account_with_tenant",
                    return_value=(MagicMock(), tenant_id),
                )
            )
            stack.enter_context(patch.object(DatasetService, "get_dataset", return_value=dataset))
            stack.enter_context(patch.object(DatasetPermissionService, "check_permission", return_value=None))
            stack.enter_context(patch.object(DatasetService, "update_dataset", return_value=dataset))
            stack.enter_context(
                patch.object(
                    DatasetPermissionService,
                    "get_dataset_partial_member_list",
                    return_value=partial_members,
                )
            )
            for extra in extra_patches:
                stack.enter_context(extra)
            return method(api, dataset.id)

    def test_patch_success_basic(self, app):
        dataset = self._make_dataset("dataset-id", tenant_id="tenant-1")
        payload = {
            "name": "updated-name",
            "description": "updated description",
        }
        result, status = self._call_patch(app, dataset, payload, [], tenant_id="tenant-1")
        assert status == 200
        assert result["partial_member_list"] == []

    def test_patch_dataset_not_found(self, app):
        api = DatasetApi()
        method = unwrap(api.patch)
        with (
            app.test_request_context("/datasets/missing"),
            patch.object(DatasetService, "get_dataset", return_value=None),
        ):
            with pytest.raises(NotFound, match="Dataset not found"):
                method(api, "missing")

    def test_patch_permission_denied(self, app):
        api = DatasetApi()
        method = unwrap(api.patch)
        dataset_id = "dataset-id"
        dataset = MagicMock()
        payload = {"name": "x"}
        with (
            app.test_request_context(f"/datasets/{dataset_id}"),
            patch.object(type(console_ns), "payload", payload),
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant"),
            ),
            patch.object(DatasetPermissionService, "check_permission", side_effect=Forbidden("no permission")),
        ):
            with pytest.raises(Forbidden):
                method(api, dataset_id)

    def test_patch_partial_members_update(self, app):
        members = [{"id": "u1"}, {"id": "u2"}]
        payload = {
            "permission": "partial_members",
            "partial_member_list": members,
        }
        dataset = self._make_dataset("dataset-id", permission="partial_members")
        update_members = MagicMock(return_value=None)
        result, _ = self._call_patch(
            app,
            dataset,
            payload,
            members,
            extra_patches=(
                patch.object(DatasetPermissionService, "update_partial_member_list", update_members),
            ),
        )
        # The returned list comes from a mock, so only this call proves the
        # submitted members were actually persisted.
        update_members.assert_called_once()
        assert result["partial_member_list"] == members

    def test_patch_clear_partial_members(self, app):
        payload = {
            "permission": "only_me",
        }
        dataset = self._make_dataset("dataset-id")
        clear_members = MagicMock(return_value=None)
        result, _ = self._call_patch(
            app,
            dataset,
            payload,
            [],
            extra_patches=(
                patch.object(DatasetPermissionService, "clear_partial_member_list", clear_members),
            ),
        )
        # Switching to only_me must drop any previously granted partial members.
        clear_members.assert_called_once()
        assert result["partial_member_list"] == []
class TestDatasetApiDelete:
    """Tests for deleting a dataset through ``DatasetApi.delete``."""

    @staticmethod
    def _account(*, can_edit=True):
        """Return a non-operator account mock whose ``has_edit_permission`` equals ``can_edit``."""
        account = MagicMock()
        account.has_edit_permission = can_edit
        account.is_dataset_operator = False
        return account

    def test_delete_success(self, app):
        api = DatasetApi()
        method = unwrap(api.delete)
        dataset_id = "dataset-id"
        clear_members = MagicMock(return_value=None)
        with (
            app.test_request_context(f"/datasets/{dataset_id}"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(self._account(), "tenant"),
            ),
            patch.object(DatasetService, "delete_dataset", return_value=True),
            patch.object(DatasetPermissionService, "clear_partial_member_list", clear_members),
        ):
            result, status = method(api, dataset_id)
        assert status == 204
        assert result == {"result": "success"}
        # A successful delete must also clear the dataset's partial-member list.
        clear_members.assert_called_once()

    def test_delete_forbidden_no_permission(self, app):
        api = DatasetApi()
        method = unwrap(api.delete)
        dataset_id = "dataset-id"
        with (
            app.test_request_context(f"/datasets/{dataset_id}"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(self._account(can_edit=False), "tenant"),
            ),
        ):
            with pytest.raises(Forbidden):
                method(api, dataset_id)

    def test_delete_dataset_not_found(self, app):
        api = DatasetApi()
        method = unwrap(api.delete)
        dataset_id = "missing-dataset"
        with (
            app.test_request_context(f"/datasets/{dataset_id}"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(self._account(), "tenant"),
            ),
            # delete_dataset returning False signals "nothing was deleted".
            patch.object(DatasetService, "delete_dataset", return_value=False),
        ):
            with pytest.raises(NotFound, match="Dataset not found"):
                method(api, dataset_id)

    def test_delete_dataset_in_use(self, app):
        api = DatasetApi()
        method = unwrap(api.delete)
        dataset_id = "dataset-id"
        with (
            app.test_request_context(f"/datasets/{dataset_id}"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(self._account(), "tenant"),
            ),
            patch.object(
                DatasetService,
                "delete_dataset",
                side_effect=services.errors.dataset.DatasetInUseError(),
            ),
        ):
            with pytest.raises(DatasetInUseError):
                method(api, dataset_id)
class TestDatasetUseCheckApi:
    """Tests for ``DatasetUseCheckApi.get``, which reports whether a dataset is in use."""

    def _use_check(self, app, in_use):
        """Call the undecorated endpoint with ``dataset_use_check`` returning ``in_use``."""
        api = DatasetUseCheckApi()
        get = unwrap(api.get)
        target = "dataset-id"
        with (
            app.test_request_context(f"/datasets/{target}/use-check"),
            patch.object(DatasetService, "dataset_use_check", return_value=in_use),
        ):
            return get(api, target)

    def test_get_use_check_true(self, app):
        body, code = self._use_check(app, True)
        assert code == 200
        assert body == {"is_using": True}

    def test_get_use_check_false(self, app):
        body, code = self._use_check(app, False)
        assert code == 200
        assert body == {"is_using": False}
class TestDatasetQueryApi:
    """Tests for the paginated query history returned by ``DatasetQueryApi.get``."""

    def _list_queries(self, app, queries, total):
        """Request page 1 (limit 20) for an accessible dataset whose query lookup yields ``(queries, total)``."""
        api = DatasetQueryApi()
        get = unwrap(api.get)
        dataset = MagicMock()
        dataset.id = "dataset-id"
        with (
            app.test_request_context("/datasets/queries?page=1&limit=20"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant-1"),
            ),
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch.object(DatasetService, "check_dataset_permission", return_value=None),
            patch.object(DatasetService, "get_dataset_queries", return_value=(queries, total)),
        ):
            return get(api, dataset.id)

    def test_get_queries_success(self, app):
        page, code = self._list_queries(app, [MagicMock(), MagicMock()], 2)
        assert code == 200
        assert (page["total"], page["page"], page["limit"]) == (2, 1, 20)
        assert page["has_more"] is False
        assert len(page["data"]) == 2

    def test_get_queries_dataset_not_found(self, app):
        api = DatasetQueryApi()
        get = unwrap(api.get)
        with (
            app.test_request_context("/datasets/queries"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant-1"),
            ),
            patch.object(DatasetService, "get_dataset", return_value=None),
            pytest.raises(NotFound, match="Dataset not found"),
        ):
            get(api, "dataset-id")

    def test_get_queries_permission_denied(self, app):
        api = DatasetQueryApi()
        get = unwrap(api.get)
        with (
            app.test_request_context("/datasets/queries"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant-1"),
            ),
            patch.object(DatasetService, "get_dataset", return_value=MagicMock()),
            patch.object(
                DatasetService,
                "check_dataset_permission",
                side_effect=services.errors.account.NoPermissionError("no access"),
            ),
            pytest.raises(Forbidden),
        ):
            get(api, "dataset-id")

    def test_get_queries_pagination_has_more(self, app):
        # A full page of 20 out of 40 total means another page exists.
        page, code = self._list_queries(app, [MagicMock() for _ in range(20)], 40)
        assert code == 200
        assert page["has_more"] is True
        assert len(page["data"]) == 20
class TestDatasetIndexingEstimateApi:
    """Tests for ``DatasetIndexingEstimateApi.post`` and its error mapping."""

    def _upload_file(self, *, tenant_id: str = "tenant-1", file_id: str = "file-1") -> UploadFile:
        """Return a real (unsaved) ``UploadFile`` row with the given tenant and id."""
        upload_file = UploadFile(
            tenant_id=tenant_id,
            storage_type="local",
            key="key",
            name="name.txt",
            size=1,
            extension="txt",
            mime_type="text/plain",
            created_by_role=CreatorUserRole.ACCOUNT,
            created_by="user-1",
            created_at=datetime.datetime.now(tz=datetime.UTC),
            used=False,
        )
        upload_file.id = file_id
        return upload_file

    def _base_payload(self):
        """Return an estimate request for one uploaded file with high_quality indexing."""
        return {
            "info_list": {
                "data_source_type": "upload_file",
                "file_info_list": {
                    "file_ids": ["file-1"],
                },
            },
            "process_rule": {"chunk_size": 100},
            "indexing_technique": "high_quality",
            "doc_form": "text_model",
            "doc_language": "English",
            "dataset_id": None,
        }

    def _estimate(self, app, *, files, estimate=None, raises=None):
        """Call the undecorated POST with the mocks every estimate test shares.

        Args:
            app: the Flask app fixture.
            files: value returned by ``db.session.scalars(...).all()``; ``None``
                makes the endpoint find no upload files.
            estimate: keyword arguments for patching
                ``IndexingRunner.indexing_estimate`` (``return_value`` or
                ``side_effect``); the runner is left unpatched when ``None``.
            raises: exception type the call must raise. It is asserted inside
                the patched contexts, and ``None`` is returned.

        Returns:
            The endpoint's ``(response, status)`` tuple when ``raises`` is ``None``.
        """
        api = DatasetIndexingEstimateApi()
        method = unwrap(api.post)
        with app.test_request_context("/"), ExitStack() as stack:
            stack.enter_context(
                patch(
                    "controllers.console.datasets.datasets.current_account_with_tenant",
                    return_value=(MagicMock(), "tenant-1"),
                )
            )
            stack.enter_context(
                patch.object(
                    type(console_ns),
                    "payload",
                    new_callable=PropertyMock,
                    return_value=self._base_payload(),
                )
            )
            stack.enter_context(
                patch(
                    "controllers.console.datasets.datasets.DocumentService.estimate_args_validate",
                    return_value=None,
                )
            )
            stack.enter_context(
                patch(
                    "controllers.console.datasets.datasets.db.session.scalars",
                    return_value=MagicMock(all=lambda: files),
                )
            )
            if estimate is not None:
                stack.enter_context(
                    patch("controllers.console.datasets.datasets.IndexingRunner.indexing_estimate", **estimate)
                )
            if raises is None:
                return method(api)
            with pytest.raises(raises):
                method(api)
        return None

    def test_post_success_upload_file(self, app):
        estimate_result = MagicMock()
        estimate_result.model_dump.return_value = {"tokens": 100}
        response, status = self._estimate(
            app,
            files=[self._upload_file()],
            estimate={"return_value": estimate_result},
        )
        assert status == 200
        assert response == {"tokens": 100}

    def test_post_file_not_found(self, app):
        self._estimate(app, files=None, raises=NotFound)

    def test_post_llm_bad_request_error(self, app):
        self._estimate(
            app,
            files=[self._upload_file()],
            estimate={"side_effect": LLMBadRequestError()},
            raises=ProviderNotInitializeError,
        )

    def test_post_provider_token_not_init(self, app):
        self._estimate(
            app,
            files=[self._upload_file()],
            estimate={"side_effect": ProviderTokenNotInitError("token missing")},
            raises=ProviderNotInitializeError,
        )

    def test_post_generic_exception(self, app):
        self._estimate(
            app,
            files=[self._upload_file()],
            estimate={"side_effect": Exception("boom")},
            raises=IndexingEstimateError,
        )
class TestDatasetRelatedAppListApi:
    """Tests for ``DatasetRelatedAppListApi.get``, which lists apps bound to a dataset."""

    def _related(self, app, joins):
        """Fetch related apps for an accessible dataset whose join rows are ``joins``."""
        api = DatasetRelatedAppListApi()
        get = unwrap(api.get)
        dataset = MagicMock()
        dataset.id = "dataset-1"
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant-1"),
            ),
            patch("controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=dataset),
            patch(
                "controllers.console.datasets.datasets.DatasetService.check_dataset_permission",
                return_value=None,
            ),
            patch("controllers.console.datasets.datasets.DatasetService.get_related_apps", return_value=joins),
        ):
            return get(api, "dataset-1")

    def test_get_success(self, app):
        first, second = MagicMock(), MagicMock()
        body, code = self._related(app, [MagicMock(app=first), MagicMock(app=second)])
        assert code == 200
        assert body["total"] == 2
        assert body["data"] == [first, second]

    def test_get_dataset_not_found(self, app):
        api = DatasetRelatedAppListApi()
        get = unwrap(api.get)
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant-1"),
            ),
            patch("controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=None),
            pytest.raises(NotFound),
        ):
            get(api, "dataset-1")

    def test_get_permission_denied(self, app):
        api = DatasetRelatedAppListApi()
        get = unwrap(api.get)
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.datasets.current_account_with_tenant",
                return_value=(MagicMock(), "tenant-1"),
            ),
            patch("controllers.console.datasets.datasets.DatasetService.get_dataset", return_value=MagicMock()),
            patch(
                "controllers.console.datasets.datasets.DatasetService.check_dataset_permission",
                side_effect=services.errors.account.NoPermissionError("no permission"),
            ),
            pytest.raises(Forbidden),
        ):
            get(api, "dataset-1")

    def test_get_filters_none_apps(self, app):
        # A join row whose app is None must be dropped from the listing.
        kept = MagicMock()
        body, code = self._related(app, [MagicMock(app=kept), MagicMock(app=None)])
        assert code == 200
        assert body["total"] == 1
        assert body["data"] == [kept]
class TestDatasetIndexingStatusApi:
    """Tests for ``DatasetIndexingStatusApi.get`` (per-document indexing progress)."""

    # Timestamp/error attributes that every document mock in these tests leaves as None.
    _EMPTY_DOCUMENT_FIELDS = (
        "processing_started_at",
        "parsing_completed_at",
        "cleaning_completed_at",
        "splitting_completed_at",
        "completed_at",
        "paused_at",
        "error",
        "stopped_at",
    )

    @classmethod
    def _make_document(cls, indexing_status):
        """Build a ``doc-1`` mock with *indexing_status* and all timestamp/error fields set to None."""
        document = MagicMock(id="doc-1", indexing_status=indexing_status)
        for field_name in cls._EMPTY_DOCUMENT_FIELDS:
            setattr(document, field_name, None)
        return document

    @staticmethod
    def _patch_tenant():
        """Patch the account lookup so the request runs as a dummy user of ``tenant-1``."""
        return patch(
            "controllers.console.datasets.datasets.current_account_with_tenant",
            return_value=(MagicMock(), "tenant-1"),
        )

    @staticmethod
    def _patch_documents(documents):
        """Patch ``db.session.scalars`` so ``.all()`` returns a fresh copy of *documents*."""
        return patch(
            "controllers.console.datasets.datasets.db.session.scalars",
            return_value=MagicMock(all=lambda: list(documents)),
        )

    @staticmethod
    def _patch_segment_query(query):
        """Patch ``db.session.query`` (used for the segment counts) to return *query*."""
        return patch(
            "controllers.console.datasets.datasets.db.session.query",
            return_value=query,
        )

    def test_get_success_with_documents(self, app):
        resource = DatasetIndexingStatusApi()
        handler = unwrap(resource.get)
        # Every count query reports 3 segments, so completed and total are equal.
        segment_query = MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: 3))
        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_documents([self._make_document("completed")]),
            self._patch_segment_query(segment_query),
        ):
            body, code = handler(resource, "dataset-1")
        assert code == 200
        assert "data" in body
        assert len(body["data"]) == 1
        (entry,) = body["data"]
        assert entry["completed_segments"] == 3
        assert entry["total_segments"] == 3

    def test_get_success_no_documents(self, app):
        resource = DatasetIndexingStatusApi()
        handler = unwrap(resource.get)
        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_documents([]),
        ):
            body, code = handler(resource, "dataset-1")
        assert code == 200
        assert body == {"data": []}

    def test_segment_counts_different_values(self, app):
        resource = DatasetIndexingStatusApi()
        handler = unwrap(resource.get)
        segment_query = MagicMock()
        # First count = completed segments, second = total segments
        segment_query.where.side_effect = [
            MagicMock(count=lambda: 2),
            MagicMock(count=lambda: 5),
        ]
        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_documents([self._make_document("indexing")]),
            self._patch_segment_query(segment_query),
        ):
            body, code = handler(resource, "dataset-1")
        assert code == 200
        entry = body["data"][0]
        assert entry["completed_segments"] == 2
        assert entry["total_segments"] == 5
class TestDatasetApiKeyApi:
    """Tests for ``DatasetApiKeyApi`` (listing and creating dataset API keys)."""

    @staticmethod
    def _patch_tenant():
        """Patch the account lookup so the request runs as a dummy user of ``tenant-1``."""
        return patch(
            "controllers.console.datasets.datasets.current_account_with_tenant",
            return_value=(MagicMock(), "tenant-1"),
        )

    @staticmethod
    def _patch_existing_key_count(count):
        """Patch ``db.session.query`` so the existing-key ``count()`` query yields *count*."""
        return patch(
            "controllers.console.datasets.datasets.db.session.query",
            return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(count=lambda: count)),
        )

    def test_get_api_keys_success(self, app):
        resource = DatasetApiKeyApi()
        handler = unwrap(resource.get)
        stored_keys = [MagicMock(spec=ApiToken), MagicMock(spec=ApiToken)]
        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            patch(
                "controllers.console.datasets.datasets.db.session.scalars",
                return_value=MagicMock(all=lambda: list(stored_keys)),
            ),
        ):
            body = handler(resource)
        assert "items" in body
        assert body["items"] == stored_keys

    def test_post_create_api_key_success(self, app):
        resource = DatasetApiKeyApi()
        handler = unwrap(resource.post)
        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_existing_key_count(3),
            patch(
                "controllers.console.datasets.datasets.ApiToken.generate_api_key",
                return_value="dataset-abc123",
            ),
            patch("controllers.console.datasets.datasets.db.session.add", return_value=None),
            patch("controllers.console.datasets.datasets.db.session.commit", return_value=None),
        ):
            created, code = handler(resource)
            assert code == 200
            assert isinstance(created, ApiToken)
            assert created.token == "dataset-abc123"
            assert created.type == "dataset"

    def test_post_exceed_max_keys(self, app):
        resource = DatasetApiKeyApi()
        handler = unwrap(resource.post)
        # Ten keys already exist, so the controller must reject another one.
        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_existing_key_count(10),
            pytest.raises(BadRequest) as exc_info,
        ):
            handler(resource)
        assert exc_info.value.code == 400
        assert exc_info.value.data == {
            "message": "Cannot create more than 10 API keys for this resource type.",
            "custom": "max_keys_exceeded",
        }
class TestDatasetApiDeleteApi:
    """Tests for ``DatasetApiDeleteApi.delete`` (revoking a dataset API key)."""

    @staticmethod
    def _patch_tenant():
        """Patch the account lookup so the request runs as a dummy user of ``tenant-1``."""
        return patch(
            "controllers.console.datasets.datasets.current_account_with_tenant",
            return_value=(MagicMock(), "tenant-1"),
        )

    @staticmethod
    def _patch_key_lookup(key):
        """Patch ``db.session.query`` so the key lookup's ``first()`` returns *key*."""
        return patch(
            "controllers.console.datasets.datasets.db.session.query",
            return_value=MagicMock(where=lambda *args, **kwargs: MagicMock(first=lambda: key)),
        )

    def test_delete_success(self, app):
        resource = DatasetApiDeleteApi()
        handler = unwrap(resource.delete)
        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_key_lookup(MagicMock()),
            patch("controllers.console.datasets.datasets.db.session.commit", return_value=None),
            patch("controllers.console.datasets.datasets.db.session.delete", return_value=None),
        ):
            body, code = handler(resource, "api-key-id")
        assert code == 204
        assert body["result"] == "success"

    def test_delete_key_not_found(self, app):
        resource = DatasetApiDeleteApi()
        handler = unwrap(resource.delete)
        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_key_lookup(None),
            pytest.raises(NotFound),
        ):
            handler(resource, "api-key-id")
class TestDatasetEnableApiApi:
    """Tests for ``DatasetEnableApiApi.post`` (switching dataset API access on or off)."""

    @staticmethod
    def _post(app, status):
        """Call the unwrapped ``post`` for ``dataset-1`` with *status*.

        Returns:
            ``((response, status_code), mock_update)``, where ``mock_update`` is the mock that
            replaced ``DatasetService.update_dataset_api_status``. Tests can inspect its calls.
        """
        api = DatasetEnableApiApi()
        method = unwrap(api.post)
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.datasets.DatasetService.update_dataset_api_status",
                return_value=None,
            ) as mock_update,
        ):
            result = method(api, "dataset-1", status)
        return result, mock_update

    def test_enable_api(self, app):
        (response, status), mock_update = self._post(app, "enable")
        assert status == 200
        assert response["result"] == "success"
        # Checking only the response would let a controller that never called the service pass.
        mock_update.assert_called_once_with("dataset-1", True)

    def test_disable_api(self, app):
        (response, status), mock_update = self._post(app, "disable")
        assert status == 200
        assert response["result"] == "success"
        mock_update.assert_called_once_with("dataset-1", False)
class TestDatasetApiBaseUrlApi:
    """Tests for ``DatasetApiBaseUrlApi.get`` (resolving the service API base URL)."""

    @staticmethod
    def _resolve_base_url(app, request_url, configured_url):
        """Call the unwrapped ``get`` within *request_url*, with ``SERVICE_API_URL`` patched to *configured_url*."""
        resource = DatasetApiBaseUrlApi()
        handler = unwrap(resource.get)
        with (
            app.test_request_context(request_url),
            patch(
                "controllers.console.datasets.datasets.dify_config.SERVICE_API_URL",
                configured_url,
            ),
        ):
            return handler(resource)

    def test_get_api_base_url_from_config(self, app):
        body = self._resolve_base_url(app, "/", "https://example.com")
        assert body["api_base_url"] == "https://example.com/v1"

    def test_get_api_base_url_from_request(self, app):
        # With SERVICE_API_URL unset, the expected URL is built from the request host.
        body = self._resolve_base_url(app, "http://localhost:5000/", None)
        assert body["api_base_url"] == "http://localhost:5000/v1"
class TestDatasetRetrievalSettingApi:
    """Tests for ``DatasetRetrievalSettingApi.get`` (retrieval methods for the configured vector store)."""

    def test_get_success(self, app):
        api = DatasetRetrievalSettingApi()
        method = unwrap(api.get)
        retrieval_methods = {"retrieval_method": ["semantic", "hybrid"]}
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.datasets.dify_config.VECTOR_STORE",
                "qdrant",
            ),
            patch(
                "controllers.console.datasets.datasets._get_retrieval_methods_by_vector_type",
                return_value=retrieval_methods,
            ) as mock_get_methods,
        ):
            response = method(api)
        # A key-presence check alone would accept any payload; require the helper's result.
        assert response == retrieval_methods
        # The lookup must use the configured VECTOR_STORE. Accept it positionally or by keyword.
        mock_get_methods.assert_called_once()
        args, kwargs = mock_get_methods.call_args
        assert "qdrant" in (*args, *kwargs.values())
class TestDatasetRetrievalSettingMockApi:
    """Tests for ``DatasetRetrievalSettingMockApi.get`` (retrieval methods for a vector type given in the URL)."""

    def test_get_success(self, app):
        api = DatasetRetrievalSettingMockApi()
        method = unwrap(api.get)
        with (
            app.test_request_context("/"),
            patch(
                "controllers.console.datasets.datasets._get_retrieval_methods_by_vector_type",
                return_value={"retrieval_method": ["semantic"]},
            ) as mock_get_methods,
        ):
            response = method(api, "milvus")
        assert response["retrieval_method"] == ["semantic"]
        # The vector type from the URL must reach the lookup helper (positionally or by keyword).
        mock_get_methods.assert_called_once()
        args, kwargs = mock_get_methods.call_args
        assert "milvus" in (*args, *kwargs.values())
class TestDatasetErrorDocs:
    """Tests for ``DatasetErrorDocs.get`` (the dataset's error documents)."""

    @staticmethod
    def _patch_get_dataset(dataset):
        """Patch ``DatasetService.get_dataset`` to return *dataset* (``None`` simulates a missing dataset)."""
        return patch(
            "controllers.console.datasets.datasets.DatasetService.get_dataset",
            return_value=dataset,
        )

    def test_get_success(self, app):
        resource = DatasetErrorDocs()
        handler = unwrap(resource.get)
        failed_document = MagicMock()
        with (
            app.test_request_context("/"),
            self._patch_get_dataset(MagicMock()),
            patch(
                "controllers.console.datasets.datasets.DocumentService.get_error_documents_by_dataset_id",
                return_value=[failed_document],
            ),
        ):
            body, code = handler(resource, "dataset-1")
        assert code == 200
        assert body["total"] == 1

    def test_get_dataset_not_found(self, app):
        resource = DatasetErrorDocs()
        handler = unwrap(resource.get)
        with (
            app.test_request_context("/"),
            self._patch_get_dataset(None),
            pytest.raises(NotFound),
        ):
            handler(resource, "dataset-1")
class TestDatasetPermissionUserListApi:
    """Tests for ``DatasetPermissionUserListApi.get`` (a dataset's partial-member list)."""

    @staticmethod
    def _patch_tenant():
        """Patch the account lookup so the request runs as a dummy user of ``tenant-1``."""
        return patch(
            "controllers.console.datasets.datasets.current_account_with_tenant",
            return_value=(MagicMock(), "tenant-1"),
        )

    @staticmethod
    def _patch_get_dataset(dataset):
        """Patch ``DatasetService.get_dataset`` to return *dataset*."""
        return patch(
            "controllers.console.datasets.datasets.DatasetService.get_dataset",
            return_value=dataset,
        )

    @staticmethod
    def _patch_permission_check(**behaviour):
        """Patch ``DatasetService.check_dataset_permission`` with the given mock behaviour kwargs."""
        return patch(
            "controllers.console.datasets.datasets.DatasetService.check_dataset_permission",
            **behaviour,
        )

    def test_get_success(self, app):
        resource = DatasetPermissionUserListApi()
        handler = unwrap(resource.get)
        members = [{"id": "u1"}, {"id": "u2"}]
        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_get_dataset(MagicMock()),
            self._patch_permission_check(return_value=None),
            patch(
                "controllers.console.datasets.datasets.DatasetPermissionService.get_dataset_partial_member_list",
                return_value=members,
            ),
        ):
            body, code = handler(resource, "dataset-1")
        assert code == 200
        assert body["data"] == members

    def test_get_permission_denied(self, app):
        resource = DatasetPermissionUserListApi()
        handler = unwrap(resource.get)
        with (
            app.test_request_context("/"),
            self._patch_tenant(),
            self._patch_get_dataset(MagicMock()),
            self._patch_permission_check(
                side_effect=services.errors.account.NoPermissionError("no permission"),
            ),
            pytest.raises(Forbidden),
        ):
            handler(resource, "dataset-1")
class TestDatasetAutoDisableLogApi:
    """Tests for ``DatasetAutoDisableLogApi.get`` (the dataset's auto-disable logs)."""

    @staticmethod
    def _patch_get_dataset(dataset):
        """Patch ``DatasetService.get_dataset`` to return *dataset* (``None`` simulates a missing dataset)."""
        return patch(
            "controllers.console.datasets.datasets.DatasetService.get_dataset",
            return_value=dataset,
        )

    def test_get_success(self, app):
        resource = DatasetAutoDisableLogApi()
        handler = unwrap(resource.get)
        disable_logs = [{"reason": "quota"}]
        with (
            app.test_request_context("/"),
            self._patch_get_dataset(MagicMock()),
            patch(
                "controllers.console.datasets.datasets.DatasetService.get_dataset_auto_disable_logs",
                return_value=disable_logs,
            ),
        ):
            body, code = handler(resource, "dataset-1")
        assert code == 200
        assert body == disable_logs

    def test_get_dataset_not_found(self, app):
        resource = DatasetAutoDisableLogApi()
        handler = unwrap(resource.get)
        with (
            app.test_request_context("/"),
            self._patch_get_dataset(None),
            pytest.raises(NotFound),
        ):
            handler(resource, "dataset-1")