# Mirror of https://github.com/langgenius/dify.git (synced 2026-03-29 18:09:57 +08:00).
# 1753 lines, 80 KiB, Python.
"""Unit tests for DatasetService and dataset-related collaborators."""
|
|
|
|
from .dataset_service_test_helpers import (
|
|
CloudPlan,
|
|
Dataset,
|
|
DatasetCollectionBindingService,
|
|
DatasetNameDuplicateError,
|
|
DatasetPermissionEnum,
|
|
DatasetPermissionService,
|
|
DatasetProcessRule,
|
|
DatasetService,
|
|
DatasetServiceUnitDataFactory,
|
|
DocumentIndexingError,
|
|
DocumentService,
|
|
LLMBadRequestError,
|
|
MagicMock,
|
|
Mock,
|
|
ModelFeature,
|
|
ModelType,
|
|
NoPermissionError,
|
|
NotFound,
|
|
PipelineIconInfo,
|
|
ProviderTokenNotInitError,
|
|
RagPipelineDatasetCreateEntity,
|
|
SimpleNamespace,
|
|
TenantAccountRole,
|
|
_make_knowledge_configuration,
|
|
_make_retrieval_model,
|
|
_make_session_context,
|
|
json,
|
|
patch,
|
|
pytest,
|
|
)
|
|
|
|
|
|
class TestDatasetServiceQueries:
    """Unit tests for DatasetService query composition and fallback branches.

    ``services.dataset_service.db`` is always replaced with a mock, so these
    tests check which collaborators are reached and what is returned, not SQL.
    """

    @pytest.fixture
    def mock_dataset_query_dependencies(self):
        """Patch the collaborators observed by the ``get_datasets`` tests.

        Yields:
            dict: the patched ``db`` object under ``"db"``, the patched
            ``helper.escape_like_pattern`` under ``"escape_like_pattern"`` and
            the patched ``TagService.get_target_ids_by_tag_ids`` under
            ``"get_target_ids"``. ``db.paginate`` is pre-seeded with a page
            holding one item and ``total=1``.
        """
        with (
            patch("services.dataset_service.db") as db_mock,
            patch(
                "services.dataset_service.helper.escape_like_pattern",
                return_value="escaped-search",
            ) as escape_mock,
            patch("services.dataset_service.TagService.get_target_ids_by_tag_ids") as tag_lookup_mock,
        ):
            db_mock.paginate.return_value = SimpleNamespace(items=["dataset"], total=1)
            dependencies = {
                "db": db_mock,
                "escape_like_pattern": escape_mock,
                "get_target_ids": tag_lookup_mock,
            }
            yield dependencies

    def test_get_datasets_returns_paginated_results_for_public_view(self, mock_dataset_query_dependencies):
        """With only paging and tenant arguments, the seeded page is returned and no search escaping happens."""
        deps = mock_dataset_query_dependencies

        datasets, count = DatasetService.get_datasets(page=1, per_page=20, tenant_id="tenant-1")

        assert (datasets, count) == (["dataset"], 1)
        deps["db"].paginate.assert_called_once()
        deps["escape_like_pattern"].assert_not_called()

    def test_get_datasets_short_circuits_for_dataset_operator_without_permissions(
        self, mock_dataset_query_dependencies
    ):
        """A DATASET_OPERATOR user whose mocked ``query().filter_by().all()`` is empty gets ``([], 0)``."""
        deps = mock_dataset_query_dependencies
        operator = DatasetServiceUnitDataFactory.create_user_mock(role=TenantAccountRole.DATASET_OPERATOR)
        filtered_query = deps["db"].session.query.return_value.filter_by.return_value
        filtered_query.all.return_value = []

        datasets, count = DatasetService.get_datasets(page=1, per_page=20, tenant_id="tenant-1", user=operator)

        assert (datasets, count) == ([], 0)
        deps["db"].paginate.assert_not_called()

    def test_get_datasets_short_circuits_when_tag_lookup_returns_no_target_ids(self, mock_dataset_query_dependencies):
        """An empty tag lookup result yields ``([], 0)`` without calling ``db.paginate``."""
        deps = mock_dataset_query_dependencies
        tag_lookup = deps["get_target_ids"]
        tag_lookup.return_value = []

        request = {"page": 1, "per_page": 20, "tenant_id": "tenant-1", "tag_ids": ["tag-1"]}
        datasets, count = DatasetService.get_datasets(**request)

        assert (datasets, count) == ([], 0)
        tag_lookup.assert_called_once_with("knowledge", "tenant-1", ["tag-1"])
        deps["db"].paginate.assert_not_called()

    def test_get_datasets_search_and_tag_filters_call_collaborators(self, mock_dataset_query_dependencies):
        """Search text reaches ``escape_like_pattern``, tag ids reach the tag lookup, and pagination runs once."""
        deps = mock_dataset_query_dependencies
        tag_lookup = deps["get_target_ids"]
        tag_lookup.return_value = ["dataset-1"]

        request = {
            "page": 2,
            "per_page": 10,
            "tenant_id": "tenant-1",
            "search": "report",
            "tag_ids": ["tag-1"],
        }
        datasets, count = DatasetService.get_datasets(**request)

        assert (datasets, count) == (["dataset"], 1)
        deps["escape_like_pattern"].assert_called_once_with("report")
        tag_lookup.assert_called_once_with("knowledge", "tenant-1", ["tag-1"])
        deps["db"].paginate.assert_called_once()

    def test_get_process_rules_returns_latest_rule_when_present(self):
        """A rule returned by the mocked query chain is exposed as ``{"mode": ..., "rules": ...}``."""
        stored_rule = Mock(spec=DatasetProcessRule)
        stored_rule.configure_mock(mode="automatic", rules_dict={"delimiter": "\n"})

        with patch("services.dataset_service.db") as db_mock:
            base_query = db_mock.session.query.return_value
            limited_query = base_query.where.return_value.order_by.return_value.limit.return_value
            limited_query.one_or_none.return_value = stored_rule

            outcome = DatasetService.get_process_rules("dataset-1")

        expected = {"mode": "automatic", "rules": {"delimiter": "\n"}}
        assert outcome == expected

    def test_get_process_rules_falls_back_to_default_rules_when_missing(self):
        """When the mocked query chain returns ``None``, ``DocumentService.DEFAULT_RULES`` is used."""
        with patch("services.dataset_service.db") as db_mock:
            base_query = db_mock.session.query.return_value
            limited_query = base_query.where.return_value.order_by.return_value.limit.return_value
            limited_query.one_or_none.return_value = None

            outcome = DatasetService.get_process_rules("dataset-1")

        defaults = DocumentService.DEFAULT_RULES
        assert outcome == {
            "mode": defaults["mode"],
            "rules": defaults["rules"],
        }

    def test_get_datasets_by_ids_returns_empty_for_missing_ids(self):
        """An empty id list returns ``([], 0)`` and never paginates."""
        with patch("services.dataset_service.db") as db_mock:
            datasets, count = DatasetService.get_datasets_by_ids([], "tenant-1")

        assert (datasets, count) == ([], 0)
        db_mock.paginate.assert_not_called()

    def test_get_datasets_by_ids_uses_paginate_for_non_empty_input(self):
        """A non-empty id list returns the items and total of the mocked ``db.paginate`` page."""
        page = SimpleNamespace(items=["dataset-1"], total=1)

        with patch("services.dataset_service.db") as db_mock:
            db_mock.paginate.return_value = page

            datasets, count = DatasetService.get_datasets_by_ids(["dataset-1"], "tenant-1")

        assert (datasets, count) == (["dataset-1"], 1)
        db_mock.paginate.assert_called_once()

    def test_get_dataset_returns_first_match(self):
        """``get_dataset`` hands back the object produced by ``query().filter_by().first()``."""
        stored = DatasetServiceUnitDataFactory.create_dataset_mock()

        with patch("services.dataset_service.db") as db_mock:
            lookup = db_mock.session.query.return_value.filter_by.return_value
            lookup.first.return_value = stored

            found = DatasetService.get_dataset(stored.id)

        assert found is stored
