refactor(tool): implement multi provider credentials support

This commit is contained in:
Harry
2025-07-02 10:04:57 +08:00
parent daec82bd44
commit 7951a1c4df
10 changed files with 296 additions and 218 deletions

View File

@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import (
)
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.configuration import ProviderConfigEncrypter, create_generic_encrypter
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
from models.tools import ApiToolProvider
@ -297,28 +297,28 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# get original credentials if exists
tool_configuration = ProviderConfigEncrypter(
encrypter, cache = create_generic_encrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
original_credentials = encrypter.decrypt(provider.credentials)
masked_credentials = encrypter.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
credentials = tool_configuration.encrypt(credentials)
credentials = encrypter.encrypt(credentials)
provider.credentials_str = json.dumps(credentials)
db.session.add(provider)
db.session.commit()
# delete cache
tool_configuration.delete_tool_credentials_cache()
cache.delete()
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
@ -416,15 +416,15 @@ class ApiToolManageService:
# decrypt credentials
if db_provider.id:
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_generic_encrypter(
tenant_id=tenant_id,
config=list(provider_controller.get_credentials_schema()),
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
decrypted_credentials = tool_configuration.decrypt(credentials)
decrypted_credentials = encrypter.decrypt(credentials)
# check if the credential has changed, save the original credential
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials)
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = decrypted_credentials[name]

View File

@ -8,19 +8,18 @@ from sqlalchemy.orm import Session
from configs import dify_config
from core.helper.position_helper import is_filtered
from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
from core.tools.entities.tool_entities import ToolProviderCredentialType
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.configuration import create_encrypter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
@ -58,20 +57,15 @@ class BuiltinToolManageService:
return result
@staticmethod
def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
def get_builtin_tool_provider_info(tenant_id: str, provider: str):
"""
get builtin tool provider info
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# check if user has added the provider
builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_configuration.decrypt(credentials)
if builtin_provider is None:
raise ValueError(f"you have not added provider {provider}")
entity = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
@ -80,7 +74,6 @@ class BuiltinToolManageService:
)
entity.original_credentials = {}
return entity
@staticmethod
@ -96,32 +89,34 @@ class BuiltinToolManageService:
:return: the list of tool providers
"""
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
return jsonable_encoder(provider.get_credentials_schema(credential_type))
return jsonable_encoder(provider.get_credentials_schema_by_type(credential_type))
@staticmethod
def update_builtin_tool_provider(
user_id: str, tenant_id: str, provider_name: str, credentials: dict, credential_id: str, name: str | None = None
user_id: str, tenant_id: str, provider: str, credentials: dict, credential_id: str, name: str | None = None
):
"""
update builtin tool provider
"""
# get if the provider exists
provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
db_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
if provider is None:
raise ValueError(f"you have not added provider {provider_name}")
if db_provider is None:
raise ValueError(f"you have not added provider {provider}")
try:
if ToolProviderCredentialType.of(provider.credential_type).is_editable():
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
if ToolProviderCredentialType.of(db_provider.credential_type).is_editable():
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
raise ValueError(f"provider {provider} does not need credentials")
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, db_provider, provider, provider_controller
)
# Decrypt and restore original credentials for masked values
original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
original_credentials = encrypter.decrypt(db_provider.credentials)
masked_credentials = encrypter.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for key, value in credentials.items():
@ -131,13 +126,13 @@ class BuiltinToolManageService:
provider_controller.validate_credentials(user_id, credentials)
# encrypt credentials
encrypted_credentials = tool_configuration.encrypt(credentials)
provider.encrypted_credentials = json.dumps(encrypted_credentials)
tool_configuration.delete_tool_credentials_cache()
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
cache.delete()
# update name if provided
if name is not None and provider.name != name:
provider.name = name
if name is not None and db_provider.name != name:
db_provider.name = name
db.session.commit()
except (
@ -176,7 +171,7 @@ class BuiltinToolManageService:
name
if name
else BuiltinToolManageService.generate_builtin_tool_provider_name(
tenant_id, provider, credential_type=api_type
tenant_id=tenant_id, provider=provider, credential_type=api_type
)
)
@ -193,20 +188,35 @@ class BuiltinToolManageService:
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider} does not need credentials")
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller=provider_controller,
tool_configuration=tool_configuration,
provider=db_provider,
credentials=credentials,
user_id=user_id,
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, db_provider, provider, provider_controller
)
# encrypt credentials
db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(credentials))
cache.delete()
db.session.add(db_provider)
db.session.commit()
return {"result": "success"}
@staticmethod
def create_tool_encrypter(
tenant_id: str,
db_provider: BuiltinToolProvider,
provider: str,
provider_controller: BuiltinToolProviderController,
):
encrypter, cache = create_encrypter(
tenant_id=tenant_id,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type)
],
cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id),
)
return encrypter, cache
@staticmethod
def generate_builtin_tool_provider_name(
tenant_id: str, provider: str, credential_type: ToolProviderCredentialType
@ -273,12 +283,13 @@ class BuiltinToolManageService:
default_provider.is_default = True
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, default_provider, default_provider.provider, provider_controller
)
credentials: list[ToolProviderCredentialApiEntity] = []
for provider in providers:
decrypt_credential = tool_configuration.mask_tool_credentials(
tool_configuration.decrypt(provider.credentials)
)
decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
provider=provider,
credentials=decrypt_credential,
@ -287,22 +298,24 @@ class BuiltinToolManageService:
return credentials
@staticmethod
def delete_builtin_tool_provider(tenant_id: str, provider_name: str, credential_id: str):
def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str):
"""
delete tool provider
"""
tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
if tool_provider is None:
raise ValueError(f"you have not added provider {provider_name}")
raise ValueError(f"you have not added provider {provider}")
db.session.delete(tool_provider)
db.session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
tool_configuration.delete_tool_credentials_cache()
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
_, cache = BuiltinToolManageService.create_tool_encrypter(
tenant_id, tool_provider, provider, provider_controller
)
cache.delete()
return {"result": "success"}
@ -493,57 +506,35 @@ class BuiltinToolManageService:
)
@staticmethod
def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController):
return ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
@staticmethod
def _encrypt_and_save_credentials(
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
tool_configuration: ProviderConfigEncrypter,
provider: BuiltinToolProvider,
credentials: dict,
user_id: str,
):
"""
Validate and encrypt credentials, then save to database
:param provider_controller: the provider controller
:param tool_configuration: the tool configuration encrypter
:param provider: the provider object from database
:param credentials: the credentials to encrypt and save
:param user_id: the user id for validation
"""
if ToolProviderCredentialType.of(provider.credential_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, credentials)
# encrypt credentials
encrypted_credentials = tool_configuration.encrypt(credentials)
provider.encrypted_credentials = json.dumps(encrypted_credentials)
tool_configuration.delete_tool_credentials_cache()
@staticmethod
def setup_oauth_custom_client(tenant_id: str, user_id: str, provider: str, client_params: dict):
def setup_oauth_custom_client(tenant_id: str, provider: str, client_params: dict):
"""
setup oauth custom client
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller:
raise ToolProviderNotFoundError(f"Provider {provider} not found")
with Session(db.engine) as session:
tool_provider = ToolProviderID(provider)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller:
raise ToolProviderNotFoundError(f"Provider {provider} not found")
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
if not isinstance(provider_controller, BuiltinToolProviderController):
raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
# Validate and encrypt credentials
BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller=provider_controller,
tool_configuration=tool_configuration,
provider=None, # No need to save in DB
credentials=client_params,
user_id=user_id,
)
encrypter, _ = create_encrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
cache=NoOpProviderCredentialCache(),
)
# encrypt credentials
encrypted_credentials = encrypter.encrypt(client_params)
session.add(
ToolOAuthTenantClient(
tenant_id=tenant_id,
plugin_id=tool_provider.plugin_id,
provider=tool_provider.provider_name,
enabled=True,
encrypted_oauth_params=json.dumps(encrypted_credentials),
)
)
session.commit()
return {"result": "success"}

View File

@ -5,6 +5,7 @@ from typing import Optional, Union, cast
from yarl import URL
from configs import dify_config
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
@ -19,7 +20,7 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.utils.configuration import create_encrypter, create_generic_encrypter
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
@ -109,7 +110,14 @@ class ToolTransformService:
result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
# get credentials schema
schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
schema = {
x.to_basic_provider_config().name: x
for x in provider_controller.get_credentials_schema_by_type(
ToolProviderCredentialType.of(db_provider.credential_type)
if db_provider
else ToolProviderCredentialType.API_KEY
)
}
for name, value in schema.items():
if result.masked_credentials:
@ -126,15 +134,23 @@ class ToolTransformService:
credentials = db_provider.credentials
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_encrypter(
tenant_id=db_provider.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
config=[
x.to_basic_provider_config()
for x in provider_controller.get_credentials_schema_by_type(
ToolProviderCredentialType.of(db_provider.credential_type)
)
],
cache=ToolProviderCredentialsCache(
tenant_id=db_provider.tenant_id,
provider=db_provider.provider,
credential_id=db_provider.id,
),
)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt(data=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
decrypted_credentials = encrypter.decrypt(data=credentials)
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials
result.original_credentials = decrypted_credentials
@ -236,7 +252,7 @@ class ToolTransformService:
if decrypt_credentials:
# init tool configuration
tool_configuration = ProviderConfigEncrypter(
encrypter, _ = create_generic_encrypter(
tenant_id=db_provider.tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
@ -244,8 +260,8 @@ class ToolTransformService:
)
# decrypt the credentials and mask the credentials
decrypted_credentials = tool_configuration.decrypt(data=credentials)
masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
decrypted_credentials = encrypter.decrypt(data=credentials)
masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials
@ -264,7 +280,7 @@ class ToolTransformService:
# fork tool runtime
tool = tool.fork_tool_runtime(
runtime=ToolRuntime(
credentials= {},
credentials={},
tenant_id=tenant_id,
)
)