mirror of
https://github.com/langgenius/dify.git
synced 2026-01-19 11:45:05 +08:00
fix: tool node
This commit is contained in:
@ -128,8 +128,10 @@ class ToolNode(BaseNode):
|
||||
else:
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == 'variable':
|
||||
# TODO: check if the variable exists in the variable pool
|
||||
parameter_value = variable_pool.get(tool_input.value).value
|
||||
parameter_value_segment = variable_pool.get(tool_input.value)
|
||||
if not parameter_value_segment:
|
||||
raise Exception("input variable dose not exists")
|
||||
parameter_value = parameter_value_segment.value
|
||||
else:
|
||||
segment_group = parser.convert_template(
|
||||
template=str(tool_input.value),
|
||||
@ -163,7 +165,7 @@ class ToolNode(BaseNode):
|
||||
|
||||
return plain_text, files, json
|
||||
|
||||
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]:
|
||||
def _extract_tool_response_binary(self, tool_response: Generator[ToolInvokeMessage, None, None]) -> list[FileVar]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
"""
|
||||
@ -172,7 +174,10 @@ class ToolNode(BaseNode):
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
url = response.message
|
||||
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
|
||||
assert response.meta
|
||||
|
||||
url = response.message.text
|
||||
ext = path.splitext(url)[1]
|
||||
mimetype = response.meta.get('mime_type', 'image/jpeg')
|
||||
filename = response.save_as or url.split('/')[-1]
|
||||
@ -192,7 +197,10 @@ class ToolNode(BaseNode):
|
||||
))
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
tool_file_id = response.message.split('/')[-1].split('.')[0]
|
||||
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
|
||||
assert response.meta
|
||||
|
||||
tool_file_id = response.message.text.split('/')[-1].split('.')[0]
|
||||
result.append(FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
@ -207,18 +215,28 @@ class ToolNode(BaseNode):
|
||||
|
||||
return result
|
||||
|
||||
def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str:
|
||||
def _extract_tool_response_text(self, tool_response: Generator[ToolInvokeMessage]) -> str:
|
||||
"""
|
||||
Extract tool response text
|
||||
"""
|
||||
return '\n'.join([
|
||||
f'{message.message}' if message.type == ToolInvokeMessage.MessageType.TEXT else
|
||||
f'Link: {message.message}' if message.type == ToolInvokeMessage.MessageType.LINK else ''
|
||||
for message in tool_response
|
||||
])
|
||||
result: list[str] = []
|
||||
for message in tool_response:
|
||||
if message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
result.append(message.message.text)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
result.append(f'Link: {message.message.text}')
|
||||
|
||||
def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]:
|
||||
return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
|
||||
return '\n'.join(result)
|
||||
|
||||
def _extract_tool_response_json(self, tool_response: Generator[ToolInvokeMessage]) -> list[dict]:
|
||||
result: list[dict] = []
|
||||
for message in tool_response:
|
||||
if message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message, ToolInvokeMessage.JsonMessage)
|
||||
result.append(message.json_object)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
|
||||
@ -231,6 +249,7 @@ class ToolNode(BaseNode):
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
input = node_data.tool_parameters[parameter_name]
|
||||
if input.type == 'mixed':
|
||||
assert isinstance(input.value, str)
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
|
||||
Reference in New Issue
Block a user