mirror of
https://github.com/langgenius/dify.git
synced 2026-04-25 21:26:15 +08:00
feat(oauth): refactor tool provider methods and enhance credential handling
This commit is contained in:
@ -446,7 +446,7 @@ class ApiToolManageService:
|
||||
return {"result": result or "empty response"}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
list api tools
|
||||
"""
|
||||
@ -474,7 +474,7 @@ class ApiToolManageService:
|
||||
for tool in tools or []:
|
||||
user_provider.tools.append(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
|
||||
tenant_id=tenant_id, tool=tool, labels=labels
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ import re
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from sqlalchemy import ColumnExpressionArgument
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
@ -13,10 +12,12 @@ from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import ToolProviderID
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
|
||||
from core.tools.entities.tool_entities import ToolProviderCredentialType
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
@ -29,6 +30,8 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BuiltinToolManageService:
|
||||
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
@ -42,22 +45,11 @@ class BuiltinToolManageService:
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
tools = provider_controller.get_tools()
|
||||
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
# check if user has added the provider
|
||||
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
|
||||
|
||||
credentials = {}
|
||||
if builtin_provider is not None:
|
||||
# get credentials
|
||||
credentials = builtin_provider.credentials
|
||||
credentials = tool_configuration.decrypt(credentials)
|
||||
|
||||
result: list[ToolApiEntity] = []
|
||||
for tool in tools or []:
|
||||
result.append(
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tool=tool,
|
||||
credentials=credentials,
|
||||
tenant_id=tenant_id,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||
)
|
||||
@ -73,7 +65,7 @@ class BuiltinToolManageService:
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
# check if user has added the provider
|
||||
builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
|
||||
builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
|
||||
|
||||
credentials = {}
|
||||
if builtin_provider is not None:
|
||||
@ -92,16 +84,19 @@ class BuiltinToolManageService:
|
||||
return entity
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(provider_name: str, credential_type: str, tenant_id: str):
|
||||
def list_builtin_provider_credentials_schema(
|
||||
provider_name: str, credential_type: ToolProviderCredentialType, tenant_id: str
|
||||
):
|
||||
"""
|
||||
list builtin provider credentials schema
|
||||
|
||||
:param credential_type: credential type
|
||||
:param provider_name: the name of the provider
|
||||
:param tenant_id: the id of the tenant
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
return jsonable_encoder(provider.get_credentials_schema())
|
||||
return jsonable_encoder(provider.get_credentials_schema(credential_type))
|
||||
|
||||
@staticmethod
|
||||
def update_builtin_tool_provider(
|
||||
@ -111,11 +106,11 @@ class BuiltinToolManageService:
|
||||
update builtin tool provider
|
||||
"""
|
||||
# get if the provider exists
|
||||
provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
|
||||
provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
|
||||
|
||||
if provider is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
|
||||
|
||||
try:
|
||||
if ToolProviderCredentialType.of(provider.credential_type).is_editable():
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
@ -133,10 +128,12 @@ class BuiltinToolManageService:
|
||||
if key in masked_credentials and value == masked_credentials[key]:
|
||||
credentials[key] = original_credentials[key]
|
||||
|
||||
# Encrypt and save the credentials
|
||||
BuiltinToolManageService._encrypt_and_save_credentials(
|
||||
provider_controller, tool_configuration, provider, credentials, user_id
|
||||
)
|
||||
provider_controller.validate_credentials(user_id, credentials)
|
||||
|
||||
# encrypt credentials
|
||||
encrypted_credentials = tool_configuration.encrypt(credentials)
|
||||
provider.encrypted_credentials = json.dumps(encrypted_credentials)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
# update name if provided
|
||||
if name is not None and provider.name != name:
|
||||
@ -158,68 +155,84 @@ class BuiltinToolManageService:
|
||||
user_id: str,
|
||||
api_type: ToolProviderCredentialType,
|
||||
tenant_id: str,
|
||||
provider_name: str,
|
||||
provider: str,
|
||||
credentials: dict,
|
||||
name: str | None = None,
|
||||
):
|
||||
"""
|
||||
add builtin tool provider
|
||||
"""
|
||||
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider_name}"
|
||||
lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
|
||||
with redis_client.lock(lock, timeout=20):
|
||||
if name is None:
|
||||
name = BuiltinToolManageService.get_next_builtin_tool_provider_name(tenant_id, provider_name, api_type)
|
||||
# check if the provider count is over the limit
|
||||
provider_count = (
|
||||
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
|
||||
)
|
||||
if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
|
||||
raise ValueError(f"you have reached the maximum number of providers for {provider}")
|
||||
|
||||
provider = BuiltinToolProvider(
|
||||
# TODO should we get name from oauth authentication?
|
||||
name = (
|
||||
name
|
||||
if name
|
||||
else BuiltinToolManageService.generate_builtin_tool_provider_name(
|
||||
tenant_id, provider, credential_type=api_type
|
||||
)
|
||||
)
|
||||
|
||||
db_provider = BuiltinToolProvider(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
provider=provider,
|
||||
encrypted_credentials=json.dumps(credentials),
|
||||
credential_type=api_type.value,
|
||||
name=name,
|
||||
)
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"provider {provider_name} does not need credentials")
|
||||
raise ValueError(f"provider {provider} does not need credentials")
|
||||
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
|
||||
# Encrypt and save the credentials
|
||||
BuiltinToolManageService._encrypt_and_save_credentials(
|
||||
provider_controller, tool_configuration, provider, credentials, user_id
|
||||
provider_controller=provider_controller,
|
||||
tool_configuration=tool_configuration,
|
||||
provider=db_provider,
|
||||
credentials=credentials,
|
||||
user_id=user_id,
|
||||
)
|
||||
db.session.add(provider)
|
||||
db.session.add(db_provider)
|
||||
db.session.commit()
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_next_builtin_tool_provider_name(
|
||||
tenant_id: str, provider_name: str, type: ToolProviderCredentialType
|
||||
def generate_builtin_tool_provider_name(
|
||||
tenant_id: str, provider: str, credential_type: ToolProviderCredentialType
|
||||
) -> str:
|
||||
try:
|
||||
providers = (
|
||||
db_providers = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter_by(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_name,
|
||||
credential_type=type.value,
|
||||
provider=provider,
|
||||
credential_type=credential_type.value,
|
||||
)
|
||||
.order_by(BuiltinToolProvider.created_at.desc())
|
||||
.limit(10)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Get the default name pattern
|
||||
default_pattern = type.get_name()
|
||||
default_pattern = f"{credential_type.get_name()}"
|
||||
|
||||
# Find all names that match the default pattern: "{default_pattern} {number}"
|
||||
pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
|
||||
numbers = []
|
||||
|
||||
for provider in providers:
|
||||
if provider.name:
|
||||
match = re.match(pattern, provider.name.strip())
|
||||
for db_provider in db_providers:
|
||||
if db_provider.name:
|
||||
match = re.match(pattern, db_provider.name.strip())
|
||||
if match:
|
||||
numbers.append(int(match.group(1)))
|
||||
|
||||
@ -231,9 +244,9 @@ class BuiltinToolManageService:
|
||||
max_number = max(numbers)
|
||||
return f"{default_pattern} {max_number + 1}"
|
||||
except Exception as e:
|
||||
logger.warning(f"Error generating next provider name for {provider_name}: {str(e)}")
|
||||
logger.warning(f"Error generating next provider name for {provider}: {str(e)}")
|
||||
# fallback
|
||||
return f"{type.get_name()} 1"
|
||||
return f"{credential_type.get_name()} 1"
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_credentials(
|
||||
@ -242,31 +255,43 @@ class BuiltinToolManageService:
|
||||
"""
|
||||
get builtin tool provider credentials
|
||||
"""
|
||||
providers = db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all()
|
||||
|
||||
if len(providers) == 0:
|
||||
return []
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(providers[0].provider, tenant_id)
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
credentials: list[ToolProviderCredentialApiEntity] = []
|
||||
for provider in providers:
|
||||
decrypt_credential = tool_configuration.mask_tool_credentials(
|
||||
tool_configuration.decrypt(provider.credentials)
|
||||
with db.session.no_autoflush:
|
||||
providers = (
|
||||
db.session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider_name).all()
|
||||
)
|
||||
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
|
||||
provider=provider,
|
||||
credentials=decrypt_credential,
|
||||
)
|
||||
credentials.append(credential_entity)
|
||||
return credentials
|
||||
|
||||
if len(providers) == 0:
|
||||
return []
|
||||
|
||||
default_provider = sorted(
|
||||
providers,
|
||||
key=lambda p: (
|
||||
not getattr(p, "is_default", False),
|
||||
getattr(p, "created_at", None) or 0,
|
||||
),
|
||||
)[0]
|
||||
|
||||
default_provider.is_default = True
|
||||
provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
credentials: list[ToolProviderCredentialApiEntity] = []
|
||||
for provider in providers:
|
||||
decrypt_credential = tool_configuration.mask_tool_credentials(
|
||||
tool_configuration.decrypt(provider.credentials)
|
||||
)
|
||||
credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
|
||||
provider=provider,
|
||||
credentials=decrypt_credential,
|
||||
)
|
||||
credentials.append(credential_entity)
|
||||
return credentials
|
||||
|
||||
@staticmethod
|
||||
def delete_builtin_tool_provider(tenant_id: str, provider_name: str, credential_id: str):
|
||||
"""
|
||||
delete tool provider
|
||||
"""
|
||||
tool_provider = BuiltinToolManageService._fetch_builtin_provider_by_id(tenant_id, credential_id)
|
||||
tool_provider = BuiltinToolManageService.get_builtin_provider_by_id(tenant_id, credential_id)
|
||||
|
||||
if tool_provider is None:
|
||||
raise ValueError(f"you have not added provider {provider_name}")
|
||||
@ -387,7 +412,6 @@ class BuiltinToolManageService:
|
||||
ToolTransformService.convert_tool_entity_to_api_entity(
|
||||
tenant_id=tenant_id,
|
||||
tool=tool,
|
||||
credentials=user_builtin_provider.original_credentials,
|
||||
labels=ToolLabelManager.get_tool_labels(provider_controller),
|
||||
)
|
||||
)
|
||||
@ -399,7 +423,7 @@ class BuiltinToolManageService:
|
||||
return BuiltinToolProviderSort.sort(result)
|
||||
|
||||
@staticmethod
|
||||
def _fetch_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
|
||||
def get_builtin_provider_by_id(tenant_id: str, credential_id: str) -> Optional[BuiltinToolProvider]:
|
||||
provider: Optional[BuiltinToolProvider] = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
@ -411,47 +435,62 @@ class BuiltinToolManageService:
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
|
||||
def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
|
||||
"""
|
||||
This method is used to fetch the builtin provider from the database
|
||||
1.if the default provider exists, return the default provider
|
||||
2.if the default provider does not exist, return the oldest provider
|
||||
"""
|
||||
with Session(db.engine) as session:
|
||||
try:
|
||||
full_provider_name = provider_name
|
||||
provider_id_entity = ToolProviderID(provider_name)
|
||||
provider_name = provider_id_entity.provider_name
|
||||
|
||||
def _query(provider_filters: list[ColumnExpressionArgument[bool]]) -> Optional[BuiltinToolProvider]:
|
||||
return (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.filter(BuiltinToolProvider.tenant_id == tenant_id, *provider_filters)
|
||||
.order_by(
|
||||
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||
if provider_id_entity.organization != "langgenius":
|
||||
provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
BuiltinToolProvider.provider == full_provider_name,
|
||||
)
|
||||
.order_by(
|
||||
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||
)
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
provider = (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter(
|
||||
BuiltinToolProvider.tenant_id == tenant_id,
|
||||
(BuiltinToolProvider.provider == provider_name)
|
||||
| (BuiltinToolProvider.provider == full_provider_name),
|
||||
)
|
||||
.order_by(
|
||||
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
return None
|
||||
|
||||
provider.provider = ToolProviderID(provider.provider).to_string()
|
||||
return provider
|
||||
except Exception:
|
||||
# it's an old provider without organization
|
||||
return (
|
||||
session.query(BuiltinToolProvider)
|
||||
.filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
|
||||
.order_by(
|
||||
BuiltinToolProvider.is_default.desc(), # default=True first
|
||||
BuiltinToolProvider.created_at.asc(), # oldest first
|
||||
)
|
||||
.first()
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
try:
|
||||
full_provider_name = provider_name
|
||||
provider_id_entity = ToolProviderID(provider_name)
|
||||
provider_name = provider_id_entity.provider_name
|
||||
|
||||
if provider_id_entity.organization != "langgenius":
|
||||
provider = _query([BuiltinToolProvider.provider == full_provider_name])
|
||||
else:
|
||||
provider = _query(
|
||||
[
|
||||
(BuiltinToolProvider.provider == provider_name)
|
||||
| (BuiltinToolProvider.provider == full_provider_name)
|
||||
]
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
return None
|
||||
|
||||
provider.provider = ToolProviderID(provider.provider).to_string()
|
||||
return provider
|
||||
except Exception:
|
||||
# it's an old provider without organization
|
||||
return _query([BuiltinToolProvider.provider == provider_name])
|
||||
|
||||
@staticmethod
|
||||
def _create_tool_configuration(tenant_id: str, provider_controller: ToolProviderController):
|
||||
@ -463,7 +502,13 @@ class BuiltinToolManageService:
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _encrypt_and_save_credentials(provider_controller, tool_configuration, provider, credentials, user_id):
|
||||
def _encrypt_and_save_credentials(
|
||||
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
|
||||
tool_configuration: ProviderConfigEncrypter,
|
||||
provider: BuiltinToolProvider,
|
||||
credentials: dict,
|
||||
user_id: str,
|
||||
):
|
||||
"""
|
||||
Validate and encrypt credentials, then save to database
|
||||
|
||||
@ -480,3 +525,25 @@ class BuiltinToolManageService:
|
||||
encrypted_credentials = tool_configuration.encrypt(credentials)
|
||||
provider.encrypted_credentials = json.dumps(encrypted_credentials)
|
||||
tool_configuration.delete_tool_credentials_cache()
|
||||
|
||||
@staticmethod
|
||||
def setup_oauth_custom_client(tenant_id: str, user_id: str, provider: str, client_params: dict):
|
||||
"""
|
||||
setup oauth custom client
|
||||
"""
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
if not provider_controller:
|
||||
raise ToolProviderNotFoundError(f"Provider {provider} not found")
|
||||
|
||||
tool_configuration = BuiltinToolManageService._create_tool_configuration(tenant_id, provider_controller)
|
||||
|
||||
# Validate and encrypt credentials
|
||||
BuiltinToolManageService._encrypt_and_save_credentials(
|
||||
provider_controller=provider_controller,
|
||||
tool_configuration=tool_configuration,
|
||||
provider=None, # No need to save in DB
|
||||
credentials=client_params,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
return {"result": "success"}
|
||||
|
||||
@ -255,7 +255,6 @@ class ToolTransformService:
|
||||
def convert_tool_entity_to_api_entity(
|
||||
tool: Union[ApiToolBundle, WorkflowTool, Tool],
|
||||
tenant_id: str,
|
||||
credentials: dict | None = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> ToolApiEntity:
|
||||
"""
|
||||
@ -265,7 +264,7 @@ class ToolTransformService:
|
||||
# fork tool runtime
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
credentials=credentials or {},
|
||||
credentials= {},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user