# NOTE(review): the following header is file-browser residue, not Python code;
# converted to comments so the module stays importable.
# File: dify/api/tests/unit_tests/services/test_dataset_service_dataset.py
# 1753 lines, 80 KiB, Python
"""Unit tests for DatasetService and dataset-related collaborators."""
from .dataset_service_test_helpers import (
CloudPlan,
Dataset,
DatasetCollectionBindingService,
DatasetNameDuplicateError,
DatasetPermissionEnum,
DatasetPermissionService,
DatasetProcessRule,
DatasetService,
DatasetServiceUnitDataFactory,
DocumentIndexingError,
DocumentService,
LLMBadRequestError,
MagicMock,
Mock,
ModelFeature,
ModelType,
NoPermissionError,
NotFound,
PipelineIconInfo,
ProviderTokenNotInitError,
RagPipelineDatasetCreateEntity,
SimpleNamespace,
TenantAccountRole,
_make_knowledge_configuration,
_make_retrieval_model,
_make_session_context,
json,
patch,
pytest,
)
class TestDatasetServiceQueries:
    """Unit tests for DatasetService query composition and fallback branches."""

    @pytest.fixture
    def mock_dataset_query_dependencies(self):
        """Patch the db/helper/TagService collaborators used by get_datasets and yield them.

        The paginate default returns a single sentinel item so tests only need to
        override what they specifically exercise.
        """
        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.helper.escape_like_pattern", return_value="escaped-search") as escape_like,
            patch("services.dataset_service.TagService.get_target_ids_by_tag_ids") as get_target_ids,
        ):
            mock_db.paginate.return_value = SimpleNamespace(items=["dataset"], total=1)
            yield {
                "db": mock_db,
                "escape_like_pattern": escape_like,
                "get_target_ids": get_target_ids,
            }

    def test_get_datasets_returns_paginated_results_for_public_view(self, mock_dataset_query_dependencies):
        """Without a user or search term, get_datasets delegates straight to db.paginate."""
        items, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id="tenant-1")
        assert items == ["dataset"]
        assert total == 1
        mock_dataset_query_dependencies["db"].paginate.assert_called_once()
        # No search term was passed, so the LIKE-escaping helper must not run.
        mock_dataset_query_dependencies["escape_like_pattern"].assert_not_called()

    def test_get_datasets_short_circuits_for_dataset_operator_without_permissions(
        self, mock_dataset_query_dependencies
    ):
        """A DATASET_OPERATOR with no permission rows gets an empty result without paginating."""
        user = DatasetServiceUnitDataFactory.create_user_mock(role=TenantAccountRole.DATASET_OPERATOR)
        # Permission lookup returns no rows for this operator.
        mock_dataset_query_dependencies["db"].session.query.return_value.filter_by.return_value.all.return_value = []
        items, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id="tenant-1", user=user)
        assert items == []
        assert total == 0
        mock_dataset_query_dependencies["db"].paginate.assert_not_called()

    def test_get_datasets_short_circuits_when_tag_lookup_returns_no_target_ids(self, mock_dataset_query_dependencies):
        """Tag filtering with no matching targets returns empty without hitting paginate."""
        mock_dataset_query_dependencies["get_target_ids"].return_value = []
        items, total = DatasetService.get_datasets(
            page=1,
            per_page=20,
            tenant_id="tenant-1",
            tag_ids=["tag-1"],
        )
        assert items == []
        assert total == 0
        mock_dataset_query_dependencies["get_target_ids"].assert_called_once_with("knowledge", "tenant-1", ["tag-1"])
        mock_dataset_query_dependencies["db"].paginate.assert_not_called()

    def test_get_datasets_search_and_tag_filters_call_collaborators(self, mock_dataset_query_dependencies):
        """Search text goes through escape_like_pattern and tag ids through TagService."""
        mock_dataset_query_dependencies["get_target_ids"].return_value = ["dataset-1"]
        items, total = DatasetService.get_datasets(
            page=2,
            per_page=10,
            tenant_id="tenant-1",
            search="report",
            tag_ids=["tag-1"],
        )
        assert items == ["dataset"]
        assert total == 1
        mock_dataset_query_dependencies["escape_like_pattern"].assert_called_once_with("report")
        mock_dataset_query_dependencies["get_target_ids"].assert_called_once_with("knowledge", "tenant-1", ["tag-1"])
        mock_dataset_query_dependencies["db"].paginate.assert_called_once()

    def test_get_process_rules_returns_latest_rule_when_present(self):
        """When a DatasetProcessRule row exists, its mode and rules dict are returned."""
        dataset_process_rule = Mock(spec=DatasetProcessRule)
        dataset_process_rule.mode = "automatic"
        dataset_process_rule.rules_dict = {"delimiter": "\n"}
        with patch("services.dataset_service.db") as mock_db:
            (
                mock_db.session.query.return_value.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value
            ) = dataset_process_rule
            result = DatasetService.get_process_rules("dataset-1")
            assert result == {"mode": "automatic", "rules": {"delimiter": "\n"}}

    def test_get_process_rules_falls_back_to_default_rules_when_missing(self):
        """With no stored rule, DocumentService.DEFAULT_RULES provides mode and rules."""
        with patch("services.dataset_service.db") as mock_db:
            (
                mock_db.session.query.return_value.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value
            ) = None
            result = DatasetService.get_process_rules("dataset-1")
            assert result == {
                "mode": DocumentService.DEFAULT_RULES["mode"],
                "rules": DocumentService.DEFAULT_RULES["rules"],
            }

    def test_get_datasets_by_ids_returns_empty_for_missing_ids(self):
        """An empty id list short-circuits without calling paginate."""
        with patch("services.dataset_service.db") as mock_db:
            items, total = DatasetService.get_datasets_by_ids([], "tenant-1")
            assert items == []
            assert total == 0
            mock_db.paginate.assert_not_called()

    def test_get_datasets_by_ids_uses_paginate_for_non_empty_input(self):
        """A non-empty id list flows through db.paginate."""
        with patch("services.dataset_service.db") as mock_db:
            mock_db.paginate.return_value = SimpleNamespace(items=["dataset-1"], total=1)
            items, total = DatasetService.get_datasets_by_ids(["dataset-1"], "tenant-1")
            assert items == ["dataset-1"]
            assert total == 1
            mock_db.paginate.assert_called_once()

    def test_get_dataset_returns_first_match(self):
        """get_dataset returns the first row produced by the filter_by query."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock()
        with patch("services.dataset_service.db") as mock_db:
            mock_db.session.query.return_value.filter_by.return_value.first.return_value = dataset
            result = DatasetService.get_dataset(dataset.id)
            assert result is dataset
class TestDatasetServiceValidation:
    """Unit tests for DatasetService validation helpers."""

    @pytest.mark.parametrize(
        ("dataset_doc_form", "incoming_doc_form"),
        [(None, "text_model"), ("text_model", "text_model")],
    )
    def test_check_doc_form_allows_matching_or_missing_dataset_doc_form(self, dataset_doc_form, incoming_doc_form):
        """check_doc_form passes when the stored doc_form is unset or equals the incoming one."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(doc_form=dataset_doc_form)
        DatasetService.check_doc_form(dataset, incoming_doc_form)

    def test_check_doc_form_rejects_mismatched_doc_form(self):
        """A conflicting doc_form raises ValueError."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(doc_form="qa_model")
        with pytest.raises(ValueError, match="doc_form is different"):
            DatasetService.check_doc_form(dataset, "text_model")

    def test_check_dataset_model_setting_skips_non_high_quality_datasets(self):
        """Economy datasets never instantiate a ModelManager."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="economy")
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            DatasetService.check_dataset_model_setting(dataset)
            model_manager_cls.assert_not_called()

    def test_check_dataset_model_setting_validates_embedding_model_for_high_quality_dataset(self):
        """High-quality datasets resolve their embedding model via ModelManager."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            DatasetService.check_dataset_model_setting(dataset)
            model_manager_cls.return_value.get_model_instance.assert_called_once_with(
                tenant_id=dataset.tenant_id,
                provider=dataset.embedding_model_provider,
                model_type=ModelType.TEXT_EMBEDDING,
                model=dataset.embedding_model,
            )

    def test_check_dataset_model_setting_wraps_llm_bad_request_error(self):
        """LLMBadRequestError is surfaced as a 'No Embedding Model available' ValueError."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError()
            with pytest.raises(ValueError, match="No Embedding Model available"):
                DatasetService.check_dataset_model_setting(dataset)

    def test_check_dataset_model_setting_wraps_provider_token_error(self):
        """ProviderTokenNotInitError is re-raised as a ValueError carrying its message."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing")
            with pytest.raises(ValueError, match="token missing"):
                DatasetService.check_dataset_model_setting(dataset)

    def test_check_embedding_model_setting_wraps_provider_token_error_description(self):
        """check_embedding_model_setting wraps the provider-token error description."""
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("provider setup")
            with pytest.raises(ValueError, match="provider setup"):
                DatasetService.check_embedding_model_setting("tenant-1", "provider", "embedding-model")

    def test_check_reranking_model_setting_uses_rerank_model_type(self):
        """check_reranking_model_setting resolves the model with ModelType.RERANK."""
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker")
            model_manager_cls.return_value.get_model_instance.assert_called_once_with(
                tenant_id="tenant-1",
                provider="provider",
                model_type=ModelType.RERANK,
                model="reranker",
            )

    def test_check_reranking_model_setting_wraps_bad_request(self):
        """LLMBadRequestError maps to a 'No Rerank Model available' ValueError."""
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError()
            with pytest.raises(ValueError, match="No Rerank Model available"):
                DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker")

    def test_check_is_multimodal_model_returns_true_when_model_supports_vision(self):
        """A schema advertising ModelFeature.VISION makes the check return True."""
        model_schema = SimpleNamespace(features=[ModelFeature.VISION])
        model_type_instance = MagicMock()
        model_type_instance.get_model_schema.return_value = model_schema
        model_instance = SimpleNamespace(
            model_type_instance=model_type_instance,
            model_name="embedding-model",
            credentials={"api_key": "secret"},
        )
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            model_manager_cls.return_value.get_model_instance.return_value = model_instance
            result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
            assert result is True

    def test_check_is_multimodal_model_returns_false_when_vision_feature_is_absent(self):
        """An empty feature list makes the multimodal check return False."""
        model_schema = SimpleNamespace(features=[])
        model_type_instance = MagicMock()
        model_type_instance.get_model_schema.return_value = model_schema
        model_instance = SimpleNamespace(
            model_type_instance=model_type_instance,
            model_name="embedding-model",
            credentials={"api_key": "secret"},
        )
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            model_manager_cls.return_value.get_model_instance.return_value = model_instance
            result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
            assert result is False

    def test_check_is_multimodal_model_raises_when_schema_is_missing(self):
        """A None model schema raises 'Model schema not found'."""
        model_type_instance = MagicMock()
        model_type_instance.get_model_schema.return_value = None
        model_instance = SimpleNamespace(
            model_type_instance=model_type_instance,
            model_name="embedding-model",
            credentials={"api_key": "secret"},
        )
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            model_manager_cls.return_value.get_model_instance.return_value = model_instance
            with pytest.raises(ValueError, match="Model schema not found"):
                DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")

    def test_check_is_multimodal_model_wraps_bad_request_error(self):
        """LLMBadRequestError maps to a 'No Model available' ValueError."""
        with patch("services.dataset_service.ModelManager") as model_manager_cls:
            model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError()
            with pytest.raises(ValueError, match="No Model available"):
                DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
class TestDatasetServiceCreationAndUpdate:
"""Unit tests for dataset creation and update helpers."""
def test_create_empty_dataset_raises_when_name_already_exists(self):
account = SimpleNamespace(id="user-1")
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
with pytest.raises(DatasetNameDuplicateError, match="Dataset with name Dataset already exists"):
DatasetService.create_empty_dataset("tenant-1", "Dataset", None, "economy", account)
    def test_create_empty_dataset_uses_default_embedding_model_for_high_quality_dataset(self):
        """A high_quality dataset with no explicit model falls back to the tenant default embedding model."""
        account = SimpleNamespace(id="user-1")
        default_embedding_model = SimpleNamespace(provider="provider", model_name="default-embedding")
        with (
            patch("services.dataset_service.db") as mock_db,
            patch(
                "services.dataset_service.Dataset",
                side_effect=lambda **kwargs: SimpleNamespace(id="dataset-1", **kwargs),
            ),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch.object(DatasetService, "check_embedding_model_setting") as check_embedding,
        ):
            # No existing dataset with the same name.
            mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
            model_manager_cls.return_value.get_default_model_instance.return_value = default_embedding_model
            dataset = DatasetService.create_empty_dataset(
                tenant_id="tenant-1",
                name="Dataset",
                description="Description",
                indexing_technique="high_quality",
                account=account,
            )
            assert dataset.embedding_model_provider == "provider"
            assert dataset.embedding_model == "default-embedding"
            assert dataset.permission == DatasetPermissionEnum.ONLY_ME
            assert dataset.provider == "vendor"
            model_manager_cls.return_value.get_default_model_instance.assert_called_once_with(
                tenant_id="tenant-1",
                model_type=ModelType.TEXT_EMBEDDING,
            )
            # The default-model path skips explicit embedding-model validation.
            check_embedding.assert_not_called()
            mock_db.session.commit.assert_called_once()
    def test_create_empty_dataset_creates_external_binding_for_high_quality_dataset(self):
        """An external-provider dataset validates its models and creates an ExternalKnowledgeBindings row."""
        account = SimpleNamespace(id="user-1")
        retrieval_model = _make_retrieval_model()
        embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model")
        with (
            patch("services.dataset_service.db") as mock_db,
            patch(
                "services.dataset_service.Dataset",
                side_effect=lambda **kwargs: SimpleNamespace(id="dataset-1", **kwargs),
            ),
            patch(
                "services.dataset_service.ExternalKnowledgeBindings",
                side_effect=lambda **kwargs: SimpleNamespace(**kwargs),
            ) as binding_cls,
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch("services.dataset_service.ExternalDatasetService.get_external_knowledge_api", return_value=object()),
            patch.object(DatasetService, "check_embedding_model_setting") as check_embedding,
            patch.object(DatasetService, "check_reranking_model_setting") as check_reranking,
        ):
            # No existing dataset with the same name.
            mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
            model_manager_cls.return_value.get_model_instance.return_value = embedding_model
            dataset = DatasetService.create_empty_dataset(
                tenant_id="tenant-1",
                name="External Dataset",
                description="Description",
                indexing_technique="high_quality",
                account=account,
                permission=DatasetPermissionEnum.ALL_TEAM,
                provider="external",
                external_knowledge_api_id="api-1",
                external_knowledge_id="knowledge-1",
                embedding_model_provider="provider",
                embedding_model_name="embedding-model",
                retrieval_model=retrieval_model,
                summary_index_setting={"enable": True},
            )
            assert dataset.embedding_model_provider == "provider"
            assert dataset.embedding_model == "embedding-model"
            assert dataset.retrieval_model == retrieval_model.model_dump()
            assert dataset.summary_index_setting == {"enable": True}
            check_embedding.assert_called_once_with("tenant-1", "provider", "embedding-model")
            # Reranking settings come from the retrieval model built by _make_retrieval_model.
            check_reranking.assert_called_once_with("tenant-1", "rerank-provider", "rerank-model")
            binding_cls.assert_called_once_with(
                tenant_id="tenant-1",
                dataset_id="dataset-1",
                external_knowledge_api_id="api-1",
                external_knowledge_id="knowledge-1",
                created_by="user-1",
            )
            # Both the dataset and the binding are added to the session.
            assert mock_db.session.add.call_count == 2
            mock_db.session.commit.assert_called_once()
def test_create_empty_rag_pipeline_dataset_raises_for_duplicate_name(self):
entity = RagPipelineDatasetCreateEntity(
name="Existing Dataset",
description="Description",
icon_info=PipelineIconInfo(icon="book", icon_background="#fff"),
permission=DatasetPermissionEnum.ALL_TEAM,
)
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
with pytest.raises(DatasetNameDuplicateError, match="Existing Dataset already exists"):
DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity)
    def test_create_empty_rag_pipeline_dataset_generates_name_and_creates_dataset(self):
        """An empty entity name is auto-generated and the dataset is wired to a new pipeline."""
        entity = RagPipelineDatasetCreateEntity(
            name="",
            description="Description",
            icon_info=PipelineIconInfo(icon="book", icon_background="#fff"),
            permission=DatasetPermissionEnum.ALL_TEAM,
        )
        pipeline = SimpleNamespace(id="pipeline-1")

        def pipeline_factory(**kwargs):
            # Mutate the shared namespace so the pipeline keeps its stable id.
            pipeline.__dict__.update(kwargs)
            return pipeline

        def dataset_factory(**kwargs):
            return SimpleNamespace(id="dataset-1", **kwargs)

        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")),
            patch("services.dataset_service.generate_incremental_name", return_value="Untitled 2") as generate_name,
            patch("services.dataset_service.Pipeline", side_effect=pipeline_factory),
            patch("services.dataset_service.Dataset", side_effect=dataset_factory),
        ):
            # Existing dataset names feed the incremental-name generator.
            mock_db.session.query.return_value.filter_by.return_value.all.return_value = [
                SimpleNamespace(name="Untitled"),
                SimpleNamespace(name="Untitled 1"),
            ]
            dataset = DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity)
            assert entity.name == "Untitled 2"
            assert dataset.pipeline_id == "pipeline-1"
            assert dataset.runtime_mode == "rag_pipeline"
            generate_name.assert_called_once_with(["Untitled", "Untitled 1"], "Untitled")
            mock_db.session.commit.assert_called_once()
    def test_create_empty_rag_pipeline_dataset_requires_current_user_id(self):
        """A current_user with a None id is rejected with a ValueError."""
        entity = RagPipelineDatasetCreateEntity(
            name="Dataset",
            description="Description",
            icon_info=PipelineIconInfo(icon="book", icon_background="#fff"),
            permission=DatasetPermissionEnum.ALL_TEAM,
        )
        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.current_user", SimpleNamespace(id=None)),
        ):
            # No duplicate-name conflict, so the user check is what fails.
            mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
            with pytest.raises(ValueError, match="Current user or current user id not found"):
                DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity)
def test_update_dataset_raises_when_dataset_is_missing(self):
with patch.object(DatasetService, "get_dataset", return_value=None):
with pytest.raises(ValueError, match="Dataset not found"):
DatasetService.update_dataset("dataset-1", {}, SimpleNamespace(id="user-1"))
    def test_update_dataset_raises_when_new_name_conflicts(self):
        """Renaming to an already-used name raises 'Dataset name already exists'."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1")
        dataset.name = "Old Dataset"
        with (
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch.object(DatasetService, "_has_dataset_same_name", return_value=True),
        ):
            with pytest.raises(ValueError, match="Dataset name already exists"):
                DatasetService.update_dataset("dataset-1", {"name": "New Dataset"}, SimpleNamespace(id="user-1"))
    def test_update_dataset_routes_external_datasets_to_external_helper(self):
        """provider == 'external' routes the update through _update_external_dataset."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1")
        dataset.provider = "external"
        user = DatasetServiceUnitDataFactory.create_user_mock()
        with (
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch.object(DatasetService, "check_dataset_permission") as check_permission,
            patch.object(DatasetService, "_update_external_dataset", return_value="updated") as update_external,
        ):
            result = DatasetService.update_dataset("dataset-1", {"name": dataset.name}, user)
            assert result == "updated"
            # Permission is checked before the helper is dispatched.
            check_permission.assert_called_once_with(dataset, user)
            update_external.assert_called_once_with(dataset, {"name": dataset.name}, user)
    def test_update_dataset_routes_internal_datasets_to_internal_helper(self):
        """provider == 'vendor' routes the update through _update_internal_dataset."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1")
        dataset.provider = "vendor"
        user = DatasetServiceUnitDataFactory.create_user_mock()
        with (
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch.object(DatasetService, "check_dataset_permission") as check_permission,
            patch.object(DatasetService, "_update_internal_dataset", return_value="updated") as update_internal,
        ):
            result = DatasetService.update_dataset("dataset-1", {"name": dataset.name}, user)
            assert result == "updated"
            # Permission is checked before the helper is dispatched.
            check_permission.assert_called_once_with(dataset, user)
            update_internal.assert_called_once_with(dataset, {"name": dataset.name}, user)
