refactor(tool oauth): update api implementation

This commit is contained in:
Harry
2025-06-23 16:51:28 +08:00
parent 7f292dc261
commit 5e7c5863ef
16 changed files with 393 additions and 738 deletions

View File

@ -23,7 +23,7 @@ class OAuthProxyService(BasePluginClient):
is used to verify the state, ensuring the request's integrity and authenticity,
and mitigating replay attacks.
"""
seconds, microseconds = redis_client.time()
seconds, _ = redis_client.time()
context_id = str(uuid.uuid4())
data = {
"user_id": user_id,
@ -55,7 +55,7 @@ class OAuthProxyService(BasePluginClient):
if not data:
raise ValueError("context_id is invalid")
# check if data is expired
seconds, microseconds = redis_client.time()
seconds, _ = redis_client.time()
state = json.loads(data)
if state.get("timestamp") < seconds - max_age:
raise ValueError("context_id is expired")

View File

@ -1,20 +1,26 @@
import json
import logging
import re
from pathlib import Path
from sqlalchemy import ColumnExpressionArgument
from sqlalchemy.orm import Session
from configs import dify_config
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
from core.tools.entities.tool_entities import ToolProviderCredentialType
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ProviderConfigEncrypter
from extensions.ext_database import db
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthUserClient, ToolProviderCredentialType
from extensions.ext_redis import redis_client
from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
@ -107,7 +113,7 @@ class BuiltinToolManageService:
@staticmethod
def update_builtin_tool_provider(
user_id: str, tenant_id: str, provider_name:str, credentials: dict, credential_id: str, name: str | None = None
user_id: str, tenant_id: str, provider_name: str, credentials: dict, credential_id: str, name: str | None = None
):
"""
update builtin tool provider
@ -119,7 +125,7 @@ class BuiltinToolManageService:
raise ValueError(f"you have not added provider {provider_name}")
try:
if ToolProviderCredentialType.get_credential_type(provider.credential_type).is_editable():
if ToolProviderCredentialType.of(provider.credential_type).is_editable():
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
@ -132,18 +138,20 @@ class BuiltinToolManageService:
)
# Decrypt and restore original credentials for masked values
credentials = BuiltinToolManageService._decrypt_and_restore_credentials(
provider_controller, tool_configuration, provider, credentials
)
original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]: # type: ignore
credentials[name] = original_credentials[name] # type: ignore
# Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller, tool_configuration, provider, credentials, user_id
)
else:
raise ValueError(
f"provider {provider_name} is not editable, you can only delete it and add a new one"
)
raise ValueError(f"provider {provider_name} is not editable, you can only delete it and add a new one")
# update name if provided
if name is not None and provider.name != name:
@ -151,10 +159,10 @@ class BuiltinToolManageService:
db.session.commit()
except (
PluginDaemonClientSideError,
ToolProviderNotFoundError,
ToolNotFoundError,
ToolProviderCredentialValidationError,
PluginDaemonClientSideError,
ToolProviderNotFoundError,
ToolNotFoundError,
ToolProviderCredentialValidationError,
) as e:
raise ValueError(str(e))
@ -162,94 +170,136 @@ class BuiltinToolManageService:
@staticmethod
def add_builtin_tool_provider(
user_id: str, type: ToolProviderCredentialType, tenant_id: str, provider_name:str, credentials: dict, name: str | None = None
user_id: str,
api_type: ToolProviderCredentialType,
tenant_id: str,
provider_name: str,
credentials: dict,
name: str | None = None,
):
"""
add builtin tool provider
"""
if name is None:
name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, type)
provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider_name,
credential_type=type.value,
credentials=json.dumps(credentials),
name=name,
)
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
lock_name = f"builtin_tool_provider_credential_lock_{tenant_id}_{provider_name}_{api_type.value}"
with redis_client.lock(lock_name, timeout=20):
if name is None:
name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, provider_name, api_type)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller, tool_configuration, provider, credentials, user_id
)
db.session.add(provider)
provider = BuiltinToolProvider(
tenant_id=tenant_id,
user_id=user_id,
provider=provider_name,
encrypted_credentials=json.dumps(credentials),
credential_type=api_type.value,
name=name,
)
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
# Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller, tool_configuration, provider, credentials, user_id
)
db.session.add(provider)
db.session.commit()
return {"result": "success"}
@staticmethod
def get_next_builtin_tool_provider_name(tenant_id: str, type: ToolProviderCredentialType) -> str:
"""
next name = max(provider_names) + 1
"""
provider_names = db.session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id,
credential_type=type.value,
).all()
if not provider_names:
return f"{type.value} 1"
# OAuth 1 then OAuth 2, if don't have OAuth 1, then return OAuth 1
# if dont have number, then get name and add 1
for provider_name in provider_names:
if provider_name.provider.startswith(type.value):
return f"{type.value} {int(provider_name.provider.split(' ')[1]) + 1}"
return f"{type.value} 1"
def get_next_builtin_tool_provider_name(
tenant_id: str, provider_name: str, type: ToolProviderCredentialType
) -> str:
try:
providers = (
db.session.query(BuiltinToolProvider)
.filter_by(
tenant_id=tenant_id,
provider=provider_name,
credential_type=type.value,
)
.order_by(BuiltinToolProvider.created_at.desc())
.limit(10)
.all()
)
# Get the default name pattern
default_pattern = type.get_name()
# Find all names that match the default pattern: "{default_pattern} {number}"
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
numbers = []
for provider in providers:
if provider.name:
match = re.match(pattern, provider.name.strip())
if match:
numbers.append(int(match.group(1)))
# If no default pattern names found, start with 1
if not numbers:
return f"{default_pattern} 1"
# Find the next number
max_number = max(numbers)
return f"{default_pattern} {max_number + 1}"
except Exception as e:
logger.warning(f"Error generating next provider name for {provider_name}: {str(e)}")
# fallback
return f"{type.get_name()} 1"
@staticmethod
def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
def get_builtin_tool_provider_credentials(
tenant_id: str, provider_name: str
) -> list[ToolProviderCredentialApiEntity]:
"""
get builtin tool provider credentials
"""
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
providers = db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all()
if provider_obj is None:
return {}
if len(providers) == 0:
return []
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
provider_controller = ToolManager.get_builtin_provider(providers[0].provider, tenant_id)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
credentials = tool_configuration.decrypt(provider_obj.credentials)
credentials = tool_configuration.mask_tool_credentials(credentials)
credentials: list[ToolProviderCredentialApiEntity] = []
for provider in providers:
decrypt_credential = tool_configuration.mask_tool_credentials(
tool_configuration.decrypt(provider.credentials)
)
credentials.append(
ToolTransformService.convert_builtin_provider_to_credential_api_entity(
provider=provider,
credentials=decrypt_credential,
)
)
return credentials
@staticmethod
def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
def delete_builtin_tool_provider(tenant_id: str, provider_name: str, credential_id: str):
"""
delete tool provider
"""
provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
provider_obj = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
if provider_obj is None:
raise ValueError(f"you have not added provider {provider_name}")
db.session.delete(provider_obj)
db.session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
tool_configuration = ProviderConfigEncrypter(
@ -267,70 +317,45 @@ class BuiltinToolManageService:
"""
set default provider
"""
# get provider
target_provider = db.session.query(BuiltinToolProvider).filter_by(id=id).first()
if target_provider is None:
raise ValueError("provider not found")
with Session(db.engine) as session:
# get provider
target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first()
if target_provider is None:
raise ValueError("provider not found")
# clear default provider
db.session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id,
user_id=user_id,
provider=provider,
default=True
).update({"default": False})
# clear default provider
session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, user_id=user_id, provider=provider, default=True
).update({"default": False})
# set new default provider
target_provider.default = True
db.session.commit()
# set new default provider
target_provider.is_default = True
session.commit()
return {"result": "success"}
@staticmethod
def fetch_default_provider(tenant_id: str, user_id: str, provider_name: str):
"""
fetch default provider
if there is no explicitly set default provider, return the oldest provider as default
"""
# 1. check if default provider exists
default_provider = db.session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id,
user_id=user_id,
provider=provider_name,
default=True
).first()
if default_provider:
return default_provider
# 2. if no default provider, set the oldest provider as default
oldest_provider = (db.session.query(BuiltinToolProvider)
.filter_by(tenant_id=tenant_id, user_id=user_id, provider=provider_name)
.order_by(BuiltinToolProvider.created_at)
.first()
)
if oldest_provider:
return oldest_provider
raise ValueError(f"no default provider found for {provider_name}")
@staticmethod
def get_builtin_tool_provider(tenant_id: str, user_id: str, provider: str, plugin_id: str):
"""
get builtin tool provider
"""
user_client = db.session.query(ToolOAuthUserClient).filter_by(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
enabled=True,
).first()
with Session(db.engine) as session:
user_client = (
session.query(ToolOAuthTenantClient)
.filter_by(
tenant_id=tenant_id,
provider=provider,
plugin_id=plugin_id,
enabled=True,
)
.first()
)
if user_client:
plugin_oauth_config = user_client
else:
plugin_oauth_config = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first()
if user_client:
plugin_oauth_config = user_client
else:
plugin_oauth_config = db.session.query(ToolOAuthSystemClient).filter_by(provider=provider).first()
if plugin_oauth_config:
return plugin_oauth_config
if plugin_oauth_config:
return plugin_oauth_config
raise ValueError("no oauth available config found for this plugin")
@ -408,73 +433,69 @@ class BuiltinToolManageService:
@staticmethod
def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> BuiltinToolProvider | None:
provider = (db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first())
provider = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.id == credential_id,
)
.first()
)
return provider
@staticmethod
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
try:
full_provider_name = provider_name
provider_id_entity = GenericProviderID(provider_name)
provider_name = provider_id_entity.provider_name
if provider_id_entity.organization != "langgenius":
provider_obj = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == full_provider_name,
)
.first()
)
else:
provider_obj = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name)
| (BuiltinToolProvider.provider == full_provider_name),
)
.first()
)
if provider_obj is None:
return None
provider_obj.provider = GenericProviderID(provider_obj.provider).to_string()
return provider_obj
except Exception:
# it's an old provider without organization
"""
This method is used to fetch the builtin provider from the database
1.if the default provider exists, return the default provider
2.if the default provider does not exist, return the oldest provider
"""
def _query(provider_filters: list[ColumnExpressionArgument[bool]]):
return (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
(BuiltinToolProvider.provider == provider_name),
.filter(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters)
.order_by(
BuiltinToolProvider.is_default.desc(), # default=True first
BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
try:
full_provider_name = provider_name
provider_id_entity = ToolProviderID(provider_name)
provider_name = provider_id_entity.provider_name
if provider_id_entity.organization != "langgenius":
provider = _query([BuiltinToolProvider.provider == full_provider_name])
else:
provider = _query(
[
(BuiltinToolProvider.provider == provider_name)
| (BuiltinToolProvider.provider == full_provider_name)
]
)
if provider is None:
return None
provider.provider = ToolProviderID(provider.provider).to_string()
return provider
except Exception:
# it's an old provider without organization
provider_obj = _query([BuiltinToolProvider.provider == provider_name])
return provider_obj
@staticmethod
def _decrypt_and_restore_credentials(provider_controller, tool_configuration, provider, credentials):
def _decrypt_and_restore_credentials(tool_configuration, provider, credentials):
"""
Decrypt original credentials and restore masked values from the input credentials
:param provider_controller: the provider controller
:param tool_configuration: the tool configuration encrypter
:param provider: the provider object from database
:param credentials: the input credentials from user
:return: the processed credentials with original values restored
"""
original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]: # type: ignore
credentials[name] = original_credentials[name] # type: ignore
return credentials
@ -489,8 +510,9 @@ class BuiltinToolManageService:
:param credentials: the credentials to encrypt and save
:param user_id: the user id for validation
"""
# validate credentials
provider_controller.validate_credentials(user_id, credentials)
if ToolProviderCredentialType.of(provider.credential_type).is_validate_allowed():
provider_controller.validate_credentials(user_id, credentials)
# encrypt credentials
encrypted_credentials = tool_configuration.encrypt(credentials)
provider.encrypted_credentials = json.dumps(encrypted_credentials)

View File

@ -9,12 +9,13 @@ from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
ToolParameter,
ToolProviderCredentialType,
ToolProviderType,
)
from core.tools.plugin_tool.provider import PluginToolProviderController
@ -304,3 +305,16 @@ class ToolTransformService:
parameters=tool.parameters,
labels=labels or [],
)
@staticmethod
def convert_builtin_provider_to_credential_api_entity(
provider: BuiltinToolProvider, credentials: dict
) -> ToolProviderCredentialApiEntity:
return ToolProviderCredentialApiEntity(
id=provider.id,
name=provider.name,
provider=provider.provider,
credential_type=ToolProviderCredentialType.of(provider.credential_type),
is_default=provider.is_default,
credentials=credentials,
)