fix: import path (#34124)

Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
QuantumGhost
2026-03-26 16:13:53 +08:00
committed by GitHub
parent 8ca54ddf94
commit e08c06cbc3
4 changed files with 42 additions and 33 deletions

View File

@ -10,7 +10,6 @@ from types import SimpleNamespace
from unittest.mock import MagicMock, Mock, create_autospec, patch
import pytest
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
from werkzeug.exceptions import Forbidden, NotFound
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
@ -18,6 +17,7 @@ from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from enums.cloud_plan import CloudPlan
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
from models import Account, TenantAccountRole
from models.dataset import (
ChildChunk,

View File

@ -190,7 +190,7 @@ class TestDatasetServiceValidation:
with patch("services.dataset_service.ModelManager") as model_manager_cls:
DatasetService.check_dataset_model_setting(dataset)
model_manager_cls.return_value.get_model_instance.assert_called_once_with(
model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
@ -201,7 +201,7 @@ class TestDatasetServiceValidation:
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")
with patch("services.dataset_service.ModelManager") as model_manager_cls:
model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError()
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError()
with pytest.raises(ValueError, match="No Embedding Model available"):
DatasetService.check_dataset_model_setting(dataset)
@ -210,14 +210,18 @@ class TestDatasetServiceValidation:
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")
with patch("services.dataset_service.ModelManager") as model_manager_cls:
model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing")
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
"token missing"
)
with pytest.raises(ValueError, match="token missing"):
with pytest.raises(ValueError, match="The dataset is unavailable, due to: token missing"):
DatasetService.check_dataset_model_setting(dataset)
def test_check_embedding_model_setting_wraps_provider_token_error_description(self):
with patch("services.dataset_service.ModelManager") as model_manager_cls:
model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("provider setup")
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
"provider setup"
)
with pytest.raises(ValueError, match="provider setup"):
DatasetService.check_embedding_model_setting("tenant-1", "provider", "embedding-model")
@ -226,7 +230,7 @@ class TestDatasetServiceValidation:
with patch("services.dataset_service.ModelManager") as model_manager_cls:
DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker")
model_manager_cls.return_value.get_model_instance.assert_called_once_with(
model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with(
tenant_id="tenant-1",
provider="provider",
model_type=ModelType.RERANK,
@ -235,7 +239,7 @@ class TestDatasetServiceValidation:
def test_check_reranking_model_setting_wraps_bad_request(self):
with patch("services.dataset_service.ModelManager") as model_manager_cls:
model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError()
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError()
with pytest.raises(ValueError, match="No Rerank Model available"):
DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker")
@ -251,7 +255,7 @@ class TestDatasetServiceValidation:
)
with patch("services.dataset_service.ModelManager") as model_manager_cls:
model_manager_cls.return_value.get_model_instance.return_value = model_instance
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance
result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
@ -268,7 +272,7 @@ class TestDatasetServiceValidation:
)
with patch("services.dataset_service.ModelManager") as model_manager_cls:
model_manager_cls.return_value.get_model_instance.return_value = model_instance
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance
result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
@ -284,14 +288,14 @@ class TestDatasetServiceValidation:
)
with patch("services.dataset_service.ModelManager") as model_manager_cls:
model_manager_cls.return_value.get_model_instance.return_value = model_instance
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance
with pytest.raises(ValueError, match="Model schema not found"):
DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
def test_check_is_multimodal_model_wraps_bad_request_error(self):
with patch("services.dataset_service.ModelManager") as model_manager_cls:
model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError()
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError()
with pytest.raises(ValueError, match="No Model available"):
DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
@ -323,7 +327,7 @@ class TestDatasetServiceCreationAndUpdate:
patch.object(DatasetService, "check_embedding_model_setting") as check_embedding,
):
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
model_manager_cls.return_value.get_default_model_instance.return_value = default_embedding_model
model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = default_embedding_model
dataset = DatasetService.create_empty_dataset(
tenant_id="tenant-1",
@ -337,7 +341,7 @@ class TestDatasetServiceCreationAndUpdate:
assert dataset.embedding_model == "default-embedding"
assert dataset.permission == DatasetPermissionEnum.ONLY_ME
assert dataset.provider == "vendor"
model_manager_cls.return_value.get_default_model_instance.assert_called_once_with(
model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_called_once_with(
tenant_id="tenant-1",
model_type=ModelType.TEXT_EMBEDDING,
)
@ -365,7 +369,7 @@ class TestDatasetServiceCreationAndUpdate:
patch.object(DatasetService, "check_reranking_model_setting") as check_reranking,
):
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
dataset = DatasetService.create_empty_dataset(
tenant_id="tenant-1",
@ -804,7 +808,7 @@ class TestDatasetServiceCreationAndUpdate:
return_value=SimpleNamespace(id="binding-1"),
),
):
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
DatasetService._configure_embedding_model_for_high_quality(
{"embedding_model_provider": "provider", "embedding_model": "embedding-model"},
@ -836,7 +840,7 @@ class TestDatasetServiceCreationAndUpdate:
patch("services.dataset_service.current_user", current_user),
patch("services.dataset_service.ModelManager") as model_manager_cls,
):
model_manager_cls.return_value.get_model_instance.side_effect = error
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = error
with pytest.raises(ValueError, match=message):
DatasetService._configure_embedding_model_for_high_quality(
@ -967,7 +971,7 @@ class TestDatasetServiceCreationAndUpdate:
return_value=SimpleNamespace(id="binding-2"),
),
):
model_manager_cls.return_value.get_model_instance.return_value = SimpleNamespace(
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace(
provider="provider-two",
model_name="embedding-model-two",
)
@ -1002,7 +1006,9 @@ class TestDatasetServiceCreationAndUpdate:
patch("services.dataset_service.current_user", current_user),
patch("services.dataset_service.ModelManager") as model_manager_cls,
):
model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing")
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
"token missing"
)
DatasetService._apply_new_embedding_settings(
dataset,
@ -1067,7 +1073,7 @@ class TestDatasetServiceRagPipelineSettings:
return_value=SimpleNamespace(id="binding-1"),
),
):
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration)
@ -1161,7 +1167,7 @@ class TestDatasetServiceRagPipelineSettings:
),
patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
):
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
DatasetService.update_rag_pipeline_dataset_settings(
session,
@ -1204,7 +1210,7 @@ class TestDatasetServiceRagPipelineSettings:
),
patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
):
model_manager_cls.return_value.get_model_instance.return_value = SimpleNamespace(
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace(
provider="provider-two",
model_name="embedding-model-two",
)
@ -1243,7 +1249,9 @@ class TestDatasetServiceRagPipelineSettings:
patch("services.dataset_service.ModelManager") as model_manager_cls,
patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
):
model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing")
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
"token missing"
)
DatasetService.update_rag_pipeline_dataset_settings(
session,

View File

@ -1828,7 +1828,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches:
) as get_binding,
patch.object(DocumentService, "update_document_with_dataset_id", return_value=updated_document),
):
model_manager_cls.return_value.get_default_model_instance.return_value = SimpleNamespace(
model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = SimpleNamespace(
model_name="default-embedding",
provider="default-provider",
)
@ -1880,7 +1880,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches:
):
DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)
model_manager_cls.return_value.get_default_model_instance.assert_not_called()
model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_not_called()
get_binding.assert_called_once_with("explicit-provider", "explicit-model")
assert dataset.embedding_model == "explicit-model"
assert dataset.embedding_model_provider == "explicit-provider"

