feat(oauth): refactor tool provider methods and enhance credential handling

This commit is contained in:
Harry
2025-06-27 13:17:09 +08:00
parent 8a954c0b19
commit daec82bd44
9 changed files with 309 additions and 170 deletions

View File

@ -446,7 +446,7 @@ class ApiToolManageService:
return {"result": result or "empty response"}
@staticmethod
def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
"""
list api tools
"""
@ -474,7 +474,7 @@ class ApiToolManageService:
for tool in tools or []:
user_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
tenant_id=tenant_id, tool=tool, labels=labels
)
)

View File

@ -4,7 +4,6 @@ import re
from pathlib import Path
from typing import Optional, Union
from sqlalchemy import ColumnExpressionArgument
from sqlalchemy.orm import Session
from configs import dify_config
@ -13,10 +12,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
from core.tools.entities.tool_entities import ToolProviderCredentialType
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
@ -29,6 +30,8 @@ logger = logging.getLogger(__name__)
class BuiltinToolManageService:
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
@staticmethod
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
"""
@ -42,22 +45,11 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tools = provider_controller.get_tools()
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
credentials = {}
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_configuration.decrypt(credentials)
result: list[ToolApiEntity] = []
for tool in tools or []:
result.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool,
credentials=credentials,
tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
@ -73,7 +65,7 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
credentials = {}
if builtin_provider is not None:
@ -92,16 +84,19 @@ class BuiltinToolManageService:
return entity
@staticmethod
def list_builtin_provider_credentials_schema(provider_name: str, credential_type: str, tenant_id: str):
def list_builtin_provider_credentials_schema(
provider_name: str, credential_type: ToolProviderCredentialType, tenant_id: str
):
"""
list builtin provider credentials schema
:param credential_type: credential type
:param provider_name: the name of the provider
:param tenant_id: the id of the tenant
:return: the list of tool providers
"""
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
return jsonable_encoder(provider.get_credentials_schema())
return jsonable_encoder(provider.get_credentials_schema(credential_type))
@staticmethod
def update_builtin_tool_provider(
@ -111,11 +106,11 @@ class BuiltinToolManageService:
update builtin tool provider
"""
# get if the provider exists
provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
if provider is None:
raise ValueError(f"you have not added provider {provider_name}")
try:
if ToolProviderCredentialType.of(provider.credential_type).is_editable():
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
@ -133,10 +128,12 @@ class BuiltinToolManageService:
if key in masked_credentials and value == masked_credentials[key]:
credentials[key] = original_credentials[key]
# Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller, tool_configuration, provider, credentials, user_id
)
provider_controller.validate_credentials(user_id, credentials)
# encrypt credentials
encrypted_credentials = tool_configuration.encrypt(credentials)
provider.encrypted_credentials = json.dumps(encrypted_credentials)
tool_configuration.delete_tool_credentials_cache()
# update name if provided
if name is not None and provider.name != name:
@ -158,68 +155,84 @@ class BuiltinToolManageService:
user_id: str,
api_type: ToolProviderCredentialType,
tenant_id: str,
provider_name: str,
provider: str,
credentials: dict,
name: str | None = None,
):
"""
add builtin tool provider
"""
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider_name}"
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
with redis_client.lock(lock, timeout=20):
if name is None:
name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, provider_name, api_type)
# check if the provider count is over the limit
provider_count = (
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
)
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
raise ValueError(f"you have reached the maximum number of providers for {provider}")
provider = BuiltinToolProvider(
# TODO should we get name from oauth authentication?
name = (
name
if name
else BuiltinToolManageService.generate_builtin_tool_provider_name(
tenant_id, provider, credential_type=api_type
)
)
db_provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider_name,
provider=provider,
encrypted_credentials=json.dumps(credentials),
credential_type=api_type.value,
name=name,
)
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
raise ValueError(f"provider {provider} does not need credentials")
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller, tool_configuration, provider, credentials, user_id
provider_controller=provider_controller,
tool_configuration=tool_configuration,
provider=db_provider,
credentials=credentials,
user_id=user_id,
)
db.session.add(provider)
db.session.add(db_provider)
db.session.commit()
return {"result": "success"}
@staticmethod
def get_next_builtin_tool_provider_name(
tenant_id: str, provider_name: str, type: ToolProviderCredentialType
def generate_builtin_tool_provider_name(
tenant_id: str, provider: str, credential_type: ToolProviderCredentialType
) -> str:
try:
providers = (
db_providers = (
db.session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider_name,
credential_type=type.value,
provider=provider,
credential_type=credential_type.value,
)
.order_by(BuiltinToolProvider.created_at.desc())
.limit(10)
.all()
)
# Get the default name pattern
default_pattern = type.get_name()
default_pattern = f"{credential_type.get_name()}"
# Find all names that match the default pattern: "{default_pattern} {number}"
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
numbers = []
for provider in providers:
if provider.name:
match = re.match(pattern, provider.name.strip())
for db_provider in db_providers:
if db_provider.name:
match = re.match(pattern, db_provider.name.strip())
if match:
numbers.append(int(match.group(1)))
@ -231,9 +244,9 @@ class BuiltinToolManageService:
max_number = max(numbers)
return f"{default_pattern} {max_number + 1}"
except Exception as e:
logger.warning(f"Error generating next provider name for {provider_name}: {str(e)}")
logger.warning(f"Error generating next provider name for {provider}: {str(e)}")
# fallback
return f"{type.get_name()} 1"
return f"{credential_type.get_name()} 1"
@staticmethod
def get_builtin_tool_provider_credentials(
@ -242,31 +255,43 @@ class BuiltinToolManageService:
"""
get builtin tool provider credentials
"""
providers = db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all()
if len(providers) == 0:
return []
provider_controller = ToolManager.get_builtin_provider(providers[0].provider, tenant_id)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
credentials: list[ToolProviderCredentialApiEntity] = []
for provider in providers:
decrypt_credential = tool_configuration.mask_tool_credentials(
tool_configuration.decrypt(provider.credentials)
with db.session.no_autoflush:
providers = (
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all()
)
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
provider=provider,
credentials=decrypt_credential,
)
credentials.append(credential_entity)
return credentials
if len(providers) == 0:
return []
default_provider = sorted(
providers,
key=lambda p: (
not getattr(p, "is_default", False),
getattr(p, "created_at", None) or 0,
),
)[0]
default_provider.is_default = True
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
credentials: list[ToolProviderCredentialApiEntity] = []
for provider in providers:
decrypt_credential = tool_configuration.mask_tool_credentials(
tool_configuration.decrypt(provider.credentials)
)
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
provider=provider,
credentials=decrypt_credential,
)
credentials.append(credential_entity)
return credentials
@staticmethod
def delete_builtin_tool_provider(tenant_id: str, provider_name: str, credential_id: str):
"""
delete tool provider
"""
tool_provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
if tool_provider is None:
raise ValueError(f"you have not added provider {provider_name}")
@ -387,7 +412,6 @@ class BuiltinToolManageService:
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
credentials=user_builtin_provider.original_credentials,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
@ -399,7 +423,7 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
def get_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
provider: Optional[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider)
.filter(
@ -411,47 +435,62 @@ class BuiltinToolManageService:
return provider
@staticmethod
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
"""
This method is used to fetch the builtin provider from the database
1.if the default provider exists, return the default provider
2.if the default provider does not exist, return the oldest provider
"""
with Session(db.engine) as session:
try:
full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name)
provider_name = provider_id_entity.provider_name
def _query(provider_filters: list[ColumnExpressionArgument[bool]]) -> Optional[BuiltinToolProvider]:
return (
db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
if provider_id_entity.organization != "langgenius":
provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == full_provider_name,
)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
else:
provider = (
session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name)
| (BuiltinToolProvider.provider == full_provider_name),
)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
if provider is None:
return None
provider.provider = ToolProviderID(provider.provider).to_string()
return provider
except Exception:
# it's an old provider without organization
return (
session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
.first()
)
try:
full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name)
provider_name = provider_id_entity.provider_name
if provider_id_entity.organization != "langgenius":
provider = _query([BuiltinToolProvider.provider == full_provider_name])
else:
provider = _query(
[
(BuiltinToolProvider.provider == provider_name)
| (BuiltinToolProvider.provider == full_provider_name)
]
)
if provider is None:
return None
provider.provider = ToolProviderID(provider.provider).to_string()
return provider
except Exception:
# it's an old provider without organization
return _query([BuiltinToolProvider.provider == provider_name])
@staticmethod
def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController):
@ -463,7 +502,13 @@ class BuiltinToolManageService:
)
@staticmethod
def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id):
def _encrypt_and_save_credentials(
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
tool_configuration: ProviderConfigEncrypter,
provider: BuiltinToolProvider,
credentials: dict,
user_id: str,
):
"""
Validate and encrypt credentials, then save to database
@ -480,3 +525,25 @@ class BuiltinToolManageService:
encrypted_credentials = tool_configuration.encrypt(credentials)
provider.encrypted_credentials = json.dumps(encrypted_credentials)
tool_configuration.delete_tool_credentials_cache()
@staticmethod
def setup_oauth_custom_client(tenant_id: str, user_id: str, provider: str, client_params: dict):
"""
setup oauth custom client
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
if not provider_controller:
raise ToolProviderNotFoundError(f"Provider {provider} not found")
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# Validate and encrypt credentials
BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller=provider_controller,
tool_configuration=tool_configuration,
provider=None, # No need to save in DB
credentials=client_params,
user_id=user_id,
)
return {"result": "success"}

View File

@ -255,7 +255,6 @@ class ToolTransformService:
def convert_tool_entity_to_api_entity(
tool: Union[ApiToolBundle, WorkflowTool, Tool],
tenant_id: str,
credentials: dict | None = None,
labels: list[str] | None = None,
) -> ToolApiEntity:
"""
@ -265,7 +264,7 @@ class ToolTransformService:
# fork tool runtime
tool = tool.fork_tool_runtime(
runtime=ToolRuntime(
credentials=credentials or {},
credentials= {},
tenant_id=tenant_id,
)
)