diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index c3bbe8fc09..8969825be4 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -70,12 +70,32 @@ class ProviderManager: Request-bound managers may carry caller identity in that runtime, and the resulting ``ProviderConfiguration`` objects must reuse it for downstream model-type and schema lookups. + + Configuration assembly is cached per manager instance so call chains that + share one request-scoped manager can reuse the same provider graph instead + of rebuilding it for every lookup. Call ``clear_configurations_cache()`` + when a long-lived manager needs to observe writes performed within the same + instance scope. """ + decoding_rsa_key: Any | None + decoding_cipher_rsa: Any | None + _model_runtime: ModelRuntime + _configurations_cache: dict[str, ProviderConfigurations] + def __init__(self, model_runtime: ModelRuntime): self.decoding_rsa_key = None self.decoding_cipher_rsa = None self._model_runtime = model_runtime + self._configurations_cache = {} + + def clear_configurations_cache(self, tenant_id: str | None = None) -> None: + """Drop assembled provider configurations cached on this manager instance.""" + if tenant_id is None: + self._configurations_cache.clear() + return + + self._configurations_cache.pop(tenant_id, None) def get_configurations(self, tenant_id: str) -> ProviderConfigurations: """ @@ -114,6 +134,10 @@ class ProviderManager: :param tenant_id: :return: """ + cached_configurations = self._configurations_cache.get(tenant_id) + if cached_configurations is not None: + return cached_configurations + # Get all provider records of the workspace provider_name_to_provider_records_dict = self._get_all_providers(tenant_id) @@ -273,6 +297,8 @@ class ProviderManager: provider_configurations[str(provider_id_entity)] = provider_configuration + self._configurations_cache[tenant_id] = provider_configurations + # Return the encapsulated object return provider_configurations diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index f45b43082c..a5a542c94f 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -372,6 +372,78 @@ def test_get_configurations_binds_manager_runtime_to_provider_configuration( provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime) +def test_get_configurations_reuses_cached_result_for_same_tenant(mocker: MockerFixture, mock_provider_entity): + manager = _build_provider_manager(mocker) + provider_configuration = Mock() + provider_factory = Mock() + provider_factory.get_providers.return_value = [mock_provider_entity] + custom_configuration = SimpleNamespace(provider=None, models=[]) + system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None) + + with ( + patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers, + patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}), + patch.object(manager, "_get_all_provider_models", return_value={"openai": []}), + patch.object(manager, "_get_all_preferred_model_providers", return_value={}), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch.object(manager, "_to_custom_configuration", return_value=custom_configuration), + patch.object(manager, "_to_system_configuration", return_value=system_configuration), + patch.object(manager, "_to_model_settings", return_value=[]), + patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory) as mock_factory_cls, + patch( + "core.provider_manager.ProviderConfiguration", + return_value=provider_configuration, + ) as mock_provider_configuration, + ): + first = manager.get_configurations("tenant-id") + second = manager.get_configurations("tenant-id") + + assert first is second + mock_get_all_providers.assert_called_once_with("tenant-id") + mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + mock_provider_configuration.assert_called_once() + provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime) + + +def test_clear_configurations_cache_rebuilds_requested_tenant(mocker: MockerFixture, mock_provider_entity): + manager = _build_provider_manager(mocker) + provider_factory = Mock() + provider_factory.get_providers.return_value = [mock_provider_entity] + custom_configuration = SimpleNamespace(provider=None, models=[]) + system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None) + provider_configuration_first = Mock() + provider_configuration_second = Mock() + + with ( + patch.object(manager, "_get_all_providers", return_value={"openai": []}) as mock_get_all_providers, + patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}), + patch.object(manager, "_get_all_provider_models", return_value={"openai": []}), + patch.object(manager, "_get_all_preferred_model_providers", return_value={}), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch.object(manager, "_to_custom_configuration", return_value=custom_configuration), + patch.object(manager, "_to_system_configuration", return_value=system_configuration), + patch.object(manager, "_to_model_settings", return_value=[]), + patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory), + patch( + "core.provider_manager.ProviderConfiguration", + side_effect=[provider_configuration_first, provider_configuration_second], + ) as mock_provider_configuration, + ): + first = manager.get_configurations("tenant-id") + manager.clear_configurations_cache("tenant-id") + second = manager.get_configurations("tenant-id") + + assert first is not second + assert mock_get_all_providers.call_count == 2 + assert mock_provider_configuration.call_count == 2 + provider_configuration_first.bind_model_runtime.assert_called_once_with(manager._model_runtime) + provider_configuration_second.bind_model_runtime.assert_called_once_with(manager._model_runtime) + + def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: MockerFixture): manager = _build_provider_manager(mocker) provider_configuration = Mock()