mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 08:58:09 +08:00
refactor(tool): implement multi provider credentials support
This commit is contained in:
@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import (
|
||||
)
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter, create_generic_encrypter
|
||||
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
||||
from extensions.ext_database import db
|
||||
from models.tools import ApiToolProvider
|
||||
@ -297,28 +297,28 @@ class ApiToolManageService:
|
||||
provider_controller.load_bundled_tools(tool_bundles)
|
||||
|
||||
# get original credentials if exists
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
encrypter, cache = create_generic_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
original_credentials = tool_configuration.decrypt(provider.credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
|
||||
original_credentials = encrypter.decrypt(provider.credentials)
|
||||
masked_credentials = encrypter.mask_tool_credentials(original_credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
for name, value in credentials.items():
|
||||
if name in masked_credentials and value == masked_credentials[name]:
|
||||
credentials[name] = original_credentials[name]
|
||||
|
||||
credentials = tool_configuration.encrypt(credentials)
|
||||
credentials = encrypter.encrypt(credentials)
|
||||
provider.credentials_str = json.dumps(credentials)
|
||||
|
||||
db.session.add(provider)
|
||||
db.session.commit()
|
||||
|
||||
# delete cache
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
cache.delete()
|
||||
|
||||
# update labels
|
||||
ToolLabelManager.update_tool_labels(provider_controller, labels)
|
||||
@ -416,15 +416,15 @@ class ApiToolManageService:
|
||||
|
||||
# decrypt credentials
|
||||
if db_provider.id:
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
encrypter, _ = create_generic_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=list(provider_controller.get_credentials_schema()),
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
decrypted_credentials = tool_configuration.decrypt(credentials)
|
||||
decrypted_credentials = encrypter.decrypt(credentials)
|
||||
# check if the credential has changed, save the original credential
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
|
||||
masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials)
|
||||
for name, value in credentials.items():
|
||||
if name in masked_credentials and value == masked_credentials[name]:
|
||||
credentials[name] = decrypted_credentials[name]
|
||||
|
||||
@ -8,19 +8,18 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
|
||||
from core.tools.entities.tool_entities import ToolProviderCredentialType
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from core.tools.utils.configuration import create_encrypter
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
|
||||
@ -58,20 +57,15 @@ class BuiltinToolManageService:
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
|
||||
def get_builtin_tool_provider_info(tenant_id: str, provider: str):
|
||||
"""
|
||||
get builtin tool provider info
|
||||
"""
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
# check if user has added the provider
|
||||
builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
|
||||
|
||||
credentials = {}
|
||||
if builtin_provider is not None:
|
||||
# get credentials
|
||||
credentials = builtin_provider.credentials
|
||||
credentials = tool_configuration.decrypt(credentials)
|
||||
if builtin_provider is None:
|
||||
raise ValueError(f"you have not added provider {provider}")
|
||||
|
||||
entity = ToolTransformService.builtin_provider_to_user_provider(
|
||||
provider_controller=provider_controller,
|
||||
@ -80,7 +74,6 @@ class BuiltinToolManageService:
|
||||
)
|
||||
|
||||
entity.original_credentials = {}
|
||||
|
||||
return entity
|
||||
|
||||
@staticmethod
|
||||
@ -96,32 +89,34 @@ class BuiltinToolManageService:
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
return jsonable_encoder(provider.get_credentials_schema(credential_type))
|
||||
return jsonable_encoder(provider.get_credentials_schema_by_type(credential_type))
|
||||
|
||||
@staticmethod
|
||||
def update_builtin_tool_provider(
|
||||
user_id: str, tenant_id: str, provider_name: str, credentials: dict, credential_id: str, name: str | None = None
|
||||
user_id: str, tenant_id: str, provider: str, credentials: dict, credential_id: str, name: str | None = None
|
||||
):
|
||||
"""
|
||||
update builtin tool provider
|
||||
"""
|
||||
# get if the provider exists
|
||||
provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
|
||||
db_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
if db_provider is None:
|
||||
raise ValueError(f"you have not added provider {provider}")
|
||||
|
||||
try:
|
||||
if ToolProviderCredentialType.of(provider.credential_type).is_editable():
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
if ToolProviderCredentialType.of(db_provider.credential_type).is_editable():
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"provider {provider_name} does not need credentials")
|
||||
raise ValueError(f"provider {provider} does not need credentials")
|
||||
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
|
||||
tenant_id, db_provider, provider, provider_controller
|
||||
)
|
||||
|
||||
# Decrypt and restore original credentials for masked values
|
||||
original_credentials = tool_configuration.decrypt(provider.credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
|
||||
original_credentials = encrypter.decrypt(db_provider.credentials)
|
||||
masked_credentials = encrypter.mask_tool_credentials(original_credentials)
|
||||
|
||||
# check if the credential has changed, save the original credential
|
||||
for key, value in credentials.items():
|
||||
@ -131,13 +126,13 @@ class BuiltinToolManageService:
|
||||
provider_controller.validate_credentials(user_id, credentials)
|
||||
|
||||
# encrypt credentials
|
||||
encrypted_credentials = tool_configuration.encrypt(credentials)
|
||||
provider.encrypted_credentials = json.dumps(encrypted_credentials)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
|
||||
|
||||
cache.delete()
|
||||
|
||||
# update name if provided
|
||||
if name is not None and provider.name != name:
|
||||
provider.name = name
|
||||
if name is not None and db_provider.name != name:
|
||||
db_provider.name = name
|
||||
|
||||
db.session.commit()
|
||||
except (
|
||||
@ -176,7 +171,7 @@ class BuiltinToolManageService:
|
||||
name
|
||||
if name
|
||||
else BuiltinToolManageService.generate_builtin_tool_provider_name(
|
||||
tenant_id, provider, credential_type=api_type
|
||||
tenant_id=tenant_id, provider=provider, credential_type=api_type
|
||||
)
|
||||
)
|
||||
|
||||
@ -193,20 +188,35 @@ class BuiltinToolManageService:
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"provider {provider} does not need credentials")
|
||||
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
|
||||
# Encrypt and save the credentials
|
||||
BuiltinToolManageService._encrypt_and_save_credentials(
|
||||
provider_controller=provider_controller,
|
||||
tool_configuration=tool_configuration,
|
||||
provider=db_provider,
|
||||
credentials=credentials,
|
||||
user_id=user_id,
|
||||
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
|
||||
tenant_id, db_provider, provider, provider_controller
|
||||
)
|
||||
|
||||
# encrypt credentials
|
||||
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
|
||||
|
||||
cache.delete()
|
||||
db.session.add(db_provider)
|
||||
db.session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def create_tool_encrypter(
|
||||
tenant_id: str,
|
||||
db_provider: BuiltinToolProvider,
|
||||
provider: str,
|
||||
provider_controller: BuiltinToolProviderController,
|
||||
):
|
||||
encrypter, cache = create_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type)
|
||||
],
|
||||
cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id),
|
||||
)
|
||||
return encrypter, cache
|
||||
|
||||
@staticmethod
|
||||
def generate_builtin_tool_provider_name(
|
||||
tenant_id: str, provider: str, credential_type: ToolProviderCredentialType
|
||||
@ -273,12 +283,13 @@ class BuiltinToolManageService:
|
||||
|
||||
default_provider.is_default = True
|
||||
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
|
||||
tenant_id, default_provider, default_provider.provider, provider_controller
|
||||
)
|
||||
|
||||
credentials: list[ToolProviderCredentialApiEntity] = []
|
||||
for provider in providers:
|
||||
decrypt_credential = tool_configuration.mask_tool_credentials(
|
||||
tool_configuration.decrypt(provider.credentials)
|
||||
)
|
||||
decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
|
||||
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
|
||||
provider=provider,
|
||||
credentials=decrypt_credential,
|
||||
@ -287,22 +298,24 @@ class BuiltinToolManageService:
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def delete_builtin_tool_provider(tenant_id: str, provider_name: str, credential_id: str):
|
||||
def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str):
|
||||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
|
||||
|
||||
if tool_provider is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
raise ValueError(f"you have not added provider {provider}")
|
||||
|
||||
db.session.delete(tool_provider)
|
||||
db.session.commit()
|
||||
|
||||
# delete cache
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
_, cache = BuiltinToolManageService.create_tool_encrypter(
|
||||
tenant_id, tool_provider, provider, provider_controller
|
||||
)
|
||||
cache.delete()
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -493,57 +506,35 @@ class BuiltinToolManageService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController):
|
||||
return ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _encrypt_and_save_credentials(
|
||||
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
|
||||
tool_configuration: ProviderConfigEncrypter,
|
||||
provider: BuiltinToolProvider,
|
||||
credentials: dict,
|
||||
user_id: str,
|
||||
):
|
||||
"""
|
||||
Validate and encrypt credentials, then save to database
|
||||
|
||||
:param provider_controller: the provider controller
|
||||
:param tool_configuration: the tool configuration encrypter
|
||||
:param provider: the provider object from database
|
||||
:param credentials: the credentials to encrypt and save
|
||||
:param user_id: the user id for validation
|
||||
"""
|
||||
if ToolProviderCredentialType.of(provider.credential_type).is_validate_allowed():
|
||||
provider_controller.validate_credentials(user_id, credentials)
|
||||
|
||||
# encrypt credentials
|
||||
encrypted_credentials = tool_configuration.encrypt(credentials)
|
||||
provider.encrypted_credentials = json.dumps(encrypted_credentials)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
@staticmethod
|
||||
def setup_oauth_custom_client(tenant_id: str, user_id: str, provider: str, client_params: dict):
|
||||
def setup_oauth_custom_client(tenant_id: str, provider: str, client_params: dict):
|
||||
"""
|
||||
setup oauth custom client
|
||||
"""
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
if not provider_controller:
|
||||
raise ToolProviderNotFoundError(f"Provider {provider} not found")
|
||||
with Session(db.engine) as session:
|
||||
tool_provider = ToolProviderID(provider)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
if not provider_controller:
|
||||
raise ToolProviderNotFoundError(f"Provider {provider} not found")
|
||||
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
if not isinstance(provider_controller, BuiltinToolProviderController):
|
||||
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
|
||||
|
||||
# Validate and encrypt credentials
|
||||
BuiltinToolManageService._encrypt_and_save_credentials(
|
||||
provider_controller=provider_controller,
|
||||
tool_configuration=tool_configuration,
|
||||
provider=None, # No need to save in DB
|
||||
credentials=client_params,
|
||||
user_id=user_id,
|
||||
)
|
||||
encrypter, _ = create_encrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
|
||||
cache=NoOpProviderCredentialCache(),
|
||||
)
|
||||
|
||||
# encrypt credentials
|
||||
encrypted_credentials = encrypter.encrypt(client_params)
|
||||
session.add(
|
||||
ToolOAuthTenantClient(
|
||||
tenant_id=tenant_id,
|
||||
plugin_id=tool_provider.plugin_id,
|
||||
provider=tool_provider.provider_name,
|
||||
enabled=True,
|
||||
encrypted_oauth_params=json.dumps(encrypted_credentials),
|
||||
)
|
||||
)
|
||||
session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@ -5,6 +5,7 @@ from typing import Optional, Union, cast
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.provider_cache import ToolProviderCredentialsCache
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
@ -19,7 +20,7 @@ from core.tools.entities.tool_entities import (
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from core.tools.utils.configuration import create_encrypter, create_generic_encrypter
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
|
||||
@ -109,7 +110,14 @@ class ToolTransformService:
|
||||
result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
|
||||
|
||||
# get credentials schema
|
||||
schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
|
||||
schema = {
|
||||
x.to_basic_provider_config().name: x
|
||||
for x in provider_controller.get_credentials_schema_by_type(
|
||||
ToolProviderCredentialType.of(db_provider.credential_type)
|
||||
if db_provider
|
||||
else ToolProviderCredentialType.API_KEY
|
||||
)
|
||||
}
|
||||
|
||||
for name, value in schema.items():
|
||||
if result.masked_credentials:
|
||||
@ -126,15 +134,23 @@ class ToolTransformService:
|
||||
credentials = db_provider.credentials
|
||||
|
||||
# init tool configuration
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
encrypter, _ = create_encrypter(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
provider_identity=provider_controller.entity.identity.name,
|
||||
config=[
|
||||
x.to_basic_provider_config()
|
||||
for x in provider_controller.get_credentials_schema_by_type(
|
||||
ToolProviderCredentialType.of(db_provider.credential_type)
|
||||
)
|
||||
],
|
||||
cache=ToolProviderCredentialsCache(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
provider=db_provider.provider,
|
||||
credential_id=db_provider.id,
|
||||
),
|
||||
)
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt(data=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
|
||||
decrypted_credentials = encrypter.decrypt(data=credentials)
|
||||
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
|
||||
|
||||
result.masked_credentials = masked_credentials
|
||||
result.original_credentials = decrypted_credentials
|
||||
@ -236,7 +252,7 @@ class ToolTransformService:
|
||||
|
||||
if decrypt_credentials:
|
||||
# init tool configuration
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
encrypter, _ = create_generic_encrypter(
|
||||
tenant_id=db_provider.tenant_id,
|
||||
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
|
||||
provider_type=provider_controller.provider_type.value,
|
||||
@ -244,8 +260,8 @@ class ToolTransformService:
|
||||
)
|
||||
|
||||
# decrypt the credentials and mask the credentials
|
||||
decrypted_credentials = tool_configuration.decrypt(data=credentials)
|
||||
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
|
||||
decrypted_credentials = encrypter.decrypt(data=credentials)
|
||||
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
|
||||
|
||||
result.masked_credentials = masked_credentials
|
||||
|
||||
@ -264,7 +280,7 @@ class ToolTransformService:
|
||||
# fork tool runtime
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
credentials= {},
|
||||
credentials={},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user