refactor: tool

This commit is contained in:
Yeuoly
2024-09-20 23:48:48 +08:00
parent 3c1d32e3ac
commit 91cb80f795
29 changed files with 498 additions and 906 deletions

View File

@ -2,13 +2,12 @@ from abc import abstractmethod
from os import listdir, path
from typing import Any
from pydantic import Field
from core.entities.provider_entities import ProviderConfig
from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import (
ToolProviderNotFoundError,
@ -17,10 +16,10 @@ from core.tools.utils.yaml_utils import load_yaml_file
class BuiltinToolProviderController(ToolProviderController):
tools: list[BuiltinTool] = Field(default_factory=list)
tools: list[BuiltinTool]
def __init__(self, **data: Any) -> None:
if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}:
if self.provider_type == ToolProviderType.API:
super().__init__(**data)
return
@ -37,10 +36,12 @@ class BuiltinToolProviderController(ToolProviderController):
for credential_name in provider_yaml["credentials_for_provider"]:
provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name
super().__init__(**{
'identity': provider_yaml['identity'],
'credentials_schema': provider_yaml.get('credentials_for_provider', {}) or {},
})
super().__init__(
entity=ToolProviderEntity(
identity=provider_yaml["identity"],
credentials_schema=provider_yaml.get("credentials_for_provider", {}) or {},
),
)
def _get_builtin_tools(self) -> list[BuiltinTool]:
"""
@ -51,7 +52,7 @@ class BuiltinToolProviderController(ToolProviderController):
if self.tools:
return self.tools
provider = self.identity.name
provider = self.entity.identity.name
tool_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, "tools")
# get all the yaml files in the tool path
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
@ -62,30 +63,36 @@ class BuiltinToolProviderController(ToolProviderController):
tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)
# get tool class, import the module
assistant_tool_class = load_single_subclass_from_source(
assistant_tool_class: type[BuiltinTool] = load_single_subclass_from_source(
module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}",
script_path=path.join(
path.dirname(path.realpath(__file__)),
"builtin_tool", "providers", provider, "tools", f"{tool_name}.py"
path.dirname(path.realpath(__file__)),
"builtin_tool",
"providers",
provider,
"tools",
f"{tool_name}.py",
),
parent_type=BuiltinTool,
)
tool["identity"]["provider"] = provider
tools.append(assistant_tool_class(**tool))
tools.append(assistant_tool_class(
entity=ToolEntity(**tool), runtime=ToolRuntime(tenant_id=""),
))
self.tools = tools
return tools
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
"""
returns the credentials schema of the provider
:return: the credentials schema
"""
if not self.credentials_schema:
if not self.entity.credentials_schema:
return {}
return self.credentials_schema.copy()
return self.entity.credentials_schema.copy()
def get_tools(self) -> list[BuiltinTool]:
"""
@ -94,12 +101,12 @@ class BuiltinToolProviderController(ToolProviderController):
:return: list of tools
"""
return self._get_builtin_tools()
def get_tool(self, tool_name: str) -> BuiltinTool | None:
"""
returns the tool that the provider can provide
"""
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
@property
def need_credentials(self) -> bool:
@ -108,7 +115,7 @@ class BuiltinToolProviderController(ToolProviderController):
:return: whether the provider needs credentials
"""
return self.credentials_schema is not None and len(self.credentials_schema) != 0
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
@property
def provider_type(self) -> ToolProviderType:
@ -133,8 +140,8 @@ class BuiltinToolProviderController(ToolProviderController):
"""
returns the labels of the provider
"""
return self.identity.tags or []
return self.entity.identity.tags or []
def validate_credentials(self, credentials: dict[str, Any]) -> None:
"""
validate the credentials of the provider

View File

@ -1,13 +1,8 @@
from typing import Any
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers.qrcode.tools.qrcode_generator import QRCodeGeneratorTool
from core.tools.errors import ToolProviderCredentialValidationError
class QRCodeProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
QRCodeGeneratorTool().invoke(user_id="", tool_parameters={"content": "Dify 123 😊"})
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
pass

View File

@ -1,16 +1,8 @@
from typing import Any
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers.time.tools.current_time import CurrentTimeTool
from core.tools.errors import ToolProviderCredentialValidationError
class WikiPediaProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
try:
CurrentTimeTool().invoke(
user_id="",
tool_parameters={},
)
except Exception as e:
raise ToolProviderCredentialValidationError(str(e))
pass

View File

@ -32,9 +32,9 @@ class BuiltinTool(Tool):
# invoke model
return ModelInvocationUtils.invoke(
user_id=user_id,
tenant_id=self.runtime.tenant_id or "",
tenant_id=self.runtime.tenant_id,
tool_type="builtin",
tool_name=self.identity.name,
tool_name=self.entity.identity.name,
prompt_messages=prompt_messages,
)
@ -79,6 +79,7 @@ class BuiltinTool(Tool):
stop=[],
)
assert isinstance(summary.message.content, str)
return summary.message.content
lines = content.split("\n")