mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 01:48:04 +08:00
refactor: tool engine
This commit is contained in:
@ -4,7 +4,6 @@ from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
@ -22,8 +21,6 @@ class Tool(BaseModel, ABC):
|
||||
parameters: Optional[list[ToolParameter]] = None
|
||||
description: ToolDescription = None
|
||||
is_team_authorization: bool = False
|
||||
agent_callback: Optional[DifyAgentCallbackHandler] = None
|
||||
use_callback: bool = False
|
||||
|
||||
class Runtime(BaseModel):
|
||||
"""
|
||||
@ -45,15 +42,10 @@ class Tool(BaseModel, ABC):
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
|
||||
if not self.agent_callback:
|
||||
self.use_callback = False
|
||||
else:
|
||||
self.use_callback = True
|
||||
|
||||
class VARIABLE_KEY(Enum):
|
||||
IMAGE = 'image'
|
||||
|
||||
def fork_tool_runtime(self, meta: dict[str, Any], agent_callback: DifyAgentCallbackHandler = None) -> 'Tool':
|
||||
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
@ -65,7 +57,6 @@ class Tool(BaseModel, ABC):
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.copy() if self.description else None,
|
||||
runtime=Tool.Runtime(**meta),
|
||||
agent_callback=agent_callback
|
||||
)
|
||||
|
||||
def load_variables(self, variables: ToolRuntimeVariablePool):
|
||||
@ -174,50 +165,19 @@ class Tool(BaseModel, ABC):
|
||||
|
||||
return result
|
||||
|
||||
def invoke(self, user_id: str, tool_parameters: Union[dict[str, Any], str]) -> list[ToolInvokeMessage]:
|
||||
# check if tool_parameters is a string
|
||||
if isinstance(tool_parameters, str):
|
||||
# check if this tool has only one parameter
|
||||
parameters = [parameter for parameter in self.parameters if parameter.form == ToolParameter.ToolParameterForm.LLM]
|
||||
if parameters and len(parameters) == 1:
|
||||
tool_parameters = {
|
||||
parameters[0].name: tool_parameters
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
|
||||
|
||||
def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
|
||||
# update tool_parameters
|
||||
if self.runtime.runtime_parameters:
|
||||
tool_parameters.update(self.runtime.runtime_parameters)
|
||||
|
||||
# hit callback
|
||||
if self.use_callback:
|
||||
self.agent_callback.on_tool_start(
|
||||
tool_name=self.identity.name,
|
||||
tool_inputs=tool_parameters
|
||||
)
|
||||
|
||||
try:
|
||||
result = self._invoke(
|
||||
user_id=user_id,
|
||||
tool_parameters=tool_parameters,
|
||||
)
|
||||
except Exception as e:
|
||||
if self.use_callback:
|
||||
self.agent_callback.on_tool_error(e)
|
||||
raise e
|
||||
result = self._invoke(
|
||||
user_id=user_id,
|
||||
tool_parameters=tool_parameters,
|
||||
)
|
||||
|
||||
if not isinstance(result, list):
|
||||
result = [result]
|
||||
|
||||
# hit callback
|
||||
if self.use_callback:
|
||||
self.agent_callback.on_tool_end(
|
||||
tool_name=self.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=self._convert_tool_response_to_str(result)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
|
||||
|
||||
Reference in New Issue
Block a user