refactor: list tools

This commit is contained in:
Yeuoly
2024-09-23 18:06:16 +08:00
parent 435e71eb60
commit 7a3e756020
26 changed files with 365 additions and 139 deletions

View File

@ -6,7 +6,10 @@ from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Union, cast
from core.plugin.manager.tool import PluginToolManager
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
@ -24,7 +27,7 @@ from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.custom_tool.provider import ApiToolProviderController
from core.tools.custom_tool.tool import ApiTool
from core.tools.entities.api_entities import UserToolProvider, UserToolProviderTypeLiteral
from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProviderTypeApiLiteral
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ApiProviderAuthType, ToolInvokeFrom, ToolParameter, ToolProviderType
from core.tools.errors import ToolProviderNotFoundError
@ -41,38 +44,61 @@ logger = logging.getLogger(__name__)
class ToolManager:
_builtin_provider_lock = Lock()
_builtin_providers = {}
_hardcoded_providers = {}
_builtin_providers_loaded = False
_builtin_tools_labels = {}
@classmethod
def get_builtin_provider(cls, provider: str) -> BuiltinToolProviderController:
def get_builtin_provider(
cls, provider: str, tenant_id: str
) -> BuiltinToolProviderController | PluginToolProviderController:
"""
get the builtin provider
:param provider: the name of the provider
:param tenant_id: the id of the tenant
:return: the provider
"""
if len(cls._builtin_providers) == 0:
if len(cls._hardcoded_providers) == 0:
# init the builtin providers
cls.load_builtin_providers_cache()
cls.load_hardcoded_providers_cache()
if provider not in cls._builtin_providers:
raise ToolProviderNotFoundError(f"builtin provider {provider} not found")
if provider not in cls._hardcoded_providers:
# get plugin provider
plugin_provider = cls.get_plugin_provider(provider, tenant_id)
if plugin_provider:
return plugin_provider
return cls._builtin_providers[provider]
return cls._hardcoded_providers[provider]
@classmethod
def get_builtin_tool(cls, provider: str, tool_name: str) -> BuiltinTool | None:
def get_plugin_provider(cls, provider: str, tenant_id: str) -> PluginToolProviderController:
"""
get the plugin provider
"""
manager = PluginToolManager()
providers = manager.fetch_tool_providers(tenant_id)
provider_entity = next((x for x in providers if x.declaration.identity.name == provider), None)
if not provider_entity:
raise ToolProviderNotFoundError(f"plugin provider {provider} not found")
return PluginToolProviderController(
entity=provider_entity.declaration,
tenant_id=tenant_id,
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
)
@classmethod
def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
"""
get the builtin tool
:param provider: the name of the provider
:param tool_name: the name of the tool
:param tenant_id: the id of the tenant
:return: the provider, the tool
"""
provider_controller = cls.get_builtin_provider(provider)
provider_controller = cls.get_builtin_provider(provider, tenant_id)
tool = provider_controller.get_tool(tool_name)
return tool
@ -97,12 +123,12 @@ class ToolManager:
:return: the tool
"""
if provider_type == ToolProviderType.BUILT_IN:
builtin_tool = cls.get_builtin_tool(provider_id, tool_name)
builtin_tool = cls.get_builtin_tool(provider_id, tool_name, tenant_id)
if not builtin_tool:
raise ValueError(f"tool {tool_name} not found")
# check if the builtin tool need credentials
provider_controller = cls.get_builtin_provider(provider_id)
provider_controller = cls.get_builtin_provider(provider_id, tenant_id)
if not provider_controller.need_credentials:
return cast(
BuiltinTool,
@ -131,7 +157,7 @@ class ToolManager:
# decrypt the credentials
credentials = builtin_provider.credentials
controller = cls.get_builtin_provider(provider_id)
controller = cls.get_builtin_provider(provider_id, tenant_id)
tool_configuration = ProviderConfigEncrypter(
tenant_id=tenant_id,
config=controller.get_credentials_schema(),
@ -246,7 +272,7 @@ class ToolManager:
tool_invoke_from=ToolInvokeFrom.AGENT,
)
runtime_parameters = {}
parameters = tool_entity.get_all_runtime_parameters()
parameters = tool_entity.get_merged_runtime_parameters()
for parameter in parameters:
# check file types
if parameter.type == ToolParameter.ToolParameterType.FILE:
@ -294,7 +320,7 @@ class ToolManager:
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
)
runtime_parameters = {}
parameters = tool_entity.get_all_runtime_parameters()
parameters = tool_entity.get_merged_runtime_parameters()
for parameter in parameters:
# save tool parameter to tool entity memory
@ -321,16 +347,17 @@ class ToolManager:
return tool_entity
@classmethod
def get_builtin_provider_icon(cls, provider: str) -> tuple[str, str]:
def get_builtin_provider_icon(cls, provider: str, tenant_id: str) -> tuple[str, str]:
"""
get the absolute path of the icon of the builtin provider
:param provider: the name of the provider
:param tenant_id: the id of the tenant
:return: the absolute path of the icon, the mime type of the icon
"""
# get provider
provider_controller = cls.get_builtin_provider(provider)
provider_controller = cls.get_builtin_provider(provider, tenant_id)
absolute_path = path.join(
path.dirname(path.realpath(__file__)),
@ -351,21 +378,48 @@ class ToolManager:
return absolute_path, mime_type
@classmethod
def list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
def list_hardcoded_providers(cls):
# use cache first
if cls._builtin_providers_loaded:
yield from list(cls._builtin_providers.values())
yield from list(cls._hardcoded_providers.values())
return
with cls._builtin_provider_lock:
if cls._builtin_providers_loaded:
yield from list(cls._builtin_providers.values())
yield from list(cls._hardcoded_providers.values())
return
yield from cls._list_builtin_providers()
yield from cls._list_hardcoded_providers()
@classmethod
def _list_builtin_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
def list_plugin_providers(cls, tenant_id: str) -> list[PluginToolProviderController]:
"""
list all the plugin providers
"""
manager = PluginToolManager()
provider_entities = manager.fetch_tool_providers(tenant_id)
return [
PluginToolProviderController(
entity=provider.declaration,
tenant_id=tenant_id,
plugin_unique_identifier=provider.plugin_unique_identifier,
)
for provider in provider_entities
]
@classmethod
def list_builtin_providers(
cls, tenant_id: str
) -> Generator[BuiltinToolProviderController | PluginToolProviderController, None, None]:
"""
list all the builtin providers
"""
yield from cls.list_hardcoded_providers()
# get plugin providers
yield from cls.list_plugin_providers(tenant_id)
@classmethod
def _list_hardcoded_providers(cls) -> Generator[BuiltinToolProviderController, None, None]:
"""
list all the builtin providers
"""
@ -391,7 +445,7 @@ class ToolManager:
parent_type=BuiltinToolProviderController,
)
provider: BuiltinToolProviderController = provider_class()
cls._builtin_providers[provider.entity.identity.name] = provider
cls._hardcoded_providers[provider.entity.identity.name] = provider
for tool in provider.get_tools():
cls._builtin_tools_labels[tool.entity.identity.name] = tool.entity.identity.label
yield provider
@ -403,13 +457,13 @@ class ToolManager:
cls._builtin_providers_loaded = True
@classmethod
def load_builtin_providers_cache(cls):
for _ in cls.list_builtin_providers():
def load_hardcoded_providers_cache(cls):
for _ in cls.list_hardcoded_providers():
pass
@classmethod
def clear_builtin_providers_cache(cls):
cls._builtin_providers = {}
def clear_hardcoded_providers_cache(cls):
cls._hardcoded_providers = {}
cls._builtin_providers_loaded = False
@classmethod
@ -423,7 +477,7 @@ class ToolManager:
"""
if len(cls._builtin_tools_labels) == 0:
# init the builtin providers
cls.load_builtin_providers_cache()
cls.load_hardcoded_providers_cache()
if tool_name not in cls._builtin_tools_labels:
return None
@ -432,9 +486,9 @@ class ToolManager:
@classmethod
def user_list_providers(
cls, user_id: str, tenant_id: str, typ: UserToolProviderTypeLiteral
) -> list[UserToolProvider]:
result_providers: dict[str, UserToolProvider] = {}
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
) -> list[ToolProviderApiEntity]:
result_providers: dict[str, ToolProviderApiEntity] = {}
filters = []
if not typ:
@ -444,7 +498,7 @@ class ToolManager:
if "builtin" in filters:
# get builtin providers
builtin_providers = cls.list_builtin_providers()
builtin_providers = cls.list_builtin_providers(tenant_id)
# get db builtin providers
db_builtin_providers: list[BuiltinToolProvider] = (
@ -666,4 +720,4 @@ class ToolManager:
raise ValueError(f"provider type {provider_type} not found")
ToolManager.load_builtin_providers_cache()
ToolManager.load_hardcoded_providers_cache()