def test_has_dataset_same_name_returns_true_when_query_matches(self):
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.first.return_value = object()
result = DatasetService._has_dataset_same_name("tenant-1", "dataset-1", "Dataset")
assert result is True
    def test_update_external_dataset_updates_dataset_and_binding(self):
        """_update_external_dataset copies payload fields onto the dataset and updates the binding."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        user = SimpleNamespace(id="user-1")
        # Sentinel timestamp so identity (is) can be asserted below.
        now = object()
        with (
            patch.object(DatasetService, "_update_external_knowledge_binding") as update_binding,
            patch("services.dataset_service.naive_utc_now", return_value=now),
            patch("services.dataset_service.db") as mock_db,
        ):
            result = DatasetService._update_external_dataset(
                dataset,
                {
                    "external_retrieval_model": {"top_k": 3},
                    "summary_index_setting": {"enable": True},
                    "name": "Updated Dataset",
                    "description": "Updated description",
                    "permission": DatasetPermissionEnum.PARTIAL_TEAM,
                    "external_knowledge_id": "knowledge-1",
                    "external_knowledge_api_id": "api-1",
                },
                user,
            )
            assert result is dataset
            # external_retrieval_model is stored on the retrieval_model attribute.
            assert dataset.retrieval_model == {"top_k": 3}
            assert dataset.summary_index_setting == {"enable": True}
            assert dataset.name == "Updated Dataset"
            assert dataset.description == "Updated description"
            assert dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM
            assert dataset.updated_by == "user-1"
            assert dataset.updated_at is now
            update_binding.assert_called_once_with("dataset-1", "knowledge-1", "api-1")
            mock_db.session.add.assert_called_once_with(dataset)
            mock_db.session.commit.assert_called_once()
    @pytest.mark.parametrize(
        ("payload", "message"),
        [
            ({"external_knowledge_api_id": "api-1"}, "External knowledge id is required"),
            ({"external_knowledge_id": "knowledge-1"}, "External knowledge api id is required"),
        ],
    )
    def test_update_external_dataset_requires_external_binding_fields(self, payload, message):
        """Supplying only one of the two external binding ids raises a ValueError."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        with pytest.raises(ValueError, match=message):
            DatasetService._update_external_dataset(dataset, payload, SimpleNamespace(id="user-1"))
    def test_update_external_knowledge_binding_updates_changed_binding_values(self):
        """Changed knowledge/api ids are written to the binding and the binding is re-added."""
        binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api")
        session = MagicMock()
        session.query.return_value.filter_by.return_value.first.return_value = binding
        session_context = _make_session_context(session)
        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.Session", return_value=session_context),
        ):
            DatasetService._update_external_knowledge_binding("dataset-1", "new-knowledge", "new-api")
            assert binding.external_knowledge_id == "new-knowledge"
            assert binding.external_knowledge_api_id == "new-api"
            mock_db.session.add.assert_called_once_with(binding)
    def test_update_external_knowledge_binding_raises_for_missing_binding(self):
        """A missing binding row raises 'External knowledge binding not found'."""
        session = MagicMock()
        session.query.return_value.filter_by.return_value.first.return_value = None
        session_context = _make_session_context(session)
        with (
            patch("services.dataset_service.db"),
            patch("services.dataset_service.Session", return_value=session_context),
        ):
            with pytest.raises(ValueError, match="External knowledge binding not found"):
                DatasetService._update_external_knowledge_binding("dataset-1", "knowledge-1", "api-1")
    def test_update_internal_dataset_updates_fields_and_dispatches_regeneration_tasks(self):
        """_update_internal_dataset filters the payload, commits, and queues the async tasks."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        user = SimpleNamespace(id="user-1")
        # Sentinel timestamp so identity (is) can be asserted below.
        now = object()
        update_payload = {
            "name": "Updated Dataset",
            "description": None,
            "partial_member_list": [{"user_id": "member-1"}],
            "external_knowledge_api_id": "api-1",
            "external_knowledge_id": "knowledge-1",
            "external_retrieval_model": {"top_k": 2},
            "retrieval_model": {"top_k": 4},
            "summary_index_setting": {"enable": True},
            "icon_info": {"icon": "book"},
        }
        with (
            # Indexing change handler reports an 'update' action, which drives the vector task.
            patch.object(DatasetService, "_handle_indexing_technique_change", return_value="update"),
            patch.object(DatasetService, "_update_pipeline_knowledge_base_node_data") as update_pipeline,
            patch("services.dataset_service.naive_utc_now", return_value=now),
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.deal_dataset_vector_index_task") as vector_task,
            patch("services.dataset_service.regenerate_summary_index_task") as regenerate_task,
        ):
            result = DatasetService._update_internal_dataset(dataset, update_payload.copy(), user)
            assert result is dataset
            # Inspect the dict handed to the bulk UPDATE.
            updated_values = mock_db.session.query.return_value.filter_by.return_value.update.call_args.args[0]
            assert updated_values["name"] == "Updated Dataset"
            assert updated_values["description"] is None
            assert updated_values["retrieval_model"] == {"top_k": 4}
            assert updated_values["summary_index_setting"] == {"enable": True}
            assert updated_values["icon_info"] == {"icon": "book"}
            assert updated_values["updated_by"] == "user-1"
            assert updated_values["updated_at"] is now
            # External-only and membership keys must be stripped from the internal update.
            assert "partial_member_list" not in updated_values
            assert "external_knowledge_api_id" not in updated_values
            assert "external_knowledge_id" not in updated_values
            assert "external_retrieval_model" not in updated_values
            mock_db.session.commit.assert_called_once()
            mock_db.session.refresh.assert_called_once_with(dataset)
            update_pipeline.assert_called_once_with(dataset, "user-1")
            vector_task.delay.assert_called_once_with("dataset-1", "update")
            regenerate_task.delay.assert_called_once_with(
                "dataset-1",
                regenerate_reason="embedding_model_changed",
                regenerate_vectors_only=True,
            )
def test_update_pipeline_knowledge_base_node_data_returns_early_for_non_pipeline_dataset(self):
dataset = SimpleNamespace(runtime_mode="workflow", pipeline_id="pipeline-1")
with patch("services.dataset_service.db") as mock_db:
DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1")
mock_db.session.query.assert_not_called()
def test_update_pipeline_knowledge_base_node_data_returns_when_pipeline_is_missing(self):
dataset = SimpleNamespace(runtime_mode="rag_pipeline", pipeline_id="pipeline-1")
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1")
mock_db.session.commit.assert_not_called()
    def test_update_pipeline_knowledge_base_node_data_updates_published_and_draft_workflows(self):
        """Dataset settings are pushed into the knowledge-index node of both workflow graphs."""
        dataset = SimpleNamespace(
            id="dataset-1",
            runtime_mode="rag_pipeline",
            pipeline_id="pipeline-1",
            embedding_model="embedding-model",
            embedding_model_provider="provider",
            retrieval_model={"top_k": 5},
            chunk_structure="paragraph",
            indexing_technique="high_quality",
            keyword_number=8,
            summary_index_setting={"enable": True},
        )
        pipeline = SimpleNamespace(id="pipeline-1", tenant_id="tenant-1")
        # Published graph has one knowledge-index node plus an unrelated start node.
        published_workflow = SimpleNamespace(
            graph=json.dumps({"nodes": [{"data": {"type": "knowledge-index"}}, {"data": {"type": "start"}}]}),
            type="chat",
            features={"feature": True},
            environment_variables=[],
            conversation_variables=[],
            rag_pipeline_variables=[],
        )
        draft_workflow = SimpleNamespace(graph=json.dumps({"nodes": [{"data": {"type": "knowledge-index"}}]}))
        new_workflow = SimpleNamespace(id="workflow-1")
        rag_pipeline_service = MagicMock()
        rag_pipeline_service.get_published_workflow.return_value = published_workflow
        rag_pipeline_service.get_draft_workflow.return_value = draft_workflow
        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.RagPipelineService", return_value=rag_pipeline_service),
            patch("services.dataset_service.Workflow.new", return_value=new_workflow) as workflow_new,
        ):
            mock_db.session.query.return_value.filter_by.return_value.first.return_value = pipeline
            DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1")
            # The published graph is re-serialized into a brand-new Workflow version.
            published_graph = json.loads(workflow_new.call_args.kwargs["graph"])
            assert published_graph["nodes"][0]["data"]["embedding_model"] == "embedding-model"
            assert published_graph["nodes"][0]["data"]["summary_index_setting"] == {"enable": True}
            # The draft workflow is mutated in place.
            assert json.loads(draft_workflow.graph)["nodes"][0]["data"]["embedding_model_provider"] == "provider"
            mock_db.session.add.assert_any_call(new_workflow)
            mock_db.session.add.assert_any_call(draft_workflow)
            mock_db.session.commit.assert_called_once()
    def test_update_pipeline_knowledge_base_node_data_rolls_back_when_update_fails(self):
        """A failure while fetching workflows rolls the session back and re-raises."""
        dataset = SimpleNamespace(runtime_mode="rag_pipeline", pipeline_id="pipeline-1")
        pipeline = SimpleNamespace(id="pipeline-1", tenant_id="tenant-1")
        rag_pipeline_service = MagicMock()
        rag_pipeline_service.get_published_workflow.side_effect = RuntimeError("boom")
        with (
            patch("services.dataset_service.db") as mock_db,
            patch("services.dataset_service.RagPipelineService", return_value=rag_pipeline_service),
        ):
            mock_db.session.query.return_value.filter_by.return_value.first.return_value = pipeline
            with pytest.raises(RuntimeError, match="boom"):
                DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1")
            mock_db.session.rollback.assert_called_once()
def test_handle_indexing_technique_change_returns_none_without_indexing_technique(self):
filtered_data: dict[str, object] = {}
dataset = SimpleNamespace(indexing_technique="economy")
result = DatasetService._handle_indexing_technique_change(dataset, {}, filtered_data)
assert result is None
assert filtered_data == {}
def test_handle_indexing_technique_change_switches_to_economy(self):
filtered_data: dict[str, object] = {}
dataset = SimpleNamespace(indexing_technique="high_quality")
result = DatasetService._handle_indexing_technique_change(
dataset,
{"indexing_technique": "economy"},
filtered_data,
)
assert result == "remove"
assert filtered_data == {
"embedding_model": None,
"embedding_model_provider": None,
"collection_binding_id": None,
}
    def test_handle_indexing_technique_change_switches_to_high_quality(self):
        """Switching economy -> high_quality returns 'add' and configures the embedding model."""
        filtered_data: dict[str, object] = {}
        dataset = SimpleNamespace(indexing_technique="economy")
        with patch.object(DatasetService, "_configure_embedding_model_for_high_quality") as configure_embedding:
            result = DatasetService._handle_indexing_technique_change(
                dataset,
                {"indexing_technique": "high_quality"},
                filtered_data,
            )
            assert result == "add"
            configure_embedding.assert_called_once_with({"indexing_technique": "high_quality"}, filtered_data)
    def test_handle_indexing_technique_change_delegates_when_technique_is_unchanged(self):
        """An unchanged technique delegates to the embedding-model-update helper."""
        filtered_data: dict[str, object] = {}
        dataset = SimpleNamespace(indexing_technique="high_quality")
        with patch.object(
            DatasetService,
            "_handle_embedding_model_update_when_technique_unchanged",
            return_value="update",
        ) as update_embedding:
            result = DatasetService._handle_indexing_technique_change(
                dataset,
                {"indexing_technique": "high_quality"},
                filtered_data,
            )
            assert result == "update"
            update_embedding.assert_called_once_with(dataset, {"indexing_technique": "high_quality"}, filtered_data)
    def test_configure_embedding_model_for_high_quality_updates_filtered_data(self):
        """Model + collection-binding lookups populate filtered_data with embedding settings."""

        class FakeAccount:
            # Stand-in for the Account class so isinstance checks on current_user pass.
            pass

        current_user = FakeAccount()
        current_user.current_tenant_id = "tenant-1"
        embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model")
        filtered_data: dict[str, object] = {}
        with (
            patch("services.dataset_service.Account", FakeAccount),
            patch("services.dataset_service.current_user", current_user),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch(
                "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding",
                return_value=SimpleNamespace(id="binding-1"),
            ),
        ):
            model_manager_cls.return_value.get_model_instance.return_value = embedding_model
            DatasetService._configure_embedding_model_for_high_quality(
                {"embedding_model_provider": "provider", "embedding_model": "embedding-model"},
                filtered_data,
            )
            assert filtered_data == {
                "embedding_model": "embedding-model",
                "embedding_model_provider": "provider",
                "collection_binding_id": "binding-1",
            }
    @pytest.mark.parametrize(
        ("error", "message"),
        [
            (LLMBadRequestError(), "No Embedding Model available"),
            (ProviderTokenNotInitError("provider setup"), "provider setup"),
        ],
    )
    def test_configure_embedding_model_for_high_quality_wraps_model_errors(self, error, message):
        """Model-resolution failures are surfaced to callers as ValueError with a matching message."""
        # Attribute-only stand-in for Account; see the companion success test.
        class FakeAccount:
            pass
        current_user = FakeAccount()
        current_user.current_tenant_id = "tenant-1"
        with (
            patch("services.dataset_service.Account", FakeAccount),
            patch("services.dataset_service.current_user", current_user),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
        ):
            # The parametrized error is raised when the service asks for a model instance.
            model_manager_cls.return_value.get_model_instance.side_effect = error
            with pytest.raises(ValueError, match=message):
                DatasetService._configure_embedding_model_for_high_quality(
                    {"embedding_model_provider": "provider", "embedding_model": "embedding-model"},
                    {},
                )
def test_handle_embedding_model_update_when_technique_unchanged_preserves_existing_settings(self):
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
embedding_model_provider="provider",
embedding_model="embedding-model",
)
filtered_data: dict[str, object] = {}
with patch.object(DatasetService, "_preserve_existing_embedding_settings") as preserve_settings:
result = DatasetService._handle_embedding_model_update_when_technique_unchanged(
dataset,
{},
filtered_data,
)
assert result is None
preserve_settings.assert_called_once_with(dataset, filtered_data)
def test_handle_embedding_model_update_when_technique_unchanged_updates_when_model_is_provided(self):
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
embedding_model_provider="provider",
embedding_model="embedding-model",
)
with patch.object(DatasetService, "_update_embedding_model_settings", return_value="update") as update_settings:
result = DatasetService._handle_embedding_model_update_when_technique_unchanged(
dataset,
{"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"},
{},
)
assert result == "update"
update_settings.assert_called_once()
def test_preserve_existing_embedding_settings_keeps_current_binding(self):
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
embedding_model_provider="provider",
embedding_model="embedding-model",
collection_binding_id="binding-1",
)
filtered_data = {"embedding_model_provider": "", "embedding_model": ""}
DatasetService._preserve_existing_embedding_settings(dataset, filtered_data)
assert filtered_data == {
"embedding_model_provider": "provider",
"embedding_model": "embedding-model",
"collection_binding_id": "binding-1",
}
def test_preserve_existing_embedding_settings_removes_empty_placeholders_without_existing_values(self):
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
embedding_model_provider=None,
embedding_model=None,
collection_binding_id=None,
)
filtered_data = {"embedding_model_provider": "", "embedding_model": ""}
DatasetService._preserve_existing_embedding_settings(dataset, filtered_data)
assert filtered_data == {}
def test_update_embedding_model_settings_returns_update_for_changed_values(self):
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
embedding_model_provider="provider",
embedding_model="embedding-model",
)
with patch.object(DatasetService, "_apply_new_embedding_settings") as apply_settings:
result = DatasetService._update_embedding_model_settings(
dataset,
{"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"},
{},
)
assert result == "update"
apply_settings.assert_called_once()
def test_update_embedding_model_settings_returns_none_for_unchanged_values(self):
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
embedding_model_provider="provider",
embedding_model="embedding-model",
)
result = DatasetService._update_embedding_model_settings(
dataset,
{"embedding_model_provider": "provider", "embedding_model": "embedding-model"},
{},
)
assert result is None
def test_update_embedding_model_settings_wraps_bad_request_errors(self):
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
embedding_model_provider="provider",
embedding_model="embedding-model",
)
with patch.object(DatasetService, "_apply_new_embedding_settings", side_effect=LLMBadRequestError()):
with pytest.raises(ValueError, match="No Embedding Model available"):
DatasetService._update_embedding_model_settings(
dataset,
{"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"},
{},
)
    def test_apply_new_embedding_settings_updates_binding_for_new_model(self):
        """Applying a new embedding model resolves a fresh collection binding into filtered_data."""
        # Attribute-only stand-in patched over the service's Account symbol —
        # presumably to satisfy an isinstance check on current_user (confirm in service).
        class FakeAccount:
            pass
        current_user = FakeAccount()
        current_user.current_tenant_id = "tenant-1"
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(collection_binding_id="binding-1")
        filtered_data: dict[str, object] = {}
        with (
            patch("services.dataset_service.Account", FakeAccount),
            patch("services.dataset_service.current_user", current_user),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch(
                "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding",
                return_value=SimpleNamespace(id="binding-2"),
            ),
        ):
            model_manager_cls.return_value.get_model_instance.return_value = SimpleNamespace(
                provider="provider-two",
                model_name="embedding-model-two",
            )
            DatasetService._apply_new_embedding_settings(
                dataset,
                {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"},
                filtered_data,
            )
        # filtered_data now carries the new model plus the newly resolved binding id.
        assert filtered_data == {
            "embedding_model": "embedding-model-two",
            "embedding_model_provider": "provider-two",
            "collection_binding_id": "binding-2",
        }
    def test_apply_new_embedding_settings_preserves_existing_values_when_provider_token_is_missing(self):
        """If the new provider has no credentials, the dataset's current settings are kept."""
        # Attribute-only stand-in for Account; see the companion success test.
        class FakeAccount:
            pass
        current_user = FakeAccount()
        current_user.current_tenant_id = "tenant-1"
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            embedding_model_provider="provider",
            embedding_model="embedding-model",
            collection_binding_id="binding-1",
        )
        filtered_data: dict[str, object] = {}
        with (
            patch("services.dataset_service.Account", FakeAccount),
            patch("services.dataset_service.current_user", current_user),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
        ):
            # Resolving the requested model fails because the provider token is not set up.
            model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing")
            DatasetService._apply_new_embedding_settings(
                dataset,
                {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"},
                filtered_data,
            )
        # Fallback: the existing provider/model/binding values are written instead.
        assert filtered_data == {
            "embedding_model_provider": "provider",
            "embedding_model": "embedding-model",
            "collection_binding_id": "binding-1",
        }
    @pytest.mark.parametrize(
        ("summary_index_setting", "expected"),
        [
            (None, False),
            ({"enable": False}, False),
            ({"enable": True, "model_name": "old-model", "model_provider_name": "provider"}, False),
            ({"enable": True, "model_name": "new-model", "model_provider_name": "provider-two"}, True),
        ],
    )
    def test_check_summary_index_setting_model_changed(self, summary_index_setting, expected):
        """Only an enabled setting with a different model/provider counts as a model change."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            dataset_id="dataset-1",
            summary_index_setting={"enable": True, "model_name": "old-model", "model_provider_name": "provider"},
        )
        # None parametrization means the update payload omits the key entirely.
        result = DatasetService._check_summary_index_setting_model_changed(
            dataset,
            {"summary_index_setting": summary_index_setting} if summary_index_setting is not None else {},
        )
        assert result is expected
class TestDatasetServiceRagPipelineSettings:
    """Unit tests for rag-pipeline dataset setting updates."""
    def test_update_rag_pipeline_dataset_settings_requires_current_tenant(self):
        """A missing current tenant id aborts the update with a ValueError."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        knowledge_configuration = _make_knowledge_configuration()
        with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id=None)):
            with pytest.raises(ValueError, match="Current user or current tenant not found"):
                DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration)
    def test_update_rag_pipeline_dataset_settings_without_published_high_quality_updates_embedding_settings(self):
        """An unpublished high-quality update resolves embedding model, binding, and multimodal flag."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        # The service mutates the instance returned by session.merge.
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(summary_index_setting={"enable": True})
        embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model")
        with (
            patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch.object(DatasetService, "check_is_multimodal_model", return_value=True) as check_multimodal,
            patch(
                "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding",
                return_value=SimpleNamespace(id="binding-1"),
            ),
        ):
            model_manager_cls.return_value.get_model_instance.return_value = embedding_model
            DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration)
            assert dataset.chunk_structure == "paragraph"
            assert dataset.indexing_technique == "high_quality"
            assert dataset.embedding_model == "embedding-model"
            assert dataset.embedding_model_provider == "provider"
            assert dataset.collection_binding_id == "binding-1"
            assert dataset.is_multimodal is True
            assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump()
            assert dataset.summary_index_setting == {"enable": True}
            check_multimodal.assert_called_once_with("tenant-1", "provider", "embedding-model")
            session.add.assert_called_once_with(dataset)
            # Without has_published the service does not commit itself.
            session.commit.assert_not_called()
    def test_update_rag_pipeline_dataset_settings_without_published_economy_updates_keyword_number(self):
        """An unpublished economy update writes the keyword number and retrieval model."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(
            indexing_technique="economy",
            embedding_model_provider="",
            embedding_model="",
            keyword_number=12,
        )
        with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")):
            DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration)
            assert dataset.indexing_technique == "economy"
            assert dataset.keyword_number == 12
            assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump()
            session.add.assert_called_once_with(dataset)
    def test_update_rag_pipeline_dataset_settings_with_published_rejects_chunk_structure_changes(self):
        """Once published, changing the chunk structure is rejected."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(chunk_structure="sentence")
        with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")):
            with pytest.raises(ValueError, match="Chunk structure is not allowed to be updated"):
                DatasetService.update_rag_pipeline_dataset_settings(
                    session,
                    dataset,
                    knowledge_configuration,
                    has_published=True,
                )
    def test_update_rag_pipeline_dataset_settings_with_published_rejects_switch_to_economy(self):
        """Once published with high_quality indexing, downgrading to economy is rejected."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        dataset.indexing_technique = "high_quality"
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(
            indexing_technique="economy",
            embedding_model_provider="",
            embedding_model="",
        )
        with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")):
            with pytest.raises(
                ValueError,
                match="Knowledge base indexing technique is not allowed to be updated to economy",
            ):
                DatasetService.update_rag_pipeline_dataset_settings(
                    session,
                    dataset,
                    knowledge_configuration,
                    has_published=True,
                )
    def test_update_rag_pipeline_dataset_settings_with_published_adds_high_quality_index(self):
        """Switching a published economy dataset to high_quality triggers an 'add' index task."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        dataset.indexing_technique = "economy"
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration()
        embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model")
        with (
            patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch.object(DatasetService, "check_is_multimodal_model", return_value=False),
            patch(
                "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding",
                return_value=SimpleNamespace(id="binding-1"),
            ),
            patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
        ):
            model_manager_cls.return_value.get_model_instance.return_value = embedding_model
            DatasetService.update_rag_pipeline_dataset_settings(
                session,
                dataset,
                knowledge_configuration,
                has_published=True,
            )
            assert dataset.indexing_technique == "high_quality"
            assert dataset.embedding_model == "embedding-model"
            assert dataset.embedding_model_provider == "provider"
            assert dataset.collection_binding_id == "binding-1"
            assert dataset.is_multimodal is False
            assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump()
            session.add.assert_called_once_with(dataset)
            # With has_published the service commits and enqueues the index task itself.
            session.commit.assert_called_once()
            update_task.delay.assert_called_once_with("dataset-1", "add")
    def test_update_rag_pipeline_dataset_settings_with_published_updates_changed_embedding_model(self):
        """Changing the embedding model on a published dataset triggers an 'update' index task."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        dataset.indexing_technique = "high_quality"
        dataset.embedding_model_provider = "provider"
        dataset.embedding_model = "embedding-model"
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(
            embedding_model_provider="provider-two",
            embedding_model="embedding-model-two",
            summary_index_setting={"enable": True},
        )
        with (
            patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch.object(DatasetService, "check_is_multimodal_model", return_value=True),
            patch(
                "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding",
                return_value=SimpleNamespace(id="binding-2"),
            ),
            patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
        ):
            model_manager_cls.return_value.get_model_instance.return_value = SimpleNamespace(
                provider="provider-two",
                model_name="embedding-model-two",
            )
            DatasetService.update_rag_pipeline_dataset_settings(
                session,
                dataset,
                knowledge_configuration,
                has_published=True,
            )
            assert dataset.embedding_model_provider == "provider-two"
            assert dataset.embedding_model == "embedding-model-two"
            assert dataset.collection_binding_id == "binding-2"
            assert dataset.is_multimodal is True
            assert dataset.summary_index_setting == {"enable": True}
            session.add.assert_called_once_with(dataset)
            session.commit.assert_called_once()
            update_task.delay.assert_called_once_with("dataset-1", "update")
    def test_update_rag_pipeline_dataset_settings_with_published_skips_embedding_update_when_token_is_missing(self):
        """When the new provider's token is missing, existing embedding settings are kept."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        dataset.indexing_technique = "high_quality"
        dataset.embedding_model_provider = "provider"
        dataset.embedding_model = "embedding-model"
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(
            embedding_model_provider="provider-two",
            embedding_model="embedding-model-two",
        )
        with (
            patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")),
            patch("services.dataset_service.ModelManager") as model_manager_cls,
            patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
        ):
            model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing")
            DatasetService.update_rag_pipeline_dataset_settings(
                session,
                dataset,
                knowledge_configuration,
                has_published=True,
            )
            # The model change is skipped, but the rest of the update still commits
            # and an 'update' task is still enqueued.
            assert dataset.embedding_model_provider == "provider"
            assert dataset.embedding_model == "embedding-model"
            assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump()
            session.add.assert_called_once_with(dataset)
            session.commit.assert_called_once()
            update_task.delay.assert_called_once_with("dataset-1", "update")
    def test_update_rag_pipeline_dataset_settings_with_published_updates_economy_keyword_number(self):
        """Updating keyword_number on a published economy dataset does not enqueue an index task."""
        session = MagicMock()
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1")
        dataset.chunk_structure = "paragraph"
        dataset.indexing_technique = "economy"
        dataset.keyword_number = 5
        session.merge.return_value = dataset
        knowledge_configuration = _make_knowledge_configuration(
            indexing_technique="economy",
            embedding_model_provider="",
            embedding_model="",
            keyword_number=9,
        )
        with (
            patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")),
            patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
        ):
            DatasetService.update_rag_pipeline_dataset_settings(
                session,
                dataset,
                knowledge_configuration,
                has_published=True,
            )
            assert dataset.keyword_number == 9
            assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump()
            session.add.assert_called_once_with(dataset)
            session.commit.assert_called_once()
            update_task.delay.assert_not_called()
class TestDatasetServicePermissionsAndLifecycle:
    """Unit tests for dataset permissions, deletion, and metadata helpers."""
    def test_delete_dataset_returns_false_when_dataset_is_missing(self):
        """Deleting an unknown dataset id is a no-op that returns False."""
        with patch.object(DatasetService, "get_dataset", return_value=None):
            result = DatasetService.delete_dataset("dataset-1", user=SimpleNamespace(id="user-1"))
            assert result is False
    def test_delete_dataset_checks_permission_and_deletes_dataset(self):
        """Deletion verifies permission, emits the deleted signal, and commits the delete."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock()
        with (
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch.object(DatasetService, "check_dataset_permission") as check_permission,
            patch("services.dataset_service.dataset_was_deleted.send") as send_deleted_signal,
            patch("services.dataset_service.db") as mock_db,
        ):
            result = DatasetService.delete_dataset(dataset.id, user=SimpleNamespace(id="user-1"))
            assert result is True
            # SimpleNamespace compares by attribute values, so a fresh instance matches.
            check_permission.assert_called_once_with(dataset, SimpleNamespace(id="user-1"))
            send_deleted_signal.assert_called_once_with(dataset)
            mock_db.session.delete.assert_called_once_with(dataset)
            mock_db.session.commit.assert_called_once()
    def test_dataset_use_check_returns_scalar_result(self):
        """The usage check returns the scalar produced by the query untouched."""
        with patch("services.dataset_service.db") as mock_db:
            mock_db.session.execute.return_value.scalar_one.return_value = True
            result = DatasetService.dataset_use_check("dataset-1")
            assert result is True
    def test_check_dataset_permission_rejects_cross_tenant_access(self):
        """Users from another tenant never have access."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(tenant_id="tenant-a")
        user = DatasetServiceUnitDataFactory.create_user_mock(tenant_id="tenant-b")
        with pytest.raises(NoPermissionError, match="do not have permission"):
            DatasetService.check_dataset_permission(dataset, user)
    def test_check_dataset_permission_rejects_only_me_dataset_for_non_creator(self):
        """ONLY_ME datasets are accessible to the creator only."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            permission=DatasetPermissionEnum.ONLY_ME,
            created_by="owner-1",
        )
        user = DatasetServiceUnitDataFactory.create_user_mock(
            user_id="member-1",
            role=TenantAccountRole.EDITOR,
        )
        with pytest.raises(NoPermissionError, match="do not have permission"):
            DatasetService.check_dataset_permission(dataset, user)
    def test_check_dataset_permission_rejects_partial_team_user_without_binding(self):
        """PARTIAL_TEAM datasets reject members without a permission binding row."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
            created_by="owner-1",
        )
        user = DatasetServiceUnitDataFactory.create_user_mock(
            user_id="member-1",
            role=TenantAccountRole.EDITOR,
        )
        with patch("services.dataset_service.db") as mock_db:
            # No binding row found for this user.
            mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
            with pytest.raises(NoPermissionError, match="do not have permission"):
                DatasetService.check_dataset_permission(dataset, user)
    def test_check_dataset_permission_allows_partial_team_creator_without_lookup(self):
        """The creator of a PARTIAL_TEAM dataset is allowed without a binding lookup."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
            created_by="creator-1",
        )
        user = DatasetServiceUnitDataFactory.create_user_mock(
            user_id="creator-1",
            role=TenantAccountRole.EDITOR,
        )
        with patch("services.dataset_service.db") as mock_db:
            DatasetService.check_dataset_permission(dataset, user)
            mock_db.session.query.assert_not_called()
    def test_check_dataset_permission_allows_partial_team_member_with_binding(self):
        """A member with an existing binding row passes the PARTIAL_TEAM check."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            permission=DatasetPermissionEnum.PARTIAL_TEAM,
            created_by="owner-1",
        )
        user = DatasetServiceUnitDataFactory.create_user_mock(
            user_id="member-1",
            role=TenantAccountRole.EDITOR,
        )
        with patch("services.dataset_service.db") as mock_db:
            mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
            DatasetService.check_dataset_permission(dataset, user)
    def test_check_dataset_operator_permission_validates_required_arguments(self):
        """Both a dataset and a user are mandatory arguments."""
        with pytest.raises(ValueError, match="Dataset not found"):
            DatasetService.check_dataset_operator_permission(user=SimpleNamespace(id="user-1"), dataset=None)
        with pytest.raises(ValueError, match="User not found"):
            DatasetService.check_dataset_operator_permission(user=None, dataset=SimpleNamespace(id="dataset-1"))
    def test_check_dataset_operator_permission_rejects_only_me_for_non_creator(self):
        """Operator checks also enforce ONLY_ME creator-only access."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            permission=DatasetPermissionEnum.ONLY_ME,
            created_by="owner-1",
        )
        user = DatasetServiceUnitDataFactory.create_user_mock(
            user_id="member-1",
            role=TenantAccountRole.EDITOR,
        )
        with pytest.raises(NoPermissionError, match="do not have permission"):
            DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)
    def test_check_dataset_operator_permission_rejects_partial_team_without_binding(self):
        """Operator checks reject PARTIAL_TEAM users with no binding rows at all."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM)
        user = DatasetServiceUnitDataFactory.create_user_mock(
            user_id="member-1",
            role=TenantAccountRole.EDITOR,
        )
        with patch("services.dataset_service.db") as mock_db:
            mock_db.session.query.return_value.filter_by.return_value.all.return_value = []
            with pytest.raises(NoPermissionError, match="do not have permission"):
                DatasetService.check_dataset_operator_permission(user=user, dataset=dataset)
    def test_get_dataset_queries_delegates_to_paginate(self):
        """Query listing returns db.paginate's items and total unchanged."""
        with patch("services.dataset_service.db") as mock_db:
            mock_db.desc.side_effect = lambda column: column
            mock_db.paginate.return_value = SimpleNamespace(items=["query"], total=1)
            items, total = DatasetService.get_dataset_queries("dataset-1", page=1, per_page=20)
            assert items == ["query"]
            assert total == 1
            mock_db.paginate.assert_called_once()
    def test_get_related_apps_returns_ordered_query_results(self):
        """Related-app lookup returns the ordered query results as a list."""
        with patch("services.dataset_service.db") as mock_db:
            mock_db.desc.side_effect = lambda column: column
            mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [
                "relation-1"
            ]
            result = DatasetService.get_related_apps("dataset-1")
            assert result == ["relation-1"]
    def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self):
        """Toggling API status on an unknown dataset raises NotFound."""
        with patch.object(DatasetService, "get_dataset", return_value=None):
            with pytest.raises(NotFound, match="Dataset not found"):
                DatasetService.update_dataset_api_status("dataset-1", True)
    def test_update_dataset_api_status_requires_current_user_id(self):
        """A current user without an id cannot toggle the API status."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(enable_api=False)
        with (
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch("services.dataset_service.current_user", SimpleNamespace(id=None)),
        ):
            with pytest.raises(ValueError, match="Current user or current user id not found"):
                DatasetService.update_dataset_api_status(dataset.id, True)
    def test_update_dataset_api_status_updates_fields_and_commits(self):
        """A successful toggle stamps updated_by/updated_at and commits."""
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(enable_api=False)
        # Sentinel object verifies the service uses naive_utc_now() verbatim.
        now = object()
        with (
            patch.object(DatasetService, "get_dataset", return_value=dataset),
            patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")),
            patch("services.dataset_service.naive_utc_now", return_value=now),
            patch("services.dataset_service.db") as mock_db,
        ):
            DatasetService.update_dataset_api_status(dataset.id, True)
            assert dataset.enable_api is True
            assert dataset.updated_by == "user-1"
            assert dataset.updated_at is now
            mock_db.session.commit.assert_called_once()
    def test_get_dataset_auto_disable_logs_returns_empty_when_billing_is_disabled(self):
        """With billing disabled, the helper short-circuits without querying the db."""
        # Attribute-only stand-in patched over the service's Account symbol —
        # presumably to satisfy an isinstance check on current_user (confirm in service).
        class FakeAccount:
            pass
        current_user = FakeAccount()
        current_user.current_tenant_id = "tenant-1"
        features = SimpleNamespace(
            billing=SimpleNamespace(enabled=False, subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL))
        )
        with (
            patch("services.dataset_service.Account", FakeAccount),
            patch("services.dataset_service.current_user", current_user),
            patch("services.dataset_service.FeatureService.get_features", return_value=features),
            patch("services.dataset_service.db") as mock_db,
        ):
            result = DatasetService.get_dataset_auto_disable_logs("dataset-1")
            assert result == {"document_ids": [], "count": 0}
            mock_db.session.scalars.assert_not_called()
    def test_get_dataset_auto_disable_logs_returns_recent_document_ids(self):
        """With billing enabled, the helper returns the logged document ids and their count."""
        class FakeAccount:
            pass
        current_user = FakeAccount()
        current_user.current_tenant_id = "tenant-1"
        logs = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")]
        features = SimpleNamespace(
            billing=SimpleNamespace(enabled=True, subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL))
        )
        with (
            patch("services.dataset_service.Account", FakeAccount),
            patch("services.dataset_service.current_user", current_user),
            patch("services.dataset_service.FeatureService.get_features", return_value=features),
            patch("services.dataset_service.db") as mock_db,
        ):
            mock_db.session.scalars.return_value.all.return_value = logs
            result = DatasetService.get_dataset_auto_disable_logs("dataset-1")
            assert result == {"document_ids": ["doc-1", "doc-2"], "count": 2}
class TestDatasetServiceDocumentIndexing:
    """Unit tests for pause/recover/retry orchestration without SQL assertions."""
    @pytest.fixture
    def mock_document_service_dependencies(self):
        """Patch redis, the db session, and current_user for the document lifecycle tests."""
        with (
            patch("services.dataset_service.redis_client") as mock_redis,
            patch("services.dataset_service.db.session") as mock_db_session,
            patch("services.dataset_service.current_user") as mock_current_user,
        ):
            mock_current_user.id = "user-123"
            yield {
                "redis_client": mock_redis,
                "db_session": mock_db_session,
                "current_user": mock_current_user,
            }
    def test_pause_document_success(self, mock_document_service_dependencies):
        """Pausing an indexing document persists the pause and sets the redis pause flag."""
        document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing")
        DocumentService.pause_document(document)
        assert document.is_paused is True
        assert document.paused_by == "user-123"
        mock_document_service_dependencies["db_session"].add.assert_called_once_with(document)
        mock_document_service_dependencies["db_session"].commit.assert_called_once()
        mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with(
            f"document_{document.id}_is_paused",
            "True",
        )
    def test_pause_document_invalid_status_error(self, mock_document_service_dependencies):
        """A document that is not actively indexing cannot be paused."""
        document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="completed")
        with pytest.raises(DocumentIndexingError):
            DocumentService.pause_document(document)
    def test_recover_document_success(self, mock_document_service_dependencies):
        """Recovery clears pause state, removes the redis flag, and re-queues indexing."""
        document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing", is_paused=True)
        with patch("services.dataset_service.recover_document_indexing_task") as recover_task:
            DocumentService.recover_document(document)
        assert document.is_paused is False
        assert document.paused_by is None
        assert document.paused_at is None
        mock_document_service_dependencies["db_session"].add.assert_called_once_with(document)
        mock_document_service_dependencies["db_session"].commit.assert_called_once()
        mock_document_service_dependencies["redis_client"].delete.assert_called_once_with(
            f"document_{document.id}_is_paused"
        )
        recover_task.delay.assert_called_once_with(document.dataset_id, document.id)
    def test_retry_document_indexing_success(self, mock_document_service_dependencies):
        """Retrying errored documents resets them to waiting and enqueues one retry task."""
        dataset_id = "dataset-123"
        documents = [
            DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1", indexing_status="error"),
            DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-2", indexing_status="error"),
        ]
        # redis get -> None: presumably no retry-guard key exists yet (confirm key semantics in service).
        mock_document_service_dependencies["redis_client"].get.return_value = None
        with patch("services.dataset_service.retry_document_indexing_task") as retry_task:
            DocumentService.retry_document(dataset_id, documents)
        assert all(document.indexing_status == "waiting" for document in documents)
        assert mock_document_service_dependencies["db_session"].add.call_count == 2
        assert mock_document_service_dependencies["db_session"].commit.call_count == 2
        assert mock_document_service_dependencies["redis_client"].setex.call_count == 2
        retry_task.delay.assert_called_once_with(dataset_id, ["doc-1", "doc-2"], "user-123")
class TestDatasetCollectionBindingService:
    """Unit tests for dataset collection binding lookups and creation."""
    def test_get_dataset_collection_binding_returns_existing_binding(self):
        """An existing binding row is returned as-is, without creating a new one."""
        existing = SimpleNamespace(id="binding-1")
        with patch("services.dataset_service.db") as fake_db:
            query_chain = fake_db.session.query.return_value.where.return_value.order_by.return_value
            query_chain.first.return_value = existing
            found = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model")
        assert found is existing
        fake_db.session.add.assert_not_called()
    def test_get_dataset_collection_binding_creates_binding_when_missing(self):
        """When no binding exists, a new one is built, persisted, and committed."""
        new_binding = SimpleNamespace(id="binding-2")
        with (
            patch("services.dataset_service.db") as fake_db,
            patch("services.dataset_service.DatasetCollectionBinding", return_value=new_binding) as binding_cls,
            patch.object(Dataset, "gen_collection_name_by_id", return_value="generated-collection"),
        ):
            query_chain = fake_db.session.query.return_value.where.return_value.order_by.return_value
            query_chain.first.return_value = None
            created = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model", "dataset")
        assert created is new_binding
        binding_cls.assert_called_once_with(
            provider_name="provider",
            model_name="model",
            collection_name="generated-collection",
            type="dataset",
        )
        fake_db.session.add.assert_called_once_with(new_binding)
        fake_db.session.commit.assert_called_once()
    def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self):
        """An unknown binding id raises a descriptive ValueError."""
        with patch("services.dataset_service.db") as fake_db:
            query_chain = fake_db.session.query.return_value.where.return_value.order_by.return_value
            query_chain.first.return_value = None
            with pytest.raises(ValueError, match="Dataset collection binding not found"):
                DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1")
    def test_get_dataset_collection_binding_by_id_and_type_returns_binding(self):
        """A known binding id resolves to the stored binding row."""
        stored = SimpleNamespace(id="binding-1")
        with patch("services.dataset_service.db") as fake_db:
            query_chain = fake_db.session.query.return_value.where.return_value.order_by.return_value
            query_chain.first.return_value = stored
            found = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1")
        assert found is stored
class TestDatasetPermissionService:
    """Unit tests for dataset partial-member management helpers."""

    def test_get_dataset_partial_member_list_returns_scalar_results(self):
        """The member list is exactly what the scalar query yields."""
        with patch("services.dataset_service.db") as mock_db:
            mock_db.session.scalars.return_value.all.return_value = ["user-1", "user-2"]
            members = DatasetPermissionService.get_dataset_partial_member_list("dataset-1")
        assert members == ["user-1", "user-2"]

    def test_update_partial_member_list_replaces_permissions_and_commits(self):
        """Replacing the member list deletes old rows, bulk-inserts new ones, and commits."""
        with patch("services.dataset_service.db") as mock_db:
            DatasetPermissionService.update_partial_member_list(
                "tenant-1",
                "dataset-1",
                [{"user_id": "user-1"}, {"user_id": "user-2"}],
            )
        mock_db.session.query.return_value.where.return_value.delete.assert_called_once()
        mock_db.session.add_all.assert_called_once()
        mock_db.session.commit.assert_called_once()

    def test_update_partial_member_list_rolls_back_on_exception(self):
        """A failure while inserting new permissions triggers a session rollback."""
        with patch("services.dataset_service.db") as mock_db:
            mock_db.session.add_all.side_effect = RuntimeError("boom")
            members = [{"user_id": "user-1"}]
            with pytest.raises(RuntimeError, match="boom"):
                DatasetPermissionService.update_partial_member_list("tenant-1", "dataset-1", members)
            mock_db.session.rollback.assert_called_once()

    def test_check_permission_requires_dataset_editor(self):
        """Users without editor rights may not manage dataset permissions."""
        viewer = SimpleNamespace(is_dataset_editor=False, is_dataset_operator=False)
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock()
        with pytest.raises(NoPermissionError, match="does not have permission"):
            DatasetPermissionService.check_permission(viewer, dataset, "all_team", [])

    def test_check_permission_prevents_dataset_operator_from_changing_permission_mode(self):
        """A dataset operator cannot switch the dataset to a different permission mode."""
        operator = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission="all_team")
        with pytest.raises(NoPermissionError, match="cannot change the dataset permissions"):
            DatasetPermissionService.check_permission(operator, dataset, "only_me", [])

    def test_check_permission_requires_partial_member_list_for_partial_members_mode(self):
        """Selecting partial_members mode with an empty member list is rejected."""
        operator = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission="partial_members")
        with pytest.raises(ValueError, match="Partial member list is required"):
            DatasetPermissionService.check_permission(operator, dataset, "partial_members", [])

    def test_check_permission_rejects_dataset_operator_member_list_changes(self):
        """An operator may not alter who is on the partial member list."""
        operator = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            dataset_id="dataset-1", permission="partial_members"
        )
        with (
            patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]),
            pytest.raises(ValueError, match="cannot change the dataset permissions"),
        ):
            # Submitting a different user id than the stored list must fail.
            DatasetPermissionService.check_permission(operator, dataset, "partial_members", [{"user_id": "user-2"}])

    def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged(self):
        """Re-submitting the identical member list is permitted for an operator."""
        operator = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True)
        dataset = DatasetServiceUnitDataFactory.create_dataset_mock(
            dataset_id="dataset-1", permission="partial_members"
        )
        with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]):
            # Should not raise: the requested list matches the stored one.
            DatasetPermissionService.check_permission(operator, dataset, "partial_members", [{"user_id": "user-1"}])

    def test_clear_partial_member_list_deletes_permissions_and_commits(self):
        """Clearing removes all permission rows for the dataset and commits."""
        with patch("services.dataset_service.db") as mock_db:
            DatasetPermissionService.clear_partial_member_list("dataset-1")
        mock_db.session.query.return_value.where.return_value.delete.assert_called_once()
        mock_db.session.commit.assert_called_once()

    def test_clear_partial_member_list_rolls_back_on_exception(self):
        """A delete failure rolls the session back and re-raises."""
        with patch("services.dataset_service.db") as mock_db:
            delete_call = mock_db.session.query.return_value.where.return_value.delete
            delete_call.side_effect = RuntimeError("boom")
            with pytest.raises(RuntimeError, match="boom"):
                DatasetPermissionService.clear_partial_member_list("dataset-1")
            mock_db.session.rollback.assert_called_once()