|
|
|
|
|
|
class TestDatasetServiceValidation:
    """Unit tests for DatasetService validation helpers.

    Model lookups go through a patched ``services.dataset_service.ModelManager``,
    so each test controls what ``get_model_instance`` returns or raises.
    """

    @staticmethod
    def _build_model_instance(schema):
        """Return a stand-in model instance whose ``get_model_schema`` yields ``schema``."""
        type_instance = MagicMock()
        type_instance.get_model_schema.return_value = schema
        return SimpleNamespace(
            model_type_instance=type_instance,
            model_name="embedding-model",
            credentials={"api_key": "secret"},
        )

    @pytest.mark.parametrize(
        ("dataset_doc_form", "incoming_doc_form"),
        [
            (None, "text_model"),
            ("text_model", "text_model"),
        ],
    )
    def test_check_doc_form_allows_matching_or_missing_dataset_doc_form(self, dataset_doc_form, incoming_doc_form):
        """No exception is raised when the dataset has no doc_form or the same doc_form."""
        subject = DatasetServiceUnitDataFactory.create_dataset_mock(doc_form=dataset_doc_form)

        # Passing means the call simply returns without raising.
        DatasetService.check_doc_form(subject, incoming_doc_form)

    def test_check_doc_form_rejects_mismatched_doc_form(self):
        """A ``qa_model`` dataset rejects an incoming ``text_model`` form with ``ValueError``."""
        subject = DatasetServiceUnitDataFactory.create_dataset_mock(doc_form="qa_model")

        with pytest.raises(ValueError, match="doc_form is different"):
            DatasetService.check_doc_form(subject, "text_model")

    def test_check_dataset_model_setting_skips_non_high_quality_datasets(self):
        """An ``economy`` dataset never instantiates ``ModelManager``."""
        subject = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="economy")

        with patch("services.dataset_service.ModelManager") as manager_cls:
            DatasetService.check_dataset_model_setting(subject)

        manager_cls.assert_not_called()

    def test_check_dataset_model_setting_validates_embedding_model_for_high_quality_dataset(self):
        """A ``high_quality`` dataset triggers one TEXT_EMBEDDING lookup using the dataset's own settings."""
        subject = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")

        with patch("services.dataset_service.ModelManager") as manager_cls:
            DatasetService.check_dataset_model_setting(subject)

        lookup = manager_cls.return_value.get_model_instance
        lookup.assert_called_once_with(
            tenant_id=subject.tenant_id,
            provider=subject.embedding_model_provider,
            model_type=ModelType.TEXT_EMBEDDING,
            model=subject.embedding_model,
        )

    def test_check_dataset_model_setting_wraps_llm_bad_request_error(self):
        """``LLMBadRequestError`` from the lookup surfaces as ``ValueError("No Embedding Model available...")``."""
        subject = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")

        with patch("services.dataset_service.ModelManager") as manager_cls:
            lookup = manager_cls.return_value.get_model_instance
            lookup.side_effect = LLMBadRequestError()

            with pytest.raises(ValueError, match="No Embedding Model available"):
                DatasetService.check_dataset_model_setting(subject)

    def test_check_dataset_model_setting_wraps_provider_token_error(self):
        """``ProviderTokenNotInitError`` text is carried into the raised ``ValueError``."""
        subject = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")

        with patch("services.dataset_service.ModelManager") as manager_cls:
            lookup = manager_cls.return_value.get_model_instance
            lookup.side_effect = ProviderTokenNotInitError("token missing")

            with pytest.raises(ValueError, match="token missing"):
                DatasetService.check_dataset_model_setting(subject)

    def test_check_embedding_model_setting_wraps_provider_token_error_description(self):
        """``check_embedding_model_setting`` re-raises the provider token message as ``ValueError``."""
        with patch("services.dataset_service.ModelManager") as manager_cls:
            lookup = manager_cls.return_value.get_model_instance
            lookup.side_effect = ProviderTokenNotInitError("provider setup")

            with pytest.raises(ValueError, match="provider setup"):
                DatasetService.check_embedding_model_setting("tenant-1", "provider", "embedding-model")

    def test_check_reranking_model_setting_uses_rerank_model_type(self):
        """The rerank check performs exactly one lookup with ``ModelType.RERANK``."""
        with patch("services.dataset_service.ModelManager") as manager_cls:
            DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker")

        lookup = manager_cls.return_value.get_model_instance
        lookup.assert_called_once_with(
            tenant_id="tenant-1",
            provider="provider",
            model_type=ModelType.RERANK,
            model="reranker",
        )

    def test_check_reranking_model_setting_wraps_bad_request(self):
        """``LLMBadRequestError`` from the lookup surfaces as ``ValueError("No Rerank Model available...")``."""
        with patch("services.dataset_service.ModelManager") as manager_cls:
            lookup = manager_cls.return_value.get_model_instance
            lookup.side_effect = LLMBadRequestError()

            with pytest.raises(ValueError, match="No Rerank Model available"):
                DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker")

    def test_check_is_multimodal_model_returns_true_when_model_supports_vision(self):
        """A schema listing ``ModelFeature.VISION`` makes the check return ``True``."""
        vision_schema = SimpleNamespace(features=[ModelFeature.VISION])
        fake_instance = self._build_model_instance(vision_schema)

        with patch("services.dataset_service.ModelManager") as manager_cls:
            manager_cls.return_value.get_model_instance.return_value = fake_instance

            verdict = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")

        assert verdict is True

    def test_check_is_multimodal_model_returns_false_when_vision_feature_is_absent(self):
        """A schema with an empty feature list makes the check return ``False``."""
        plain_schema = SimpleNamespace(features=[])
        fake_instance = self._build_model_instance(plain_schema)

        with patch("services.dataset_service.ModelManager") as manager_cls:
            manager_cls.return_value.get_model_instance.return_value = fake_instance

            verdict = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")

        assert verdict is False

    def test_check_is_multimodal_model_raises_when_schema_is_missing(self):
        """A ``None`` schema raises ``ValueError("Model schema not found...")``."""
        fake_instance = self._build_model_instance(None)

        with patch("services.dataset_service.ModelManager") as manager_cls:
            manager_cls.return_value.get_model_instance.return_value = fake_instance

            with pytest.raises(ValueError, match="Model schema not found"):
                DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")

    def test_check_is_multimodal_model_wraps_bad_request_error(self):
        """``LLMBadRequestError`` from the lookup surfaces as ``ValueError("No Model available...")``."""
        with patch("services.dataset_service.ModelManager") as manager_cls:
            lookup = manager_cls.return_value.get_model_instance
            lookup.side_effect = LLMBadRequestError()

            with pytest.raises(ValueError, match="No Model available"):
                DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
|
|
|
|
|
|
class TestDatasetServiceCreationAndUpdate:
|
|
"""Unit tests for dataset creation and update helpers."""
|
|
|
|
def test_create_empty_dataset_raises_when_name_already_exists(self):
    """A truthy same-name lookup makes ``create_empty_dataset`` raise ``DatasetNameDuplicateError``."""
    creator = SimpleNamespace(id="user-1")

    with patch("services.dataset_service.db") as db_mock:
        name_lookup = db_mock.session.query.return_value.filter_by.return_value
        # Any non-None object stands in for an existing dataset row.
        name_lookup.first.return_value = object()

        with pytest.raises(DatasetNameDuplicateError, match="Dataset with name Dataset already exists"):
            DatasetService.create_empty_dataset("tenant-1", "Dataset", None, "economy", creator)
|
|
|
|
def test_create_empty_dataset_uses_default_embedding_model_for_high_quality_dataset(self):
    """With no explicit embedding model, the tenant default TEXT_EMBEDDING model is stored on the dataset."""
    creator = SimpleNamespace(id="user-1")
    fallback_model = SimpleNamespace(provider="provider", model_name="default-embedding")

    def fake_dataset(**fields):
        # Stand-in for the ORM constructor: echoes the keyword arguments back as attributes.
        return SimpleNamespace(id="dataset-1", **fields)

    with (
        patch("services.dataset_service.db") as db_mock,
        patch("services.dataset_service.Dataset", side_effect=fake_dataset),
        patch("services.dataset_service.ModelManager") as manager_cls,
        patch.object(DatasetService, "check_embedding_model_setting") as embedding_check,
    ):
        db_mock.session.query.return_value.filter_by.return_value.first.return_value = None
        manager = manager_cls.return_value
        manager.get_default_model_instance.return_value = fallback_model

        created = DatasetService.create_empty_dataset(
            tenant_id="tenant-1",
            name="Dataset",
            description="Description",
            indexing_technique="high_quality",
            account=creator,
        )

    assert created.embedding_model_provider == "provider"
    assert created.embedding_model == "default-embedding"
    assert created.permission == DatasetPermissionEnum.ONLY_ME
    assert created.provider == "vendor"

    manager.get_default_model_instance.assert_called_once_with(
        tenant_id="tenant-1",
        model_type=ModelType.TEXT_EMBEDDING,
    )
    embedding_check.assert_not_called()
    db_mock.session.commit.assert_called_once()
|
|
|
|
def test_create_empty_dataset_creates_external_binding_for_high_quality_dataset(self):
    """An external high-quality dataset runs both model checks and builds one external knowledge binding.

    Also checks the stored embedding/retrieval/summary settings, that two
    objects are added to the session, and that a single commit happens.
    """
    creator = SimpleNamespace(id="user-1")
    retrieval_config = _make_retrieval_model()
    resolved_embedding = SimpleNamespace(provider="provider", model_name="embedding-model")

    def fake_dataset(**fields):
        # Stand-in for the ORM constructor: echoes the keyword arguments back as attributes.
        return SimpleNamespace(id="dataset-1", **fields)

    def fake_binding(**fields):
        return SimpleNamespace(**fields)

    with (
        patch("services.dataset_service.db") as db_mock,
        patch("services.dataset_service.Dataset", side_effect=fake_dataset),
        patch("services.dataset_service.ExternalKnowledgeBindings", side_effect=fake_binding) as binding_cls,
        patch("services.dataset_service.ModelManager") as manager_cls,
        patch("services.dataset_service.ExternalDatasetService.get_external_knowledge_api", return_value=object()),
        patch.object(DatasetService, "check_embedding_model_setting") as embedding_check,
        patch.object(DatasetService, "check_reranking_model_setting") as reranking_check,
    ):
        db_mock.session.query.return_value.filter_by.return_value.first.return_value = None
        manager_cls.return_value.get_model_instance.return_value = resolved_embedding

        created = DatasetService.create_empty_dataset(
            tenant_id="tenant-1",
            name="External Dataset",
            description="Description",
            indexing_technique="high_quality",
            account=creator,
            permission=DatasetPermissionEnum.ALL_TEAM,
            provider="external",
            external_knowledge_api_id="api-1",
            external_knowledge_id="knowledge-1",
            embedding_model_provider="provider",
            embedding_model_name="embedding-model",
            retrieval_model=retrieval_config,
            summary_index_setting={"enable": True},
        )

    # Fields persisted on the dataset object.
    assert created.embedding_model_provider == "provider"
    assert created.embedding_model == "embedding-model"
    assert created.retrieval_model == retrieval_config.model_dump()
    assert created.summary_index_setting == {"enable": True}

    # Model validation helpers.
    embedding_check.assert_called_once_with("tenant-1", "provider", "embedding-model")
    reranking_check.assert_called_once_with("tenant-1", "rerank-provider", "rerank-model")

    # External binding row and session bookkeeping.
    binding_cls.assert_called_once_with(
        tenant_id="tenant-1",
        dataset_id="dataset-1",
        external_knowledge_api_id="api-1",
        external_knowledge_id="knowledge-1",
        created_by="user-1",
    )
    assert db_mock.session.add.call_count == 2
    db_mock.session.commit.assert_called_once()
