mirror of
https://github.com/langgenius/dify.git
synced 2026-03-16 12:27:42 +08:00
feat: Remove GPT-4 special-casing from default model selection (#33458)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -305,9 +305,7 @@ class ProviderManager:
|
||||
available_models = provider_configurations.get_models(model_type=model_type, only_active=True)
|
||||
|
||||
if available_models:
|
||||
available_model = next(
|
||||
(model for model in available_models if model.model == "gpt-4"), available_models[0]
|
||||
)
|
||||
available_model = available_models[0]
|
||||
|
||||
default_model = TenantDefaultModel(
|
||||
tenant_id=tenant_id,
|
||||
|
||||
@ -1,32 +1,34 @@
|
||||
from unittest.mock import Mock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.entities.provider_entities import ModelSettings
|
||||
from core.provider_manager import ProviderManager
|
||||
from dify_graph.model_runtime.entities.common_entities import I18nObject
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from models.provider import LoadBalancingModelConfig, ProviderModelSetting
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider_entity(mocker: MockerFixture):
|
||||
mock_entity = mocker.Mock()
|
||||
def mock_provider_entity():
|
||||
mock_entity = Mock()
|
||||
mock_entity.provider = "openai"
|
||||
mock_entity.configurate_methods = ["predefined-model"]
|
||||
mock_entity.supported_model_types = [ModelType.LLM]
|
||||
|
||||
# Use PropertyMock to ensure credential_form_schemas is iterable
|
||||
provider_credential_schema = mocker.Mock()
|
||||
type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
|
||||
provider_credential_schema = Mock()
|
||||
type(provider_credential_schema).credential_form_schemas = PropertyMock(return_value=[])
|
||||
mock_entity.provider_credential_schema = provider_credential_schema
|
||||
|
||||
model_credential_schema = mocker.Mock()
|
||||
type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
|
||||
model_credential_schema = Mock()
|
||||
type(model_credential_schema).credential_form_schemas = PropertyMock(return_value=[])
|
||||
mock_entity.model_credential_schema = model_credential_schema
|
||||
|
||||
return mock_entity
|
||||
|
||||
|
||||
def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
||||
def test__to_model_settings(mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
ps = ProviderModelSetting(
|
||||
tenant_id="tenant_id",
|
||||
@ -63,18 +65,18 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
||||
load_balancing_model_configs[0].id = "id1"
|
||||
load_balancing_model_configs[1].id = "id2"
|
||||
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
with patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get",
|
||||
return_value={"openai_api_key": "fake_key"},
|
||||
):
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
@ -87,7 +89,7 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
|
||||
assert result[0].load_balancing_configs[1].name == "first"
|
||||
|
||||
|
||||
def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity):
|
||||
def test__to_model_settings_only_one_lb(mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
|
||||
ps = ProviderModelSetting(
|
||||
@ -113,18 +115,18 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent
|
||||
]
|
||||
load_balancing_model_configs[0].id = "id1"
|
||||
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
with patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get",
|
||||
return_value={"openai_api_key": "fake_key"},
|
||||
):
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
@ -135,7 +137,7 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent
|
||||
assert len(result[0].load_balancing_configs) == 0
|
||||
|
||||
|
||||
def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity):
|
||||
def test__to_model_settings_lb_disabled(mock_provider_entity):
|
||||
# Mocking the inputs
|
||||
ps = ProviderModelSetting(
|
||||
tenant_id="tenant_id",
|
||||
@ -170,18 +172,18 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
|
||||
load_balancing_model_configs[0].id = "id1"
|
||||
load_balancing_model_configs[1].id = "id2"
|
||||
|
||||
mocker.patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
|
||||
)
|
||||
with patch(
|
||||
"core.helper.model_provider_cache.ProviderCredentialsCache.get",
|
||||
return_value={"openai_api_key": "fake_key"},
|
||||
):
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
provider_manager = ProviderManager()
|
||||
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
# Running the method
|
||||
result = provider_manager._to_model_settings(
|
||||
provider_entity=mock_provider_entity,
|
||||
provider_model_settings=provider_model_settings,
|
||||
load_balancing_model_configs=load_balancing_model_configs,
|
||||
)
|
||||
|
||||
# Asserting that the result is as expected
|
||||
assert len(result) == 1
|
||||
@ -190,3 +192,39 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
|
||||
assert result[0].model_type == ModelType.LLM
|
||||
assert result[0].enabled is True
|
||||
assert len(result[0].load_balancing_configs) == 0
|
||||
|
||||
|
||||
def test_get_default_model_uses_first_available_active_model():
|
||||
mock_session = Mock()
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
provider_configurations = Mock()
|
||||
provider_configurations.get_models.return_value = [
|
||||
Mock(model="gpt-3.5-turbo", provider=Mock(provider="openai")),
|
||||
Mock(model="gpt-4", provider=Mock(provider="openai")),
|
||||
]
|
||||
|
||||
manager = ProviderManager()
|
||||
with (
|
||||
patch("core.provider_manager.db.session", mock_session),
|
||||
patch.object(manager, "get_configurations", return_value=provider_configurations),
|
||||
patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls,
|
||||
):
|
||||
mock_factory_cls.return_value.get_provider_schema.return_value = Mock(
|
||||
provider="openai",
|
||||
label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
|
||||
icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"),
|
||||
supported_model_types=[ModelType.LLM],
|
||||
)
|
||||
|
||||
result = manager.get_default_model("tenant-id", ModelType.LLM)
|
||||
|
||||
assert result is not None
|
||||
assert result.model == "gpt-3.5-turbo"
|
||||
assert result.provider.provider == "openai"
|
||||
provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
|
||||
mock_session.add.assert_called_once()
|
||||
saved_default_model = mock_session.add.call_args.args[0]
|
||||
assert saved_default_model.model_name == "gpt-3.5-turbo"
|
||||
assert saved_default_model.provider_name == "openai"
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
Reference in New Issue
Block a user