feat: add context file support

This commit is contained in:
Novice
2026-01-16 17:01:19 +08:00
parent e85e31773a
commit 18abc66585
7 changed files with 585 additions and 9 deletions

View File

@ -1,4 +1,5 @@
import base64
import logging
from collections.abc import Mapping
from configs import dify_config
@ -10,7 +11,10 @@ from core.model_runtime.entities import (
TextPromptMessageContent,
VideoPromptMessageContent,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from core.model_runtime.entities.message_entities import (
MultiModalPromptMessageContent,
PromptMessageContentUnionTypes,
)
from core.tools.signature import sign_tool_file
from extensions.ext_storage import storage
@ -18,6 +22,8 @@ from . import helpers
from .enums import FileAttribute
from .models import File, FileTransferMethod, FileType
logger = logging.getLogger(__name__)
def get_attr(*, file: File, attr: FileAttribute):
match attr:
@ -89,6 +95,8 @@ def to_prompt_message_content(
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
"filename": f.filename or "",
# Encoded file reference for context restoration: "transfer_method:related_id" or "remote:url"
"file_ref": _encode_file_ref(f),
}
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
@ -96,6 +104,17 @@ def to_prompt_message_content(
return prompt_class_map[f.type].model_validate(params)
def _encode_file_ref(f: File) -> str | None:
"""Encode file reference as 'transfer_method:id_or_url' string."""
if f.transfer_method == FileTransferMethod.REMOTE_URL:
return f"remote:{f.remote_url}" if f.remote_url else None
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
return f"local:{f.related_id}" if f.related_id else None
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
return f"tool:{f.related_id}" if f.related_id else None
return None
def download(f: File, /):
if f.transfer_method in (
FileTransferMethod.TOOL_FILE,
@ -164,3 +183,128 @@ def _to_url(f: File, /):
return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
else:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
def restore_multimodal_content(
content: MultiModalPromptMessageContent,
) -> MultiModalPromptMessageContent:
"""
Restore base64_data or url for multimodal content from file_ref.
file_ref format: "transfer_method:id_or_url" (e.g., "local:abc123", "remote:https://...")
Args:
content: MultiModalPromptMessageContent with file_ref field
Returns:
MultiModalPromptMessageContent with restored base64_data or url
"""
# Skip if no file reference or content already has data
if not content.file_ref:
return content
if content.base64_data or content.url:
return content
try:
file = _build_file_from_ref(
file_ref=content.file_ref,
file_format=content.format,
mime_type=content.mime_type,
filename=content.filename,
)
if not file:
return content
# Restore content based on config
if dify_config.MULTIMODAL_SEND_FORMAT == "base64":
restored_base64 = _get_encoded_string(file)
return content.model_copy(update={"base64_data": restored_base64})
else:
restored_url = _to_url(file)
return content.model_copy(update={"url": restored_url})
except Exception as e:
logger.warning("Failed to restore multimodal content: %s", e)
return content
def _build_file_from_ref(
file_ref: str,
file_format: str | None,
mime_type: str | None,
filename: str | None,
) -> File | None:
"""
Build a File object from encoded file_ref string.
Args:
file_ref: Encoded reference "transfer_method:id_or_url"
file_format: The file format/extension (without dot)
mime_type: The mime type
filename: The filename
Returns:
File object with storage_key loaded, or None if not found
"""
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.model import UploadFile
from models.tools import ToolFile
# Parse file_ref: "method:value"
if ":" not in file_ref:
logger.warning("Invalid file_ref format: %s", file_ref)
return None
method, value = file_ref.split(":", 1)
extension = f".{file_format}" if file_format else None
if method == "remote":
return File(
tenant_id="",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=value,
extension=extension,
mime_type=mime_type,
filename=filename,
storage_key="",
)
# Query database for storage_key
with Session(db.engine) as session:
if method == "local":
stmt = select(UploadFile).where(UploadFile.id == value)
upload_file = session.scalar(stmt)
if upload_file:
return File(
tenant_id=upload_file.tenant_id,
type=FileType(upload_file.extension)
if hasattr(FileType, upload_file.extension.upper())
else FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id=value,
extension=extension or ("." + upload_file.extension if upload_file.extension else None),
mime_type=mime_type or upload_file.mime_type,
filename=filename or upload_file.name,
storage_key=upload_file.key,
)
elif method == "tool":
stmt = select(ToolFile).where(ToolFile.id == value)
tool_file = session.scalar(stmt)
if tool_file:
return File(
tenant_id=tool_file.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=value,
extension=extension,
mime_type=mime_type or tool_file.mimetype,
filename=filename or tool_file.name,
storage_key=tool_file.file_key,
)
logger.warning("File not found for file_ref: %s", file_ref)
return None

View File

@ -15,20 +15,24 @@ Design:
import logging
from collections.abc import Sequence
from typing import cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.file import file_manager
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
MultiModalPromptMessageContent,
PromptMessage,
PromptMessageRole,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import Message
@ -108,11 +112,36 @@ class NodeTokenBufferMemory(BaseMemory):
messages = []
for msg_dict in context_data:
try:
messages.append(self._deserialize_prompt_message(msg_dict))
msg = self._deserialize_prompt_message(msg_dict)
msg = self._restore_multimodal_content(msg)
messages.append(msg)
except Exception as e:
logger.warning("Failed to deserialize prompt message: %s", e)
return messages
def _restore_multimodal_content(self, message: PromptMessage) -> PromptMessage:
"""
Restore multimodal content (base64 or url) from file_ref.
When context is saved, base64_data is cleared to save storage space.
This method restores the content by parsing file_ref (format: "method:id_or_url").
"""
content = message.content
if content is None or isinstance(content, str):
return message
# Process list content, restoring multimodal data from file references
restored_content: list[PromptMessageContentUnionTypes] = []
for item in content:
if isinstance(item, MultiModalPromptMessageContent):
# restore_multimodal_content preserves the concrete subclass type
restored_item = file_manager.restore_multimodal_content(item)
restored_content.append(cast(PromptMessageContentUnionTypes, restored_item))
else:
restored_content.append(item)
return message.model_copy(update={"content": restored_content})
def get_history_prompt_messages(
self,
*,

View File

@ -91,6 +91,9 @@ class MultiModalPromptMessageContent(PromptMessageContent):
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
filename: str = Field(default="", description="the filename of multi-modal file")
# File reference for context restoration, format: "transfer_method:related_id" or "remote:url"
file_ref: str | None = Field(default=None, description="Encoded file reference for restoration")
@property
def data(self):
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"

View File

@ -233,21 +233,63 @@ def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
"""
Truncate multi-modal content base64 data in a message to avoid storing large data.
Preserves the PromptMessage structure for ArrayPromptMessageSegment compatibility.
If file_ref is present, clears base64_data and url (they can be restored later).
Otherwise, truncates base64_data as fallback for legacy data.
"""
content = message.content
if content is None or isinstance(content, str):
return message
# Process list content, truncating multi-modal base64 data
# Process list content, handling multi-modal data based on file_ref availability
new_content: list[PromptMessageContentUnionTypes] = []
for item in content:
if isinstance(item, MultiModalPromptMessageContent):
# Truncate base64_data similar to prompt_messages_to_prompt_for_saving
truncated_base64 = ""
if item.base64_data:
truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:]
new_content.append(item.model_copy(update={"base64_data": truncated_base64}))
if item.file_ref:
# Clear base64 and url, keep file_ref for later restoration
new_content.append(item.model_copy(update={"base64_data": "", "url": ""}))
else:
# Fallback: truncate base64_data if no file_ref (legacy data)
truncated_base64 = ""
if item.base64_data:
truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:]
new_content.append(item.model_copy(update={"base64_data": truncated_base64}))
else:
new_content.append(item)
return message.model_copy(update={"content": new_content})
def restore_multimodal_content_in_messages(messages: Sequence[PromptMessage]) -> list[PromptMessage]:
"""
Restore multimodal content (base64 or url) in a list of PromptMessages.
When context is saved, base64_data is cleared to save storage space.
This function restores the content by parsing file_ref in each MultiModalPromptMessageContent.
Args:
messages: List of PromptMessages that may contain truncated multimodal content
Returns:
List of PromptMessages with restored multimodal content
"""
from core.file import file_manager
return [_restore_message_content(msg, file_manager) for msg in messages]
def _restore_message_content(message: PromptMessage, file_manager) -> PromptMessage:
"""Restore multimodal content in a single PromptMessage."""
content = message.content
if content is None or isinstance(content, str):
return message
restored_content: list[PromptMessageContentUnionTypes] = []
for item in content:
if isinstance(item, MultiModalPromptMessageContent):
restored_item = file_manager.restore_multimodal_content(item)
restored_content.append(cast(PromptMessageContentUnionTypes, restored_item))
else:
restored_content.append(item)
return message.model_copy(update={"content": restored_content})

View File

@ -675,7 +675,9 @@ class LLMNode(Node[LLMNodeData]):
raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found")
if not isinstance(ctx_var, ArrayPromptMessageSegment):
raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]")
combined_messages.extend(ctx_var.value)
# Restore multimodal content (base64/url) that was truncated when saving context
restored_messages = llm_utils.restore_multimodal_content_in_messages(ctx_var.value)
combined_messages.extend(restored_messages)
context_idx += 1
else:
# Handle static message

View File

@ -0,0 +1,182 @@
"""Tests for file_manager module, specifically multimodal content handling."""
from unittest.mock import patch
from core.file import File, FileTransferMethod, FileType
from core.file.file_manager import (
_encode_file_ref,
restore_multimodal_content,
to_prompt_message_content,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
class TestEncodeFileRef:
"""Tests for _encode_file_ref function."""
def test_encodes_local_file(self):
"""Local file should be encoded as 'local:id'."""
file = File(
tenant_id="t",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="abc123",
storage_key="key",
)
assert _encode_file_ref(file) == "local:abc123"
def test_encodes_tool_file(self):
"""Tool file should be encoded as 'tool:id'."""
file = File(
tenant_id="t",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="xyz789",
storage_key="key",
)
assert _encode_file_ref(file) == "tool:xyz789"
def test_encodes_remote_url(self):
"""Remote URL should be encoded as 'remote:url'."""
file = File(
tenant_id="t",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.png",
storage_key="",
)
assert _encode_file_ref(file) == "remote:https://example.com/image.png"
class TestToPromptMessageContent:
"""Tests for to_prompt_message_content function with file_ref field."""
@patch("core.file.file_manager.dify_config")
@patch("core.file.file_manager._get_encoded_string")
def test_includes_file_ref(self, mock_get_encoded, mock_config):
"""Generated content should include file_ref field."""
mock_config.MULTIMODAL_SEND_FORMAT = "base64"
mock_get_encoded.return_value = "base64data"
file = File(
id="test-message-file-id",
tenant_id="test-tenant",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test-related-id",
remote_url=None,
extension=".png",
mime_type="image/png",
filename="test.png",
storage_key="test-key",
)
result = to_prompt_message_content(file)
assert isinstance(result, ImagePromptMessageContent)
assert result.file_ref == "local:test-related-id"
assert result.base64_data == "base64data"
class TestRestoreMultimodalContent:
"""Tests for restore_multimodal_content function."""
def test_returns_content_unchanged_when_no_file_ref(self):
"""Content without file_ref should pass through unchanged."""
content = ImagePromptMessageContent(
format="png",
base64_data="existing-data",
mime_type="image/png",
file_ref=None,
)
result = restore_multimodal_content(content)
assert result.base64_data == "existing-data"
def test_returns_content_unchanged_when_already_has_data(self):
"""Content that already has base64_data should not be reloaded."""
content = ImagePromptMessageContent(
format="png",
base64_data="existing-data",
mime_type="image/png",
file_ref="local:file-id",
)
result = restore_multimodal_content(content)
assert result.base64_data == "existing-data"
def test_returns_content_unchanged_when_already_has_url(self):
"""Content that already has url should not be reloaded."""
content = ImagePromptMessageContent(
format="png",
url="https://example.com/image.png",
mime_type="image/png",
file_ref="local:file-id",
)
result = restore_multimodal_content(content)
assert result.url == "https://example.com/image.png"
@patch("core.file.file_manager.dify_config")
@patch("core.file.file_manager._build_file_from_ref")
@patch("core.file.file_manager._to_url")
def test_restores_url_from_file_ref(self, mock_to_url, mock_build_file, mock_config):
"""Content should be restored from file_ref when url is empty (url mode)."""
mock_config.MULTIMODAL_SEND_FORMAT = "url"
mock_build_file.return_value = "mock_file"
mock_to_url.return_value = "https://restored-url.com/image.png"
content = ImagePromptMessageContent(
format="png",
base64_data="",
url="",
mime_type="image/png",
filename="test.png",
file_ref="local:test-file-id",
)
result = restore_multimodal_content(content)
assert result.url == "https://restored-url.com/image.png"
mock_build_file.assert_called_once()
@patch("core.file.file_manager.dify_config")
@patch("core.file.file_manager._build_file_from_ref")
@patch("core.file.file_manager._get_encoded_string")
def test_restores_base64_from_file_ref(self, mock_get_encoded, mock_build_file, mock_config):
"""Content should be restored as base64 when in base64 mode."""
mock_config.MULTIMODAL_SEND_FORMAT = "base64"
mock_build_file.return_value = "mock_file"
mock_get_encoded.return_value = "restored-base64-data"
content = ImagePromptMessageContent(
format="png",
base64_data="",
url="",
mime_type="image/png",
filename="test.png",
file_ref="local:test-file-id",
)
result = restore_multimodal_content(content)
assert result.base64_data == "restored-base64-data"
mock_build_file.assert_called_once()
def test_handles_invalid_file_ref_gracefully(self):
"""Invalid file_ref format should be handled gracefully."""
content = ImagePromptMessageContent(
format="png",
base64_data="",
url="",
mime_type="image/png",
file_ref="invalid_format_no_colon",
)
result = restore_multimodal_content(content)
# Should return unchanged on error
assert result.base64_data == ""

View File

@ -0,0 +1,174 @@
"""Tests for llm_utils module, specifically multimodal content handling."""
import string
from unittest.mock import patch
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
TextPromptMessageContent,
UserPromptMessage,
)
from core.workflow.nodes.llm.llm_utils import (
_truncate_multimodal_content,
build_context,
restore_multimodal_content_in_messages,
)
class TestTruncateMultimodalContent:
"""Tests for _truncate_multimodal_content function."""
def test_returns_message_unchanged_for_string_content(self):
"""String content should pass through unchanged."""
message = UserPromptMessage(content="Hello, world!")
result = _truncate_multimodal_content(message)
assert result.content == "Hello, world!"
def test_returns_message_unchanged_for_none_content(self):
"""None content should pass through unchanged."""
message = UserPromptMessage(content=None)
result = _truncate_multimodal_content(message)
assert result.content is None
def test_clears_base64_when_file_ref_present(self):
"""When file_ref is present, base64_data and url should be cleared."""
image_content = ImagePromptMessageContent(
format="png",
base64_data=string.ascii_lowercase,
url="https://example.com/image.png",
mime_type="image/png",
filename="test.png",
file_ref="local:test-file-id",
)
message = UserPromptMessage(content=[image_content])
result = _truncate_multimodal_content(message)
assert isinstance(result.content, list)
assert len(result.content) == 1
result_content = result.content[0]
assert isinstance(result_content, ImagePromptMessageContent)
assert result_content.base64_data == ""
assert result_content.url == ""
# file_ref should be preserved
assert result_content.file_ref == "local:test-file-id"
def test_truncates_base64_when_no_file_ref(self):
"""When file_ref is missing (legacy), base64_data should be truncated."""
long_base64 = "a" * 100
image_content = ImagePromptMessageContent(
format="png",
base64_data=long_base64,
mime_type="image/png",
filename="test.png",
file_ref=None,
)
message = UserPromptMessage(content=[image_content])
result = _truncate_multimodal_content(message)
assert isinstance(result.content, list)
result_content = result.content[0]
assert isinstance(result_content, ImagePromptMessageContent)
# Should be truncated with marker
assert "...[TRUNCATED]..." in result_content.base64_data
assert len(result_content.base64_data) < len(long_base64)
def test_preserves_text_content(self):
"""Text content should pass through unchanged."""
text_content = TextPromptMessageContent(data="Hello!")
image_content = ImagePromptMessageContent(
format="png",
base64_data="test123",
mime_type="image/png",
file_ref="local:file-id",
)
message = UserPromptMessage(content=[text_content, image_content])
result = _truncate_multimodal_content(message)
assert isinstance(result.content, list)
assert len(result.content) == 2
# Text content unchanged
assert result.content[0].data == "Hello!"
# Image content base64 cleared
assert result.content[1].base64_data == ""
class TestBuildContext:
"""Tests for build_context function."""
def test_excludes_system_messages(self):
"""System messages should be excluded from context."""
from core.model_runtime.entities.message_entities import SystemPromptMessage
messages = [
SystemPromptMessage(content="You are a helpful assistant."),
UserPromptMessage(content="Hello!"),
]
context = build_context(messages, "Hi there!")
# Should have user message + assistant response, no system message
assert len(context) == 2
assert context[0].content == "Hello!"
assert context[1].content == "Hi there!"
def test_appends_assistant_response(self):
"""Assistant response should be appended to context."""
messages = [UserPromptMessage(content="What is 2+2?")]
context = build_context(messages, "The answer is 4.")
assert len(context) == 2
assert context[1].content == "The answer is 4."
class TestRestoreMultimodalContentInMessages:
"""Tests for restore_multimodal_content_in_messages function."""
@patch("core.file.file_manager.restore_multimodal_content")
def test_restores_multimodal_content(self, mock_restore):
"""Should restore multimodal content in messages."""
# Setup mock
restored_content = ImagePromptMessageContent(
format="png",
base64_data="restored-base64",
mime_type="image/png",
file_ref="local:abc123",
)
mock_restore.return_value = restored_content
# Create message with truncated content
truncated_content = ImagePromptMessageContent(
format="png",
base64_data="",
mime_type="image/png",
file_ref="local:abc123",
)
message = UserPromptMessage(content=[truncated_content])
result = restore_multimodal_content_in_messages([message])
assert len(result) == 1
assert result[0].content[0].base64_data == "restored-base64"
mock_restore.assert_called_once()
def test_passes_through_string_content(self):
"""String content should pass through unchanged."""
message = UserPromptMessage(content="Hello!")
result = restore_multimodal_content_in_messages([message])
assert len(result) == 1
assert result[0].content == "Hello!"
def test_passes_through_text_content(self):
"""TextPromptMessageContent should pass through unchanged."""
text_content = TextPromptMessageContent(data="Hello!")
message = UserPromptMessage(content=[text_content])
result = restore_multimodal_content_in_messages([message])
assert len(result) == 1
assert result[0].content[0].data == "Hello!"