mirror of
https://github.com/langgenius/dify.git
synced 2026-03-22 23:08:03 +08:00
feat: enhance download URL generation with optional filename parameter
Added support for an optional `download_filename` parameter in the `get_download_url` and `get_download_urls` methods across various storage classes. This allows users to specify a custom filename for downloads, improving user experience by enabling better file naming during downloads. Updated related methods and tests to accommodate this new functionality.
This commit is contained in:
@ -118,8 +118,14 @@ class Storage:
|
||||
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
|
||||
return self.storage_runner.scan(path, files=files, directories=directories)
|
||||
|
||||
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
return self.storage_runner.get_download_url(filename, expires_in)
|
||||
def get_download_url(
|
||||
self,
|
||||
filename: str,
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filename: str | None = None,
|
||||
) -> str:
|
||||
return self.storage_runner.get_download_url(filename, expires_in, download_filename=download_filename)
|
||||
|
||||
|
||||
storage = Storage()
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from urllib.parse import quote
|
||||
|
||||
import boto3
|
||||
from botocore.client import BaseClient, Config
|
||||
@ -105,23 +106,73 @@ class AwsS3Storage(BaseStorage):
|
||||
def delete(self, filename):
|
||||
self.client.delete_object(Bucket=self.bucket_name, Key=filename)
|
||||
|
||||
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
def get_download_url(
|
||||
self,
|
||||
filename: str,
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filename: str | None = None,
|
||||
) -> str:
|
||||
"""Generate a presigned download URL.
|
||||
|
||||
Args:
|
||||
filename: The S3 object key
|
||||
expires_in: URL validity duration in seconds
|
||||
download_filename: If provided, sets Content-Disposition header so browser
|
||||
downloads the file with this name instead of the S3 key.
|
||||
"""
|
||||
params: dict = {"Bucket": self.bucket_name, "Key": filename}
|
||||
if download_filename:
|
||||
# RFC 5987 / RFC 6266: Use both filename and filename* for compatibility.
|
||||
# filename* with UTF-8 encoding handles non-ASCII characters.
|
||||
encoded = quote(download_filename)
|
||||
params["ResponseContentDisposition"] = f"attachment; filename=\"{encoded}\"; filename*=UTF-8''{encoded}"
|
||||
url: str = self.client.generate_presigned_url(
|
||||
ClientMethod="get_object",
|
||||
Params={"Bucket": self.bucket_name, "Key": filename},
|
||||
Params=params,
|
||||
ExpiresIn=expires_in,
|
||||
)
|
||||
return url
|
||||
|
||||
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
|
||||
return [
|
||||
self.client.generate_presigned_url(
|
||||
ClientMethod="get_object",
|
||||
Params={"Bucket": self.bucket_name, "Key": filename},
|
||||
ExpiresIn=expires_in,
|
||||
def get_download_urls(
|
||||
self,
|
||||
filenames: list[str],
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filenames: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Generate presigned download URLs for multiple files.
|
||||
|
||||
Args:
|
||||
filenames: List of S3 object keys
|
||||
expires_in: URL validity duration in seconds
|
||||
download_filenames: If provided, must match len(filenames). Sets
|
||||
Content-Disposition for each file.
|
||||
"""
|
||||
if download_filenames is None:
|
||||
return [
|
||||
self.client.generate_presigned_url(
|
||||
ClientMethod="get_object",
|
||||
Params={"Bucket": self.bucket_name, "Key": filename},
|
||||
ExpiresIn=expires_in,
|
||||
)
|
||||
for filename in filenames
|
||||
]
|
||||
|
||||
urls: list[str] = []
|
||||
for filename, download_filename in zip(filenames, download_filenames, strict=True):
|
||||
params: dict = {"Bucket": self.bucket_name, "Key": filename}
|
||||
if download_filename:
|
||||
encoded = quote(download_filename)
|
||||
params["ResponseContentDisposition"] = f"attachment; filename=\"{encoded}\"; filename*=UTF-8''{encoded}"
|
||||
urls.append(
|
||||
self.client.generate_presigned_url(
|
||||
ClientMethod="get_object",
|
||||
Params=params,
|
||||
ExpiresIn=expires_in,
|
||||
)
|
||||
)
|
||||
for filename in filenames
|
||||
]
|
||||
return urls
|
||||
|
||||
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
url: str = self.client.generate_presigned_url(
|
||||
|
||||
@ -39,7 +39,13 @@ class BaseStorage(ABC):
|
||||
"""
|
||||
raise NotImplementedError("This storage backend doesn't support scanning")
|
||||
|
||||
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
def get_download_url(
|
||||
self,
|
||||
filename: str,
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filename: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a pre-signed URL for downloading a file.
|
||||
|
||||
@ -49,6 +55,9 @@ class BaseStorage(ABC):
|
||||
Args:
|
||||
filename: The file path/key in storage
|
||||
expires_in: URL validity duration in seconds (default: 1 hour)
|
||||
download_filename: If provided, the browser will use this as the downloaded
|
||||
file name instead of the storage key. Implemented via response header
|
||||
override (e.g., Content-Disposition) where supported.
|
||||
|
||||
Returns:
|
||||
Pre-signed URL string
|
||||
@ -58,9 +67,21 @@ class BaseStorage(ABC):
|
||||
"""
|
||||
raise NotImplementedError("This storage backend doesn't support pre-signed URLs")
|
||||
|
||||
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
|
||||
def get_download_urls(
|
||||
self,
|
||||
filenames: list[str],
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filenames: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Generate pre-signed URLs for downloading multiple files.
|
||||
|
||||
Args:
|
||||
filenames: List of file paths/keys in storage
|
||||
expires_in: URL validity duration in seconds (default: 1 hour)
|
||||
download_filenames: If provided, must match len(filenames). Each element
|
||||
specifies the download filename for the corresponding file.
|
||||
"""
|
||||
raise NotImplementedError("This storage backend doesn't support pre-signed URLs")
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
"""Storage wrapper that caches presigned download URLs."""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from itertools import starmap
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
@ -39,23 +41,31 @@ class CachedPresignStorage(StorageWrapper):
|
||||
super().delete(filename)
|
||||
self.invalidate([filename])
|
||||
|
||||
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
def get_download_url(
|
||||
self,
|
||||
filename: str,
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filename: str | None = None,
|
||||
) -> str:
|
||||
"""Get a presigned download URL, using cache when available.
|
||||
|
||||
Args:
|
||||
filename: The file path/key in storage
|
||||
expires_in: URL validity duration in seconds (default: 1 hour)
|
||||
download_filename: If provided, the browser will use this as the downloaded
|
||||
file name. Cache keys include this value to avoid conflicts.
|
||||
|
||||
Returns:
|
||||
Presigned URL string
|
||||
"""
|
||||
cache_key = self._cache_key(filename)
|
||||
cache_key = self._cache_key(filename, download_filename)
|
||||
|
||||
cached = self._get_cached(cache_key)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
url = self._storage.get_download_url(filename, expires_in)
|
||||
url = self._storage.get_download_url(filename, expires_in, download_filename=download_filename)
|
||||
self._set_cached(cache_key, url, expires_in)
|
||||
|
||||
return url
|
||||
@ -64,12 +74,16 @@ class CachedPresignStorage(StorageWrapper):
|
||||
self,
|
||||
filenames: list[str],
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filenames: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Batch get download URLs with cache.
|
||||
|
||||
Args:
|
||||
filenames: List of file paths/keys in storage
|
||||
expires_in: URL validity duration in seconds (default: 1 hour)
|
||||
download_filenames: If provided, must match len(filenames). Each element
|
||||
specifies the download filename for the corresponding file.
|
||||
|
||||
Returns:
|
||||
List of presigned URLs in the same order as filenames
|
||||
@ -77,22 +91,32 @@ class CachedPresignStorage(StorageWrapper):
|
||||
if not filenames:
|
||||
return []
|
||||
|
||||
cache_keys = [self._cache_key(f) for f in filenames]
|
||||
# Build cache keys including download_filename for uniqueness
|
||||
if download_filenames is None:
|
||||
cache_keys = [self._cache_key(f, None) for f in filenames]
|
||||
else:
|
||||
cache_keys = list(starmap(self._cache_key, zip(filenames, download_filenames, strict=True)))
|
||||
|
||||
cached_values = self._get_cached_batch(cache_keys)
|
||||
|
||||
# Build results list, tracking which indices need fetching
|
||||
results: list[str | None] = list(cached_values)
|
||||
uncached_indices: list[int] = []
|
||||
uncached_filenames: list[str] = []
|
||||
uncached_download_filenames: list[str | None] = []
|
||||
|
||||
for i, (filename, cached) in enumerate(zip(filenames, cached_values)):
|
||||
if not cached:
|
||||
uncached_indices.append(i)
|
||||
uncached_filenames.append(filename)
|
||||
uncached_download_filenames.append(download_filenames[i] if download_filenames else None)
|
||||
|
||||
# Batch fetch uncached URLs from storage
|
||||
if uncached_filenames:
|
||||
uncached_urls = [self._storage.get_download_url(f, expires_in) for f in uncached_filenames]
|
||||
uncached_urls = [
|
||||
self._storage.get_download_url(f, expires_in, download_filename=df)
|
||||
for f, df in zip(uncached_filenames, uncached_download_filenames, strict=True)
|
||||
]
|
||||
|
||||
# Fill results at correct positions
|
||||
for idx, url in zip(uncached_indices, uncached_urls):
|
||||
@ -119,8 +143,19 @@ class CachedPresignStorage(StorageWrapper):
|
||||
except Exception:
|
||||
logger.warning("Failed to invalidate presign cache", exc_info=True)
|
||||
|
||||
def _cache_key(self, filename: str) -> str:
|
||||
"""Generate cache key for a filename."""
|
||||
def _cache_key(self, filename: str, download_filename: str | None = None) -> str:
|
||||
"""Generate cache key for a filename.
|
||||
|
||||
When download_filename is provided, its hash is appended to the key to ensure
|
||||
different download names for the same storage key get separate cache entries.
|
||||
We use a hash (truncated MD5) instead of the raw string because:
|
||||
- download_filename may contain special characters unsafe for Redis keys
|
||||
- Hash collisions only cause a cache miss, no functional impact
|
||||
"""
|
||||
if download_filename:
|
||||
# Use first 16 chars of MD5 hex digest (64 bits) - sufficient for cache key uniqueness
|
||||
name_hash = hashlib.md5(download_filename.encode("utf-8")).hexdigest()[:16]
|
||||
return f"{self._cache_key_prefix}:{filename}::{name_hash}"
|
||||
return f"{self._cache_key_prefix}:{filename}"
|
||||
|
||||
def _compute_ttl(self, expires_in: int) -> int:
|
||||
|
||||
@ -28,23 +28,40 @@ class FilePresignStorage(StorageWrapper):
|
||||
Otherwise, generates ticket-based URLs for both download and upload operations.
|
||||
"""
|
||||
|
||||
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
def get_download_url(
|
||||
self,
|
||||
filename: str,
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filename: str | None = None,
|
||||
) -> str:
|
||||
"""Get a presigned download URL, falling back to ticket URL if not supported."""
|
||||
try:
|
||||
return self._storage.get_download_url(filename, expires_in)
|
||||
return self._storage.get_download_url(filename, expires_in, download_filename=download_filename)
|
||||
except NotImplementedError:
|
||||
from services.storage_ticket_service import StorageTicketService
|
||||
|
||||
return StorageTicketService.create_download_url(filename, expires_in=expires_in)
|
||||
return StorageTicketService.create_download_url(filename, expires_in=expires_in, filename=download_filename)
|
||||
|
||||
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
|
||||
def get_download_urls(
|
||||
self,
|
||||
filenames: list[str],
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filenames: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Get presigned download URLs for multiple files."""
|
||||
try:
|
||||
return self._storage.get_download_urls(filenames, expires_in)
|
||||
return self._storage.get_download_urls(filenames, expires_in, download_filenames=download_filenames)
|
||||
except NotImplementedError:
|
||||
from services.storage_ticket_service import StorageTicketService
|
||||
|
||||
return [StorageTicketService.create_download_url(f, expires_in=expires_in) for f in filenames]
|
||||
if download_filenames is None:
|
||||
return [StorageTicketService.create_download_url(f, expires_in=expires_in) for f in filenames]
|
||||
return [
|
||||
StorageTicketService.create_download_url(f, expires_in=expires_in, filename=df)
|
||||
for f, df in zip(filenames, download_filenames, strict=True)
|
||||
]
|
||||
|
||||
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
"""Get a presigned upload URL, falling back to ticket URL if not supported."""
|
||||
|
||||
@ -44,11 +44,23 @@ class StorageWrapper(BaseStorage):
|
||||
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
|
||||
return self._storage.scan(path, files=files, directories=directories)
|
||||
|
||||
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
return self._storage.get_download_url(filename, expires_in)
|
||||
def get_download_url(
|
||||
self,
|
||||
filename: str,
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filename: str | None = None,
|
||||
) -> str:
|
||||
return self._storage.get_download_url(filename, expires_in, download_filename=download_filename)
|
||||
|
||||
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
|
||||
return self._storage.get_download_urls(filenames, expires_in)
|
||||
def get_download_urls(
|
||||
self,
|
||||
filenames: list[str],
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filenames: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
return self._storage.get_download_urls(filenames, expires_in, download_filenames=download_filenames)
|
||||
|
||||
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
return self._storage.get_upload_url(filename, expires_in)
|
||||
|
||||
@ -366,7 +366,7 @@ class AppAssetService:
|
||||
|
||||
asset_storage = AppAssetService.get_storage()
|
||||
key = AssetPaths.draft(app_model.tenant_id, app_model.id, node_id)
|
||||
return asset_storage.get_download_url(key, expires_in)
|
||||
return asset_storage.get_download_url(key, expires_in, download_filename=node.name)
|
||||
|
||||
@staticmethod
|
||||
def get_source_zip_bytes(tenant_id: str, app_id: str, workflow_id: str) -> bytes | None:
|
||||
|
||||
@ -145,8 +145,9 @@ class AppBundleService:
|
||||
archive = zs.zip(src="bundle_root", include_base=False)
|
||||
zs.upload(archive, upload_url)
|
||||
|
||||
download_url = asset_storage.get_download_url(export_key, expires_in)
|
||||
return BundleExportResult(download_url=download_url, filename=f"{safe_name}.zip")
|
||||
bundle_filename = f"{safe_name}.zip"
|
||||
download_url = asset_storage.get_download_url(export_key, expires_in, download_filename=bundle_filename)
|
||||
return BundleExportResult(download_url=download_url, filename=bundle_filename)
|
||||
|
||||
# ========== Import ==========
|
||||
|
||||
|
||||
@ -32,10 +32,22 @@ class DummyStorage(BaseStorage):
|
||||
def delete(self, filename: str):
|
||||
return None
|
||||
|
||||
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
def get_download_url(
|
||||
self,
|
||||
filename: str,
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filename: str | None = None,
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
|
||||
def get_download_urls(
|
||||
self,
|
||||
filenames: list[str],
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filenames: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from unittest.mock import Mock
|
||||
import hashlib
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -19,12 +20,16 @@ class TestCachedPresignStorage:
|
||||
return Mock()
|
||||
|
||||
@pytest.fixture
|
||||
def cached_storage(self, mock_storage):
|
||||
def cached_storage(self, mock_storage, mock_redis):
|
||||
"""Create CachedPresignStorage with mocks."""
|
||||
return CachedPresignStorage(
|
||||
storage=mock_storage,
|
||||
cache_key_prefix="test_prefix",
|
||||
)
|
||||
with patch("extensions.storage.cached_presign_storage.redis_client", mock_redis):
|
||||
storage = CachedPresignStorage(
|
||||
storage=mock_storage,
|
||||
cache_key_prefix="test_prefix",
|
||||
)
|
||||
# Inject mock_redis after creation for tests to verify calls
|
||||
storage._redis = mock_redis
|
||||
yield storage
|
||||
|
||||
def test_get_download_url_returns_cached_on_hit(self, cached_storage, mock_storage, mock_redis):
|
||||
"""Test that cached URL is returned when cache hit occurs."""
|
||||
@ -46,7 +51,7 @@ class TestCachedPresignStorage:
|
||||
|
||||
assert result == "https://new-url.com/file.txt"
|
||||
mock_redis.mget.assert_called_once_with(["test_prefix:path/to/file.txt"])
|
||||
mock_storage.get_download_url.assert_called_once_with("path/to/file.txt", 3600)
|
||||
mock_storage.get_download_url.assert_called_once_with("path/to/file.txt", 3600, download_filename=None)
|
||||
mock_redis.setex.assert_called_once()
|
||||
call_args = mock_redis.setex.call_args
|
||||
assert call_args[0][0] == "test_prefix:path/to/file.txt"
|
||||
@ -61,7 +66,7 @@ class TestCachedPresignStorage:
|
||||
result = cached_storage.get_download_urls(filenames, expires_in=3600)
|
||||
|
||||
assert result == ["https://cached1.com", "https://new.com", "https://cached2.com"]
|
||||
mock_storage.get_download_url.assert_called_once_with("file2.txt", 3600)
|
||||
mock_storage.get_download_url.assert_called_once_with("file2.txt", 3600, download_filename=None)
|
||||
# Verify pipeline was used for batch cache write
|
||||
mock_redis.pipeline.assert_called_once()
|
||||
mock_redis.pipeline().execute.assert_called_once()
|
||||
@ -114,7 +119,7 @@ class TestCachedPresignStorage:
|
||||
result = cached_storage.get_download_url("path/to/file.txt", expires_in=3600)
|
||||
|
||||
assert result == "https://new-url.com/file.txt"
|
||||
mock_storage.get_download_url.assert_called_once_with("path/to/file.txt", 3600)
|
||||
mock_storage.get_download_url.assert_called_once_with("path/to/file.txt", 3600, download_filename=None)
|
||||
|
||||
def test_graceful_degradation_on_redis_setex_error(self, cached_storage, mock_storage, mock_redis):
|
||||
"""Test that URL is still returned when Redis setex fails."""
|
||||
@ -177,28 +182,45 @@ class TestCachedPresignStorage:
|
||||
key = cached_storage._cache_key("path/to/file.txt")
|
||||
assert key == "test_prefix:path/to/file.txt"
|
||||
|
||||
def test_cached_value_decoded_from_bytes(self, cached_storage, mock_storage, mock_redis):
|
||||
"""Test that bytes cached values are decoded to strings."""
|
||||
mock_redis.mget.return_value = [b"https://cached-url.com"]
|
||||
def test_cache_key_with_download_filename(self, cached_storage):
|
||||
"""Test cache key includes hashed download_filename when provided."""
|
||||
key = cached_storage._cache_key("path/to/file.txt", "custom_name.txt")
|
||||
# download_filename is hashed (first 16 chars of MD5 hex digest)
|
||||
expected_hash = hashlib.md5(b"custom_name.txt").hexdigest()[:16]
|
||||
assert key == f"test_prefix:path/to/file.txt::{expected_hash}"
|
||||
|
||||
result = cached_storage.get_download_url("file.txt")
|
||||
def test_get_download_url_with_download_filename(self, cached_storage, mock_storage, mock_redis):
|
||||
"""Test that download_filename is passed to storage and affects cache key."""
|
||||
mock_redis.mget.return_value = [None]
|
||||
mock_storage.get_download_url.return_value = "https://new-url.com/file.txt"
|
||||
|
||||
assert result == "https://cached-url.com"
|
||||
assert isinstance(result, str)
|
||||
result = cached_storage.get_download_url("path/to/file.txt", expires_in=3600, download_filename="custom.txt")
|
||||
|
||||
def test_cached_value_decoded_from_bytearray(self, cached_storage, mock_storage, mock_redis):
|
||||
"""Test that bytearray cached values are decoded to strings."""
|
||||
mock_redis.mget.return_value = [bytearray(b"https://cached-url.com")]
|
||||
assert result == "https://new-url.com/file.txt"
|
||||
expected_hash = hashlib.md5(b"custom.txt").hexdigest()[:16]
|
||||
mock_redis.mget.assert_called_once_with([f"test_prefix:path/to/file.txt::{expected_hash}"])
|
||||
mock_storage.get_download_url.assert_called_once_with("path/to/file.txt", 3600, download_filename="custom.txt")
|
||||
|
||||
result = cached_storage.get_download_url("file.txt")
|
||||
def test_get_download_urls_with_download_filenames(self, cached_storage, mock_storage, mock_redis):
|
||||
"""Test batch URL retrieval with download_filenames."""
|
||||
mock_redis.mget.return_value = [None, None]
|
||||
mock_storage.get_download_url.side_effect = ["https://url1.com", "https://url2.com"]
|
||||
|
||||
assert result == "https://cached-url.com"
|
||||
assert isinstance(result, str)
|
||||
filenames = ["file1.txt", "file2.txt"]
|
||||
download_filenames = ["custom1.txt", "custom2.txt"]
|
||||
result = cached_storage.get_download_urls(filenames, expires_in=3600, download_filenames=download_filenames)
|
||||
|
||||
def test_default_cache_key_prefix(self, mock_storage):
|
||||
"""Test default cache key prefix is used when not specified."""
|
||||
storage = CachedPresignStorage(
|
||||
storage=mock_storage,
|
||||
assert result == ["https://url1.com", "https://url2.com"]
|
||||
# Verify cache keys include hashed download_filenames
|
||||
hash1 = hashlib.md5(b"custom1.txt").hexdigest()[:16]
|
||||
hash2 = hashlib.md5(b"custom2.txt").hexdigest()[:16]
|
||||
mock_redis.mget.assert_called_once_with(
|
||||
[
|
||||
f"test_prefix:file1.txt::{hash1}",
|
||||
f"test_prefix:file2.txt::{hash2}",
|
||||
]
|
||||
)
|
||||
key = storage._cache_key("file.txt")
|
||||
assert key == "presign_cache:file.txt"
|
||||
# Verify storage calls include download_filename
|
||||
assert mock_storage.get_download_url.call_count == 2
|
||||
mock_storage.get_download_url.assert_any_call("file1.txt", 3600, download_filename="custom1.txt")
|
||||
mock_storage.get_download_url.assert_any_call("file2.txt", 3600, download_filename="custom2.txt")
|
||||
|
||||
Reference in New Issue
Block a user