fix: invoke tool streamingly

This commit is contained in:
Yeuoly
2024-08-30 18:11:38 +08:00
parent cf4e9f317e
commit 886a160115
16 changed files with 149 additions and 92 deletions

View File

@ -1,6 +1,6 @@
from collections.abc import Generator, Sequence
from os import path
from typing import Any, cast
from typing import Any, Iterable, cast
from core.app.segments import ArrayAnySegment, ArrayAnyVariable, parser
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
@ -158,14 +158,17 @@ class ToolNode(BaseNode):
tenant_id=self.tenant_id,
conversation_id=None,
)
result = list(messages)
# extract plain text and files
files = self._extract_tool_response_binary(messages)
plain_text = self._extract_tool_response_text(messages)
json = self._extract_tool_response_json(messages)
files = self._extract_tool_response_binary(result)
plain_text = self._extract_tool_response_text(result)
json = self._extract_tool_response_json(result)
return plain_text, files, json
def _extract_tool_response_binary(self, tool_response: Generator[ToolInvokeMessage, None, None]) -> list[FileVar]:
def _extract_tool_response_binary(self, tool_response: Iterable[ToolInvokeMessage]) -> list[FileVar]:
"""
Extract tool response binary
"""
@ -215,7 +218,7 @@ class ToolNode(BaseNode):
return result
def _extract_tool_response_text(self, tool_response: Generator[ToolInvokeMessage]) -> str:
def _extract_tool_response_text(self, tool_response: Iterable[ToolInvokeMessage]) -> str:
"""
Extract tool response text
"""
@ -230,7 +233,7 @@ class ToolNode(BaseNode):
return '\n'.join(result)
def _extract_tool_response_json(self, tool_response: Generator[ToolInvokeMessage]) -> list[dict]:
def _extract_tool_response_json(self, tool_response: Iterable[ToolInvokeMessage]) -> list[dict]:
result: list[dict] = []
for message in tool_response:
if message.type == ToolInvokeMessage.MessageType.JSON: