mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 00:48:04 +08:00
refactor(tool): implement multi provider credentials support
This commit is contained in:
77
api/core/helper/provider_cache.py
Normal file
77
api/core/helper/provider_cache.py
Normal file
@ -0,0 +1,77 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from json import JSONDecodeError
|
||||
from typing import Any, Optional
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class ProviderCredentialsCache(ABC):
|
||||
"""Base class for provider credentials cache"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self.cache_key = self._generate_cache_key(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
"""Generate cache key based on subclass implementation"""
|
||||
pass
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""Get cached provider credentials"""
|
||||
cached_credentials = redis_client.get(self.cache_key)
|
||||
if cached_credentials:
|
||||
try:
|
||||
cached_credentials = cached_credentials.decode("utf-8")
|
||||
return dict(json.loads(cached_credentials))
|
||||
except JSONDecodeError:
|
||||
return None
|
||||
return None
|
||||
|
||||
def set(self, config: dict[str, Any]) -> None:
|
||||
"""Cache provider credentials"""
|
||||
redis_client.setex(self.cache_key, 86400, json.dumps(config))
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete cached provider credentials"""
|
||||
redis_client.delete(self.cache_key)
|
||||
|
||||
|
||||
class GenericProviderCredentialsCache(ProviderCredentialsCache):
|
||||
"""Cache for generic provider credentials"""
|
||||
|
||||
def __init__(self, tenant_id: str, identity_id: str):
|
||||
super().__init__(tenant_id=tenant_id, identity_id=identity_id)
|
||||
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
identity_id = kwargs["identity_id"]
|
||||
return f"generic_provider_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||
|
||||
class ToolProviderCredentialsCache(ProviderCredentialsCache):
|
||||
"""Cache for tool provider credentials"""
|
||||
|
||||
def __init__(self, tenant_id: str, provider: str, credential_id: str):
|
||||
super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
|
||||
|
||||
def _generate_cache_key(self, **kwargs) -> str:
|
||||
tenant_id = kwargs["tenant_id"]
|
||||
provider = kwargs["provider"]
|
||||
credential_id = kwargs["credential_id"]
|
||||
return f"provider_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
|
||||
|
||||
|
||||
class NoOpProviderCredentialCache:
|
||||
"""No-op provider credential cache"""
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""Get cached provider credentials"""
|
||||
return None
|
||||
|
||||
def set(self, config: dict[str, Any]) -> None:
|
||||
"""Cache provider credentials"""
|
||||
pass
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete cached provider credentials"""
|
||||
pass
|
||||
@ -1,51 +0,0 @@
|
||||
import json
|
||||
from enum import Enum
|
||||
from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
|
||||
class ToolProviderCredentialsCacheType(Enum):
|
||||
PROVIDER = "tool_provider"
|
||||
ENDPOINT = "endpoint"
|
||||
|
||||
|
||||
class ToolProviderCredentialsCache:
|
||||
def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
|
||||
self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""
|
||||
Get cached model provider credentials.
|
||||
|
||||
:return:
|
||||
"""
|
||||
cached_provider_credentials = redis_client.get(self.cache_key)
|
||||
if cached_provider_credentials:
|
||||
try:
|
||||
cached_provider_credentials = cached_provider_credentials.decode("utf-8")
|
||||
cached_provider_credentials = json.loads(cached_provider_credentials)
|
||||
except JSONDecodeError:
|
||||
return None
|
||||
|
||||
return dict(cached_provider_credentials)
|
||||
else:
|
||||
return None
|
||||
|
||||
def set(self, credentials: dict) -> None:
|
||||
"""
|
||||
Cache model provider credentials.
|
||||
|
||||
:param credentials: provider credentials
|
||||
:return:
|
||||
"""
|
||||
redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
|
||||
|
||||
def delete(self) -> None:
|
||||
"""
|
||||
Delete cached model provider credentials.
|
||||
|
||||
:return:
|
||||
"""
|
||||
redis_client.delete(self.cache_key)
|
||||
@ -1,12 +1,12 @@
|
||||
from core.plugin.entities.request import RequestInvokeEncrypt
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from core.tools.utils.configuration import create_generic_encrypter
|
||||
from models.account import Tenant
|
||||
|
||||
|
||||
class PluginEncrypter:
|
||||
@classmethod
|
||||
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
|
||||
encrypter = ProviderConfigEncrypter(
|
||||
encrypter, cache = create_generic_encrypter(
|
||||
tenant_id=tenant.id,
|
||||
config=payload.config,
|
||||
provider_type=payload.namespace,
|
||||
@ -22,7 +22,7 @@ class PluginEncrypter:
|
||||
"data": encrypter.decrypt(payload.data),
|
||||
}
|
||||
elif payload.opt == "clear":
|
||||
encrypter.delete_tool_credentials_cache()
|
||||
cache.delete()
|
||||
return {
|
||||
"data": {},
|
||||
}
|
||||
|
||||
@ -105,20 +105,34 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
return self.tools
|
||||
|
||||
def get_credentials_schema(
|
||||
self, credential_type: ToolProviderCredentialType = ToolProviderCredentialType.API_KEY
|
||||
) -> list[ProviderConfig]:
|
||||
def get_credentials_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
if credential_type == ToolProviderCredentialType.OAUTH2:
|
||||
return self.get_credentials_schema_by_type(ToolProviderCredentialType.API_KEY.value)
|
||||
|
||||
def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:param credential_type: the type of the credential
|
||||
:return: the credentials schema of the provider
|
||||
"""
|
||||
if credential_type == ToolProviderCredentialType.OAUTH2.value:
|
||||
return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
|
||||
elif credential_type == ToolProviderCredentialType.API_KEY:
|
||||
if credential_type == ToolProviderCredentialType.API_KEY.value:
|
||||
return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
|
||||
else:
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
raise ValueError(f"Invalid credential type: {credential_type}")
|
||||
|
||||
def get_oauth_client_schema(self) -> list[ProviderConfig]:
|
||||
"""
|
||||
returns the oauth client schema of the provider
|
||||
|
||||
:return: the oauth client schema
|
||||
"""
|
||||
return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
|
||||
|
||||
def get_tools(self) -> list[BuiltinTool]:
|
||||
"""
|
||||
@ -141,7 +155,11 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
|
||||
:return: whether the provider needs credentials
|
||||
"""
|
||||
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
|
||||
return (
|
||||
self.entity.credentials_schema is not None
|
||||
and len(self.entity.credentials_schema) != 0
|
||||
or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0)
|
||||
)
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
|
||||
@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Union, cast
|
||||
from yarl import URL
|
||||
|
||||
import contexts
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.tool import PluginToolManager
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
@ -38,12 +39,16 @@ from core.tools.entities.tool_entities import (
|
||||
ApiProviderAuthType,
|
||||
ToolInvokeFrom,
|
||||
ToolParameter,
|
||||
ToolProviderCredentialType,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.errors import ToolProviderNotFoundError
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter, ToolParameterConfigurationManager
|
||||
from core.tools.utils.configuration import (
|
||||
ProviderConfigEncrypter,
|
||||
ToolParameterConfigurationManager,
|
||||
create_encrypter,
|
||||
create_generic_encrypter,
|
||||
)
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
@ -206,19 +211,18 @@ class ToolManager:
|
||||
|
||||
# decrypt the credentials
|
||||
credentials = builtin_provider.credentials
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
encrypter, _ = create_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema(
|
||||
ToolProviderCredentialType.of(builtin_provider.credential_type)
|
||||
)
|
||||
for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
|
||||
],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
cache=ToolProviderCredentialsCache(
|
||||
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
|
||||
),
|
||||
)
|
||||
|
||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
||||
decrypted_credentials = encrypter.decrypt(credentials)
|
||||
|
||||
return cast(
|
||||
BuiltinTool,
|
||||
@ -235,22 +239,18 @@ class ToolManager:
|
||||
|
||||
elif provider_type == ToolProviderType.API:
|
||||
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
|
||||
|
||||
# decrypt the credentials
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
encrypter, _ = create_generic_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
|
||||
provider_type=api_provider.provider_type.value,
|
||||
provider_identity=api_provider.entity.identity.name,
|
||||
)
|
||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
||||
|
||||
return cast(
|
||||
ApiTool,
|
||||
api_provider.get_tool(tool_name).fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
tenant_id=tenant_id,
|
||||
credentials=decrypted_credentials,
|
||||
credentials=encrypter.decrypt(credentials),
|
||||
invoke_from=invoke_from,
|
||||
tool_invoke_from=tool_invoke_from,
|
||||
)
|
||||
@ -730,7 +730,7 @@ class ToolManager:
|
||||
ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
|
||||
)
|
||||
# init tool configuration
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tool_configuration = ProviderConfigEncrypter.create_cached(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
|
||||
provider_type=controller.provider_type.value,
|
||||
|
||||
@ -1,12 +1,10 @@
|
||||
from copy import deepcopy
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Any, Optional, Protocol
|
||||
|
||||
from core.entities.provider_entities import BasicProviderConfig
|
||||
from core.helper import encrypter
|
||||
from core.helper.provider_cache import GenericProviderCredentialsCache
|
||||
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
@ -14,11 +12,38 @@ from core.tools.entities.tool_entities import (
|
||||
)
|
||||
|
||||
|
||||
class ProviderConfigEncrypter(BaseModel):
|
||||
class ProviderConfigCache(Protocol):
|
||||
"""
|
||||
Interface for provider configuration cache operations
|
||||
"""
|
||||
|
||||
def get(self) -> Optional[dict]:
|
||||
"""Get cached provider configuration"""
|
||||
...
|
||||
|
||||
def set(self, config: dict[str, Any]) -> None:
|
||||
"""Cache provider configuration"""
|
||||
...
|
||||
|
||||
def delete(self) -> None:
|
||||
"""Delete cached provider configuration"""
|
||||
...
|
||||
|
||||
|
||||
class ProviderConfigEncrypter:
|
||||
tenant_id: str
|
||||
config: list[BasicProviderConfig]
|
||||
provider_type: str
|
||||
provider_identity: str
|
||||
provider_config_cache: ProviderConfigCache
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
config: list[BasicProviderConfig],
|
||||
provider_config_cache: ProviderConfigCache,
|
||||
):
|
||||
self.tenant_id = tenant_id
|
||||
self.config = config
|
||||
self.provider_config_cache = provider_config_cache
|
||||
|
||||
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
@ -72,18 +97,13 @@ class ProviderConfigEncrypter(BaseModel):
|
||||
|
||||
return data
|
||||
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||
def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
|
||||
"""
|
||||
decrypt tool credentials with tenant id
|
||||
|
||||
return a deep copy of credentials with decrypted values
|
||||
"""
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cached_credentials = cache.get()
|
||||
cached_credentials = self.provider_config_cache.get()
|
||||
if cached_credentials:
|
||||
return cached_credentials
|
||||
data = self._deep_copy(data)
|
||||
@ -104,16 +124,24 @@ class ProviderConfigEncrypter(BaseModel):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cache.set(data)
|
||||
self.provider_config_cache.set(data)
|
||||
return data
|
||||
|
||||
def delete_tool_credentials_cache(self):
|
||||
cache = ToolProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
cache.delete()
|
||||
|
||||
def create_encrypter(
|
||||
tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache
|
||||
):
|
||||
return ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id, config=config, provider_config_cache=cache
|
||||
), cache
|
||||
|
||||
|
||||
def create_generic_encrypter(
|
||||
tenant_id: str, config: list[BasicProviderConfig], provider_type: str, provider_identity: str
|
||||
):
|
||||
cache = GenericProviderCredentialsCache(tenant_id=tenant_id, identity_id=f"{provider_type}.{provider_identity}")
|
||||
encrypt = ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache)
|
||||
return encrypt, cache
|
||||
|
||||
|
||||
class ToolParameterConfigurationManager:
|
||||
|
||||
Reference in New Issue
Block a user