mirror of
https://github.com/langgenius/dify.git
synced 2026-05-04 09:28:04 +08:00
feat: add context file support
This commit is contained in:
182
api/tests/unit_tests/core/file/test_file_manager.py
Normal file
182
api/tests/unit_tests/core/file/test_file_manager.py
Normal file
@ -0,0 +1,182 @@
|
||||
"""Tests for file_manager module, specifically multimodal content handling."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.file.file_manager import (
|
||||
_encode_file_ref,
|
||||
restore_multimodal_content,
|
||||
to_prompt_message_content,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
|
||||
|
||||
|
||||
class TestEncodeFileRef:
|
||||
"""Tests for _encode_file_ref function."""
|
||||
|
||||
def test_encodes_local_file(self):
|
||||
"""Local file should be encoded as 'local:id'."""
|
||||
file = File(
|
||||
tenant_id="t",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="abc123",
|
||||
storage_key="key",
|
||||
)
|
||||
assert _encode_file_ref(file) == "local:abc123"
|
||||
|
||||
def test_encodes_tool_file(self):
|
||||
"""Tool file should be encoded as 'tool:id'."""
|
||||
file = File(
|
||||
tenant_id="t",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id="xyz789",
|
||||
storage_key="key",
|
||||
)
|
||||
assert _encode_file_ref(file) == "tool:xyz789"
|
||||
|
||||
def test_encodes_remote_url(self):
|
||||
"""Remote URL should be encoded as 'remote:url'."""
|
||||
file = File(
|
||||
tenant_id="t",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.png",
|
||||
storage_key="",
|
||||
)
|
||||
assert _encode_file_ref(file) == "remote:https://example.com/image.png"
|
||||
|
||||
|
||||
class TestToPromptMessageContent:
|
||||
"""Tests for to_prompt_message_content function with file_ref field."""
|
||||
|
||||
@patch("core.file.file_manager.dify_config")
|
||||
@patch("core.file.file_manager._get_encoded_string")
|
||||
def test_includes_file_ref(self, mock_get_encoded, mock_config):
|
||||
"""Generated content should include file_ref field."""
|
||||
mock_config.MULTIMODAL_SEND_FORMAT = "base64"
|
||||
mock_get_encoded.return_value = "base64data"
|
||||
|
||||
file = File(
|
||||
id="test-message-file-id",
|
||||
tenant_id="test-tenant",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="test-related-id",
|
||||
remote_url=None,
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
filename="test.png",
|
||||
storage_key="test-key",
|
||||
)
|
||||
|
||||
result = to_prompt_message_content(file)
|
||||
|
||||
assert isinstance(result, ImagePromptMessageContent)
|
||||
assert result.file_ref == "local:test-related-id"
|
||||
assert result.base64_data == "base64data"
|
||||
|
||||
|
||||
class TestRestoreMultimodalContent:
|
||||
"""Tests for restore_multimodal_content function."""
|
||||
|
||||
def test_returns_content_unchanged_when_no_file_ref(self):
|
||||
"""Content without file_ref should pass through unchanged."""
|
||||
content = ImagePromptMessageContent(
|
||||
format="png",
|
||||
base64_data="existing-data",
|
||||
mime_type="image/png",
|
||||
file_ref=None,
|
||||
)
|
||||
|
||||
result = restore_multimodal_content(content)
|
||||
|
||||
assert result.base64_data == "existing-data"
|
||||
|
||||
def test_returns_content_unchanged_when_already_has_data(self):
|
||||
"""Content that already has base64_data should not be reloaded."""
|
||||
content = ImagePromptMessageContent(
|
||||
format="png",
|
||||
base64_data="existing-data",
|
||||
mime_type="image/png",
|
||||
file_ref="local:file-id",
|
||||
)
|
||||
|
||||
result = restore_multimodal_content(content)
|
||||
|
||||
assert result.base64_data == "existing-data"
|
||||
|
||||
def test_returns_content_unchanged_when_already_has_url(self):
|
||||
"""Content that already has url should not be reloaded."""
|
||||
content = ImagePromptMessageContent(
|
||||
format="png",
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
file_ref="local:file-id",
|
||||
)
|
||||
|
||||
result = restore_multimodal_content(content)
|
||||
|
||||
assert result.url == "https://example.com/image.png"
|
||||
|
||||
@patch("core.file.file_manager.dify_config")
|
||||
@patch("core.file.file_manager._build_file_from_ref")
|
||||
@patch("core.file.file_manager._to_url")
|
||||
def test_restores_url_from_file_ref(self, mock_to_url, mock_build_file, mock_config):
|
||||
"""Content should be restored from file_ref when url is empty (url mode)."""
|
||||
mock_config.MULTIMODAL_SEND_FORMAT = "url"
|
||||
mock_build_file.return_value = "mock_file"
|
||||
mock_to_url.return_value = "https://restored-url.com/image.png"
|
||||
|
||||
content = ImagePromptMessageContent(
|
||||
format="png",
|
||||
base64_data="",
|
||||
url="",
|
||||
mime_type="image/png",
|
||||
filename="test.png",
|
||||
file_ref="local:test-file-id",
|
||||
)
|
||||
|
||||
result = restore_multimodal_content(content)
|
||||
|
||||
assert result.url == "https://restored-url.com/image.png"
|
||||
mock_build_file.assert_called_once()
|
||||
|
||||
@patch("core.file.file_manager.dify_config")
|
||||
@patch("core.file.file_manager._build_file_from_ref")
|
||||
@patch("core.file.file_manager._get_encoded_string")
|
||||
def test_restores_base64_from_file_ref(self, mock_get_encoded, mock_build_file, mock_config):
|
||||
"""Content should be restored as base64 when in base64 mode."""
|
||||
mock_config.MULTIMODAL_SEND_FORMAT = "base64"
|
||||
mock_build_file.return_value = "mock_file"
|
||||
mock_get_encoded.return_value = "restored-base64-data"
|
||||
|
||||
content = ImagePromptMessageContent(
|
||||
format="png",
|
||||
base64_data="",
|
||||
url="",
|
||||
mime_type="image/png",
|
||||
filename="test.png",
|
||||
file_ref="local:test-file-id",
|
||||
)
|
||||
|
||||
result = restore_multimodal_content(content)
|
||||
|
||||
assert result.base64_data == "restored-base64-data"
|
||||
mock_build_file.assert_called_once()
|
||||
|
||||
def test_handles_invalid_file_ref_gracefully(self):
|
||||
"""Invalid file_ref format should be handled gracefully."""
|
||||
content = ImagePromptMessageContent(
|
||||
format="png",
|
||||
base64_data="",
|
||||
url="",
|
||||
mime_type="image/png",
|
||||
file_ref="invalid_format_no_colon",
|
||||
)
|
||||
|
||||
result = restore_multimodal_content(content)
|
||||
|
||||
# Should return unchanged on error
|
||||
assert result.base64_data == ""
|
||||
174
api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
Normal file
174
api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py
Normal file
@ -0,0 +1,174 @@
|
||||
"""Tests for llm_utils module, specifically multimodal content handling."""
|
||||
|
||||
import string
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
ImagePromptMessageContent,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.workflow.nodes.llm.llm_utils import (
|
||||
_truncate_multimodal_content,
|
||||
build_context,
|
||||
restore_multimodal_content_in_messages,
|
||||
)
|
||||
|
||||
|
||||
class TestTruncateMultimodalContent:
|
||||
"""Tests for _truncate_multimodal_content function."""
|
||||
|
||||
def test_returns_message_unchanged_for_string_content(self):
|
||||
"""String content should pass through unchanged."""
|
||||
message = UserPromptMessage(content="Hello, world!")
|
||||
result = _truncate_multimodal_content(message)
|
||||
assert result.content == "Hello, world!"
|
||||
|
||||
def test_returns_message_unchanged_for_none_content(self):
|
||||
"""None content should pass through unchanged."""
|
||||
message = UserPromptMessage(content=None)
|
||||
result = _truncate_multimodal_content(message)
|
||||
assert result.content is None
|
||||
|
||||
def test_clears_base64_when_file_ref_present(self):
|
||||
"""When file_ref is present, base64_data and url should be cleared."""
|
||||
image_content = ImagePromptMessageContent(
|
||||
format="png",
|
||||
base64_data=string.ascii_lowercase,
|
||||
url="https://example.com/image.png",
|
||||
mime_type="image/png",
|
||||
filename="test.png",
|
||||
file_ref="local:test-file-id",
|
||||
)
|
||||
message = UserPromptMessage(content=[image_content])
|
||||
|
||||
result = _truncate_multimodal_content(message)
|
||||
|
||||
assert isinstance(result.content, list)
|
||||
assert len(result.content) == 1
|
||||
result_content = result.content[0]
|
||||
assert isinstance(result_content, ImagePromptMessageContent)
|
||||
assert result_content.base64_data == ""
|
||||
assert result_content.url == ""
|
||||
# file_ref should be preserved
|
||||
assert result_content.file_ref == "local:test-file-id"
|
||||
|
||||
def test_truncates_base64_when_no_file_ref(self):
|
||||
"""When file_ref is missing (legacy), base64_data should be truncated."""
|
||||
long_base64 = "a" * 100
|
||||
image_content = ImagePromptMessageContent(
|
||||
format="png",
|
||||
base64_data=long_base64,
|
||||
mime_type="image/png",
|
||||
filename="test.png",
|
||||
file_ref=None,
|
||||
)
|
||||
message = UserPromptMessage(content=[image_content])
|
||||
|
||||
result = _truncate_multimodal_content(message)
|
||||
|
||||
assert isinstance(result.content, list)
|
||||
result_content = result.content[0]
|
||||
assert isinstance(result_content, ImagePromptMessageContent)
|
||||
# Should be truncated with marker
|
||||
assert "...[TRUNCATED]..." in result_content.base64_data
|
||||
assert len(result_content.base64_data) < len(long_base64)
|
||||
|
||||
def test_preserves_text_content(self):
|
||||
"""Text content should pass through unchanged."""
|
||||
text_content = TextPromptMessageContent(data="Hello!")
|
||||
image_content = ImagePromptMessageContent(
|
||||
format="png",
|
||||
base64_data="test123",
|
||||
mime_type="image/png",
|
||||
file_ref="local:file-id",
|
||||
)
|
||||
message = UserPromptMessage(content=[text_content, image_content])
|
||||
|
||||
result = _truncate_multimodal_content(message)
|
||||
|
||||
assert isinstance(result.content, list)
|
||||
assert len(result.content) == 2
|
||||
# Text content unchanged
|
||||
assert result.content[0].data == "Hello!"
|
||||
# Image content base64 cleared
|
||||
assert result.content[1].base64_data == ""
|
||||
|
||||
|
||||
class TestBuildContext:
|
||||
"""Tests for build_context function."""
|
||||
|
||||
def test_excludes_system_messages(self):
|
||||
"""System messages should be excluded from context."""
|
||||
from core.model_runtime.entities.message_entities import SystemPromptMessage
|
||||
|
||||
messages = [
|
||||
SystemPromptMessage(content="You are a helpful assistant."),
|
||||
UserPromptMessage(content="Hello!"),
|
||||
]
|
||||
|
||||
context = build_context(messages, "Hi there!")
|
||||
|
||||
# Should have user message + assistant response, no system message
|
||||
assert len(context) == 2
|
||||
assert context[0].content == "Hello!"
|
||||
assert context[1].content == "Hi there!"
|
||||
|
||||
def test_appends_assistant_response(self):
|
||||
"""Assistant response should be appended to context."""
|
||||
messages = [UserPromptMessage(content="What is 2+2?")]
|
||||
|
||||
context = build_context(messages, "The answer is 4.")
|
||||
|
||||
assert len(context) == 2
|
||||
assert context[1].content == "The answer is 4."
|
||||
|
||||
|
||||
class TestRestoreMultimodalContentInMessages:
|
||||
"""Tests for restore_multimodal_content_in_messages function."""
|
||||
|
||||
@patch("core.file.file_manager.restore_multimodal_content")
|
||||
def test_restores_multimodal_content(self, mock_restore):
|
||||
"""Should restore multimodal content in messages."""
|
||||
# Setup mock
|
||||
restored_content = ImagePromptMessageContent(
|
||||
format="png",
|
||||
base64_data="restored-base64",
|
||||
mime_type="image/png",
|
||||
file_ref="local:abc123",
|
||||
)
|
||||
mock_restore.return_value = restored_content
|
||||
|
||||
# Create message with truncated content
|
||||
truncated_content = ImagePromptMessageContent(
|
||||
format="png",
|
||||
base64_data="",
|
||||
mime_type="image/png",
|
||||
file_ref="local:abc123",
|
||||
)
|
||||
message = UserPromptMessage(content=[truncated_content])
|
||||
|
||||
result = restore_multimodal_content_in_messages([message])
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].content[0].base64_data == "restored-base64"
|
||||
mock_restore.assert_called_once()
|
||||
|
||||
def test_passes_through_string_content(self):
|
||||
"""String content should pass through unchanged."""
|
||||
message = UserPromptMessage(content="Hello!")
|
||||
|
||||
result = restore_multimodal_content_in_messages([message])
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].content == "Hello!"
|
||||
|
||||
def test_passes_through_text_content(self):
|
||||
"""TextPromptMessageContent should pass through unchanged."""
|
||||
text_content = TextPromptMessageContent(data="Hello!")
|
||||
message = UserPromptMessage(content=[text_content])
|
||||
|
||||
result = restore_multimodal_content_in_messages([message])
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].content[0].data == "Hello!"
|
||||
Reference in New Issue
Block a user