fix: sync 35528 (#35539)

This commit is contained in:
Yunlu Wen
2026-04-24 11:59:33 +08:00
committed by GitHub
parent 38fc2a6574
commit 48e13f65dc
4 changed files with 89 additions and 8 deletions

View File

@@ -1,5 +1,6 @@
 from __future__ import annotations
+from copy import deepcopy
 from typing import Any

 from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
@@ -14,8 +15,21 @@ from graphon.nodes.llm.protocols import CredentialsProvider
 class DifyCredentialsProvider:
+    """Resolves and returns LLM credentials for a given provider and model.
+
+    Fetched credentials are stored in :attr:`credentials_cache` and reused for
+    subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
+    Because of that cache, a single instance can return stale credentials after
+    the tenant or provider configuration changes (e.g. API key rotation).
+
+    Do **not** keep one instance for the lifetime of a process or across
+    unrelated invocations. Create a new provider per request, workflow run, or
+    other bounded scope where up-to-date credentials matter.
+    """
+
     tenant_id: str
     provider_manager: ProviderManager
+    credentials_cache: dict[tuple[str, str], dict[str, Any]]

     def __init__(
         self,
@@ -30,8 +44,12 @@ class DifyCredentialsProvider:
             user_id=run_context.user_id,
         )
         self.provider_manager = provider_manager
+        self.credentials_cache = {}

     def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
+        if (provider_name, model_name) in self.credentials_cache:
+            return deepcopy(self.credentials_cache[(provider_name, model_name)])
+
         provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
         provider_configuration = provider_configurations.get(provider_name)
         if not provider_configuration:
@@ -46,6 +64,7 @@ class DifyCredentialsProvider:
         if credentials is None:
             raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")

+        self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
         return credentials
@@ -65,7 +84,8 @@ class DifyModelFactory:
             provider_manager=create_plugin_provider_manager(
                 tenant_id=run_context.tenant_id,
                 user_id=run_context.user_id,
-            )
+            ),
+            enable_credentials_cache=True,
         )
         self.model_manager = model_manager
@@ -84,7 +104,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
         tenant_id=run_context.tenant_id,
         user_id=run_context.user_id,
     )
-    model_manager = ModelManager(provider_manager=provider_manager)
+    model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)
     return (
         DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),

View File

@@ -1,5 +1,6 @@
 import logging
 from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
+from copy import deepcopy
 from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload

 from configs import dify_config
@@ -36,11 +37,13 @@ class ModelInstance:
     Model instance class.
     """

-    def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
+    def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
         self.provider_model_bundle = provider_model_bundle
         self.model_name = model
         self.provider = provider_model_bundle.configuration.provider.provider
-        self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
+        if credentials is None:
+            credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
+        self.credentials = credentials

         # Runtime LLM invocation fields.
         self.parameters: Mapping[str, Any] = {}
         self.stop: Sequence[str] = ()
@@ -434,8 +437,30 @@
 class ModelManager:
-    def __init__(self, provider_manager: ProviderManager):
+    """Resolves :class:`ModelInstance` objects for a tenant and provider.
+
+    When ``enable_credentials_cache`` is ``True``, resolved credentials for each
+    ``(tenant_id, provider, model_type, model)`` are stored in
+    ``_credentials_cache`` and reused. That can return **stale** credentials after
+    API keys or provider settings change, so a manager constructed with
+    ``enable_credentials_cache=True`` should not be kept for the lifetime of a
+    process or shared across unrelated work. Prefer a new manager per request,
+    workflow run, or similar bounded scope.
+
+    The default is ``enable_credentials_cache=False``; in that mode the internal
+    credential cache is not populated, and each ``get_model_instance`` call
+    loads credentials from the current provider configuration.
+    """
+
+    def __init__(
+        self,
+        provider_manager: ProviderManager,
+        *,
+        enable_credentials_cache: bool = False,
+    ) -> None:
         self._provider_manager = provider_manager
+        self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
+        self._enable_credentials_cache = enable_credentials_cache

     @classmethod
     def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
@@ -463,8 +488,19 @@ class ModelManager:
             tenant_id=tenant_id, provider=provider, model_type=model_type
         )

-        model_instance = ModelInstance(provider_model_bundle, model)
-        return model_instance
+        cred_cache_key = (tenant_id, provider, model_type.value, model)
+
+        if cred_cache_key in self._credentials_cache:
+            return ModelInstance(
+                provider_model_bundle,
+                model,
+                deepcopy(self._credentials_cache[cred_cache_key]),
+            )
+
+        ret = ModelInstance(provider_model_bundle, model)
+        if self._enable_credentials_cache:
+            self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
+        return ret

     def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
         """

View File

@@ -5,6 +5,7 @@ import uuid
 from datetime import datetime
 from typing import TYPE_CHECKING

+from cachetools.func import ttl_cache
 from pydantic import BaseModel, ConfigDict, Field, model_validator

 from configs import dify_config
@@ -99,6 +100,7 @@ def try_join_default_workspace(account_id: str) -> None:
 class EnterpriseService:
     @classmethod
+    @ttl_cache(ttl=5)
     def get_info(cls):
         return EnterpriseRequest.send_request("GET", "/info")

View File

@@ -5,7 +5,7 @@ import redis
 from pytest_mock import MockerFixture

 from core.entities.provider_entities import ModelLoadBalancingConfiguration
-from core.model_manager import LBModelManager
+from core.model_manager import LBModelManager, ModelManager
 from extensions.ext_redis import redis_client
 from graphon.model_runtime.entities.model_entities import ModelType
@@ -40,6 +40,29 @@ def lb_model_manager():
     return lb_model_manager


+def test_model_manager_with_cache_enabled_reuses_stored_credentials():
+    """With ``enable_credentials_cache=True``, later calls for the same key return cached creds."""
+    provider_manager = MagicMock()
+    bundle = MagicMock()
+    bundle.configuration.provider.provider = "openai"
+    bundle.configuration.tenant_id = "tenant-1"
+    bundle.configuration.model_settings = []
+    bundle.model_type_instance.model_type = ModelType.LLM
+    get_creds = MagicMock(return_value={"api_key": "first"})
+    bundle.configuration.get_current_credentials = get_creds
+    provider_manager.get_provider_model_bundle.return_value = bundle
+
+    manager = ModelManager(provider_manager, enable_credentials_cache=True)
+
+    first = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")
+    assert first.credentials == {"api_key": "first"}
+    get_creds.assert_called_once()
+
+    get_creds.return_value = {"api_key": "second"}
+    second = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")
+    assert second.credentials == {"api_key": "first"}
+    get_creds.assert_called_once()
+
+
 def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
     # initialize redis client
     redis_client.initialize(redis.Redis())