mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
feat: add context file support
This commit is contained in:
@ -1,4 +1,5 @@
|
||||
import base64
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
|
||||
from configs import dify_config
|
||||
@ -10,7 +11,10 @@ from core.model_runtime.entities import (
|
||||
TextPromptMessageContent,
|
||||
VideoPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
MultiModalPromptMessageContent,
|
||||
PromptMessageContentUnionTypes,
|
||||
)
|
||||
from core.tools.signature import sign_tool_file
|
||||
from extensions.ext_storage import storage
|
||||
|
||||
@ -18,6 +22,8 @@ from . import helpers
|
||||
from .enums import FileAttribute
|
||||
from .models import File, FileTransferMethod, FileType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_attr(*, file: File, attr: FileAttribute):
|
||||
match attr:
|
||||
@ -89,6 +95,8 @@ def to_prompt_message_content(
|
||||
"format": f.extension.removeprefix("."),
|
||||
"mime_type": f.mime_type,
|
||||
"filename": f.filename or "",
|
||||
# Encoded file reference for context restoration: "transfer_method:related_id" or "remote:url"
|
||||
"file_ref": _encode_file_ref(f),
|
||||
}
|
||||
if f.type == FileType.IMAGE:
|
||||
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
@ -96,6 +104,17 @@ def to_prompt_message_content(
|
||||
return prompt_class_map[f.type].model_validate(params)
|
||||
|
||||
|
||||
def _encode_file_ref(f: File) -> str | None:
|
||||
"""Encode file reference as 'transfer_method:id_or_url' string."""
|
||||
if f.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
return f"remote:{f.remote_url}" if f.remote_url else None
|
||||
elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
return f"local:{f.related_id}" if f.related_id else None
|
||||
elif f.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||
return f"tool:{f.related_id}" if f.related_id else None
|
||||
return None
|
||||
|
||||
|
||||
def download(f: File, /):
|
||||
if f.transfer_method in (
|
||||
FileTransferMethod.TOOL_FILE,
|
||||
@ -164,3 +183,128 @@ def _to_url(f: File, /):
|
||||
return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
|
||||
else:
|
||||
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||
|
||||
|
||||
def restore_multimodal_content(
|
||||
content: MultiModalPromptMessageContent,
|
||||
) -> MultiModalPromptMessageContent:
|
||||
"""
|
||||
Restore base64_data or url for multimodal content from file_ref.
|
||||
|
||||
file_ref format: "transfer_method:id_or_url" (e.g., "local:abc123", "remote:https://...")
|
||||
|
||||
Args:
|
||||
content: MultiModalPromptMessageContent with file_ref field
|
||||
|
||||
Returns:
|
||||
MultiModalPromptMessageContent with restored base64_data or url
|
||||
"""
|
||||
# Skip if no file reference or content already has data
|
||||
if not content.file_ref:
|
||||
return content
|
||||
if content.base64_data or content.url:
|
||||
return content
|
||||
|
||||
try:
|
||||
file = _build_file_from_ref(
|
||||
file_ref=content.file_ref,
|
||||
file_format=content.format,
|
||||
mime_type=content.mime_type,
|
||||
filename=content.filename,
|
||||
)
|
||||
if not file:
|
||||
return content
|
||||
|
||||
# Restore content based on config
|
||||
if dify_config.MULTIMODAL_SEND_FORMAT == "base64":
|
||||
restored_base64 = _get_encoded_string(file)
|
||||
return content.model_copy(update={"base64_data": restored_base64})
|
||||
else:
|
||||
restored_url = _to_url(file)
|
||||
return content.model_copy(update={"url": restored_url})
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to restore multimodal content: %s", e)
|
||||
return content
|
||||
|
||||
|
||||
def _build_file_from_ref(
|
||||
file_ref: str,
|
||||
file_format: str | None,
|
||||
mime_type: str | None,
|
||||
filename: str | None,
|
||||
) -> File | None:
|
||||
"""
|
||||
Build a File object from encoded file_ref string.
|
||||
|
||||
Args:
|
||||
file_ref: Encoded reference "transfer_method:id_or_url"
|
||||
file_format: The file format/extension (without dot)
|
||||
mime_type: The mime type
|
||||
filename: The filename
|
||||
|
||||
Returns:
|
||||
File object with storage_key loaded, or None if not found
|
||||
"""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.model import UploadFile
|
||||
from models.tools import ToolFile
|
||||
|
||||
# Parse file_ref: "method:value"
|
||||
if ":" not in file_ref:
|
||||
logger.warning("Invalid file_ref format: %s", file_ref)
|
||||
return None
|
||||
|
||||
method, value = file_ref.split(":", 1)
|
||||
extension = f".{file_format}" if file_format else None
|
||||
|
||||
if method == "remote":
|
||||
return File(
|
||||
tenant_id="",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url=value,
|
||||
extension=extension,
|
||||
mime_type=mime_type,
|
||||
filename=filename,
|
||||
storage_key="",
|
||||
)
|
||||
|
||||
# Query database for storage_key
|
||||
with Session(db.engine) as session:
|
||||
if method == "local":
|
||||
stmt = select(UploadFile).where(UploadFile.id == value)
|
||||
upload_file = session.scalar(stmt)
|
||||
if upload_file:
|
||||
return File(
|
||||
tenant_id=upload_file.tenant_id,
|
||||
type=FileType(upload_file.extension)
|
||||
if hasattr(FileType, upload_file.extension.upper())
|
||||
else FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id=value,
|
||||
extension=extension or ("." + upload_file.extension if upload_file.extension else None),
|
||||
mime_type=mime_type or upload_file.mime_type,
|
||||
filename=filename or upload_file.name,
|
||||
storage_key=upload_file.key,
|
||||
)
|
||||
elif method == "tool":
|
||||
stmt = select(ToolFile).where(ToolFile.id == value)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file:
|
||||
return File(
|
||||
tenant_id=tool_file.tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id=value,
|
||||
extension=extension,
|
||||
mime_type=mime_type or tool_file.mimetype,
|
||||
filename=filename or tool_file.name,
|
||||
storage_key=tool_file.file_key,
|
||||
)
|
||||
|
||||
logger.warning("File not found for file_ref: %s", file_ref)
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user