refactor: list tools

This commit is contained in:
Yeuoly
2024-09-23 18:06:16 +08:00
parent 435e71eb60
commit 7a3e756020
26 changed files with 365 additions and 139 deletions

View File

@ -7,7 +7,7 @@ from core.entities.provider_entities import ProviderConfig
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
@ -201,7 +201,7 @@ class ApiToolManageService:
return {"schema": schema}
@staticmethod
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[UserTool]:
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider_name: str) -> list[ToolApiEntity]:
"""
list api tool provider tools
"""
@ -438,7 +438,7 @@ class ApiToolManageService:
return {"result": result or "empty response"}
@staticmethod
def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
"""
list api tools
"""
@ -447,7 +447,7 @@ class ApiToolManageService:
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or []
)
result: list[UserToolProvider] = []
result: list[ToolProviderApiEntity] = []
for provider in db_providers:
# convert provider controller to user provider

View File

@ -5,9 +5,8 @@ from pathlib import Path
from configs import dify_config
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
@ -21,11 +20,17 @@ logger = logging.getLogger(__name__)
class BuiltinToolManageService:
@staticmethod
def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]:
def list_builtin_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[ToolApiEntity]:
"""
list builtin tool provider tools
:param user_id: the id of the user
:param tenant_id: the id of the tenant
:param provider: the name of the provider
:return: the list of tools
"""
provider_controller: ToolProviderController = ToolManager.get_builtin_provider(provider)
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tools = provider_controller.get_tools()
tool_provider_configurations = ProviderConfigEncrypter(
@ -64,14 +69,16 @@ class BuiltinToolManageService:
return result
@staticmethod
def list_builtin_provider_credentials_schema(provider_name):
def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
"""
list builtin provider credentials schema
:param provider_name: the name of the provider
:param tenant_id: the id of the tenant
:return: the list of tool providers
"""
provider = ToolManager.get_builtin_provider(provider_name)
return jsonable_encoder([v for _, v in (provider.entity.credentials_schema or {}).items()])
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
return jsonable_encoder([v for _, v in (provider.get_credentials_schema() or {}).items()])
@staticmethod
def update_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str, credentials: dict):
@ -90,7 +97,7 @@ class BuiltinToolManageService:
try:
# get provider
provider_controller = ToolManager.get_builtin_provider(provider_name)
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
if not provider_controller.need_credentials:
raise ValueError(f"provider {provider_name} does not need credentials")
tool_configuration = ProviderConfigEncrypter(
@ -109,7 +116,7 @@ class BuiltinToolManageService:
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
# validate credentials
provider_controller.validate_credentials(credentials)
provider_controller.validate_credentials(user_id, credentials)
# encrypt credentials
credentials = tool_configuration.encrypt(credentials)
except (ToolProviderNotFoundError, ToolNotFoundError, ToolProviderCredentialValidationError) as e:
@ -154,7 +161,7 @@ class BuiltinToolManageService:
if provider_obj is None:
return {}
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider)
provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
@ -186,7 +193,7 @@ class BuiltinToolManageService:
db.session.commit()
# delete cache
provider_controller = ToolManager.get_builtin_provider(provider_name)
provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=provider_controller.get_credentials_schema(),
@ -198,22 +205,22 @@ class BuiltinToolManageService:
return {"result": "success"}
@staticmethod
def get_builtin_tool_provider_icon(provider: str):
def get_builtin_tool_provider_icon(provider: str, tenant_id: str):
"""
get tool provider icon and it's mimetype
"""
icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider)
icon_path, mime_type = ToolManager.get_builtin_provider_icon(provider, tenant_id)
icon_bytes = Path(icon_path).read_bytes()
return icon_bytes, mime_type
@staticmethod
def list_builtin_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]:
def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
"""
list builtin tools
"""
# get all builtin providers
provider_controllers = ToolManager.list_builtin_providers()
provider_controllers = ToolManager.list_builtin_providers(tenant_id)
# get all user added providers
db_providers: list[BuiltinToolProvider] = (
@ -225,7 +232,7 @@ class BuiltinToolManageService:
filter(lambda db_provider: db_provider.provider == provider, db_providers), None
)
result: list[UserToolProvider] = []
result: list[ToolProviderApiEntity] = []
for provider_controller in provider_controllers:
try:

View File