View File

@ -9,6 +9,7 @@ from .dataset_service_test_helpers import (
DocumentSegment,
IndexStructureType,
MagicMock,
ModelType,
SegmentService,
SegmentUpdateArgs,
SimpleNamespace,
@ -459,7 +460,7 @@ class TestSegmentServiceMutations:
patch("services.dataset_service.naive_utc_now", return_value="now"),
):
mock_redis.lock.return_value = _make_lock_context()
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
mock_db.session.query.return_value.where.return_value.scalar.return_value = 1
vector_service.create_segments_vector.side_effect = RuntimeError("vector failed")
@ -571,7 +572,7 @@ class TestSegmentServiceMutations:
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
):
mock_redis.get.return_value = None
model_manager_cls.return_value.get_model_instance.return_value = embedding_model_instance
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model_instance
processing_rule_query = MagicMock()
processing_rule_query.where.return_value.first.return_value = processing_rule
@ -618,7 +619,7 @@ class TestSegmentServiceMutations:
) as generate_summary,
):
mock_redis.get.return_value = None
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
summary_query = MagicMock()
summary_query.where.return_value.first.return_value = existing_summary
@ -661,7 +662,7 @@ class TestSegmentServiceMutations:
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
):
mock_redis.get.return_value = None
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
summary_query = MagicMock()
summary_query.where.return_value.first.return_value = existing_summary
@ -900,7 +901,7 @@ class TestSegmentServiceAdditionalRegenerationBranches:
patch("services.dataset_service.naive_utc_now", return_value="now"),
):
mock_redis.get.return_value = None
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
summary_query = MagicMock()
summary_query.where.return_value.first.return_value = None
refreshed_query = MagicMock()
@ -947,7 +948,7 @@ class TestSegmentServiceAdditionalRegenerationBranches:
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
):
mock_redis.get.return_value = None
model_manager_cls.return_value.get_default_model_instance.return_value = embedding_model_instance
model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = embedding_model_instance
update_summary.side_effect = RuntimeError("summary failed")
processing_rule_query = MagicMock()
@ -966,9 +967,9 @@ class TestSegmentServiceAdditionalRegenerationBranches:
)
assert result is refreshed_segment
model_manager_cls.return_value.get_default_model_instance.assert_called_once_with(
model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_called_once_with(
tenant_id="tenant-1",
model_type="text-embedding",
model_type=ModelType.TEXT_EMBEDDING,
)
vector_service.generate_child_chunks.assert_called_once_with(
segment,