mirror of
https://github.com/langgenius/dify.git
synced 2026-04-28 14:38:06 +08:00
feat: tool node
This commit is contained in:
23
api/core/workflow/nodes/tool/entities.py
Normal file
23
api/core/workflow/nodes/tool/entities.py
Normal file
@ -0,0 +1,23 @@
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.variable_entities import VariableSelector
|
||||
|
||||
ToolParameterValue = Union[str, int, float, bool]
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
provider_id: str
|
||||
provider_type: Literal['builtin', 'api']
|
||||
provider_name: str # redundancy
|
||||
tool_name: str
|
||||
tool_label: str # redundancy
|
||||
tool_parameters: dict[str, ToolParameterValue]
|
||||
|
||||
|
||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
"""
|
||||
Tool Node Schema
|
||||
"""
|
||||
tool_inputs: list[VariableSelector]
|
||||
@ -1,5 +1,139 @@
|
||||
from os import path
|
||||
from typing import cast
|
||||
|
||||
from core.file.file_obj import FileTransferMethod
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class ToolNode(BaseNode):
|
||||
pass
|
||||
"""
|
||||
Tool Node
|
||||
"""
|
||||
_node_data_cls = ToolNodeData
|
||||
_node_type = NodeType.TOOL
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> NodeRunResult:
|
||||
"""
|
||||
Run the tool node
|
||||
"""
|
||||
|
||||
node_data = cast(ToolNodeData, self.node_data)
|
||||
|
||||
# extract tool parameters
|
||||
parameters = {
|
||||
k.variable: variable_pool.get_variable_value(k.value_selector)
|
||||
for k in node_data.tool_inputs
|
||||
}
|
||||
|
||||
if len(parameters) != len(node_data.tool_inputs):
|
||||
raise ValueError('Invalid tool parameters')
|
||||
|
||||
# get tool runtime
|
||||
try:
|
||||
tool_runtime = ToolManager.get_workflow_tool_runtime(self.tenant_id, node_data, None)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters,
|
||||
error=f'Failed to get tool runtime: {str(e)}'
|
||||
)
|
||||
|
||||
try:
|
||||
messages = tool_runtime.invoke(None, parameters)
|
||||
except Exception as e:
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters,
|
||||
error=f'Failed to invoke tool: {str(e)}'
|
||||
)
|
||||
|
||||
# convert tool messages
|
||||
plain_text, files = self._convert_tool_messages(messages)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCESS,
|
||||
outputs={
|
||||
'text': plain_text,
|
||||
'files': files
|
||||
},
|
||||
)
|
||||
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[dict]]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
messages = ToolFileMessageTransformer.transform_tool_invoke_messages(messages)
|
||||
# extract plain text and files
|
||||
files = self._extract_tool_response_binary(messages)
|
||||
plain_text = self._extract_tool_response_text(messages)
|
||||
|
||||
return plain_text, files
|
||||
|
||||
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[dict]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
"""
|
||||
result = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
url = response.message
|
||||
ext = path.splitext(url)[1]
|
||||
mimetype = response.meta.get('mime_type', 'image/jpeg')
|
||||
filename = response.save_as or url.split('/')[-1]
|
||||
result.append({
|
||||
'type': 'image',
|
||||
'transfer_method': FileTransferMethod.TOOL_FILE,
|
||||
'url': url,
|
||||
'upload_file_id': None,
|
||||
'filename': filename,
|
||||
'file-ext': ext,
|
||||
'mime-type': mimetype,
|
||||
})
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
result.append({
|
||||
'type': 'image', # TODO: only support image for now
|
||||
'transfer_method': FileTransferMethod.TOOL_FILE,
|
||||
'url': response.message,
|
||||
'upload_file_id': None,
|
||||
'filename': response.save_as,
|
||||
'file-ext': path.splitext(response.save_as)[1],
|
||||
'mime-type': response.meta.get('mime_type', 'application/octet-stream'),
|
||||
})
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
pass # TODO:
|
||||
|
||||
return result
|
||||
|
||||
def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str:
|
||||
"""
|
||||
Extract tool response text
|
||||
"""
|
||||
return ''.join([
|
||||
f'{message.message}\n' if message.type == ToolInvokeMessage.MessageType.TEXT else
|
||||
f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else ''
|
||||
for message in tool_response
|
||||
])
|
||||
|
||||
def _convert_tool_file(message: list[ToolInvokeMessage]) -> dict:
|
||||
"""
|
||||
Convert ToolInvokeMessage into file
|
||||
"""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: BaseNodeData) -> dict[list[str], str]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
"""
|
||||
pass
|
||||
Reference in New Issue
Block a user