mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 10:28:10 +08:00
fix: sync 35528 (#35539)
This commit is contained in:
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from copy import deepcopy
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
|
from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity
|
||||||
@ -14,8 +15,21 @@ from graphon.nodes.llm.protocols import CredentialsProvider
|
|||||||
|
|
||||||
|
|
||||||
class DifyCredentialsProvider:
|
class DifyCredentialsProvider:
|
||||||
|
"""Resolves and returns LLM credentials for a given provider and model.
|
||||||
|
|
||||||
|
Fetched credentials are stored in :attr:`credentials_cache` and reused for
|
||||||
|
subsequent ``fetch`` calls for the same ``(provider_name, model_name)``.
|
||||||
|
Because of that cache, a single instance can return stale credentials after
|
||||||
|
the tenant or provider configuration changes (e.g. API key rotation).
|
||||||
|
|
||||||
|
Do **not** keep one instance for the lifetime of a process or across
|
||||||
|
unrelated invocations. Create a new provider per request, workflow run, or
|
||||||
|
other bounded scope where up-to-date credentials matter.
|
||||||
|
"""
|
||||||
|
|
||||||
tenant_id: str
|
tenant_id: str
|
||||||
provider_manager: ProviderManager
|
provider_manager: ProviderManager
|
||||||
|
credentials_cache: dict[tuple[str, str], dict[str, Any]]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@ -30,8 +44,12 @@ class DifyCredentialsProvider:
|
|||||||
user_id=run_context.user_id,
|
user_id=run_context.user_id,
|
||||||
)
|
)
|
||||||
self.provider_manager = provider_manager
|
self.provider_manager = provider_manager
|
||||||
|
self.credentials_cache = {}
|
||||||
|
|
||||||
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
|
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
|
||||||
|
if (provider_name, model_name) in self.credentials_cache:
|
||||||
|
return deepcopy(self.credentials_cache[(provider_name, model_name)])
|
||||||
|
|
||||||
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
|
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
|
||||||
provider_configuration = provider_configurations.get(provider_name)
|
provider_configuration = provider_configurations.get(provider_name)
|
||||||
if not provider_configuration:
|
if not provider_configuration:
|
||||||
@ -46,6 +64,7 @@ class DifyCredentialsProvider:
|
|||||||
if credentials is None:
|
if credentials is None:
|
||||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||||
|
|
||||||
|
self.credentials_cache[(provider_name, model_name)] = deepcopy(credentials)
|
||||||
return credentials
|
return credentials
|
||||||
|
|
||||||
|
|
||||||
@ -65,7 +84,8 @@ class DifyModelFactory:
|
|||||||
provider_manager=create_plugin_provider_manager(
|
provider_manager=create_plugin_provider_manager(
|
||||||
tenant_id=run_context.tenant_id,
|
tenant_id=run_context.tenant_id,
|
||||||
user_id=run_context.user_id,
|
user_id=run_context.user_id,
|
||||||
)
|
),
|
||||||
|
enable_credentials_cache=True,
|
||||||
)
|
)
|
||||||
self.model_manager = model_manager
|
self.model_manager = model_manager
|
||||||
|
|
||||||
@ -84,7 +104,7 @@ def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsPro
|
|||||||
tenant_id=run_context.tenant_id,
|
tenant_id=run_context.tenant_id,
|
||||||
user_id=run_context.user_id,
|
user_id=run_context.user_id,
|
||||||
)
|
)
|
||||||
model_manager = ModelManager(provider_manager=provider_manager)
|
model_manager = ModelManager(provider_manager=provider_manager, enable_credentials_cache=True)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),
|
DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager),
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
|
from collections.abc import Callable, Generator, Iterable, Mapping, Sequence
|
||||||
|
from copy import deepcopy
|
||||||
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
|
from typing import IO, Any, Literal, Optional, ParamSpec, TypeVar, Union, cast, overload
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -36,11 +37,13 @@ class ModelInstance:
|
|||||||
Model instance class.
|
Model instance class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
|
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str, credentials: dict | None = None) -> None:
|
||||||
self.provider_model_bundle = provider_model_bundle
|
self.provider_model_bundle = provider_model_bundle
|
||||||
self.model_name = model
|
self.model_name = model
|
||||||
self.provider = provider_model_bundle.configuration.provider.provider
|
self.provider = provider_model_bundle.configuration.provider.provider
|
||||||
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
|
if credentials is None:
|
||||||
|
credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
|
||||||
|
self.credentials = credentials
|
||||||
# Runtime LLM invocation fields.
|
# Runtime LLM invocation fields.
|
||||||
self.parameters: Mapping[str, Any] = {}
|
self.parameters: Mapping[str, Any] = {}
|
||||||
self.stop: Sequence[str] = ()
|
self.stop: Sequence[str] = ()
|
||||||
@ -434,8 +437,30 @@ class ModelInstance:
|
|||||||
|
|
||||||
|
|
||||||
class ModelManager:
|
class ModelManager:
|
||||||
def __init__(self, provider_manager: ProviderManager):
|
"""Resolves :class:`ModelInstance` objects for a tenant and provider.
|
||||||
|
|
||||||
|
When ``enable_credentials_cache`` is ``True``, resolved credentials for each
|
||||||
|
``(tenant_id, provider, model_type, model)`` are stored in
|
||||||
|
``_credentials_cache`` and reused. That can return **stale** credentials after
|
||||||
|
API keys or provider settings change, so a manager constructed with
|
||||||
|
``enable_credentials_cache=True`` should not be kept for the lifetime of a
|
||||||
|
process or shared across unrelated work. Prefer a new manager per request,
|
||||||
|
workflow run, or similar bounded scope.
|
||||||
|
|
||||||
|
The default is ``enable_credentials_cache=False``; in that mode the internal
|
||||||
|
credential cache is not populated, and each ``get_model_instance`` call
|
||||||
|
loads credentials from the current provider configuration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
provider_manager: ProviderManager,
|
||||||
|
*,
|
||||||
|
enable_credentials_cache: bool = False,
|
||||||
|
) -> None:
|
||||||
self._provider_manager = provider_manager
|
self._provider_manager = provider_manager
|
||||||
|
self._credentials_cache: dict[tuple[str, str, str, str], Any] = {}
|
||||||
|
self._enable_credentials_cache = enable_credentials_cache
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
|
def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager":
|
||||||
@ -463,8 +488,19 @@ class ModelManager:
|
|||||||
tenant_id=tenant_id, provider=provider, model_type=model_type
|
tenant_id=tenant_id, provider=provider, model_type=model_type
|
||||||
)
|
)
|
||||||
|
|
||||||
model_instance = ModelInstance(provider_model_bundle, model)
|
cred_cache_key = (tenant_id, provider, model_type.value, model)
|
||||||
return model_instance
|
|
||||||
|
if cred_cache_key in self._credentials_cache:
|
||||||
|
return ModelInstance(
|
||||||
|
provider_model_bundle,
|
||||||
|
model,
|
||||||
|
deepcopy(self._credentials_cache[cred_cache_key]),
|
||||||
|
)
|
||||||
|
|
||||||
|
ret = ModelInstance(provider_model_bundle, model)
|
||||||
|
if self._enable_credentials_cache:
|
||||||
|
self._credentials_cache[cred_cache_key] = deepcopy(ret.credentials)
|
||||||
|
return ret
|
||||||
|
|
||||||
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
|
def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -5,6 +5,7 @@ import uuid
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from cachetools.func import ttl_cache
|
||||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@ -99,6 +100,7 @@ def try_join_default_workspace(account_id: str) -> None:
|
|||||||
|
|
||||||
class EnterpriseService:
|
class EnterpriseService:
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ttl_cache(ttl=5)
|
||||||
def get_info(cls):
|
def get_info(cls):
|
||||||
return EnterpriseRequest.send_request("GET", "/info")
|
return EnterpriseRequest.send_request("GET", "/info")
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import redis
|
|||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
|
|
||||||
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
from core.entities.provider_entities import ModelLoadBalancingConfiguration
|
||||||
from core.model_manager import LBModelManager
|
from core.model_manager import LBModelManager, ModelManager
|
||||||
from extensions.ext_redis import redis_client
|
from extensions.ext_redis import redis_client
|
||||||
from graphon.model_runtime.entities.model_entities import ModelType
|
from graphon.model_runtime.entities.model_entities import ModelType
|
||||||
|
|
||||||
@ -40,6 +40,29 @@ def lb_model_manager():
|
|||||||
return lb_model_manager
|
return lb_model_manager
|
||||||
|
|
||||||
|
|
||||||
|
def test_model_manager_with_cache_enabled_reuses_stored_credentials():
|
||||||
|
"""With ``enable_credentials_cache=True``, later calls for the same key return cached creds."""
|
||||||
|
provider_manager = MagicMock()
|
||||||
|
bundle = MagicMock()
|
||||||
|
bundle.configuration.provider.provider = "openai"
|
||||||
|
bundle.configuration.tenant_id = "tenant-1"
|
||||||
|
bundle.configuration.model_settings = []
|
||||||
|
bundle.model_type_instance.model_type = ModelType.LLM
|
||||||
|
get_creds = MagicMock(return_value={"api_key": "first"})
|
||||||
|
bundle.configuration.get_current_credentials = get_creds
|
||||||
|
provider_manager.get_provider_model_bundle.return_value = bundle
|
||||||
|
|
||||||
|
manager = ModelManager(provider_manager, enable_credentials_cache=True)
|
||||||
|
first = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")
|
||||||
|
assert first.credentials == {"api_key": "first"}
|
||||||
|
get_creds.assert_called_once()
|
||||||
|
|
||||||
|
get_creds.return_value = {"api_key": "second"}
|
||||||
|
second = manager.get_model_instance("tenant-1", "openai", ModelType.LLM, "gpt-4")
|
||||||
|
assert second.credentials == {"api_key": "first"}
|
||||||
|
get_creds.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
|
def test_lb_model_manager_fetch_next(mocker: MockerFixture, lb_model_manager: LBModelManager):
|
||||||
# initialize redis client
|
# initialize redis client
|
||||||
redis_client.initialize(redis.Redis())
|
redis_client.initialize(redis.Redis())
|
||||||
|
|||||||
Reference in New Issue
Block a user