@ -1,6 +1,6 @@
import logging
from core.tools.entities.api_entities import UserToolProviderTypeLiteral
from core.tools.entities.api_entities import ToolProviderTypeApiLiteral
from core.tools.tool_manager import ToolManager
from services.tools.tools_transform_service import ToolTransformService
@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
class ToolCommonService:
@staticmethod
def list_tool_providers(user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral = None):
def list_tool_providers(user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral = None):
"""
list tool providers

View File

@ -7,7 +7,7 @@ from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
@ -15,6 +15,7 @@ from core.tools.entities.tool_entities import (
ToolParameter,
ToolProviderType,
)
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.utils.configuration import ProviderConfigEncrypter
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
@ -44,7 +45,7 @@ class ToolTransformService:
return ""
@staticmethod
def repack_provider(provider: Union[dict, UserToolProvider]):
def repack_provider(provider: Union[dict, ToolProviderApiEntity]):
"""
repack provider
@ -54,7 +55,7 @@ class ToolTransformService:
provider["icon"] = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider["type"], provider_name=provider["name"], icon=provider["icon"]
)
elif isinstance(provider, UserToolProvider):
elif isinstance(provider, ToolProviderApiEntity):
provider.icon = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
)
@ -62,14 +63,14 @@ class ToolTransformService:
@classmethod
def builtin_provider_to_user_provider(
cls,
provider_controller: BuiltinToolProviderController,
provider_controller: BuiltinToolProviderController | PluginToolProviderController,
db_provider: Optional[BuiltinToolProvider],
decrypt_credentials: bool = True,
) -> UserToolProvider:
) -> ToolProviderApiEntity:
"""
convert provider controller to user provider
"""
result = UserToolProvider(
result = ToolProviderApiEntity(
id=provider_controller.entity.identity.name,
author=provider_controller.entity.identity.author,
name=provider_controller.entity.identity.name,
@ -154,7 +155,7 @@ class ToolTransformService:
"""
convert provider controller to user provider
"""
return UserToolProvider(
return ToolProviderApiEntity(
id=provider_controller.provider_id,
author=provider_controller.entity.identity.author,
name=provider_controller.entity.identity.name,
@ -181,7 +182,7 @@ class ToolTransformService:
db_provider: ApiToolProvider,
decrypt_credentials: bool = True,
labels: list[str] | None = None,
) -> UserToolProvider:
) -> ToolProviderApiEntity:
"""
convert provider controller to user provider
"""
@ -197,7 +198,7 @@ class ToolTransformService:
# add provider into providers
credentials = db_provider.credentials
result = UserToolProvider(
result = ToolProviderApiEntity(
id=db_provider.id,
author=username,
name=db_provider.name,
@ -240,7 +241,7 @@ class ToolTransformService:
tenant_id: str,
credentials: dict | None = None,
labels: list[str] | None = None,
) -> UserTool:
) -> ToolApiEntity:
"""
convert tool to user tool
"""
@ -248,7 +249,7 @@ class ToolTransformService:
# fork tool runtime
tool = tool.fork_tool_runtime(
runtime=ToolRuntime(
credentials=credentials,
credentials=credentials or {},
tenant_id=tenant_id,
)
)
@ -270,7 +271,7 @@ class ToolTransformService:
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
current_parameters.append(runtime_parameter)
return UserTool(
return ToolApiEntity(
author=tool.entity.identity.author,
name=tool.entity.identity.name,
label=tool.entity.identity.label,
@ -279,7 +280,7 @@ class ToolTransformService:
labels=labels or [],
)
if isinstance(tool, ApiToolBundle):
return UserTool(
return ToolApiEntity(
author=tool.author,
name=tool.operation_id,
label=I18nObject(en_US=tool.operation_id, zh_Hans=tool.operation_id),

View File

@ -4,7 +4,7 @@ from datetime import datetime
from sqlalchemy import or_
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.api_entities import UserTool, UserToolProvider
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
@ -183,7 +183,7 @@ class WorkflowToolManageService:
return {"result": "success"}
@classmethod
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[UserToolProvider]:
def list_tenant_workflow_tools(cls, user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
"""
List workflow tools.
:param user_id: the user id
@ -309,7 +309,7 @@ class WorkflowToolManageService:
}
@classmethod
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[UserTool]:
def list_single_workflow_tools(cls, user_id: str, tenant_id: str, workflow_tool_id: str) -> list[ToolApiEntity]:
"""
List workflow tool provider tools.
:param user_id: the user id