feat: add context file support

This commit is contained in:
Novice
2026-01-16 17:01:19 +08:00
parent e85e31773a
commit 18abc66585
7 changed files with 585 additions and 9 deletions

View File

@ -0,0 +1,182 @@
"""Tests for file_manager module, specifically multimodal content handling."""
from unittest.mock import patch
from core.file import File, FileTransferMethod, FileType
from core.file.file_manager import (
_encode_file_ref,
restore_multimodal_content,
to_prompt_message_content,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
class TestEncodeFileRef:
"""Tests for _encode_file_ref function."""
def test_encodes_local_file(self):
"""Local file should be encoded as 'local:id'."""
file = File(
tenant_id="t",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="abc123",
storage_key="key",
)
assert _encode_file_ref(file) == "local:abc123"
def test_encodes_tool_file(self):
"""Tool file should be encoded as 'tool:id'."""
file = File(
tenant_id="t",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id="xyz789",
storage_key="key",
)
assert _encode_file_ref(file) == "tool:xyz789"
def test_encodes_remote_url(self):
"""Remote URL should be encoded as 'remote:url'."""
file = File(
tenant_id="t",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.png",
storage_key="",
)
assert _encode_file_ref(file) == "remote:https://example.com/image.png"
class TestToPromptMessageContent:
"""Tests for to_prompt_message_content function with file_ref field."""
@patch("core.file.file_manager.dify_config")
@patch("core.file.file_manager._get_encoded_string")
def test_includes_file_ref(self, mock_get_encoded, mock_config):
"""Generated content should include file_ref field."""
mock_config.MULTIMODAL_SEND_FORMAT = "base64"
mock_get_encoded.return_value = "base64data"
file = File(
id="test-message-file-id",
tenant_id="test-tenant",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test-related-id",
remote_url=None,
extension=".png",
mime_type="image/png",
filename="test.png",
storage_key="test-key",
)
result = to_prompt_message_content(file)
assert isinstance(result, ImagePromptMessageContent)
assert result.file_ref == "local:test-related-id"
assert result.base64_data == "base64data"
class TestRestoreMultimodalContent:
"""Tests for restore_multimodal_content function."""
def test_returns_content_unchanged_when_no_file_ref(self):
"""Content without file_ref should pass through unchanged."""
content = ImagePromptMessageContent(
format="png",
base64_data="existing-data",
mime_type="image/png",
file_ref=None,
)
result = restore_multimodal_content(content)
assert result.base64_data == "existing-data"
def test_returns_content_unchanged_when_already_has_data(self):
"""Content that already has base64_data should not be reloaded."""
content = ImagePromptMessageContent(
format="png",
base64_data="existing-data",
mime_type="image/png",
file_ref="local:file-id",
)
result = restore_multimodal_content(content)
assert result.base64_data == "existing-data"
def test_returns_content_unchanged_when_already_has_url(self):
"""Content that already has url should not be reloaded."""
content = ImagePromptMessageContent(
format="png",
url="https://example.com/image.png",
mime_type="image/png",
file_ref="local:file-id",
)
result = restore_multimodal_content(content)
assert result.url == "https://example.com/image.png"
@patch("core.file.file_manager.dify_config")
@patch("core.file.file_manager._build_file_from_ref")
@patch("core.file.file_manager._to_url")
def test_restores_url_from_file_ref(self, mock_to_url, mock_build_file, mock_config):
"""Content should be restored from file_ref when url is empty (url mode)."""
mock_config.MULTIMODAL_SEND_FORMAT = "url"
mock_build_file.return_value = "mock_file"
mock_to_url.return_value = "https://restored-url.com/image.png"
content = ImagePromptMessageContent(
format="png",
base64_data="",
url="",
mime_type="image/png",
filename="test.png",
file_ref="local:test-file-id",
)
result = restore_multimodal_content(content)
assert result.url == "https://restored-url.com/image.png"
mock_build_file.assert_called_once()
@patch("core.file.file_manager.dify_config")
@patch("core.file.file_manager._build_file_from_ref")
@patch("core.file.file_manager._get_encoded_string")
def test_restores_base64_from_file_ref(self, mock_get_encoded, mock_build_file, mock_config):
"""Content should be restored as base64 when in base64 mode."""
mock_config.MULTIMODAL_SEND_FORMAT = "base64"
mock_build_file.return_value = "mock_file"
mock_get_encoded.return_value = "restored-base64-data"
content = ImagePromptMessageContent(
format="png",
base64_data="",
url="",
mime_type="image/png",
filename="test.png",
file_ref="local:test-file-id",
)
result = restore_multimodal_content(content)
assert result.base64_data == "restored-base64-data"
mock_build_file.assert_called_once()
def test_handles_invalid_file_ref_gracefully(self):
"""Invalid file_ref format should be handled gracefully."""
content = ImagePromptMessageContent(
format="png",
base64_data="",
url="",
mime_type="image/png",
file_ref="invalid_format_no_colon",
)
result = restore_multimodal_content(content)
# Should return unchanged on error
assert result.base64_data == ""

View File

@ -0,0 +1,174 @@
"""Tests for llm_utils module, specifically multimodal content handling."""
import string
from unittest.mock import patch
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
TextPromptMessageContent,
UserPromptMessage,
)
from core.workflow.nodes.llm.llm_utils import (
_truncate_multimodal_content,
build_context,
restore_multimodal_content_in_messages,
)
class TestTruncateMultimodalContent:
"""Tests for _truncate_multimodal_content function."""
def test_returns_message_unchanged_for_string_content(self):
"""String content should pass through unchanged."""
message = UserPromptMessage(content="Hello, world!")
result = _truncate_multimodal_content(message)
assert result.content == "Hello, world!"
def test_returns_message_unchanged_for_none_content(self):
"""None content should pass through unchanged."""
message = UserPromptMessage(content=None)
result = _truncate_multimodal_content(message)
assert result.content is None
def test_clears_base64_when_file_ref_present(self):
"""When file_ref is present, base64_data and url should be cleared."""
image_content = ImagePromptMessageContent(
format="png",
base64_data=string.ascii_lowercase,
url="https://example.com/image.png",
mime_type="image/png",
filename="test.png",
file_ref="local:test-file-id",
)
message = UserPromptMessage(content=[image_content])
result = _truncate_multimodal_content(message)
assert isinstance(result.content, list)
assert len(result.content) == 1
result_content = result.content[0]
assert isinstance(result_content, ImagePromptMessageContent)
assert result_content.base64_data == ""
assert result_content.url == ""
# file_ref should be preserved
assert result_content.file_ref == "local:test-file-id"
def test_truncates_base64_when_no_file_ref(self):
"""When file_ref is missing (legacy), base64_data should be truncated."""
long_base64 = "a" * 100
image_content = ImagePromptMessageContent(
format="png",
base64_data=long_base64,
mime_type="image/png",
filename="test.png",
file_ref=None,
)
message = UserPromptMessage(content=[image_content])
result = _truncate_multimodal_content(message)
assert isinstance(result.content, list)
result_content = result.content[0]
assert isinstance(result_content, ImagePromptMessageContent)
# Should be truncated with marker
assert "...[TRUNCATED]..." in result_content.base64_data
assert len(result_content.base64_data) < len(long_base64)
def test_preserves_text_content(self):
"""Text content should pass through unchanged."""
text_content = TextPromptMessageContent(data="Hello!")
image_content = ImagePromptMessageContent(
format="png",
base64_data="test123",
mime_type="image/png",
file_ref="local:file-id",
)
message = UserPromptMessage(content=[text_content, image_content])
result = _truncate_multimodal_content(message)
assert isinstance(result.content, list)
assert len(result.content) == 2
# Text content unchanged
assert result.content[0].data == "Hello!"
# Image content base64 cleared
assert result.content[1].base64_data == ""
class TestBuildContext:
"""Tests for build_context function."""
def test_excludes_system_messages(self):
"""System messages should be excluded from context."""
from core.model_runtime.entities.message_entities import SystemPromptMessage
messages = [
SystemPromptMessage(content="You are a helpful assistant."),
UserPromptMessage(content="Hello!"),
]
context = build_context(messages, "Hi there!")
# Should have user message + assistant response, no system message
assert len(context) == 2
assert context[0].content == "Hello!"
assert context[1].content == "Hi there!"
def test_appends_assistant_response(self):
"""Assistant response should be appended to context."""
messages = [UserPromptMessage(content="What is 2+2?")]
context = build_context(messages, "The answer is 4.")
assert len(context) == 2
assert context[1].content == "The answer is 4."
class TestRestoreMultimodalContentInMessages:
"""Tests for restore_multimodal_content_in_messages function."""
@patch("core.file.file_manager.restore_multimodal_content")
def test_restores_multimodal_content(self, mock_restore):
"""Should restore multimodal content in messages."""
# Setup mock
restored_content = ImagePromptMessageContent(
format="png",
base64_data="restored-base64",
mime_type="image/png",
file_ref="local:abc123",
)
mock_restore.return_value = restored_content
# Create message with truncated content
truncated_content = ImagePromptMessageContent(
format="png",
base64_data="",
mime_type="image/png",
file_ref="local:abc123",
)
message = UserPromptMessage(content=[truncated_content])
result = restore_multimodal_content_in_messages([message])
assert len(result) == 1
assert result[0].content[0].base64_data == "restored-base64"
mock_restore.assert_called_once()
def test_passes_through_string_content(self):
"""String content should pass through unchanged."""
message = UserPromptMessage(content="Hello!")
result = restore_multimodal_content_in_messages([message])
assert len(result) == 1
assert result[0].content == "Hello!"
def test_passes_through_text_content(self):
"""TextPromptMessageContent should pass through unchanged."""
text_content = TextPromptMessageContent(data="Hello!")
message = UserPromptMessage(content=[text_content])
result = restore_multimodal_content_in_messages([message])
assert len(result) == 1
assert result[0].content[0].data == "Hello!"