diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index c29a463bb6..c538a557fb 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -305,9 +305,7 @@ class ProviderManager: available_models = provider_configurations.get_models(model_type=model_type, only_active=True) if available_models: - available_model = next( - (model for model in available_models if model.model == "gpt-4"), available_models[0] - ) + available_model = available_models[0] default_model = TenantDefaultModel( tenant_id=tenant_id, diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 3abfb8c9f8..69567c54eb 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -1,32 +1,34 @@ +from unittest.mock import Mock, PropertyMock, patch + import pytest -from pytest_mock import MockerFixture from core.entities.provider_entities import ModelSettings from core.provider_manager import ProviderManager +from dify_graph.model_runtime.entities.common_entities import I18nObject from dify_graph.model_runtime.entities.model_entities import ModelType from models.provider import LoadBalancingModelConfig, ProviderModelSetting @pytest.fixture -def mock_provider_entity(mocker: MockerFixture): - mock_entity = mocker.Mock() +def mock_provider_entity(): + mock_entity = Mock() mock_entity.provider = "openai" mock_entity.configurate_methods = ["predefined-model"] mock_entity.supported_model_types = [ModelType.LLM] # Use PropertyMock to ensure credential_form_schemas is iterable - provider_credential_schema = mocker.Mock() - type(provider_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[]) + provider_credential_schema = Mock() + type(provider_credential_schema).credential_form_schemas = PropertyMock(return_value=[]) mock_entity.provider_credential_schema = provider_credential_schema - model_credential_schema = mocker.Mock() - type(model_credential_schema).credential_form_schemas = mocker.PropertyMock(return_value=[]) + model_credential_schema = Mock() + type(model_credential_schema).credential_form_schemas = PropertyMock(return_value=[]) mock_entity.model_credential_schema = model_credential_schema return mock_entity -def test__to_model_settings(mocker: MockerFixture, mock_provider_entity): +def test__to_model_settings(mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( tenant_id="tenant_id", @@ -63,18 +65,18 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity): load_balancing_model_configs[0].id = "id1" load_balancing_model_configs[1].id = "id2" - mocker.patch( - "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} - ) + with patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", + return_value={"openai_api_key": "fake_key"}, + ): + provider_manager = ProviderManager() - provider_manager = ProviderManager() - - # Running the method - result = provider_manager._to_model_settings( - provider_entity=mock_provider_entity, - provider_model_settings=provider_model_settings, - load_balancing_model_configs=load_balancing_model_configs, - ) + # Running the method + result = provider_manager._to_model_settings( + provider_entity=mock_provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=load_balancing_model_configs, + ) # Asserting that the result is as expected assert len(result) == 1 @@ -87,7 +89,7 @@ def test__to_model_settings(mocker: MockerFixture, mock_provider_entity): assert result[0].load_balancing_configs[1].name == "first" -def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity): +def test__to_model_settings_only_one_lb(mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( @@ -113,18 +115,18 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent ] load_balancing_model_configs[0].id = "id1" - mocker.patch( - "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} - ) + with patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", + return_value={"openai_api_key": "fake_key"}, + ): + provider_manager = ProviderManager() - provider_manager = ProviderManager() - - # Running the method - result = provider_manager._to_model_settings( - provider_entity=mock_provider_entity, - provider_model_settings=provider_model_settings, - load_balancing_model_configs=load_balancing_model_configs, - ) + # Running the method + result = provider_manager._to_model_settings( + provider_entity=mock_provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=load_balancing_model_configs, + ) # Asserting that the result is as expected assert len(result) == 1 @@ -135,7 +137,7 @@ def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_ent assert len(result[0].load_balancing_configs) == 0 -def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity): +def test__to_model_settings_lb_disabled(mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( tenant_id="tenant_id", @@ -170,18 +172,18 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent load_balancing_model_configs[0].id = "id1" load_balancing_model_configs[1].id = "id2" - mocker.patch( - "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"} - ) + with patch( + "core.helper.model_provider_cache.ProviderCredentialsCache.get", + return_value={"openai_api_key": "fake_key"}, + ): + provider_manager = ProviderManager() - provider_manager = ProviderManager() - - # Running the method - result = provider_manager._to_model_settings( - provider_entity=mock_provider_entity, - provider_model_settings=provider_model_settings, - load_balancing_model_configs=load_balancing_model_configs, - ) + # Running the method + result = provider_manager._to_model_settings( + provider_entity=mock_provider_entity, + provider_model_settings=provider_model_settings, + load_balancing_model_configs=load_balancing_model_configs, + ) # Asserting that the result is as expected assert len(result) == 1 @@ -190,3 +192,39 @@ def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_ent assert result[0].model_type == ModelType.LLM assert result[0].enabled is True assert len(result[0].load_balancing_configs) == 0 + + +def test_get_default_model_uses_first_available_active_model(): + mock_session = Mock() + mock_session.scalar.return_value = None + + provider_configurations = Mock() + provider_configurations.get_models.return_value = [ + Mock(model="gpt-3.5-turbo", provider=Mock(provider="openai")), + Mock(model="gpt-4", provider=Mock(provider="openai")), + ] + + manager = ProviderManager() + with ( + patch("core.provider_manager.db.session", mock_session), + patch.object(manager, "get_configurations", return_value=provider_configurations), + patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls, + ): + mock_factory_cls.return_value.get_provider_schema.return_value = Mock( + provider="openai", + label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), + icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"), + supported_model_types=[ModelType.LLM], + ) + + result = manager.get_default_model("tenant-id", ModelType.LLM) + + assert result is not None + assert result.model == "gpt-3.5-turbo" + assert result.provider.provider == "openai" + provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) + mock_session.add.assert_called_once() + saved_default_model = mock_session.add.call_args.args[0] + assert saved_default_model.model_name == "gpt-3.5-turbo" + assert saved_default_model.provider_name == "openai" + mock_session.commit.assert_called_once()