From e08c06cbc3b7cb583a184a647d9528ae9f6c68cd Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Thu, 26 Mar 2026 16:13:53 +0800 Subject: [PATCH] fix: import path (#34124) Co-authored-by: -LAN- --- .../services/dataset_service_test_helpers.py | 2 +- .../services/test_dataset_service_dataset.py | 52 +++++++++++-------- .../services/test_dataset_service_document.py | 4 +- .../services/test_dataset_service_segment.py | 17 +++--- 4 files changed, 42 insertions(+), 33 deletions(-) diff --git a/api/tests/unit_tests/services/dataset_service_test_helpers.py b/api/tests/unit_tests/services/dataset_service_test_helpers.py index 542179e2a3..c95b60fad0 100644 --- a/api/tests/unit_tests/services/dataset_service_test_helpers.py +++ b/api/tests/unit_tests/services/dataset_service_test_helpers.py @@ -10,7 +10,6 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, create_autospec, patch import pytest -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType from werkzeug.exceptions import Forbidden, NotFound from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError @@ -18,6 +17,7 @@ from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.index_type import IndexStructureType from core.rag.retrieval.retrieval_methods import RetrievalMethod from enums.cloud_plan import CloudPlan +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType from models import Account, TenantAccountRole from models.dataset import ( ChildChunk, diff --git a/api/tests/unit_tests/services/test_dataset_service_dataset.py b/api/tests/unit_tests/services/test_dataset_service_dataset.py index 771240e7bf..92aed7c30a 100644 --- a/api/tests/unit_tests/services/test_dataset_service_dataset.py +++ b/api/tests/unit_tests/services/test_dataset_service_dataset.py @@ -190,7 +190,7 @@ class TestDatasetServiceValidation: with patch("services.dataset_service.ModelManager") as model_manager_cls: DatasetService.check_dataset_model_setting(dataset) - model_manager_cls.return_value.get_model_instance.assert_called_once_with( + model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, @@ -201,7 +201,7 @@ class TestDatasetServiceValidation: dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality") with patch("services.dataset_service.ModelManager") as model_manager_cls: - model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError() + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError() with pytest.raises(ValueError, match="No Embedding Model available"): DatasetService.check_dataset_model_setting(dataset) @@ -210,14 +210,18 @@ class TestDatasetServiceValidation: dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality") with patch("services.dataset_service.ModelManager") as model_manager_cls: - model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing") + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError( + "token missing" + ) - with pytest.raises(ValueError, match="token missing"): + with pytest.raises(ValueError, match="The dataset is unavailable, due to: token missing"): DatasetService.check_dataset_model_setting(dataset) def test_check_embedding_model_setting_wraps_provider_token_error_description(self): with patch("services.dataset_service.ModelManager") as model_manager_cls: - model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("provider setup") + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError( + "provider setup" + ) with pytest.raises(ValueError, match="provider setup"): DatasetService.check_embedding_model_setting("tenant-1", "provider", "embedding-model") @@ -226,7 +230,7 @@ class TestDatasetServiceValidation: with patch("services.dataset_service.ModelManager") as model_manager_cls: DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker") - model_manager_cls.return_value.get_model_instance.assert_called_once_with( + model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with( tenant_id="tenant-1", provider="provider", model_type=ModelType.RERANK, @@ -235,7 +239,7 @@ class TestDatasetServiceValidation: def test_check_reranking_model_setting_wraps_bad_request(self): with patch("services.dataset_service.ModelManager") as model_manager_cls: - model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError() + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError() with pytest.raises(ValueError, match="No Rerank Model available"): DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker") @@ -251,7 +255,7 @@ class TestDatasetServiceValidation: ) with patch("services.dataset_service.ModelManager") as model_manager_cls: - model_manager_cls.return_value.get_model_instance.return_value = model_instance + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") @@ -268,7 +272,7 @@ class TestDatasetServiceValidation: ) with patch("services.dataset_service.ModelManager") as model_manager_cls: - model_manager_cls.return_value.get_model_instance.return_value = model_instance + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") @@ -284,14 +288,14 @@ class TestDatasetServiceValidation: ) with patch("services.dataset_service.ModelManager") as model_manager_cls: - model_manager_cls.return_value.get_model_instance.return_value = model_instance + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance with pytest.raises(ValueError, match="Model schema not found"): DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") def test_check_is_multimodal_model_wraps_bad_request_error(self): with patch("services.dataset_service.ModelManager") as model_manager_cls: - model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError() + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError() with pytest.raises(ValueError, match="No Model available"): DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") @@ -323,7 +327,7 @@ class TestDatasetServiceCreationAndUpdate: patch.object(DatasetService, "check_embedding_model_setting") as check_embedding, ): mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - model_manager_cls.return_value.get_default_model_instance.return_value = default_embedding_model + model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = default_embedding_model dataset = DatasetService.create_empty_dataset( tenant_id="tenant-1", @@ -337,7 +341,7 @@ class TestDatasetServiceCreationAndUpdate: assert dataset.embedding_model == "default-embedding" assert dataset.permission == DatasetPermissionEnum.ONLY_ME assert dataset.provider == "vendor" - model_manager_cls.return_value.get_default_model_instance.assert_called_once_with( + model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_called_once_with( tenant_id="tenant-1", model_type=ModelType.TEXT_EMBEDDING, ) @@ -365,7 +369,7 @@ class TestDatasetServiceCreationAndUpdate: patch.object(DatasetService, "check_reranking_model_setting") as check_reranking, ): mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - model_manager_cls.return_value.get_model_instance.return_value = embedding_model + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model dataset = DatasetService.create_empty_dataset( tenant_id="tenant-1", @@ -804,7 +808,7 @@ class TestDatasetServiceCreationAndUpdate: return_value=SimpleNamespace(id="binding-1"), ), ): - model_manager_cls.return_value.get_model_instance.return_value = embedding_model + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model DatasetService._configure_embedding_model_for_high_quality( {"embedding_model_provider": "provider", "embedding_model": "embedding-model"}, @@ -836,7 +840,7 @@ class TestDatasetServiceCreationAndUpdate: patch("services.dataset_service.current_user", current_user), patch("services.dataset_service.ModelManager") as model_manager_cls, ): - model_manager_cls.return_value.get_model_instance.side_effect = error + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = error with pytest.raises(ValueError, match=message): DatasetService._configure_embedding_model_for_high_quality( @@ -967,7 +971,7 @@ class TestDatasetServiceCreationAndUpdate: return_value=SimpleNamespace(id="binding-2"), ), ): - model_manager_cls.return_value.get_model_instance.return_value = SimpleNamespace( + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace( provider="provider-two", model_name="embedding-model-two", ) @@ -1002,7 +1006,9 @@ class TestDatasetServiceCreationAndUpdate: patch("services.dataset_service.current_user", current_user), patch("services.dataset_service.ModelManager") as model_manager_cls, ): - model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing") + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError( + "token missing" + ) DatasetService._apply_new_embedding_settings( dataset, @@ -1067,7 +1073,7 @@ class TestDatasetServiceRagPipelineSettings: return_value=SimpleNamespace(id="binding-1"), ), ): - model_manager_cls.return_value.get_model_instance.return_value = embedding_model + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration) @@ -1161,7 +1167,7 @@ class TestDatasetServiceRagPipelineSettings: ), patch("services.dataset_service.deal_dataset_index_update_task") as update_task, ): - model_manager_cls.return_value.get_model_instance.return_value = embedding_model + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model DatasetService.update_rag_pipeline_dataset_settings( session, @@ -1204,7 +1210,7 @@ class TestDatasetServiceRagPipelineSettings: ), patch("services.dataset_service.deal_dataset_index_update_task") as update_task, ): - model_manager_cls.return_value.get_model_instance.return_value = SimpleNamespace( + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace( provider="provider-two", model_name="embedding-model-two", ) @@ -1243,7 +1249,9 @@ class TestDatasetServiceRagPipelineSettings: patch("services.dataset_service.ModelManager") as model_manager_cls, patch("services.dataset_service.deal_dataset_index_update_task") as update_task, ): - model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing") + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError( + "token missing" + ) DatasetService.update_rag_pipeline_dataset_settings( session, diff --git a/api/tests/unit_tests/services/test_dataset_service_document.py b/api/tests/unit_tests/services/test_dataset_service_document.py index 8cb3dcae4a..c8036487ab 100644 --- a/api/tests/unit_tests/services/test_dataset_service_document.py +++ b/api/tests/unit_tests/services/test_dataset_service_document.py @@ -1828,7 +1828,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches: ) as get_binding, patch.object(DocumentService, "update_document_with_dataset_id", return_value=updated_document), ): - model_manager_cls.return_value.get_default_model_instance.return_value = SimpleNamespace( + model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = SimpleNamespace( model_name="default-embedding", provider="default-provider", ) @@ -1880,7 +1880,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches: ): DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) - model_manager_cls.return_value.get_default_model_instance.assert_not_called() + model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_not_called() get_binding.assert_called_once_with("explicit-provider", "explicit-model") assert dataset.embedding_model == "explicit-model" assert dataset.embedding_model_provider == "explicit-provider" diff --git a/api/tests/unit_tests/services/test_dataset_service_segment.py b/api/tests/unit_tests/services/test_dataset_service_segment.py index f3933448f7..2f8ae14a8e 100644 --- a/api/tests/unit_tests/services/test_dataset_service_segment.py +++ b/api/tests/unit_tests/services/test_dataset_service_segment.py @@ -9,6 +9,7 @@ from .dataset_service_test_helpers import ( DocumentSegment, IndexStructureType, MagicMock, + ModelType, SegmentService, SegmentUpdateArgs, SimpleNamespace, @@ -459,7 +460,7 @@ class TestSegmentServiceMutations: patch("services.dataset_service.naive_utc_now", return_value="now"), ): mock_redis.lock.return_value = _make_lock_context() - model_manager_cls.return_value.get_model_instance.return_value = embedding_model + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model mock_db.session.query.return_value.where.return_value.scalar.return_value = 1 vector_service.create_segments_vector.side_effect = RuntimeError("vector failed") @@ -571,7 +572,7 @@ class TestSegmentServiceMutations: patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary, ): mock_redis.get.return_value = None - model_manager_cls.return_value.get_model_instance.return_value = embedding_model_instance + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model_instance processing_rule_query = MagicMock() processing_rule_query.where.return_value.first.return_value = processing_rule @@ -618,7 +619,7 @@ class TestSegmentServiceMutations: ) as generate_summary, ): mock_redis.get.return_value = None - model_manager_cls.return_value.get_model_instance.return_value = embedding_model + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model summary_query = MagicMock() summary_query.where.return_value.first.return_value = existing_summary @@ -661,7 +662,7 @@ class TestSegmentServiceMutations: patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary, ): mock_redis.get.return_value = None - model_manager_cls.return_value.get_model_instance.return_value = embedding_model + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model summary_query = MagicMock() summary_query.where.return_value.first.return_value = existing_summary @@ -900,7 +901,7 @@ class TestSegmentServiceAdditionalRegenerationBranches: patch("services.dataset_service.naive_utc_now", return_value="now"), ): mock_redis.get.return_value = None - model_manager_cls.return_value.get_model_instance.return_value = embedding_model + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model summary_query = MagicMock() summary_query.where.return_value.first.return_value = None refreshed_query = MagicMock() @@ -947,7 +948,7 @@ class TestSegmentServiceAdditionalRegenerationBranches: patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary, ): mock_redis.get.return_value = None - model_manager_cls.return_value.get_default_model_instance.return_value = embedding_model_instance + model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = embedding_model_instance update_summary.side_effect = RuntimeError("summary failed") processing_rule_query = MagicMock() @@ -966,9 +967,9 @@ class TestSegmentServiceAdditionalRegenerationBranches: ) assert result is refreshed_segment - model_manager_cls.return_value.get_default_model_instance.assert_called_once_with( + model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_called_once_with( tenant_id="tenant-1", - model_type="text-embedding", + model_type=ModelType.TEXT_EMBEDDING, ) vector_service.generate_child_chunks.assert_called_once_with( segment,