mirror of
https://github.com/langgenius/dify.git
synced 2026-05-05 01:48:04 +08:00
add image file as markdown stream outupt
This commit is contained in:
@ -1,9 +1,10 @@
|
||||
from enum import Enum
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
from core.workflow.entities.node_entities import SystemVariable
|
||||
|
||||
VariableValue = Union[str, int, float, dict, list]
|
||||
VariableValue = Union[str, int, float, dict, list, FileVar]
|
||||
|
||||
|
||||
class ValueType(Enum):
|
||||
|
||||
@ -5,7 +5,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
|
||||
from core.file.file_obj import FileObj
|
||||
from core.file.file_obj import FileVar
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
@ -51,15 +51,10 @@ class LLMNode(BaseNode):
|
||||
}
|
||||
|
||||
# fetch files
|
||||
files: list[FileObj] = self._fetch_files(node_data, variable_pool)
|
||||
files: list[FileVar] = self._fetch_files(node_data, variable_pool)
|
||||
|
||||
if files:
|
||||
node_inputs['#files#'] = [{
|
||||
'type': file.type.value,
|
||||
'transfer_method': file.transfer_method.value,
|
||||
'url': file.url,
|
||||
'upload_file_id': file.upload_file_id,
|
||||
} for file in files]
|
||||
node_inputs['#files#'] = [file.to_dict() for file in files]
|
||||
|
||||
# fetch context value
|
||||
context = self._fetch_context(node_data, variable_pool)
|
||||
@ -202,7 +197,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
return inputs
|
||||
|
||||
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileObj]:
|
||||
def _fetch_files(self, node_data: LLMNodeData, variable_pool: VariablePool) -> list[FileVar]:
|
||||
"""
|
||||
Fetch files
|
||||
:param node_data: node data
|
||||
@ -350,7 +345,7 @@ class LLMNode(BaseNode):
|
||||
|
||||
def _fetch_prompt_messages(self, node_data: LLMNodeData,
|
||||
inputs: dict[str, str],
|
||||
files: list[FileObj],
|
||||
files: list[FileVar],
|
||||
context: Optional[str],
|
||||
memory: Optional[TokenBufferMemory],
|
||||
model_config: ModelConfigWithCredentialsEntity) \
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
from os import path
|
||||
from typing import cast
|
||||
|
||||
from core.file.file_obj import FileTransferMethod
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
@ -58,19 +58,19 @@ class ToolNode(BaseNode):
|
||||
},
|
||||
inputs=parameters
|
||||
)
|
||||
|
||||
|
||||
def _generate_parameters(self, variable_pool: VariablePool, node_data: ToolNodeData) -> dict:
|
||||
"""
|
||||
Generate parameters
|
||||
"""
|
||||
return {
|
||||
k.variable:
|
||||
k.value if k.variable_type == 'static' else
|
||||
k.variable:
|
||||
k.value if k.variable_type == 'static' else
|
||||
variable_pool.get_variable_value(k.value_selector) if k.variable_type == 'selector' else ''
|
||||
for k in node_data.tool_parameters
|
||||
}
|
||||
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[dict]]:
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
@ -87,7 +87,7 @@ class ToolNode(BaseNode):
|
||||
|
||||
return plain_text, files
|
||||
|
||||
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[dict]:
|
||||
def _extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[FileVar]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
"""
|
||||
@ -95,46 +95,50 @@ class ToolNode(BaseNode):
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
url = response.message
|
||||
ext = path.splitext(url)[1]
|
||||
mimetype = response.meta.get('mime_type', 'image/jpeg')
|
||||
filename = response.save_as or url.split('/')[-1]
|
||||
result.append({
|
||||
'type': 'image',
|
||||
'transfer_method': FileTransferMethod.TOOL_FILE,
|
||||
'url': url,
|
||||
'upload_file_id': None,
|
||||
'filename': filename,
|
||||
'file-ext': ext,
|
||||
'mime-type': mimetype,
|
||||
})
|
||||
|
||||
# get tool file id
|
||||
tool_file_id = url.split('/')[-1]
|
||||
result.append(FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=tool_file_id,
|
||||
filename=filename,
|
||||
extension=ext,
|
||||
mime_type=mimetype,
|
||||
))
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
result.append({
|
||||
'type': 'image', # TODO: only support image for now
|
||||
'transfer_method': FileTransferMethod.TOOL_FILE,
|
||||
'url': response.message,
|
||||
'upload_file_id': None,
|
||||
'filename': response.save_as,
|
||||
'file-ext': path.splitext(response.save_as)[1],
|
||||
'mime-type': response.meta.get('mime_type', 'application/octet-stream'),
|
||||
})
|
||||
# get tool file id
|
||||
tool_file_id = response.message.split('/')[-1]
|
||||
result.append(FileVar(
|
||||
tenant_id=self.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=tool_file_id,
|
||||
filename=response.save_as,
|
||||
extension=path.splitext(response.save_as)[1],
|
||||
mime_type=response.meta.get('mime_type', 'application/octet-stream'),
|
||||
))
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
pass # TODO:
|
||||
pass # TODO:
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _extract_tool_response_text(self, tool_response: list[ToolInvokeMessage]) -> str:
|
||||
"""
|
||||
Extract tool response text
|
||||
"""
|
||||
return ''.join([
|
||||
f'{message.message}\n' if message.type == ToolInvokeMessage.MessageType.TEXT else
|
||||
f'{message.message}\n' if message.type == ToolInvokeMessage.MessageType.TEXT else
|
||||
f'Link: {message.message}\n' if message.type == ToolInvokeMessage.MessageType.LINK else ''
|
||||
for message in tool_response
|
||||
])
|
||||
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: ToolNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user