|
|
|
|
def test_create_empty_rag_pipeline_dataset_raises_for_duplicate_name(self):
    """A named entity whose same-name lookup is truthy raises ``DatasetNameDuplicateError``."""
    icon = PipelineIconInfo(icon="book", icon_background="#fff")
    create_request = RagPipelineDatasetCreateEntity(
        name="Existing Dataset",
        description="Description",
        icon_info=icon,
        permission=DatasetPermissionEnum.ALL_TEAM,
    )

    with patch("services.dataset_service.db") as db_mock:
        name_lookup = db_mock.session.query.return_value.filter_by.return_value
        name_lookup.first.return_value = object()

        with pytest.raises(DatasetNameDuplicateError, match="Existing Dataset already exists"):
            DatasetService.create_empty_rag_pipeline_dataset("tenant-1", create_request)
|
|
|
|
def test_create_empty_rag_pipeline_dataset_generates_name_and_creates_dataset(self):
    """An empty entity name is replaced by ``generate_incremental_name`` and the dataset is tied to the pipeline."""
    icon = PipelineIconInfo(icon="book", icon_background="#fff")
    create_request = RagPipelineDatasetCreateEntity(
        name="",
        description="Description",
        icon_info=icon,
        permission=DatasetPermissionEnum.ALL_TEAM,
    )
    shared_pipeline = SimpleNamespace(id="pipeline-1")

    def build_pipeline(**fields):
        # Reuse one namespace so the id is known up front; constructor kwargs are merged in.
        vars(shared_pipeline).update(fields)
        return shared_pipeline

    def build_dataset(**fields):
        return SimpleNamespace(id="dataset-1", **fields)

    existing_rows = [
        SimpleNamespace(name="Untitled"),
        SimpleNamespace(name="Untitled 1"),
    ]

    with (
        patch("services.dataset_service.db") as db_mock,
        patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")),
        patch("services.dataset_service.generate_incremental_name", return_value="Untitled 2") as name_generator,
        patch("services.dataset_service.Pipeline", side_effect=build_pipeline),
        patch("services.dataset_service.Dataset", side_effect=build_dataset),
    ):
        db_mock.session.query.return_value.filter_by.return_value.all.return_value = existing_rows

        created = DatasetService.create_empty_rag_pipeline_dataset("tenant-1", create_request)

    assert create_request.name == "Untitled 2"
    assert created.pipeline_id == "pipeline-1"
    assert created.runtime_mode == "rag_pipeline"
    name_generator.assert_called_once_with(["Untitled", "Untitled 1"], "Untitled")
    db_mock.session.commit.assert_called_once()
|
|
|
|
def test_create_empty_rag_pipeline_dataset_requires_current_user_id(self):
    """A ``current_user`` whose ``id`` is ``None`` makes creation fail with ``ValueError``."""
    icon = PipelineIconInfo(icon="book", icon_background="#fff")
    create_request = RagPipelineDatasetCreateEntity(
        name="Dataset",
        description="Description",
        icon_info=icon,
        permission=DatasetPermissionEnum.ALL_TEAM,
    )
    anonymous = SimpleNamespace(id=None)

    with (
        patch("services.dataset_service.db") as db_mock,
        patch("services.dataset_service.current_user", anonymous),
    ):
        # No name clash, so the failure can only come from the missing user id.
        db_mock.session.query.return_value.filter_by.return_value.first.return_value = None

        with pytest.raises(ValueError, match="Current user or current user id not found"):
            DatasetService.create_empty_rag_pipeline_dataset("tenant-1", create_request)
|
|
|
|
def test_update_dataset_raises_when_dataset_is_missing(self):
    """``update_dataset`` raises ``ValueError`` when ``get_dataset`` returns ``None``."""
    editor = SimpleNamespace(id="user-1")

    with (
        patch.object(DatasetService, "get_dataset", return_value=None),
        pytest.raises(ValueError, match="Dataset not found"),
    ):
        DatasetService.update_dataset("dataset-1", {}, editor)
|
|
|
|
def test_update_dataset_raises_when_new_name_conflicts(self):
    """Renaming to a name that ``_has_dataset_same_name`` reports as taken raises ``ValueError``."""
    existing = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1")
    existing.name = "Old Dataset"
    editor = SimpleNamespace(id="user-1")

    with (
        patch.object(DatasetService, "get_dataset", return_value=existing),
        patch.object(DatasetService, "_has_dataset_same_name", return_value=True),
        pytest.raises(ValueError, match="Dataset name already exists"),
    ):
        DatasetService.update_dataset("dataset-1", {"name": "New Dataset"}, editor)
|
|
|
|
def test_update_dataset_routes_external_datasets_to_external_helper(self):
    """A dataset with ``provider == "external"`` is permission-checked and passed to ``_update_external_dataset``."""
    target = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1")
    target.provider = "external"
    editor = DatasetServiceUnitDataFactory.create_user_mock()
    # The name is left unchanged in the payload, matching the dataset's current name.
    current_name = target.name

    with (
        patch.object(DatasetService, "get_dataset", return_value=target),
        patch.object(DatasetService, "check_dataset_permission") as permission_check,
        patch.object(DatasetService, "_update_external_dataset", return_value="updated") as external_helper,
    ):
        outcome = DatasetService.update_dataset("dataset-1", {"name": current_name}, editor)

    assert outcome == "updated"
    permission_check.assert_called_once_with(target, editor)
    external_helper.assert_called_once_with(target, {"name": current_name}, editor)
|
|
|
|
def test_update_dataset_routes_internal_datasets_to_internal_helper(self):
    """A dataset with ``provider == "vendor"`` is permission-checked and passed to ``_update_internal_dataset``."""
    target = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1")
    target.provider = "vendor"
    editor = DatasetServiceUnitDataFactory.create_user_mock()
    # The name is left unchanged in the payload, matching the dataset's current name.
    current_name = target.name

    with (
        patch.object(DatasetService, "get_dataset", return_value=target),
        patch.object(DatasetService, "check_dataset_permission") as permission_check,
        patch.object(DatasetService, "_update_internal_dataset", return_value="updated") as internal_helper,
    ):
        outcome = DatasetService.update_dataset("dataset-1", {"name": current_name}, editor)

    assert outcome == "updated"
    permission_check.assert_called_once_with(target, editor)
    internal_helper.assert_called_once_with(target, {"name": current_name}, editor)
|
|
|
|
def test_has_dataset_same_name_returns_true_when_query_matches(self):
    """``_has_dataset_same_name`` is ``True`` when ``query().where().first()`` yields an object."""
    with patch("services.dataset_service.db") as db_mock:
        filtered = db_mock.session.query.return_value.where.return_value
        filtered.first.return_value = object()

        clash = DatasetService._has_dataset_same_name("tenant-1", "dataset-1", "Dataset")

    assert clash is True
|
|
|
|
def test_update_external_dataset_updates_dataset_and_binding(self):
    """``_update_external_dataset`` copies the payload onto the dataset, updates the binding, and commits once."""
    target = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
    editor = SimpleNamespace(id="user-1")
    # Unique sentinel so the timestamp can be checked by identity.
    frozen_now = object()
    changes = {
        "external_retrieval_model": {"top_k": 3},
        "summary_index_setting": {"enable": True},
        "name": "Updated Dataset",
        "description": "Updated description",
        "permission": DatasetPermissionEnum.PARTIAL_TEAM,
        "external_knowledge_id": "knowledge-1",
        "external_knowledge_api_id": "api-1",
    }

    with (
        patch.object(DatasetService, "_update_external_knowledge_binding") as binding_update,
        patch("services.dataset_service.naive_utc_now", return_value=frozen_now),
        patch("services.dataset_service.db") as db_mock,
    ):
        returned = DatasetService._update_external_dataset(target, changes, editor)

    assert returned is target

    # Attributes written onto the dataset.
    assert target.retrieval_model == {"top_k": 3}
    assert target.summary_index_setting == {"enable": True}
    assert target.name == "Updated Dataset"
    assert target.description == "Updated description"
    assert target.permission == DatasetPermissionEnum.PARTIAL_TEAM
    assert target.updated_by == "user-1"
    assert target.updated_at is frozen_now

    # Collaborators.
    binding_update.assert_called_once_with("dataset-1", "knowledge-1", "api-1")
    db_mock.session.add.assert_called_once_with(target)
    db_mock.session.commit.assert_called_once()
|
|
|
|
@pytest.mark.parametrize(
    ("payload", "message"),
    [
        (
            {"external_knowledge_api_id": "api-1"},
            "External knowledge id is required",
        ),
        (
            {"external_knowledge_id": "knowledge-1"},
            "External knowledge api id is required",
        ),
    ],
)
def test_update_external_dataset_requires_external_binding_fields(self, payload, message):
    """Supplying only one of the two external binding fields raises ``ValueError`` naming the missing one."""
    target = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
    editor = SimpleNamespace(id="user-1")

    with pytest.raises(ValueError, match=message):
        DatasetService._update_external_dataset(target, payload, editor)
