feat: Remove GPT-4 special-casing from default model selection (#33458)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
-LAN-
2026-03-16 03:09:20 +08:00
committed by GitHub
parent b09a75aae0
commit 101d6d4d04
2 changed files with 82 additions and 46 deletions

View File

@ -305,9 +305,7 @@ class ProviderManager:
available_models = provider_configurations.get_models(model_type=model_type, only_active=True)
if available_models:
available_model = next(
(model for model in available_models if model.model == "gpt-4"), available_models[0]
)
available_model = available_models[0]
default_model = TenantDefaultModel(
tenant_id=tenant_id,

View File

@ -1,32 +1,34 @@
from unittest.mock import Mock, PropertyMock, patch
import pytest
from pytest_mock import MockerFixture
from core.entities.provider_entities import ModelSettings
from core.provider_manager import ProviderManager
from dify_graph.model_runtime.entities.common_entities import I18nObject
from dify_graph.model_runtime.entities.model_entities import ModelType
from models.provider import LoadBalancingModelConfig, ProviderModelSetting
@pytest.fixture
def mock_provider_entity(mocker: MockerFixture):
mock_entity = mocker.Mock()
def mock_provider_entity():
mock_entity = Mock()
mock_entity.provider = "openai"
mock_entity.configurate_methods = ["predefined-model"]
mock_entity.supported_model_types = [ModelType.LLM]
# Use PropertyMock to ensure credential_form_schemas is iterable
provider_credential_schema = mocker.Mock()
type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
provider_credential_schema = Mock()
type(provider_credential_schema).credential_form_schemas = PropertyMock(return_value=[])
mock_entity.provider_credential_schema = provider_credential_schema
model_credential_schema = mocker.Mock()
type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[])
model_credential_schema = Mock()
type(model_credential_schema).credential_form_schemas = PropertyMock(return_value=[])
mock_entity.model_credential_schema = model_credential_schema
return mock_entity
def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
def test__to_model_settings(mock_provider_entity):
# Mocking the inputs
ps = ProviderModelSetting(
tenant_id="tenant_id",
@ -63,18 +65,18 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
load_balancing_model_configs[0].id = "id1"
load_balancing_model_configs[1].id = "id2"
mocker.patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
)
with patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get",
return_value={"openai_api_key": "fake_key"},
):
provider_manager = ProviderManager()
provider_manager = ProviderManager()
# Running the method
result = provider_manager._to_model_settings(
provider_entity=mock_provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=load_balancing_model_configs,
)
# Running the method
result = provider_manager._to_model_settings(
provider_entity=mock_provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=load_balancing_model_configs,
)
# Asserting that the result is as expected
assert len(result) == 1
@ -87,7 +89,7 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity):
assert result[0].load_balancing_configs[1].name == "first"
def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity):
def test__to_model_settings_only_one_lb(mock_provider_entity):
# Mocking the inputs
ps = ProviderModelSetting(
@ -113,18 +115,18 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent
]
load_balancing_model_configs[0].id = "id1"
mocker.patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
)
with patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get",
return_value={"openai_api_key": "fake_key"},
):
provider_manager = ProviderManager()
provider_manager = ProviderManager()
# Running the method
result = provider_manager._to_model_settings(
provider_entity=mock_provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=load_balancing_model_configs,
)
# Running the method
result = provider_manager._to_model_settings(
provider_entity=mock_provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=load_balancing_model_configs,
)
# Asserting that the result is as expected
assert len(result) == 1
@ -135,7 +137,7 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent
assert len(result[0].load_balancing_configs) == 0
def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity):
def test__to_model_settings_lb_disabled(mock_provider_entity):
# Mocking the inputs
ps = ProviderModelSetting(
tenant_id="tenant_id",
@ -170,18 +172,18 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
load_balancing_model_configs[0].id = "id1"
load_balancing_model_configs[1].id = "id2"
mocker.patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}
)
with patch(
"core.helper.model_provider_cache.ProviderCredentialsCache.get",
return_value={"openai_api_key": "fake_key"},
):
provider_manager = ProviderManager()
provider_manager = ProviderManager()
# Running the method
result = provider_manager._to_model_settings(
provider_entity=mock_provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=load_balancing_model_configs,
)
# Running the method
result = provider_manager._to_model_settings(
provider_entity=mock_provider_entity,
provider_model_settings=provider_model_settings,
load_balancing_model_configs=load_balancing_model_configs,
)
# Asserting that the result is as expected
assert len(result) == 1
@ -190,3 +192,39 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent
assert result[0].model_type == ModelType.LLM
assert result[0].enabled is True
assert len(result[0].load_balancing_configs) == 0
def test_get_default_model_uses_first_available_active_model():
mock_session = Mock()
mock_session.scalar.return_value = None
provider_configurations = Mock()
provider_configurations.get_models.return_value = [
Mock(model="gpt-3.5-turbo", provider=Mock(provider="openai")),
Mock(model="gpt-4", provider=Mock(provider="openai")),
]
manager = ProviderManager()
with (
patch("core.provider_manager.db.session", mock_session),
patch.object(manager, "get_configurations", return_value=provider_configurations),
patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls,
):
mock_factory_cls.return_value.get_provider_schema.return_value = Mock(
provider="openai",
label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"),
icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"),
supported_model_types=[ModelType.LLM],
)
result = manager.get_default_model("tenant-id", ModelType.LLM)
assert result is not None
assert result.model == "gpt-3.5-turbo"
assert result.provider.provider == "openai"
provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True)
mock_session.add.assert_called_once()
saved_default_model = mock_session.add.call_args.args[0]
assert saved_default_model.model_name == "gpt-3.5-turbo"
assert saved_default_model.provider_name == "openai"
mock_session.commit.assert_called_once()