feat(oauth): update api

This commit is contained in:
Harry
2025-06-26 11:44:00 +08:00
parent 6c9e99b0c6
commit ba843c2691
6 changed files with 84 additions and 210 deletions

View File

@ -2,6 +2,7 @@ import json
import logging
import re
from pathlib import Path
from typing import Optional, Union
from sqlalchemy import ColumnExpressionArgument
from sqlalchemy.orm import Session
@ -11,6 +12,7 @@ from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
from core.tools.entities.tool_entities import ToolProviderCredentialType
@ -40,12 +42,7 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tools = provider_controller.get_tools()
tool_provider_configurations = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
@ -53,7 +50,7 @@ class BuiltinToolManageService:
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt(credentials)
credentials = tool_configuration.decrypt(credentials)
result: list[ToolApiEntity] = []
for tool in tools or []:
@ -74,12 +71,7 @@ class BuiltinToolManageService:
get builtin tool provider info
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tool_provider_configurations = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# check if user has added the provider
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
@ -87,7 +79,7 @@ class BuiltinToolManageService:
if builtin_provider is not None:
# get credentials
credentials = builtin_provider.credentials
credentials = tool_provider_configurations.decrypt(credentials)
credentials = tool_configuration.decrypt(credentials)
entity = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
@ -100,7 +92,7 @@ class BuiltinToolManageService:
return entity
@staticmethod
def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
def list_builtin_provider_credentials_schema(provider_name: str, credential_type: str, tenant_id: str):
"""
list builtin provider credentials schema
@ -123,35 +115,28 @@ class BuiltinToolManageService:
if provider is None:
raise ValueError(f"you have not added provider {provider_name}")
try:
if ToolProviderCredentialType.of(provider.credential_type).is_editable():
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# Decrypt and restore original credentials for masked values
original_credentials = tool_configuration.decrypt(provider.credentials)
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]: # type: ignore
credentials[name] = original_credentials[name] # type: ignore
for key, value in credentials.items():
if key in masked_credentials and value == masked_credentials[key]:
credentials[key] = original_credentials[key]
# Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials(
provider_controller, tool_configuration, provider, credentials, user_id
)
else:
raise ValueError(f"provider {provider_name} is not editable, you can only delete it and add a new one")
# update name if provided
if name is not None and provider.name != name:
@ -180,8 +165,8 @@ class BuiltinToolManageService:
"""
add builtin tool provider
"""
lock_name = f"builtin_tool_provider_credential_lock_{tenant_id}_{provider_name}_{api_type.value}"
with redis_client.lock(lock_name, timeout=20):
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider_name}"
with redis_client.lock(lock, timeout=20):
if name is None:
name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, provider_name, api_type)
@ -198,12 +183,7 @@ class BuiltinToolManageService:
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
# Encrypt and save the credentials
BuiltinToolManageService._encrypt_and_save_credentials(
@ -268,23 +248,17 @@ class BuiltinToolManageService:
return []
provider_controller = ToolManager.get_builtin_provider(providers[0].provider, tenant_id)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
credentials: list[ToolProviderCredentialApiEntity] = []
for provider in providers:
decrypt_credential = tool_configuration.mask_tool_credentials(
tool_configuration.decrypt(provider.credentials)
)
credentials.append(
ToolTransformService.convert_builtin_provider_to_credential_api_entity(
provider=provider,
credentials=decrypt_credential,
)
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
provider=provider,
credentials=decrypt_credential,
)
credentials.append(credential_entity)
return credentials
@staticmethod
@ -292,22 +266,17 @@ class BuiltinToolManageService:
"""
delete tool provider
"""
provider_obj = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
tool_provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
if provider_obj is None:
if tool_provider is None:
raise ValueError(f"you have not added provider {provider_name}")
db.session.delete(provider_obj)
db.session.delete(tool_provider)
db.session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
tool_configuration.delete_tool_credentials_cache()
return {"result": "success"}
@ -334,7 +303,9 @@ class BuiltinToolManageService:
return {"result": "success"}
@staticmethod
def get_builtin_tool_oauth_client(tenant_id: str, provider: str, plugin_id: str):
def get_builtin_tool_oauth_client(
tenant_id: str, provider: str, plugin_id: str
) -> Union[ToolOAuthTenantClient, ToolOAuthSystemClient]:
"""
get builtin tool provider
"""
@ -350,14 +321,12 @@ class BuiltinToolManageService:
.first()
)
if user_client:
plugin_oauth_config = user_client
else:
plugin_oauth_config = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first()
return user_client
if plugin_oauth_config:
return plugin_oauth_config
raise ValueError("no oauth available config found for this plugin")
system_client = session.query(ToolOAuthSystemClient).filter_by(provider=provider).first()
if system_client is None:
raise ValueError("no oauth available client config found for this tool provider")
return system_client
@staticmethod
def get_builtin_tool_provider_icon(provider: str):
@ -379,9 +348,7 @@ class BuiltinToolManageService:
with db.session.no_autoflush:
# get all user added providers
db_providers: list[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
)
db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
# rewrite db_providers
for db_provider in db_providers:
@ -432,8 +399,8 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> BuiltinToolProvider | None:
provider = (
def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
provider: Optional[BuiltinToolProvider] = (
db.session.query(BuiltinToolProvider)
.filter(
BuiltinToolProvider.tenant_id == tenant_id,
@ -444,14 +411,14 @@ class BuiltinToolManageService:
return provider
@staticmethod
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
"""
This method is used to fetch the builtin provider from the database
1.if the default provider exists, return the default provider
2.if the default provider does not exist, return the oldest provider
"""
def _query(provider_filters: list[ColumnExpressionArgument[bool]]):
def _query(provider_filters: list[ColumnExpressionArgument[bool]]) -> Optional[BuiltinToolProvider]:
return (
db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters)
@ -484,21 +451,16 @@ class BuiltinToolManageService:
return provider
except Exception:
# it's an old provider without organization
provider_obj = _query([BuiltinToolProvider.provider == provider_name])
return provider_obj
return _query([BuiltinToolProvider.provider == provider_name])
@staticmethod
def _decrypt_and_restore_credentials(tool_configuration, provider, credentials):
"""
Decrypt original credentials and restore masked values from the input credentials
:param tool_configuration: the tool configuration encrypter
:param provider: the provider object from database
:param credentials: the input credentials from user
:return: the processed credentials with original values restored
"""
return credentials
def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController):
return ProviderConfigEncrypter(
tenant_id=tenant_id,
config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
provider_type=provider_controller.provider_type.value,
provider_identity=provider_controller.entity.identity.name,
)
@staticmethod
def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id):

View File

@ -307,7 +307,7 @@ class ToolTransformService:
)
@staticmethod
def convert_builtin_provider_to_credential_api_entity(
def convert_builtin_provider_to_credential_entity(
provider: BuiltinToolProvider, credentials: dict
) -> ToolProviderCredentialApiEntity:
return ToolProviderCredentialApiEntity(