|
|
|
|
def test_update_external_knowledge_binding_updates_changed_binding_values(self):
    """Differing ids are written onto the stored binding, which is then added via ``db.session``."""
    stored_binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api")
    scoped_session = MagicMock()
    binding_lookup = scoped_session.query.return_value.filter_by.return_value
    binding_lookup.first.return_value = stored_binding
    session_cm = _make_session_context(scoped_session)

    with (
        patch("services.dataset_service.db") as db_mock,
        patch("services.dataset_service.Session", return_value=session_cm),
    ):
        DatasetService._update_external_knowledge_binding("dataset-1", "new-knowledge", "new-api")

    assert stored_binding.external_knowledge_id == "new-knowledge"
    assert stored_binding.external_knowledge_api_id == "new-api"
    db_mock.session.add.assert_called_once_with(stored_binding)
|
|
|
|
def test_update_external_knowledge_binding_raises_for_missing_binding(self):
    """A ``None`` binding lookup raises ``ValueError("External knowledge binding not found...")``."""
    scoped_session = MagicMock()
    binding_lookup = scoped_session.query.return_value.filter_by.return_value
    binding_lookup.first.return_value = None
    session_cm = _make_session_context(scoped_session)

    with (
        patch("services.dataset_service.db"),
        patch("services.dataset_service.Session", return_value=session_cm),
        pytest.raises(ValueError, match="External knowledge binding not found"),
    ):
        DatasetService._update_external_knowledge_binding("dataset-1", "knowledge-1", "api-1")
|
|
|
|
def test_update_internal_dataset_updates_fields_and_dispatches_regeneration_tasks(self):
    """``_update_internal_dataset`` writes the filtered payload and dispatches the follow-up tasks.

    ``_handle_indexing_technique_change`` is forced to return ``"update"``. The
    test then checks the dict handed to ``query().filter_by().update``, the
    commit/refresh calls, the pipeline sync call, and both ``.delay`` calls.
    """
    target = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
    editor = SimpleNamespace(id="user-1")
    # Unique sentinel so the timestamp can be checked by identity.
    frozen_now = object()
    requested_changes = {
        "name": "Updated Dataset",
        "description": None,
        "partial_member_list": [{"user_id": "member-1"}],
        "external_knowledge_api_id": "api-1",
        "external_knowledge_id": "knowledge-1",
        "external_retrieval_model": {"top_k": 2},
        "retrieval_model": {"top_k": 4},
        "summary_index_setting": {"enable": True},
        "icon_info": {"icon": "book"},
    }

    with (
        patch.object(DatasetService, "_handle_indexing_technique_change", return_value="update"),
        patch.object(DatasetService, "_update_pipeline_knowledge_base_node_data") as pipeline_sync,
        patch("services.dataset_service.naive_utc_now", return_value=frozen_now),
        patch("services.dataset_service.db") as db_mock,
        patch("services.dataset_service.deal_dataset_vector_index_task") as vector_index_task,
        patch("services.dataset_service.regenerate_summary_index_task") as summary_task,
    ):
        returned = DatasetService._update_internal_dataset(target, requested_changes.copy(), editor)

    assert returned is target

    update_call = db_mock.session.query.return_value.filter_by.return_value.update
    written = update_call.call_args.args[0]

    # Columns compared by equality.
    expected_columns = {
        "name": "Updated Dataset",
        "retrieval_model": {"top_k": 4},
        "summary_index_setting": {"enable": True},
        "icon_info": {"icon": "book"},
        "updated_by": "user-1",
    }
    for column, value in expected_columns.items():
        assert written[column] == value

    # Columns compared by identity.
    assert written["description"] is None
    assert written["updated_at"] is frozen_now

    # Payload keys that must not reach the column update.
    for dropped_key in (
        "partial_member_list",
        "external_knowledge_api_id",
        "external_knowledge_id",
        "external_retrieval_model",
    ):
        assert dropped_key not in written

    db_mock.session.commit.assert_called_once()
    db_mock.session.refresh.assert_called_once_with(target)
    pipeline_sync.assert_called_once_with(target, "user-1")
    vector_index_task.delay.assert_called_once_with("dataset-1", "update")
    summary_task.delay.assert_called_once_with(
        "dataset-1",
        regenerate_reason="embedding_model_changed",
        regenerate_vectors_only=True,
    )
|
|
|
|
def test_update_pipeline_knowledge_base_node_data_returns_early_for_non_pipeline_dataset(self):
    """A dataset whose ``runtime_mode`` is ``"workflow"`` never triggers a ``db.session.query``."""
    workflow_dataset = SimpleNamespace(runtime_mode="workflow", pipeline_id="pipeline-1")

    with patch("services.dataset_service.db") as db_mock:
        DatasetService._update_pipeline_knowledge_base_node_data(workflow_dataset, "user-1")

    db_mock.session.query.assert_not_called()
|
|
|
|
def test_update_pipeline_knowledge_base_node_data_returns_when_pipeline_is_missing(self):
    """When the pipeline lookup returns ``None`` nothing is committed."""
    pipeline_dataset = SimpleNamespace(runtime_mode="rag_pipeline", pipeline_id="pipeline-1")

    with patch("services.dataset_service.db") as db_mock:
        pipeline_lookup = db_mock.session.query.return_value.filter_by.return_value
        pipeline_lookup.first.return_value = None

        DatasetService._update_pipeline_knowledge_base_node_data(pipeline_dataset, "user-1")

    db_mock.session.commit.assert_not_called()
|
|
|
|
def test_update_pipeline_knowledge_base_node_data_updates_published_and_draft_workflows(self):
    """Dataset settings are copied into the ``knowledge-index`` node of both workflows.

    The published graph is inspected via the ``graph`` keyword passed to
    ``Workflow.new``; the draft graph is inspected on the draft object itself.
    """

    def graph_json(*node_types):
        # Serialize a minimal graph with one node per given type.
        return json.dumps({"nodes": [{"data": {"type": node_type}} for node_type in node_types]})

    pipeline_dataset = SimpleNamespace(
        id="dataset-1",
        runtime_mode="rag_pipeline",
        pipeline_id="pipeline-1",
        embedding_model="embedding-model",
        embedding_model_provider="provider",
        retrieval_model={"top_k": 5},
        chunk_structure="paragraph",
        indexing_technique="high_quality",
        keyword_number=8,
        summary_index_setting={"enable": True},
    )
    stored_pipeline = SimpleNamespace(id="pipeline-1", tenant_id="tenant-1")
    published = SimpleNamespace(
        graph=graph_json("knowledge-index", "start"),
        type="chat",
        features={"feature": True},
        environment_variables=[],
        conversation_variables=[],
        rag_pipeline_variables=[],
    )
    draft = SimpleNamespace(graph=graph_json("knowledge-index"))
    republished = SimpleNamespace(id="workflow-1")

    pipeline_service = MagicMock()
    pipeline_service.get_published_workflow.return_value = published
    pipeline_service.get_draft_workflow.return_value = draft

    with (
        patch("services.dataset_service.db") as db_mock,
        patch("services.dataset_service.RagPipelineService", return_value=pipeline_service),
        patch("services.dataset_service.Workflow.new", return_value=republished) as workflow_factory,
    ):
        db_mock.session.query.return_value.filter_by.return_value.first.return_value = stored_pipeline

        DatasetService._update_pipeline_knowledge_base_node_data(pipeline_dataset, "user-1")

    published_graph = json.loads(workflow_factory.call_args.kwargs["graph"])
    published_node = published_graph["nodes"][0]["data"]
    assert published_node["embedding_model"] == "embedding-model"
    assert published_node["summary_index_setting"] == {"enable": True}

    draft_node = json.loads(draft.graph)["nodes"][0]["data"]
    assert draft_node["embedding_model_provider"] == "provider"

    session = db_mock.session
    session.add.assert_any_call(republished)
    session.add.assert_any_call(draft)
    session.commit.assert_called_once()
|
|
|
|
def test_update_pipeline_knowledge_base_node_data_rolls_back_when_update_fails(self):
    """An error from ``get_published_workflow`` propagates and the session is rolled back once."""
    pipeline_dataset = SimpleNamespace(runtime_mode="rag_pipeline", pipeline_id="pipeline-1")
    stored_pipeline = SimpleNamespace(id="pipeline-1", tenant_id="tenant-1")
    pipeline_service = MagicMock()
    pipeline_service.get_published_workflow.side_effect = RuntimeError("boom")

    with (
        patch("services.dataset_service.db") as db_mock,
        patch("services.dataset_service.RagPipelineService", return_value=pipeline_service),
    ):
        pipeline_lookup = db_mock.session.query.return_value.filter_by.return_value
        pipeline_lookup.first.return_value = stored_pipeline

        with pytest.raises(RuntimeError, match="boom"):
            DatasetService._update_pipeline_knowledge_base_node_data(pipeline_dataset, "user-1")

    db_mock.session.rollback.assert_called_once()
|
|
|
|
def test_handle_indexing_technique_change_returns_none_without_indexing_technique(self):
    """With no ``indexing_technique`` in the update data, the handler returns None and adds nothing."""
    filtered_data: dict[str, object] = {}
    dataset = SimpleNamespace(indexing_technique="economy")

    result = DatasetService._handle_indexing_technique_change(dataset, {}, filtered_data)

    assert result is None
    assert filtered_data == {}
|
|
|
|
def test_handle_indexing_technique_change_switches_to_economy(self):
    """Switching from high_quality to economy returns "remove" and nulls the embedding fields."""
    filtered_data: dict[str, object] = {}
    dataset = SimpleNamespace(indexing_technique="high_quality")

    result = DatasetService._handle_indexing_technique_change(
        dataset,
        {"indexing_technique": "economy"},
        filtered_data,
    )

    assert result == "remove"
    assert filtered_data == {
        "embedding_model": None,
        "embedding_model_provider": None,
        "collection_binding_id": None,
    }
