mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 00:48:04 +08:00
refactor: list tools
This commit is contained in:
@ -7,7 +7,7 @@ from core.entities.provider_entities import ProviderConfig
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
@ -201,7 +201,7 @@ class ApiToolManageService:
|
||||
return {"schema": schema}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]:
|
||||
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
list api tool provider tools
|
||||
"""
|
||||
@ -438,7 +438,7 @@ class ApiToolManageService:
|
||||
return {"result": result or "empty response"}
|
||||
|
||||
@staticmethod
|
||||
def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
list api tools
|
||||
"""
|
||||
@ -447,7 +447,7 @@ class ApiToolManageService:
|
||||
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or []
|
||||
)
|
||||
|
||||
result: list[UserToolProvider] = []
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
for provider in db_providers:
|
||||
# convert provider controller to user provider
|
||||
|
||||
@ -5,9 +5,8 @@ from pathlib import Path
|
||||
from configs import dify_config
|
||||
from core.helper.position_helper import is_filtered
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.tool_manager import ToolManager
|
||||
@ -21,11 +20,17 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class BuiltinToolManageService:
|
||||
@staticmethod
|
||||
def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
|
||||
def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
list builtin tool provider tools
|
||||
|
||||
:param user_id: the id of the user
|
||||
:param tenant_id: the id of the tenant
|
||||
:param provider: the name of the provider
|
||||
|
||||
:return: the list of tools
|
||||
"""
|
||||
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
|
||||
tools = provider_controller.get_tools()
|
||||
|
||||
tool_provider_configurations = ProviderConfigEncrypter(
|
||||
@ -64,14 +69,16 @@ class BuiltinToolManageService:
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_provider_credentials_schema(provider_name):
|
||||
def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
|
||||
"""
|
||||
list builtin provider credentials schema
|
||||
|
||||
:param provider_name: the name of the provider
|
||||
:param tenant_id: the id of the tenant
|
||||
:return: the list of tool providers
|
||||
"""
|
||||
provider = ToolManager.get_builtin_provider(provider_name)
|
||||
return jsonable_encoder([v for _, v in (provider.entity.credentials_schema or {}).items()])
|
||||
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
return jsonable_encoder([v for _, v in (provider.get_credentials_schema() or {}).items()])
|
||||
|
||||
@staticmethod
|
||||
def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
|
||||
@ -90,7 +97,7 @@ class BuiltinToolManageService:
|
||||
|
||||
try:
|
||||
# get provider
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
if not provider_controller.need_credentials:
|
||||
raise ValueError(f"provider {provider_name} does not need credentials")
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
@ -109,7 +116,7 @@ class BuiltinToolManageService:
|
||||
if name in masked_credentials and value == masked_credentials[name]:
|
||||
credentials[name] = original_credentials[name]
|
||||
# validate credentials
|
||||
provider_controller.validate_credentials(credentials)
|
||||
provider_controller.validate_credentials(user_id, credentials)
|
||||
# encrypt credentials
|
||||
credentials = tool_configuration.encrypt(credentials)
|
||||
except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
|
||||
@ -154,7 +161,7 @@ class BuiltinToolManageService:
|
||||
if provider_obj is None:
|
||||
return {}
|
||||
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
@ -186,7 +193,7 @@ class BuiltinToolManageService:
|
||||
db.session.commit()
|
||||
|
||||
# delete cache
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name)
|
||||
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
|
||||
tool_configuration = ProviderConfigEncrypter(
|
||||
tenant_id=tenant_id,
|
||||
config=provider_controller.get_credentials_schema(),
|
||||
@ -198,22 +205,22 @@ class BuiltinToolManageService:
|
||||
return {"result": "success"}
|
||||
|
||||
@staticmethod
|
||||
def get_builtin_tool_provider_icon(provider: str):
|
||||
def get_builtin_tool_provider_icon(provider: str, tenant_id: str):
|
||||
"""
|
||||
get tool provider icon and it's mimetype
|
||||
"""
|
||||
icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
|
||||
icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider, tenant_id)
|
||||
icon_bytes = Path(icon_path).read_bytes()
|
||||
|
||||
return icon_bytes, mime_type
|
||||
|
||||
@staticmethod
|
||||
def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
list builtin tools
|
||||
"""
|
||||
# get all builtin providers
|
||||
provider_controllers = ToolManager.list_builtin_providers()
|
||||
provider_controllers = ToolManager.list_builtin_providers(tenant_id)
|
||||
|
||||
# get all user added providers
|
||||
db_providers: list[BuiltinToolProvider] = (
|
||||
@ -225,7 +232,7 @@ class BuiltinToolManageService:
|
||||
filter(lambda db_provider: db_provider.provider == provider, db_providers), None
|
||||
)
|
||||
|
||||
result: list[UserToolProvider] = []
|
||||
result: list[ToolProviderApiEntity] = []
|
||||
|
||||
for provider_controller in provider_controllers:
|
||||
try:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from core.tools.entities.api_entities import UserToolProviderTypeLiteral
|
||||
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class ToolCommonService:
|
||||
@staticmethod
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None):
|
||||
def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
|
||||
"""
|
||||
list tool providers
|
||||
|
||||
|
||||
@ -7,7 +7,7 @@ from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.custom_tool.provider import ApiToolProviderController
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||
from core.tools.entities.tool_entities import (
|
||||
@ -15,6 +15,7 @@ from core.tools.entities.tool_entities import (
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||
from core.tools.utils.configuration import ProviderConfigEncrypter
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
@ -44,7 +45,7 @@ class ToolTransformService:
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def repack_provider(provider: Union[dict, UserToolProvider]):
|
||||
def repack_provider(provider: Union[dict, ToolProviderApiEntity]):
|
||||
"""
|
||||
repack provider
|
||||
|
||||
@ -54,7 +55,7 @@ class ToolTransformService:
|
||||
provider["icon"] = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
|
||||
)
|
||||
elif isinstance(provider, UserToolProvider):
|
||||
elif isinstance(provider, ToolProviderApiEntity):
|
||||
provider.icon = ToolTransformService.get_tool_provider_icon_url(
|
||||
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
|
||||
)
|
||||
@ -62,14 +63,14 @@ class ToolTransformService:
|
||||
@classmethod
|
||||
def builtin_provider_to_user_provider(
|
||||
cls,
|
||||
provider_controller: BuiltinToolProviderController,
|
||||
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
|
||||
db_provider: Optional[BuiltinToolProvider],
|
||||
decrypt_credentials: bool = True,
|
||||
) -> UserToolProvider:
|
||||
) -> ToolProviderApiEntity:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
result = UserToolProvider(
|
||||
result = ToolProviderApiEntity(
|
||||
id=provider_controller.entity.identity.name,
|
||||
author=provider_controller.entity.identity.author,
|
||||
name=provider_controller.entity.identity.name,
|
||||
@ -154,7 +155,7 @@ class ToolTransformService:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
return UserToolProvider(
|
||||
return ToolProviderApiEntity(
|
||||
id=provider_controller.provider_id,
|
||||
author=provider_controller.entity.identity.author,
|
||||
name=provider_controller.entity.identity.name,
|
||||
@ -181,7 +182,7 @@ class ToolTransformService:
|
||||
db_provider: ApiToolProvider,
|
||||
decrypt_credentials: bool = True,
|
||||
labels: list[str] | None = None,
|
||||
) -> UserToolProvider:
|
||||
) -> ToolProviderApiEntity:
|
||||
"""
|
||||
convert provider controller to user provider
|
||||
"""
|
||||
@ -197,7 +198,7 @@ class ToolTransformService:
|
||||
|
||||
# add provider into providers
|
||||
credentials = db_provider.credentials
|
||||
result = UserToolProvider(
|
||||
result = ToolProviderApiEntity(
|
||||
id=db_provider.id,
|
||||
author=username,
|
||||
name=db_provider.name,
|
||||
@ -240,7 +241,7 @@ class ToolTransformService:
|
||||
tenant_id: str,
|
||||
credentials: dict | None = None,
|
||||
labels: list[str] | None = None,
|
||||
) -> UserTool:
|
||||
) -> ToolApiEntity:
|
||||
"""
|
||||
convert tool to user tool
|
||||
"""
|
||||
@ -248,7 +249,7 @@ class ToolTransformService:
|
||||
# fork tool runtime
|
||||
tool = tool.fork_tool_runtime(
|
||||
runtime=ToolRuntime(
|
||||
credentials=credentials,
|
||||
credentials=credentials or {},
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
@ -270,7 +271,7 @@ class ToolTransformService:
|
||||
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||
current_parameters.append(runtime_parameter)
|
||||
|
||||
return UserTool(
|
||||
return ToolApiEntity(
|
||||
author=tool.entity.identity.author,
|
||||
name=tool.entity.identity.name,
|
||||
label=tool.entity.identity.label,
|
||||
@ -279,7 +280,7 @@ class ToolTransformService:
|
||||
labels=labels or [],
|
||||
)
|
||||
if isinstance(tool, ApiToolBundle):
|
||||
return UserTool(
|
||||
return ToolApiEntity(
|
||||
author=tool.author,
|
||||
name=tool.operation_id,
|
||||
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),
|
||||
|
||||
@ -4,7 +4,7 @@ from datetime import datetime
|
||||
from sqlalchemy import or_
|
||||
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.api_entities import UserTool, UserToolProvider
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.tool_label_manager import ToolLabelManager
|
||||
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
|
||||
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
|
||||
@ -183,7 +183,7 @@ class WorkflowToolManageService:
|
||||
return {"result": "success"}
|
||||
|
||||
@classmethod
|
||||
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
|
||||
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
|
||||
"""
|
||||
List workflow tools.
|
||||
:param user_id: the user id
|
||||
@ -309,7 +309,7 @@ class WorkflowToolManageService:
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]:
|
||||
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[ToolApiEntity]:
|
||||
"""
|
||||
List workflow tool provider tools.
|
||||
:param user_id: the user id
|
||||
|
||||
Reference in New Issue
Block a user