refactor: tool engine

This commit is contained in:
Yeuoly
2024-03-28 18:36:58 +08:00
parent 82a82fff35
commit 51404f9035
8 changed files with 318 additions and 283 deletions

View File

@ -4,7 +4,6 @@ from typing import Any, Optional, Union
from pydantic import BaseModel
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.tools.entities.tool_entities import (
ToolDescription,
ToolIdentity,
@ -22,8 +21,6 @@ class Tool(BaseModel, ABC):
parameters: Optional[list[ToolParameter]] = None
description: ToolDescription = None
is_team_authorization: bool = False
agent_callback: Optional[DifyAgentCallbackHandler] = None
use_callback: bool = False
class Runtime(BaseModel):
"""
@ -45,15 +42,10 @@ class Tool(BaseModel, ABC):
def __init__(self, **data: Any):
super().__init__(**data)
if not self.agent_callback:
self.use_callback = False
else:
self.use_callback = True
class VARIABLE_KEY(Enum):
IMAGE = 'image'
def fork_tool_runtime(self, meta: dict[str, Any], agent_callback: DifyAgentCallbackHandler = None) -> 'Tool':
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
"""
fork a new tool with meta data
@ -65,7 +57,6 @@ class Tool(BaseModel, ABC):
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.copy() if self.description else None,
runtime=Tool.Runtime(**meta),
agent_callback=agent_callback
)
def load_variables(self, variables: ToolRuntimeVariablePool):
@ -174,50 +165,19 @@ class Tool(BaseModel, ABC):
return result
def invoke(self, user_id: str, tool_parameters: Union[dict[str, Any], str]) -> list[ToolInvokeMessage]:
# check if tool_parameters is a string
if isinstance(tool_parameters, str):
# check if this tool has only one parameter
parameters = [parameter for parameter in self.parameters if parameter.form == ToolParameter.ToolParameterForm.LLM]
if parameters and len(parameters) == 1:
tool_parameters = {
parameters[0].name: tool_parameters
}
else:
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
# update tool_parameters
if self.runtime.runtime_parameters:
tool_parameters.update(self.runtime.runtime_parameters)
# hit callback
if self.use_callback:
self.agent_callback.on_tool_start(
tool_name=self.identity.name,
tool_inputs=tool_parameters
)
try:
result = self._invoke(
user_id=user_id,
tool_parameters=tool_parameters,
)
except Exception as e:
if self.use_callback:
self.agent_callback.on_tool_error(e)
raise e
result = self._invoke(
user_id=user_id,
tool_parameters=tool_parameters,
)
if not isinstance(result, list):
result = [result]
# hit callback
if self.use_callback:
self.agent_callback.on_tool_end(
tool_name=self.identity.name,
tool_inputs=tool_parameters,
tool_outputs=self._convert_tool_response_to_str(result)
)
return result
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str: