fix: tool node

This commit is contained in:
Yeuoly
2024-08-29 13:56:48 +08:00
parent c28998a6f0
commit 531ffaec4f

View File

@ -128,8 +128,10 @@ class ToolNode(BaseNode):
else:
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == 'variable':
# TODO: check if the variable exists in the variable pool
parameter_value = variable_pool.get(tool_input.value).value
parameter_value_segment = variable_pool.get(tool_input.value)
if not parameter_value_segment:
raise Exception("input variable dose not exists")
parameter_value = parameter_value_segment.value
else:
segment_group = parser.convert_template(
template=str(tool_input.value),
@ -163,7 +165,7 @@ class ToolNode(BaseNode):
return plain_text, files, json
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]:
def _extract_tool_response_binary(self, tool_response: Generator[ToolInvokeMessage, None, None]) -> list[FileVar]:
"""
Extract tool response binary
"""
@ -172,7 +174,10 @@ class ToolNode(BaseNode):
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
url = response.message
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
assert response.meta
url = response.message.text
ext = path.splitext(url)[1]
mimetype = response.meta.get('mime_type', 'image/jpeg')
filename = response.save_as or url.split('/')[-1]
@ -192,7 +197,10 @@ class ToolNode(BaseNode):
))
elif response.type == ToolInvokeMessage.MessageType.BLOB:
# get tool file id
tool_file_id = response.message.split('/')[-1].split('.')[0]
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
assert response.meta
tool_file_id = response.message.text.split('/')[-1].split('.')[0]
result.append(FileVar(
tenant_id=self.tenant_id,
type=FileType.IMAGE,
@ -207,18 +215,28 @@ class ToolNode(BaseNode):
return result
def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str:
def _extract_tool_response_text(self, tool_response: Generator[ToolInvokeMessage]) -> str:
"""
Extract tool response text
"""
return '\n'.join([
f'{message.message}' if message.type == ToolInvokeMessage.MessageType.TEXT else
f'Link: {message.message}' if message.type == ToolInvokeMessage.MessageType.LINK else ''
for message in tool_response
])
result: list[str] = []
for message in tool_response:
if message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
result.append(message.message.text)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
result.append(f'Link: {message.message.text}')
def _extract_tool_response_json(self, tool_response: list[ToolInvokeMessage]) -> list[dict]:
return [message.message for message in tool_response if message.type == ToolInvokeMessage.MessageType.JSON]
return '\n'.join(result)
def _extract_tool_response_json(self, tool_response: Generator[ToolInvokeMessage]) -> list[dict]:
result: list[dict] = []
for message in tool_response:
if message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message, ToolInvokeMessage.JsonMessage)
result.append(message.json_object)
return result
@classmethod
def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
@ -231,6 +249,7 @@ class ToolNode(BaseNode):
for parameter_name in node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name]
if input.type == 'mixed':
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector