mirror of
https://github.com/langgenius/dify.git
synced 2026-03-29 09:59:59 +08:00
fix: import path (#34124)
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
@ -10,7 +10,6 @@ from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
@ -18,6 +17,7 @@ from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||
from core.rag.index_processor.constant.index_type import IndexStructureType
|
||||
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from models import Account, TenantAccountRole
|
||||
from models.dataset import (
|
||||
ChildChunk,
|
||||
|
||||
@ -190,7 +190,7 @@ class TestDatasetServiceValidation:
|
||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
model_manager_cls.return_value.get_model_instance.assert_called_once_with(
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with(
|
||||
tenant_id=dataset.tenant_id,
|
||||
provider=dataset.embedding_model_provider,
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
@ -201,7 +201,7 @@ class TestDatasetServiceValidation:
|
||||
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")
|
||||
|
||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||
model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError()
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError()
|
||||
|
||||
with pytest.raises(ValueError, match="No Embedding Model available"):
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
@ -210,14 +210,18 @@ class TestDatasetServiceValidation:
|
||||
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality")
|
||||
|
||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||
model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing")
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
|
||||
"token missing"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="token missing"):
|
||||
with pytest.raises(ValueError, match="The dataset is unavailable, due to: token missing"):
|
||||
DatasetService.check_dataset_model_setting(dataset)
|
||||
|
||||
def test_check_embedding_model_setting_wraps_provider_token_error_description(self):
|
||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||
model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("provider setup")
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
|
||||
"provider setup"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="provider setup"):
|
||||
DatasetService.check_embedding_model_setting("tenant-1", "provider", "embedding-model")
|
||||
@ -226,7 +230,7 @@ class TestDatasetServiceValidation:
|
||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||
DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker")
|
||||
|
||||
model_manager_cls.return_value.get_model_instance.assert_called_once_with(
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
provider="provider",
|
||||
model_type=ModelType.RERANK,
|
||||
@ -235,7 +239,7 @@ class TestDatasetServiceValidation:
|
||||
|
||||
def test_check_reranking_model_setting_wraps_bad_request(self):
|
||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||
model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError()
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError()
|
||||
|
||||
with pytest.raises(ValueError, match="No Rerank Model available"):
|
||||
DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker")
|
||||
@ -251,7 +255,7 @@ class TestDatasetServiceValidation:
|
||||
)
|
||||
|
||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||
model_manager_cls.return_value.get_model_instance.return_value = model_instance
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance
|
||||
|
||||
result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
|
||||
|
||||
@ -268,7 +272,7 @@ class TestDatasetServiceValidation:
|
||||
)
|
||||
|
||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||
model_manager_cls.return_value.get_model_instance.return_value = model_instance
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance
|
||||
|
||||
result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
|
||||
|
||||
@ -284,14 +288,14 @@ class TestDatasetServiceValidation:
|
||||
)
|
||||
|
||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||
model_manager_cls.return_value.get_model_instance.return_value = model_instance
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance
|
||||
|
||||
with pytest.raises(ValueError, match="Model schema not found"):
|
||||
DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
|
||||
|
||||
def test_check_is_multimodal_model_wraps_bad_request_error(self):
|
||||
with patch("services.dataset_service.ModelManager") as model_manager_cls:
|
||||
model_manager_cls.return_value.get_model_instance.side_effect = LLMBadRequestError()
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError()
|
||||
|
||||
with pytest.raises(ValueError, match="No Model available"):
|
||||
DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model")
|
||||
@ -323,7 +327,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
patch.object(DatasetService, "check_embedding_model_setting") as check_embedding,
|
||||
):
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
model_manager_cls.return_value.get_default_model_instance.return_value = default_embedding_model
|
||||
model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = default_embedding_model
|
||||
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id="tenant-1",
|
||||
@ -337,7 +341,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
assert dataset.embedding_model == "default-embedding"
|
||||
assert dataset.permission == DatasetPermissionEnum.ONLY_ME
|
||||
assert dataset.provider == "vendor"
|
||||
model_manager_cls.return_value.get_default_model_instance.assert_called_once_with(
|
||||
model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
)
|
||||
@ -365,7 +369,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
patch.object(DatasetService, "check_reranking_model_setting") as check_reranking,
|
||||
):
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id="tenant-1",
|
||||
@ -804,7 +808,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
return_value=SimpleNamespace(id="binding-1"),
|
||||
),
|
||||
):
|
||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||
|
||||
DatasetService._configure_embedding_model_for_high_quality(
|
||||
{"embedding_model_provider": "provider", "embedding_model": "embedding-model"},
|
||||
@ -836,7 +840,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
patch("services.dataset_service.current_user", current_user),
|
||||
patch("services.dataset_service.ModelManager") as model_manager_cls,
|
||||
):
|
||||
model_manager_cls.return_value.get_model_instance.side_effect = error
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = error
|
||||
|
||||
with pytest.raises(ValueError, match=message):
|
||||
DatasetService._configure_embedding_model_for_high_quality(
|
||||
@ -967,7 +971,7 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
return_value=SimpleNamespace(id="binding-2"),
|
||||
),
|
||||
):
|
||||
model_manager_cls.return_value.get_model_instance.return_value = SimpleNamespace(
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace(
|
||||
provider="provider-two",
|
||||
model_name="embedding-model-two",
|
||||
)
|
||||
@ -1002,7 +1006,9 @@ class TestDatasetServiceCreationAndUpdate:
|
||||
patch("services.dataset_service.current_user", current_user),
|
||||
patch("services.dataset_service.ModelManager") as model_manager_cls,
|
||||
):
|
||||
model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing")
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
|
||||
"token missing"
|
||||
)
|
||||
|
||||
DatasetService._apply_new_embedding_settings(
|
||||
dataset,
|
||||
@ -1067,7 +1073,7 @@ class TestDatasetServiceRagPipelineSettings:
|
||||
return_value=SimpleNamespace(id="binding-1"),
|
||||
),
|
||||
):
|
||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||
|
||||
DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration)
|
||||
|
||||
@ -1161,7 +1167,7 @@ class TestDatasetServiceRagPipelineSettings:
|
||||
),
|
||||
patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
|
||||
):
|
||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||
|
||||
DatasetService.update_rag_pipeline_dataset_settings(
|
||||
session,
|
||||
@ -1204,7 +1210,7 @@ class TestDatasetServiceRagPipelineSettings:
|
||||
),
|
||||
patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
|
||||
):
|
||||
model_manager_cls.return_value.get_model_instance.return_value = SimpleNamespace(
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace(
|
||||
provider="provider-two",
|
||||
model_name="embedding-model-two",
|
||||
)
|
||||
@ -1243,7 +1249,9 @@ class TestDatasetServiceRagPipelineSettings:
|
||||
patch("services.dataset_service.ModelManager") as model_manager_cls,
|
||||
patch("services.dataset_service.deal_dataset_index_update_task") as update_task,
|
||||
):
|
||||
model_manager_cls.return_value.get_model_instance.side_effect = ProviderTokenNotInitError("token missing")
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError(
|
||||
"token missing"
|
||||
)
|
||||
|
||||
DatasetService.update_rag_pipeline_dataset_settings(
|
||||
session,
|
||||
|
||||
@ -1828,7 +1828,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches:
|
||||
) as get_binding,
|
||||
patch.object(DocumentService, "update_document_with_dataset_id", return_value=updated_document),
|
||||
):
|
||||
model_manager_cls.return_value.get_default_model_instance.return_value = SimpleNamespace(
|
||||
model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = SimpleNamespace(
|
||||
model_name="default-embedding",
|
||||
provider="default-provider",
|
||||
)
|
||||
@ -1880,7 +1880,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches:
|
||||
):
|
||||
DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)
|
||||
|
||||
model_manager_cls.return_value.get_default_model_instance.assert_not_called()
|
||||
model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_not_called()
|
||||
get_binding.assert_called_once_with("explicit-provider", "explicit-model")
|
||||
assert dataset.embedding_model == "explicit-model"
|
||||
assert dataset.embedding_model_provider == "explicit-provider"
|
||||
|
||||
@ -9,6 +9,7 @@ from .dataset_service_test_helpers import (
|
||||
DocumentSegment,
|
||||
IndexStructureType,
|
||||
MagicMock,
|
||||
ModelType,
|
||||
SegmentService,
|
||||
SegmentUpdateArgs,
|
||||
SimpleNamespace,
|
||||
@ -459,7 +460,7 @@ class TestSegmentServiceMutations:
|
||||
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
||||
):
|
||||
mock_redis.lock.return_value = _make_lock_context()
|
||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||
mock_db.session.query.return_value.where.return_value.scalar.return_value = 1
|
||||
vector_service.create_segments_vector.side_effect = RuntimeError("vector failed")
|
||||
|
||||
@ -571,7 +572,7 @@ class TestSegmentServiceMutations:
|
||||
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
|
||||
):
|
||||
mock_redis.get.return_value = None
|
||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model_instance
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model_instance
|
||||
|
||||
processing_rule_query = MagicMock()
|
||||
processing_rule_query.where.return_value.first.return_value = processing_rule
|
||||
@ -618,7 +619,7 @@ class TestSegmentServiceMutations:
|
||||
) as generate_summary,
|
||||
):
|
||||
mock_redis.get.return_value = None
|
||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||
|
||||
summary_query = MagicMock()
|
||||
summary_query.where.return_value.first.return_value = existing_summary
|
||||
@ -661,7 +662,7 @@ class TestSegmentServiceMutations:
|
||||
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
|
||||
):
|
||||
mock_redis.get.return_value = None
|
||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||
|
||||
summary_query = MagicMock()
|
||||
summary_query.where.return_value.first.return_value = existing_summary
|
||||
@ -900,7 +901,7 @@ class TestSegmentServiceAdditionalRegenerationBranches:
|
||||
patch("services.dataset_service.naive_utc_now", return_value="now"),
|
||||
):
|
||||
mock_redis.get.return_value = None
|
||||
model_manager_cls.return_value.get_model_instance.return_value = embedding_model
|
||||
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model
|
||||
summary_query = MagicMock()
|
||||
summary_query.where.return_value.first.return_value = None
|
||||
refreshed_query = MagicMock()
|
||||
@ -947,7 +948,7 @@ class TestSegmentServiceAdditionalRegenerationBranches:
|
||||
patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary,
|
||||
):
|
||||
mock_redis.get.return_value = None
|
||||
model_manager_cls.return_value.get_default_model_instance.return_value = embedding_model_instance
|
||||
model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = embedding_model_instance
|
||||
update_summary.side_effect = RuntimeError("summary failed")
|
||||
|
||||
processing_rule_query = MagicMock()
|
||||
@ -966,9 +967,9 @@ class TestSegmentServiceAdditionalRegenerationBranches:
|
||||
)
|
||||
|
||||
assert result is refreshed_segment
|
||||
model_manager_cls.return_value.get_default_model_instance.assert_called_once_with(
|
||||
model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_called_once_with(
|
||||
tenant_id="tenant-1",
|
||||
model_type="text-embedding",
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
)
|
||||
vector_service.generate_child_chunks.assert_called_once_with(
|
||||
segment,
|
||||
|
||||
Reference in New Issue
Block a user