feat(api): Add image multimodal support for LLMNode (#17372)

Enhance `LLMNode` with multimodal capability, introducing support for
image outputs.

This implementation extracts base64-encoded images from LLM responses,
saves them to the storage service, and records the file metadata in the
`ToolFile` table. In conversations, these images are rendered as
markdown-based inline images.
Additionally, the images are included in the LLMNode's output as
file variables, enabling subsequent nodes in the workflow to utilize them.

To integrate file outputs into workflows, adjustments to the frontend code
are necessary.

For multimodal output functionality, updates to related model configurations
are required. Currently, this capability has been applied exclusively to
Google's Gemini models.

Close #15814.

Signed-off-by: -LAN- <laipz8200@outlook.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
This commit is contained in:
QuantumGhost
2025-04-30 17:28:02 +08:00
committed by GitHub
parent 6c9a9d344a
commit 349c3cf7b8
24 changed files with 971 additions and 191 deletions

View File

@ -191,8 +191,9 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
mime_type = (
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
)
tool_file_manager = ToolFileManager()
tool_file = ToolFileManager.create_file_by_raw(
tool_file = tool_file_manager.create_file_by_raw(
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,

View File

@ -38,3 +38,8 @@ class MemoryRolePrefixRequiredError(LLMNodeError):
class FileTypeNotSupportError(LLMNodeError):
def __init__(self, *, type_name: str):
super().__init__(f"{type_name} type is not supported by this model")
class UnsupportedPromptContentTypeError(LLMNodeError):
def __init__(self, *, type_name: str) -> None:
super().__init__(f"Prompt content type {type_name} is not supported.")

View File

@ -0,0 +1,160 @@
import mimetypes
import typing as tp
from sqlalchemy import Engine
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
from core.file import File, FileTransferMethod, FileType
from core.helper import ssrf_proxy
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
from models import db as global_db
class LLMFileSaver(tp.Protocol):
"""LLMFileSaver is responsible for save multimodal output returned by
LLM.
"""
def save_binary_string(
self,
data: bytes,
mime_type: str,
file_type: FileType,
extension_override: str | None = None,
) -> File:
"""save_binary_string saves the inline file data returned by LLM.
Currently (2025-04-30), only some of Google Gemini models will return
multimodal output as inline data.
:param data: the contents of the file
:param mime_type: the media type of the file, specified by rfc6838
(https://datatracker.ietf.org/doc/html/rfc6838)
:param file_type: The file type of the inline file.
:param extension_override: Override the auto-detected file extension while saving this file.
The default value is `None`, which means do not override the file extension and guessing it
from the `mime_type` attribute while saving the file.
Setting it to values other than `None` means override the file's extension, and
will bypass the extension guessing saving the file.
Specially, setting it to empty string (`""`) will leave the file extension empty.
When it is not `None` or empty string (`""`), it should be a string beginning with a
dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
and `tar.gz` are not.
"""
pass
def save_remote_url(self, url: str, file_type: FileType) -> File:
"""save_remote_url saves the file from a remote url returned by LLM.
Currently (2025-04-30), no model returns multimodel output as a url.
:param url: the url of the file.
:param file_type: the file type of the file, check `FileType` enum for reference.
"""
pass
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
class FileSaverImpl(LLMFileSaver):
_engine_factory: EngineFactory
_tenant_id: str
_user_id: str
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
if engine_factory is None:
def _factory():
return global_db.engine
engine_factory = _factory
self._engine_factory = engine_factory
self._user_id = user_id
self._tenant_id = tenant_id
def _get_tool_file_manager(self):
return ToolFileManager(engine=self._engine_factory())
def save_remote_url(self, url: str, file_type: FileType) -> File:
http_response = ssrf_proxy.get(url)
http_response.raise_for_status()
data = http_response.content
mime_type_from_header = http_response.headers.get("Content-Type")
mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header)
return self.save_binary_string(data, mime_type, file_type, extension_override=extension)
def save_binary_string(
self,
data: bytes,
mime_type: str,
file_type: FileType,
extension_override: str | None = None,
) -> File:
tool_file_manager = self._get_tool_file_manager()
tool_file = tool_file_manager.create_file_by_raw(
user_id=self._user_id,
tenant_id=self._tenant_id,
# TODO(QuantumGhost): what is conversation id?
conversation_id=None,
file_binary=data,
mimetype=mime_type,
)
extension_override = _validate_extension_override(extension_override)
extension = _get_extension(mime_type, extension_override)
url = sign_tool_file(tool_file.id, extension)
return File(
tenant_id=self._tenant_id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
filename=tool_file.name,
extension=extension,
mime_type=mime_type,
size=len(data),
related_id=tool_file.id,
url=url,
# TODO(QuantumGhost): how should I set the following key?
# What's the difference between `remote_url` and `url`?
# What's the purpose of `storage_key` and `dify_model_identity`?
storage_key=tool_file.file_key,
)
def _get_extension(mime_type: str, extension_override: str | None = None) -> str:
"""get_extension return the extension of file.
If the `extension_override` parameter is set, this function should honor it and
return its value.
"""
if extension_override is not None:
return extension_override
return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION
def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
"""_extract_content_type_and_extension tries to
guess content type of file from url and `Content-Type` header in response.
"""
if content_type_header:
extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
return content_type_header, extension
content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
return content_type, extension
def _validate_extension_override(extension_override: str | None) -> str | None:
# `extension_override` is allow to be `None or `""`.
if extension_override is None:
return None
if extension_override == "":
return ""
if not extension_override.startswith("."):
raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
return extension_override

View File

@ -1,3 +1,5 @@
import base64
import io
import json
import logging
from collections.abc import Generator, Mapping, Sequence
@ -21,7 +23,7 @@ from core.model_runtime.entities import (
PromptMessageContentType,
TextPromptMessageContent,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageContentUnionTypes,
@ -38,7 +40,6 @@ from core.model_runtime.entities.model_entities import (
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@ -95,9 +96,13 @@ from .exc import (
TemplateTypeNotSupportError,
VariableNotFoundError,
)
from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
logger = logging.getLogger(__name__)
@ -106,6 +111,43 @@ class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
)
self._llm_file_saver = llm_file_saver
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
"""Process structured output if enabled"""
@ -215,6 +257,9 @@ class LLMNode(BaseNode[LLMNodeData]):
structured_output = process_structured_output(result_text)
if structured_output:
outputs["structured_output"] = structured_output
if self._file_outputs is not None:
outputs["files"] = self._file_outputs
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -240,6 +285,7 @@ class LLMNode(BaseNode[LLMNodeData]):
)
)
except Exception as e:
logger.exception("error while executing llm node")
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -268,44 +314,45 @@ class LLMNode(BaseNode[LLMNodeData]):
return self._handle_invoke_result(invoke_result=invoke_result)
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
def _handle_invoke_result(
self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None]
) -> Generator[NodeEvent, None, None]:
# For blocking mode
if isinstance(invoke_result, LLMResult):
message_text = convert_llm_result_chunk_to_str(invoke_result.message.content)
yield ModelInvokeCompletedEvent(
text=message_text,
usage=invoke_result.usage,
finish_reason=None,
)
event = self._handle_blocking_result(invoke_result=invoke_result)
yield event
return
model = None
# For streaming mode
model = ""
prompt_messages: list[PromptMessage] = []
full_text = ""
usage = None
usage = LLMUsage.empty_usage()
finish_reason = None
full_text_buffer = io.StringIO()
for result in invoke_result:
text = convert_llm_result_chunk_to_str(result.delta.message.content)
full_text += text
contents = result.delta.message.content
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
full_text_buffer.write(text_part)
yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[self.node_id, "text"])
yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
if not model:
# Update the whole metadata
if not model and result.model:
model = result.model
if not prompt_messages:
prompt_messages = result.prompt_messages
if not usage and result.delta.usage:
if len(prompt_messages) == 0:
# TODO(QuantumGhost): it seems that this update has no visable effect.
# What's the purpose of the line below?
prompt_messages = list(result.prompt_messages)
if usage.prompt_tokens == 0 and result.delta.usage:
usage = result.delta.usage
if not finish_reason and result.delta.finish_reason:
if finish_reason is None and result.delta.finish_reason:
finish_reason = result.delta.finish_reason
if not usage:
usage = LLMUsage.empty_usage()
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason)
def _image_file_to_markdown(self, file: "File", /):
text_chunk = f"![]({file.generate_url()})"
return text_chunk
def _transform_chat_messages(
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
@ -963,6 +1010,42 @@ class LLMNode(BaseNode[LLMNodeData]):
return prompt_messages
def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
buffer = io.StringIO()
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content):
buffer.write(text_part)
return ModelInvokeCompletedEvent(
text=buffer.getvalue(),
usage=invoke_result.usage,
finish_reason=None,
)
def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File":
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
There are two kinds of multimodal outputs:
- Inlined data encoded in base64, which would be saved to storage directly.
- Remote files referenced by an url, which would be downloaded and then saved to storage.
Currently, only image files are supported.
"""
# Inject the saver somehow...
_saver = self._llm_file_saver
# If this
if content.url != "":
saved_file = _saver.save_remote_url(content.url, FileType.IMAGE)
else:
saved_file = _saver.save_binary_string(
data=base64.b64decode(content.base64_data),
mime_type=content.mime_type,
file_type=FileType.IMAGE,
)
self._file_outputs.append(saved_file)
return saved_file
def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
"""
Handle structured output for models with native JSON schema support.
@ -1123,6 +1206,41 @@ class LLMNode(BaseNode[LLMNodeData]):
else SupportStructuredOutputStatus.UNSUPPORTED
)
def _save_multimodal_output_and_convert_result_to_markdown(
self,
contents: str | list[PromptMessageContentUnionTypes] | None,
) -> Generator[str, None, None]:
"""Convert intermediate prompt messages into strings and yield them to the caller.
If the messages contain non-textual content (e.g., multimedia like images or videos),
it will be saved separately, and the corresponding Markdown representation will
be yielded to the caller.
"""
# NOTE(QuantumGhost): This function should yield results to the caller immediately
# whenever new content or partial content is available. Avoid any intermediate buffering
# of results. Additionally, do not yield empty strings; instead, yield from an empty list
# if necessary.
if contents is None:
yield from []
return
if isinstance(contents, str):
yield contents
elif isinstance(contents, list):
for item in contents:
if isinstance(item, TextPromptMessageContent):
yield item.data
elif isinstance(item, ImagePromptMessageContent):
file = self._save_multimodal_image_output(item)
self._file_outputs.append(file)
yield self._image_file_to_markdown(file)
else:
logger.warning("unknown item type encountered, type=%s", type(item))
yield str(item)
else:
logger.warning("unknown contents type encountered, type=%s", type(contents))
yield str(contents)
def _combine_message_content_with_role(
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole