feat: add context file support

This commit is contained in:
Novice
2026-01-16 17:01:19 +08:00
parent e85e31773a
commit 18abc66585
7 changed files with 585 additions and 9 deletions

View File

@ -1,4 +1,5 @@
import base64
import logging
from collections.abc import Mapping
from configs import dify_config
@ -10,7 +11,10 @@ from core.model_runtime.entities import (
TextPromptMessageContent,
VideoPromptMessageContent,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from core.model_runtime.entities.message_entities import (
MultiModalPromptMessageContent,
PromptMessageContentUnionTypes,
)
from core.tools.signature import sign_tool_file
from extensions.ext_storage import storage
@ -18,6 +22,8 @@ from . import helpers
from .enums import FileAttribute
from .models import File, FileTransferMethod, FileType
logger = logging.getLogger(__name__)
def get_attr(*, file: File, attr: FileAttribute):
match attr:
@ -89,6 +95,8 @@ def to_prompt_message_content(
"format": f.extension.removeprefix("."),
"mime_type": f.mime_type,
"filename": f.filename or "",
# Encoded file reference for context restoration: "transfer_method:related_id" or "remote:url"
"file_ref": _encode_file_ref(f),
}
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
@ -96,6 +104,17 @@ def to_prompt_message_content(
return prompt_class_map[f.type].model_validate(params)
def _encode_file_ref(f: File) -> str | None:
"""Encode file reference as 'transfer_method:id_or_url' string."""
if f.transfer_method == FileTransferMethod.REMOTE_URL:
return f"remote:{f.remote_url}" if f.remote_url else None
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
return f"local:{f.related_id}" if f.related_id else None
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
return f"tool:{f.related_id}" if f.related_id else None
return None
def download(f: File, /):
if f.transfer_method in (
FileTransferMethod.TOOL_FILE,
@ -164,3 +183,128 @@ def _to_url(f: File, /):
return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
else:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
def restore_multimodal_content(
content: MultiModalPromptMessageContent,
) -> MultiModalPromptMessageContent:
"""
Restore base64_data or url for multimodal content from file_ref.
file_ref format: "transfer_method:id_or_url" (e.g., "local:abc123", "remote:https://...")
Args:
content: MultiModalPromptMessageContent with file_ref field
Returns:
MultiModalPromptMessageContent with restored base64_data or url
"""
# Skip if no file reference or content already has data
if not content.file_ref:
return content
if content.base64_data or content.url:
return content
try:
file = _build_file_from_ref(
file_ref=content.file_ref,
file_format=content.format,
mime_type=content.mime_type,
filename=content.filename,
)
if not file:
return content
# Restore content based on config
if dify_config.MULTIMODAL_SEND_FORMAT == "base64":
restored_base64 = _get_encoded_string(file)
return content.model_copy(update={"base64_data": restored_base64})
else:
restored_url = _to_url(file)
return content.model_copy(update={"url": restored_url})
except Exception as e:
logger.warning("Failed to restore multimodal content: %s", e)
return content
def _build_file_from_ref(
file_ref: str,
file_format: str | None,
mime_type: str | None,
filename: str | None,
) -> File | None:
"""
Build a File object from encoded file_ref string.
Args:
file_ref: Encoded reference "transfer_method:id_or_url"
file_format: The file format/extension (without dot)
mime_type: The mime type
filename: The filename
Returns:
File object with storage_key loaded, or None if not found
"""
from sqlalchemy import select
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.model import UploadFile
from models.tools import ToolFile
# Parse file_ref: "method:value"
if ":" not in file_ref:
logger.warning("Invalid file_ref format: %s", file_ref)
return None
method, value = file_ref.split(":", 1)
extension = f".{file_format}" if file_format else None
if method == "remote":
return File(
tenant_id="",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url=value,
extension=extension,
mime_type=mime_type,
filename=filename,
storage_key="",
)
# Query database for storage_key
with Session(db.engine) as session:
if method == "local":
stmt = select(UploadFile).where(UploadFile.id == value)
upload_file = session.scalar(stmt)
if upload_file:
return File(
tenant_id=upload_file.tenant_id,
type=FileType(upload_file.extension)
if hasattr(FileType, upload_file.extension.upper())
else FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id=value,
extension=extension or ("." + upload_file.extension if upload_file.extension else None),
mime_type=mime_type or upload_file.mime_type,
filename=filename or upload_file.name,
storage_key=upload_file.key,
)
elif method == "tool":
stmt = select(ToolFile).where(ToolFile.id == value)
tool_file = session.scalar(stmt)
if tool_file:
return File(
tenant_id=tool_file.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=value,
extension=extension,
mime_type=mime_type or tool_file.mimetype,
filename=filename or tool_file.name,
storage_key=tool_file.file_key,
)
logger.warning("File not found for file_ref: %s", file_ref)
return None