|
|
|
|
def test_handle_indexing_technique_change_switches_to_high_quality(self):
    """Switching from economy to high_quality returns "add" and delegates embedding configuration."""
    filtered_data: dict[str, object] = {}
    dataset = SimpleNamespace(indexing_technique="economy")

    with patch.object(DatasetService, "_configure_embedding_model_for_high_quality") as configure_embedding:
        result = DatasetService._handle_indexing_technique_change(
            dataset,
            {"indexing_technique": "high_quality"},
            filtered_data,
        )

    assert result == "add"
    configure_embedding.assert_called_once_with({"indexing_technique": "high_quality"}, filtered_data)
|
|
|
|
def test_handle_indexing_technique_change_delegates_when_technique_is_unchanged(self):
    """When the technique is unchanged, the unchanged-technique helper's return value is passed through."""
    filtered_data: dict[str, object] = {}
    dataset = SimpleNamespace(indexing_technique="high_quality")

    with patch.object(
        DatasetService,
        "_handle_embedding_model_update_when_technique_unchanged",
        return_value="update",
    ) as update_embedding:
        result = DatasetService._handle_indexing_technique_change(
            dataset,
            {"indexing_technique": "high_quality"},
            filtered_data,
        )

    assert result == "update"
    update_embedding.assert_called_once_with(dataset, {"indexing_technique": "high_quality"}, filtered_data)
|
|
|
|
def test_configure_embedding_model_for_high_quality_updates_filtered_data(self):
    """The resolved model, provider and collection binding id are written into filtered_data."""

    class FakeAccount:
        pass

    # The same class is patched in as ``Account`` so the stand-in user is an instance of it.
    current_user = FakeAccount()
    current_user.current_tenant_id = "tenant-1"
    embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model")
    filtered_data: dict[str, object] = {}

    with (
        patch("services.dataset_service.Account", FakeAccount),
        patch("services.dataset_service.current_user", current_user),
        patch("services.dataset_service.ModelManager") as model_manager_cls,
        patch(
            "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding",
            return_value=SimpleNamespace(id="binding-1"),
        ),
    ):
        model_manager_cls.return_value.get_model_instance.return_value = embedding_model

        DatasetService._configure_embedding_model_for_high_quality(
            {"embedding_model_provider": "provider", "embedding_model": "embedding-model"},
            filtered_data,
        )

    assert filtered_data == {
        "embedding_model": "embedding-model",
        "embedding_model_provider": "provider",
        "collection_binding_id": "binding-1",
    }
|
|
|
|
@pytest.mark.parametrize(
    ("error", "message"),
    [
        (LLMBadRequestError(), "No Embedding Model available"),
        (ProviderTokenNotInitError("provider setup"), "provider setup"),
    ],
)
def test_configure_embedding_model_for_high_quality_wraps_model_errors(self, error, message):
    """Model lookup errors surface as ``ValueError`` whose text matches the parametrized message."""

    class FakeAccount:
        pass

    # The same class is patched in as ``Account`` so the stand-in user is an instance of it.
    current_user = FakeAccount()
    current_user.current_tenant_id = "tenant-1"

    with (
        patch("services.dataset_service.Account", FakeAccount),
        patch("services.dataset_service.current_user", current_user),
        patch("services.dataset_service.ModelManager") as model_manager_cls,
    ):
        model_manager_cls.return_value.get_model_instance.side_effect = error

        with pytest.raises(ValueError, match=message):
            DatasetService._configure_embedding_model_for_high_quality(
                {"embedding_model_provider": "provider", "embedding_model": "embedding-model"},
                {},
            )
|
|
|
|
def test_handle_embedding_model_update_when_technique_unchanged_preserves_existing_settings(self):
    """With no embedding fields in the update data, existing settings are preserved and None is returned."""
    dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
        embedding_model_provider="provider",
        embedding_model="embedding-model",
    )
    filtered_data: dict[str, object] = {}

    with patch.object(DatasetService, "_preserve_existing_embedding_settings") as preserve_settings:
        result = DatasetService._handle_embedding_model_update_when_technique_unchanged(
            dataset,
            {},
            filtered_data,
        )

    assert result is None
    preserve_settings.assert_called_once_with(dataset, filtered_data)
|
|
|
|
def test_handle_embedding_model_update_when_technique_unchanged_updates_when_model_is_provided(self):
    """When provider and model are supplied, the settings updater is called once and its result returned."""
    dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
        embedding_model_provider="provider",
        embedding_model="embedding-model",
    )

    with patch.object(DatasetService, "_update_embedding_model_settings", return_value="update") as update_settings:
        result = DatasetService._handle_embedding_model_update_when_technique_unchanged(
            dataset,
            {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"},
            {},
        )

    assert result == "update"
    update_settings.assert_called_once()
|
|
|
|
def test_preserve_existing_embedding_settings_keeps_current_binding(self):
    """Empty placeholder values are replaced by the dataset's current provider, model and binding id."""
    dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
        embedding_model_provider="provider",
        embedding_model="embedding-model",
        collection_binding_id="binding-1",
    )
    filtered_data = {"embedding_model_provider": "", "embedding_model": ""}

    DatasetService._preserve_existing_embedding_settings(dataset, filtered_data)

    assert filtered_data == {
        "embedding_model_provider": "provider",
        "embedding_model": "embedding-model",
        "collection_binding_id": "binding-1",
    }
|
|
|
|
def test_preserve_existing_embedding_settings_removes_empty_placeholders_without_existing_values(self):
    """When the dataset has no embedding settings, empty placeholder keys are dropped from filtered_data."""
    dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
        embedding_model_provider=None,
        embedding_model=None,
        collection_binding_id=None,
    )
    filtered_data = {"embedding_model_provider": "", "embedding_model": ""}

    DatasetService._preserve_existing_embedding_settings(dataset, filtered_data)

    assert filtered_data == {}
|
|
|
|
def test_update_embedding_model_settings_returns_update_for_changed_values(self):
    """A different provider/model pair applies the new settings once and returns "update"."""
    dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
        embedding_model_provider="provider",
        embedding_model="embedding-model",
    )

    with patch.object(DatasetService, "_apply_new_embedding_settings") as apply_settings:
        result = DatasetService._update_embedding_model_settings(
            dataset,
            {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"},
            {},
        )

    assert result == "update"
    apply_settings.assert_called_once()
|
|
|
|
def test_update_embedding_model_settings_returns_none_for_unchanged_values(self):
    """Submitting the provider/model pair the dataset already has returns None."""
    dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
        embedding_model_provider="provider",
        embedding_model="embedding-model",
    )

    result = DatasetService._update_embedding_model_settings(
        dataset,
        {"embedding_model_provider": "provider", "embedding_model": "embedding-model"},
        {},
    )

    assert result is None
