mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
refactor: tool
This commit is contained in:
@ -2,13 +2,12 @@ from abc import abstractmethod
|
||||
from os import listdir, path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.helper.module_import_helper import load_single_subclass_from_source
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.builtin_tool.tool import BuiltinTool
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
|
||||
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
|
||||
from core.tools.errors import (
|
||||
ToolProviderNotFoundError,
|
||||
@ -17,10 +16,10 @@ from core.tools.utils.yaml_utils import load_yaml_file
|
||||
|
||||
|
||||
class BuiltinToolProviderController(ToolProviderController):
|
||||
tools: list[BuiltinTool] = Field(default_factory=list)
|
||||
tools: list[BuiltinTool]
|
||||
|
||||
def __init__(self, **data: Any) -> None:
|
||||
if self.provider_type in {ToolProviderType.API, ToolProviderType.APP}:
|
||||
if self.provider_type == ToolProviderType.API:
|
||||
super().__init__(**data)
|
||||
return
|
||||
|
||||
@ -37,10 +36,12 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
for credential_name in provider_yaml["credentials_for_provider"]:
|
||||
provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name
|
||||
|
||||
super().__init__(**{
|
||||
'identity': provider_yaml['identity'],
|
||||
'credentials_schema': provider_yaml.get('credentials_for_provider', {}) or {},
|
||||
})
|
||||
super().__init__(
|
||||
entity=ToolProviderEntity(
|
||||
identity=provider_yaml["identity"],
|
||||
credentials_schema=provider_yaml.get("credentials_for_provider", {}) or {},
|
||||
),
|
||||
)
|
||||
|
||||
def _get_builtin_tools(self) -> list[BuiltinTool]:
|
||||
"""
|
||||
@ -51,7 +52,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
if self.tools:
|
||||
return self.tools
|
||||
|
||||
provider = self.identity.name
|
||||
provider = self.entity.identity.name
|
||||
tool_path = path.join(path.dirname(path.realpath(__file__)), "providers", provider, "tools")
|
||||
# get all the yaml files in the tool path
|
||||
tool_files = list(filter(lambda x: x.endswith(".yaml") and not x.startswith("__"), listdir(tool_path)))
|
||||
@ -62,30 +63,36 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
tool = load_yaml_file(path.join(tool_path, tool_file), ignore_error=False)
|
||||
|
||||
# get tool class, import the module
|
||||
assistant_tool_class = load_single_subclass_from_source(
|
||||
assistant_tool_class: type[BuiltinTool] = load_single_subclass_from_source(
|
||||
module_name=f"core.tools.builtin_tool.providers.{provider}.tools.{tool_name}",
|
||||
script_path=path.join(
|
||||
path.dirname(path.realpath(__file__)),
|
||||
"builtin_tool", "providers", provider, "tools", f"{tool_name}.py"
|
||||
path.dirname(path.realpath(__file__)),
|
||||
"builtin_tool",
|
||||
"providers",
|
||||
provider,
|
||||
"tools",
|
||||
f"{tool_name}.py",
|
||||
),
|
||||
parent_type=BuiltinTool,
|
||||
)
|
||||
tool["identity"]["provider"] = provider
|
||||
tools.append(assistant_tool_class(**tool))
|
||||
tools.append(assistant_tool_class(
|
||||
entity=ToolEntity(**tool), runtime=ToolRuntime(tenant_id=""),
|
||||
))
|
||||
|
||||
self.tools = tools
|
||||
return tools
|
||||
|
||||
|
||||
def get_credentials_schema(self) -> dict[str, ProviderConfig]:
|
||||
"""
|
||||
returns the credentials schema of the provider
|
||||
|
||||
:return: the credentials schema
|
||||
"""
|
||||
if not self.credentials_schema:
|
||||
if not self.entity.credentials_schema:
|
||||
return {}
|
||||
|
||||
return self.credentials_schema.copy()
|
||||
return self.entity.credentials_schema.copy()
|
||||
|
||||
def get_tools(self) -> list[BuiltinTool]:
|
||||
"""
|
||||
@ -94,12 +101,12 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
:return: list of tools
|
||||
"""
|
||||
return self._get_builtin_tools()
|
||||
|
||||
|
||||
def get_tool(self, tool_name: str) -> BuiltinTool | None:
|
||||
"""
|
||||
returns the tool that the provider can provide
|
||||
"""
|
||||
return next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
|
||||
return next(filter(lambda x: x.entity.identity.name == tool_name, self.get_tools()), None)
|
||||
|
||||
@property
|
||||
def need_credentials(self) -> bool:
|
||||
@ -108,7 +115,7 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
|
||||
:return: whether the provider needs credentials
|
||||
"""
|
||||
return self.credentials_schema is not None and len(self.credentials_schema) != 0
|
||||
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
|
||||
|
||||
@property
|
||||
def provider_type(self) -> ToolProviderType:
|
||||
@ -133,8 +140,8 @@ class BuiltinToolProviderController(ToolProviderController):
|
||||
"""
|
||||
returns the labels of the provider
|
||||
"""
|
||||
return self.identity.tags or []
|
||||
|
||||
return self.entity.identity.tags or []
|
||||
|
||||
def validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
|
||||
@ -1,13 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers.qrcode.tools.qrcode_generator import QRCodeGeneratorTool
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class QRCodeProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
QRCodeGeneratorTool().invoke(user_id="", tool_parameters={"content": "Dify 123 😊"})
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
pass
|
||||
|
||||
@ -1,16 +1,8 @@
|
||||
from typing import Any
|
||||
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
from core.tools.builtin_tool.providers.time.tools.current_time import CurrentTimeTool
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class WikiPediaProvider(BuiltinToolProviderController):
|
||||
def _validate_credentials(self, credentials: dict[str, Any]) -> None:
|
||||
try:
|
||||
CurrentTimeTool().invoke(
|
||||
user_id="",
|
||||
tool_parameters={},
|
||||
)
|
||||
except Exception as e:
|
||||
raise ToolProviderCredentialValidationError(str(e))
|
||||
pass
|
||||
|
||||
@ -32,9 +32,9 @@ class BuiltinTool(Tool):
|
||||
# invoke model
|
||||
return ModelInvocationUtils.invoke(
|
||||
user_id=user_id,
|
||||
tenant_id=self.runtime.tenant_id or "",
|
||||
tenant_id=self.runtime.tenant_id,
|
||||
tool_type="builtin",
|
||||
tool_name=self.identity.name,
|
||||
tool_name=self.entity.identity.name,
|
||||
prompt_messages=prompt_messages,
|
||||
)
|
||||
|
||||
@ -79,6 +79,7 @@ class BuiltinTool(Tool):
|
||||
stop=[],
|
||||
)
|
||||
|
||||
assert isinstance(summary.message.content, str)
|
||||
return summary.message.content
|
||||
|
||||
lines = content.split("\n")
|
||||
|
||||
Reference in New Issue
Block a user