|
|
|
|
def test_update_embedding_model_settings_wraps_bad_request_errors(self):
    """An ``LLMBadRequestError`` while applying settings becomes a "No Embedding Model available" ValueError."""
    dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
        embedding_model_provider="provider",
        embedding_model="embedding-model",
    )

    with patch.object(DatasetService, "_apply_new_embedding_settings", side_effect=LLMBadRequestError()):
        with pytest.raises(ValueError, match="No Embedding Model available"):
            DatasetService._update_embedding_model_settings(
                dataset,
                {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"},
                {},
            )
|
|
|
|
def test_apply_new_embedding_settings_updates_binding_for_new_model(self):
    """A resolvable new model writes its name, provider and the new binding id into filtered_data."""

    class FakeAccount:
        pass

    # The same class is patched in as ``Account`` so the stand-in user is an instance of it.
    current_user = FakeAccount()
    current_user.current_tenant_id = "tenant-1"
    dataset = DatasetServiceUnitDataFactory.create_dataset_mock(collection_binding_id="binding-1")
    filtered_data: dict[str, object] = {}

    with (
        patch("services.dataset_service.Account", FakeAccount),
        patch("services.dataset_service.current_user", current_user),
        patch("services.dataset_service.ModelManager") as model_manager_cls,
        patch(
            "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding",
            return_value=SimpleNamespace(id="binding-2"),
        ),
    ):
        model_manager_cls.return_value.get_model_instance.return_value = SimpleNamespace(
            provider="provider-two",
            model_name="embedding-model-two",
        )

        DatasetService._apply_new_embedding_settings(
            dataset,
            {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"},
            filtered_data,
        )

    assert filtered_data == {
        "embedding_model": "embedding-model-two",
        "embedding_model_provider": "provider-two",
        "collection_binding_id": "binding-2",
    }
|
|
|
|
def test_apply_new_embedding_settings_preserves_existing_values_when_provider_token_is_missing(self):
    """A ``ProviderTokenNotInitError`` is not raised; filtered_data keeps the dataset's current settings."""

    class FakeAccount:
        pass

    # The same class is patched in as ``Account`` so the stand-in user is an instance of it.
    current_user = FakeAccount()
    current_user.current_tenant_id = "tenant-1"
    dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
        embedding_model_provider="provider",
        embedding_model="embedding-model",
        collection_binding_id="binding-1",
    )
    filtered_data: dict[str, object] = {}

    with (
        patch("services.dataset_service.Account", FakeAccount),
        patch("services.dataset_service.current_user", current_user),
        patch("services.dataset_service.ModelManager") as model_manager_cls,
    ):
        model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing")

        DatasetService._apply_new_embedding_settings(
            dataset,
            {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"},
            filtered_data,
        )

    assert filtered_data == {
        "embedding_model_provider": "provider",
        "embedding_model": "embedding-model",
        "collection_binding_id": "binding-1",
    }
|
|
|
|
@pytest.mark.parametrize(
    ("summary_index_setting", "expected"),
    [
        (None, False),
        ({"enable": False}, False),
        ({"enable": True, "model_name": "old-model", "model_provider_name": "provider"}, False),
        ({"enable": True, "model_name": "new-model", "model_provider_name": "provider-two"}, True),
    ],
)
def test_check_summary_index_setting_model_changed(self, summary_index_setting, expected):
    """Only an enabled setting with a different model/provider is reported as changed.

    A ``None`` parameter is translated into update data without the
    ``summary_index_setting`` key at all.
    """
    dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
        dataset_id="dataset-1",
        summary_index_setting={"enable": True, "model_name": "old-model", "model_provider_name": "provider"},
    )

    result = DatasetService._check_summary_index_setting_model_changed(
        dataset,
        {"summary_index_setting": summary_index_setting} if summary_index_setting is not None else {},
    )

    assert result is expected
|
|
|
|
|
|
class TestDatasetServiceRagPipelineSettings:
    """Unit tests for rag-pipeline dataset setting updates."""

    def test_update_rag_pipeline_dataset_settings_requires_current_tenant(self):
        """A current user without a tenant id is rejected with a ``ValueError``."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        knowledge_configuration = _make_knowledge_configuration()

        with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id=None)):
            with pytest.raises(ValueError, match="Current user or current tenant not found"):
                DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration)

    def test_update_rag_pipeline_dataset_settings_without_published_high_quality_updates_embedding_settings(self):
        """Unpublished high-quality update copies embedding settings onto the dataset without committing."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(summary_index_setting={"enable": True})
        embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model")

        with (
            patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch.object(DatasetService, "check_is_multimodal_model", return_value=True) as check_multimodal,
            patch(
                "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding",
                return_value=SimpleNamespace(id="binding-1"),
            ),
        ):
            model_manager_cls.return_value.get_model_instance.return_value = embedding_model

            DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration)

        assert dataset.chunk_structure == "paragraph"
        assert dataset.indexing_technique == "high_quality"
        assert dataset.embedding_model == "embedding-model"
        assert dataset.embedding_model_provider == "provider"
        assert dataset.collection_binding_id == "binding-1"
        assert dataset.is_multimodal is True
        assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump()
        assert dataset.summary_index_setting == {"enable": True}
        check_multimodal.assert_called_once_with("tenant-1", "provider", "embedding-model")
        session.add.assert_called_once_with(dataset)
        session.commit.assert_not_called()

    def test_update_rag_pipeline_dataset_settings_without_published_economy_updates_keyword_number(self):
        """Unpublished economy update stores the keyword number and retrieval model on the dataset."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(
            indexing_technique="economy",
            embedding_model_provider="",
            embedding_model="",
            keyword_number=12,
        )

        with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")):
            DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration)

        assert dataset.indexing_technique == "economy"
        assert dataset.keyword_number == 12
        assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump()
        session.add.assert_called_once_with(dataset)

    def test_update_rag_pipeline_dataset_settings_with_published_rejects_chunk_structure_changes(self):
        """Once published, a differing chunk structure is rejected with a ``ValueError``."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(chunk_structure="sentence")

        with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")):
            with pytest.raises(ValueError, match="Chunk structure is not allowed to be updated"):
                DatasetService.update_rag_pipeline_dataset_settings(
                    session,
                    dataset,
                    knowledge_configuration,
                    has_published=True,
                )

    def test_update_rag_pipeline_dataset_settings_with_published_rejects_switch_to_economy(self):
        """Once published, switching a high_quality dataset to economy is rejected with a ``ValueError``."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        dataset.indexing_technique = "high_quality"
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(
            indexing_technique="economy",
            embedding_model_provider="",
            embedding_model="",
        )

        with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")):
            with pytest.raises(
                ValueError,
                match="Knowledge base indexing technique is not allowed to be updated to economy",
            ):
                DatasetService.update_rag_pipeline_dataset_settings(
                    session,
                    dataset,
                    knowledge_configuration,
                    has_published=True,
                )

    def test_update_rag_pipeline_dataset_settings_with_published_adds_high_quality_index(self):
        """Published economy -> high_quality update commits and queues the index task with "add"."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        dataset.indexing_technique = "economy"
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration()
        embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model")

        with (
            patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch.object(DatasetService, "check_is_multimodal_model", return_value=False),
            patch(
                "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding",
                return_value=SimpleNamespace(id="binding-1"),
            ),
            patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
        ):
            model_manager_cls.return_value.get_model_instance.return_value = embedding_model

            DatasetService.update_rag_pipeline_dataset_settings(
                session,
                dataset,
                knowledge_configuration,
                has_published=True,
            )

        assert dataset.indexing_technique == "high_quality"
        assert dataset.embedding_model == "embedding-model"
        assert dataset.embedding_model_provider == "provider"
        assert dataset.collection_binding_id == "binding-1"
        assert dataset.is_multimodal is False
        assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump()
        session.add.assert_called_once_with(dataset)
        session.commit.assert_called_once()
        update_task.delay.assert_called_once_with("dataset-1", "add")

    def test_update_rag_pipeline_dataset_settings_with_published_updates_changed_embedding_model(self):
        """Published update with a new embedding model rebinds the dataset and queues an "update" task."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        dataset.indexing_technique = "high_quality"
        dataset.embedding_model_provider = "provider"
        dataset.embedding_model = "embedding-model"
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(
            embedding_model_provider="provider-two",
            embedding_model="embedding-model-two",
            summary_index_setting={"enable": True},
        )

        with (
            patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch.object(DatasetService, "check_is_multimodal_model", return_value=True),
            patch(
                "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding",
                return_value=SimpleNamespace(id="binding-2"),
            ),
            patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
        ):
            model_manager_cls.return_value.get_model_instance.return_value = SimpleNamespace(
                provider="provider-two",
                model_name="embedding-model-two",
            )

            DatasetService.update_rag_pipeline_dataset_settings(
                session,
                dataset,
                knowledge_configuration,
                has_published=True,
            )

        assert dataset.embedding_model_provider == "provider-two"
        assert dataset.embedding_model == "embedding-model-two"
        assert dataset.collection_binding_id == "binding-2"
        assert dataset.is_multimodal is True
        assert dataset.summary_index_setting == {"enable": True}
        session.add.assert_called_once_with(dataset)
        session.commit.assert_called_once()
        update_task.delay.assert_called_once_with("dataset-1", "update")

    def test_update_rag_pipeline_dataset_settings_with_published_skips_embedding_update_when_token_is_missing(self):
        """A missing provider token keeps the old embedding settings; commit and "update" task still happen."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        dataset.indexing_technique = "high_quality"
        dataset.embedding_model_provider = "provider"
        dataset.embedding_model = "embedding-model"
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(
            embedding_model_provider="provider-two",
            embedding_model="embedding-model-two",
        )

        with (
            patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
        ):
            model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing")

            DatasetService.update_rag_pipeline_dataset_settings(
                session,
                dataset,
                knowledge_configuration,
                has_published=True,
            )

        assert dataset.embedding_model_provider == "provider"
        assert dataset.embedding_model == "embedding-model"
        assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump()
        session.add.assert_called_once_with(dataset)
        session.commit.assert_called_once()
        update_task.delay.assert_called_once_with("dataset-1", "update")

    def test_update_rag_pipeline_dataset_settings_with_published_updates_economy_keyword_number(self):
        """Published economy update changes the keyword number and commits without queuing an index task."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        dataset.indexing_technique = "economy"
        dataset.keyword_number = 5
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(
            indexing_technique="economy",
            embedding_model_provider="",
            embedding_model="",
            keyword_number=9,
        )

        with (
            patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")),
            patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
        ):
            DatasetService.update_rag_pipeline_dataset_settings(
                session,
                dataset,
                knowledge_configuration,
                has_published=True,
            )

        assert dataset.keyword_number == 9
        assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump()
        session.add.assert_called_once_with(dataset)
        session.commit.assert_called_once()
        update_task.delay.assert_not_called()
|
|
|
|
|
|
class TestDatasetServicePermissionsAndLifecycle:
    """Unit tests for dataset permissions, deletion, and metadata helpers."""

    def test_delete_dataset_returns_false_when_dataset_is_missing(self):
        """Deleting an unknown dataset returns False."""
        with patch.object(DatasetService, "get_dataset", return_value=None):
            result = DatasetService.delete_dataset("dataset-1", user=SimpleNamespace(id="user-1"))

        assert result is False

    def test_delete_dataset_checks_permission_and_deletes_dataset(self):
        """Deletion checks permission, emits the deleted signal, deletes the row and commits."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock()
        # One shared user object so the permission assertion checks the exact value that was passed in.
        user = SimpleNamespace(id="user-1")

        with (
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch.object(DatasetService, "check_dataset_permission") as check_permission,
            patch("services.dataset_service.dataset_was_deleted.send") as send_deleted_signal,
            patch("services.dataset_service.db") as mock_db,
        ):
            result = DatasetService.delete_dataset(dataset.id, user=user)

        assert result is True
        check_permission.assert_called_once_with(dataset, user)
        send_deleted_signal.assert_called_once_with(dataset)
        mock_db.session.delete.assert_called_once_with(dataset)
        mock_db.session.commit.assert_called_once()

    def test_dataset_use_check_returns_scalar_result(self):
        """The value of ``execute(...).scalar_one()`` is returned unchanged."""
        with patch("services.dataset_service.db") as mock_db:
            mock_db.session.execute.return_value.scalar_one.return_value = True

            result = DatasetService.dataset_use_check("dataset-1")

        assert result is True

    def test_check_dataset_permission_rejects_cross_tenant_access(self):
        """A user from another tenant is denied with ``NoPermissionError``."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(tenant_id="tenant-a")
        user = DatasetServiceUnitDataFactory.create_user_mock(tenant_id="tenant-b")

        with pytest.raises(NoPermissionError, match="do not have permission"):
            DatasetService.check_dataset_permission(dataset, user)

    def test_check_dataset_permission_rejects_only_me_dataset_for_non_creator(self):
        """An ONLY_ME dataset denies an editor who is not its creator."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            permission=DatasetPermissionEnum.ONLY_ME,
            created_by="owner-1",
        )
        user = DatasetServiceUnitDataFactory.create_user_mock(
            user_id="member-1",
            role=TenantAccountRole.EDITOR,
        )

        with pytest.raises(NoPermissionError, match="do not have permission"):
            DatasetService.check_dataset_permission(dataset, user)

    def test_check_dataset_permission_rejects_partial_team_user_without_binding(self):
        """A PARTIAL_TEAM dataset denies a non-creator when the permission lookup finds no row."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
            created_by="owner-1",
        )
        user = DatasetServiceUnitDataFactory.create_user_mock(
            user_id="member-1",
            role=TenantAccountRole.EDITOR,
        )

        with patch("services.dataset_service.db") as mock_db:
            mock_db.session.query.return_value.filter_by.return_value.first.return_value = None

            with pytest.raises(NoPermissionError, match="do not have permission"):
                DatasetService.check_dataset_permission(dataset, user)

    def test_check_dataset_permission_allows_partial_team_creator_without_lookup(self):
        """The creator of a PARTIAL_TEAM dataset passes without any database query."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
            created_by="creator-1",
        )
        user = DatasetServiceUnitDataFactory.create_user_mock(
            user_id="creator-1",
            role=TenantAccountRole.EDITOR,
        )

        with patch("services.dataset_service.db") as mock_db:
            DatasetService.check_dataset_permission(dataset, user)

        mock_db.session.query.assert_not_called()

    def test_check_dataset_permission_allows_partial_team_member_with_binding(self):
        """A PARTIAL_TEAM dataset admits a non-creator when the permission lookup finds a row."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
            created_by="owner-1",
        )
        user = DatasetServiceUnitDataFactory.create_user_mock(
            user_id="member-1",
            role=TenantAccountRole.EDITOR,
        )

        with patch("services.dataset_service.db") as mock_db:
            # Any non-None value stands in for an existing permission row.
            mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()

            DatasetService.check_dataset_permission(dataset, user)

    def test_check_dataset_operator_permission_validates_required_arguments(self):
        """A missing dataset or a missing user each raise ``ValueError`` with its own message."""
        with pytest.raises(ValueError, match="Dataset not found"):
            DatasetService.check_dataset_operator_permission(user=SimpleNamespace(id="user-1"), dataset=None)

        with pytest.raises(ValueError, match="User not found"):
            DatasetService.check_dataset_operator_permission(user=None, dataset=SimpleNamespace(id="dataset-1"))

    def test_check_dataset_operator_permission_rejects_only_me_for_non_creator(self):
        """The operator check denies an editor on an ONLY_ME dataset they did not create."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            permission=DatasetPermissionEnum.ONLY_ME,
            created_by="owner-1",
        )
        user = DatasetServiceUnitDataFactory.create_user_mock(
            user_id="member-1",
            role=TenantAccountRole.EDITOR,
        )

        with pytest.raises(NoPermissionError, match="do not have permission"):
            DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)

    def test_check_dataset_operator_permission_rejects_partial_team_without_binding(self):
        """The operator check denies a PARTIAL_TEAM user when the permission query returns no rows."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
        user = DatasetServiceUnitDataFactory.create_user_mock(
            user_id="member-1",
            role=TenantAccountRole.EDITOR,
        )

        with patch("services.dataset_service.db") as mock_db:
            mock_db.session.query.return_value.filter_by.return_value.all.return_value = []

            with pytest.raises(NoPermissionError, match="do not have permission"):
                DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)

    def test_get_dataset_queries_delegates_to_paginate(self):
        """Items and total come straight from a single ``db.paginate`` call."""
        with patch("services.dataset_service.db") as mock_db:
            # Make ``db.desc`` an identity function so it returns its argument unchanged.
            mock_db.desc.side_effect = lambda column: column
            mock_db.paginate.return_value = SimpleNamespace(items=["query"], total=1)

            items, total = DatasetService.get_dataset_queries("dataset-1", page=1, per_page=20)

        assert items == ["query"]
        assert total == 1
        mock_db.paginate.assert_called_once()

    def test_get_related_apps_returns_ordered_query_results(self):
        """The result of the ``query().where().order_by().all()`` chain is returned as-is."""
        with patch("services.dataset_service.db") as mock_db:
            mock_db.desc.side_effect = lambda column: column
            mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [
                "relation-1"
            ]

            result = DatasetService.get_related_apps("dataset-1")

        assert result == ["relation-1"]

    def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self):
        """An unknown dataset id raises ``NotFound``."""
        with patch.object(DatasetService, "get_dataset", return_value=None):
            with pytest.raises(NotFound, match="Dataset not found"):
                DatasetService.update_dataset_api_status("dataset-1", True)

    def test_update_dataset_api_status_requires_current_user_id(self):
        """A current user without an id raises ``ValueError``."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(enable_api=False)

        with (
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch("services.dataset_service.current_user", SimpleNamespace(id=None)),
        ):
            with pytest.raises(ValueError, match="Current user or current user id not found"):
                DatasetService.update_dataset_api_status(dataset.id, True)

    def test_update_dataset_api_status_updates_fields_and_commits(self):
        """Enabling the API sets the flag, the updater id and the timestamp, then commits once."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(enable_api=False)
        # Unique sentinel so the ``is`` check proves the patched clock value was stored.
        now = object()

        with (
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")),
            patch("services.dataset_service.naive_utc_now", return_value=now),
            patch("services.dataset_service.db") as mock_db,
        ):
            DatasetService.update_dataset_api_status(dataset.id, True)

        assert dataset.enable_api is True
        assert dataset.updated_by == "user-1"
        assert dataset.updated_at is now
        mock_db.session.commit.assert_called_once()

    def test_get_dataset_auto_disable_logs_returns_empty_when_billing_is_disabled(self):
        """With billing disabled the result is empty and ``session.scalars`` is never called."""

        class FakeAccount:
            pass

        # The same class is patched in as ``Account`` so the stand-in user is an instance of it.
        current_user = FakeAccount()
        current_user.current_tenant_id = "tenant-1"

        features = SimpleNamespace(
            billing=SimpleNamespace(enabled=False, subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL))
        )

        with (
            patch("services.dataset_service.Account", FakeAccount),
            patch("services.dataset_service.current_user", current_user),
            patch("services.dataset_service.FeatureService.get_features", return_value=features),
            patch("services.dataset_service.db") as mock_db,
        ):
            result = DatasetService.get_dataset_auto_disable_logs("dataset-1")

        assert result == {"document_ids": [], "count": 0}
        mock_db.session.scalars.assert_not_called()

    def test_get_dataset_auto_disable_logs_returns_recent_document_ids(self):
        """With billing enabled the document ids of the fetched logs and their count are returned."""

        class FakeAccount:
            pass

        # The same class is patched in as ``Account`` so the stand-in user is an instance of it.
        current_user = FakeAccount()
        current_user.current_tenant_id = "tenant-1"
        logs = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")]
        features = SimpleNamespace(
            billing=SimpleNamespace(enabled=True, subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL))
        )

        with (
            patch("services.dataset_service.Account", FakeAccount),
            patch("services.dataset_service.current_user", current_user),
            patch("services.dataset_service.FeatureService.get_features", return_value=features),
            patch("services.dataset_service.db") as mock_db,
        ):
            mock_db.session.scalars.return_value.all.return_value = logs

            result = DatasetService.get_dataset_auto_disable_logs("dataset-1")

        assert result == {"document_ids": ["doc-1", "doc-2"], "count": 2}
|
|
|
|
|
|
class TestDatasetServiceDocumentIndexing:
    """Unit tests for pause/recover/retry orchestration without SQL assertions."""

    @pytest.fixture
    def mock_document_service_dependencies(self):
        """Patch redis, the DB session and ``current_user`` in ``services.dataset_service``.

        Yields a dict exposing the three mocks under the keys ``redis_client``,
        ``db_session`` and ``current_user``; the patched user has id ``user-123``.
        """
        with (
            patch("services.dataset_service.redis_client") as redis_mock,
            patch("services.dataset_service.db.session") as session_mock,
            patch("services.dataset_service.current_user") as user_mock,
        ):
            user_mock.id = "user-123"
            patched = {
                "redis_client": redis_mock,
                "db_session": session_mock,
                "current_user": user_mock,
            }
            yield patched

    def test_pause_document_success(self, mock_document_service_dependencies):
        """Pausing an ``indexing`` document flags it, persists it and sets the redis pause key."""
        session = mock_document_service_dependencies["db_session"]
        redis = mock_document_service_dependencies["redis_client"]
        document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing")

        DocumentService.pause_document(document)

        assert document.is_paused is True
        assert document.paused_by == "user-123"
        session.add.assert_called_once_with(document)
        session.commit.assert_called_once()
        redis.setnx.assert_called_once_with(f"document_{document.id}_is_paused", "True")

    def test_pause_document_invalid_status_error(self, mock_document_service_dependencies):
        """Pausing a ``completed`` document raises ``DocumentIndexingError``."""
        document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="completed")
        expectation = pytest.raises(DocumentIndexingError)

        with expectation:
            DocumentService.pause_document(document)

    def test_recover_document_success(self, mock_document_service_dependencies):
        """Recovering a paused document clears the pause fields, deletes the redis key and queues the task."""
        session = mock_document_service_dependencies["db_session"]
        redis = mock_document_service_dependencies["redis_client"]
        document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing", is_paused=True)

        with patch("services.dataset_service.recover_document_indexing_task") as recover_task:
            DocumentService.recover_document(document)

        # All three pause markers must be reset on the document itself.
        assert document.is_paused is False
        assert document.paused_by is None
        assert document.paused_at is None

        session.add.assert_called_once_with(document)
        session.commit.assert_called_once()
        redis.delete.assert_called_once_with(f"document_{document.id}_is_paused")
        recover_task.delay.assert_called_once_with(document.dataset_id, document.id)

    def test_retry_document_indexing_success(self, mock_document_service_dependencies):
        """Retrying two errored documents resets each to ``waiting`` and dispatches one task for both ids."""
        session = mock_document_service_dependencies["db_session"]
        redis = mock_document_service_dependencies["redis_client"]
        dataset_id = "dataset-123"
        documents = [
            DatasetServiceUnitDataFactory.create_document_mock(document_id=doc_id, indexing_status="error")
            for doc_id in ("doc-1", "doc-2")
        ]
        # ``redis.get`` is stubbed to return ``None`` for every lookup.
        redis.get.return_value = None

        with patch("services.dataset_service.retry_document_indexing_task") as retry_task:
            DocumentService.retry_document(dataset_id, documents)

        for document in documents:
            assert document.indexing_status == "waiting"
        # One add/commit/setex per document.
        assert session.add.call_count == 2
        assert session.commit.call_count == 2
        assert redis.setex.call_count == 2
        retry_task.delay.assert_called_once_with(dataset_id, ["doc-1", "doc-2"], "user-123")
|
|
|
|
|
|
class TestDatasetCollectionBindingService:
    """Unit tests for dataset collection binding lookups and creation."""

    def test_get_dataset_collection_binding_returns_existing_binding(self):
        """When the query's ``first()`` yields a binding it is returned and nothing is added to the session."""
        existing = SimpleNamespace(id="binding-1")

        with patch("services.dataset_service.db") as mock_db:
            ordered_query = mock_db.session.query.return_value.where.return_value.order_by.return_value
            ordered_query.first.return_value = existing

            result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model")

        assert result is existing
        mock_db.session.add.assert_not_called()

    def test_get_dataset_collection_binding_creates_binding_when_missing(self):
        """When ``first()`` yields ``None`` a binding is constructed, added to the session and committed."""
        created_binding = SimpleNamespace(id="binding-2")

        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.DatasetCollectionBinding", return_value=created_binding) as binding_cls,
            patch.object(Dataset, "gen_collection_name_by_id", return_value="generated-collection"),
        ):
            ordered_query = mock_db.session.query.return_value.where.return_value.order_by.return_value
            ordered_query.first.return_value = None

            result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model", "dataset")

        assert result is created_binding
        # The collection name comes from the patched ``Dataset.gen_collection_name_by_id``.
        expected_kwargs = {
            "provider_name": "provider",
            "model_name": "model",
            "collection_name": "generated-collection",
            "type": "dataset",
        }
        binding_cls.assert_called_once_with(**expected_kwargs)

        session = mock_db.session
        session.add.assert_called_once_with(created_binding)
        session.commit.assert_called_once()

    def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self):
        """A lookup whose ``first()`` yields ``None`` raises ``ValueError``."""
        with patch("services.dataset_service.db") as mock_db:
            ordered_query = mock_db.session.query.return_value.where.return_value.order_by.return_value
            ordered_query.first.return_value = None

            expectation = pytest.raises(ValueError, match="Dataset collection binding not found")
            with expectation:
                DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1")

    def test_get_dataset_collection_binding_by_id_and_type_returns_binding(self):
        """A lookup whose ``first()`` yields a binding returns that same object."""
        existing = SimpleNamespace(id="binding-1")

        with patch("services.dataset_service.db") as mock_db:
            ordered_query = mock_db.session.query.return_value.where.return_value.order_by.return_value
            ordered_query.first.return_value = existing

            result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1")

        assert result is existing
|
|
|
|
|
|
class TestDatasetPermissionService:
|
|
"""Unit tests for dataset partial-member management helpers."""
|
|
|
|
def test_get_dataset_partial_member_list_returns_scalar_results(self):
    """The member list is whatever ``db.session.scalars(...).all()`` yields."""
    member_ids = ["user-1", "user-2"]

    with patch("services.dataset_service.db") as mock_db:
        scalar_result = mock_db.session.scalars.return_value
        scalar_result.all.return_value = member_ids

        result = DatasetPermissionService.get_dataset_partial_member_list("dataset-1")

    assert result == ["user-1", "user-2"]
|
|
|
|
def test_update_partial_member_list_replaces_permissions_and_commits(self):
    """Updating the list issues one delete, one ``add_all`` and one commit."""
    members = [{"user_id": user_id} for user_id in ("user-1", "user-2")]

    with patch("services.dataset_service.db") as mock_db:
        DatasetPermissionService.update_partial_member_list("tenant-1", "dataset-1", members)

    session = mock_db.session
    filtered_query = session.query.return_value.where.return_value
    filtered_query.delete.assert_called_once()
    session.add_all.assert_called_once()
    session.commit.assert_called_once()
|
|
|
|
def test_update_partial_member_list_rolls_back_on_exception(self):
    """A failure in ``add_all`` propagates to the caller and triggers exactly one rollback."""
    members = [{"user_id": "user-1"}]

    with patch("services.dataset_service.db") as mock_db:
        session = mock_db.session
        session.add_all.side_effect = RuntimeError("boom")

        expectation = pytest.raises(RuntimeError, match="boom")
        with expectation:
            DatasetPermissionService.update_partial_member_list("tenant-1", "dataset-1", members)

    mock_db.session.rollback.assert_called_once()
|
|
|
|
def test_check_permission_requires_dataset_editor(self):
    """A user who is not a dataset editor is rejected with ``NoPermissionError``."""
    dataset = DatasetServiceUnitDataFactory.create_dataset_mock()
    non_editor = SimpleNamespace(is_dataset_editor=False, is_dataset_operator=False)
    expectation = pytest.raises(NoPermissionError, match="does not have permission")

    with expectation:
        DatasetPermissionService.check_permission(non_editor, dataset, "all_team", [])
|
|
|
|
def test_check_permission_prevents_dataset_operator_from_changing_permission_mode(self):
    """A dataset operator requesting ``only_me`` on an ``all_team`` dataset gets ``NoPermissionError``."""
    dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission="all_team")
    operator = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
    expectation = pytest.raises(NoPermissionError, match="cannot change the dataset permissions")

    with expectation:
        DatasetPermissionService.check_permission(operator, dataset, "only_me", [])
|
|
|
|
def test_check_permission_requires_partial_member_list_for_partial_members_mode(self):
    """Requesting ``partial_members`` with an empty member list raises ``ValueError``."""
    dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission="partial_members")
    operator = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
    no_members = []
    expectation = pytest.raises(ValueError, match="Partial member list is required")

    with expectation:
        DatasetPermissionService.check_permission(operator, dataset, "partial_members", no_members)
|
|
|
|
def test_check_permission_rejects_dataset_operator_member_list_changes(self):
    """A dataset operator whose requested members differ from the stored list gets ``ValueError``."""
    operator = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
    dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", permission="partial_members")
    # Stored list is ["user-1"] (patched below); the request names a different user.
    requested_members = [{"user_id": "user-2"}]

    with (
        patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]),
        pytest.raises(ValueError, match="cannot change the dataset permissions"),
    ):
        DatasetPermissionService.check_permission(operator, dataset, "partial_members", requested_members)
|
|
|
|
def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged(self):
    """A dataset operator re-submitting the stored member list passes; the test succeeds by not raising."""
    operator = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
    dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", permission="partial_members")
    # Same single member as the patched stored list below.
    requested_members = [{"user_id": "user-1"}]

    stored_members_patch = patch.object(
        DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]
    )
    with stored_members_patch:
        DatasetPermissionService.check_permission(operator, dataset, "partial_members", requested_members)
|
|
|
|
def test_clear_partial_member_list_deletes_permissions_and_commits(self):
    """Clearing the list issues one delete on the filtered query and one commit."""
    with patch("services.dataset_service.db") as mock_db:
        DatasetPermissionService.clear_partial_member_list("dataset-1")

    session = mock_db.session
    filtered_query = session.query.return_value.where.return_value
    filtered_query.delete.assert_called_once()
    session.commit.assert_called_once()
|
|
|
|
def test_clear_partial_member_list_rolls_back_on_exception(self):
    """A failure in the delete call propagates to the caller and triggers exactly one rollback."""
    with patch("services.dataset_service.db") as mock_db:
        filtered_query = mock_db.session.query.return_value.where.return_value
        filtered_query.delete.side_effect = RuntimeError("boom")

        expectation = pytest.raises(RuntimeError, match="boom")
        with expectation:
            DatasetPermissionService.clear_partial_member_list("dataset-1")

    mock_db.session.rollback.assert_called_once()
|