feat: enhance download URL generation with optional filename parameter

Added an optional keyword-only `download_filename` parameter to `get_download_url`, and a matching `download_filenames` list parameter to `get_download_urls`, across the storage classes. This lets callers specify the file name a browser should use when saving a download instead of the storage key, improving the user experience. Updated related methods and tests to accommodate this new functionality.
This commit is contained in:
Harry
2026-02-03 14:40:14 +08:00
parent 5441b9c3ad
commit 49befa6d3f
10 changed files with 240 additions and 63 deletions

View File

@ -118,8 +118,14 @@ class Storage:
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
return self.storage_runner.scan(path, files=files, directories=directories)
# NOTE(review): this is a rendered diff without +/- markers; the one-line signature and body
# directly below appear to be the pre-change version superseded by the multi-line definition.
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
return self.storage_runner.get_download_url(filename, expires_in)
def get_download_url(
self,
filename: str,
expires_in: int = 3600,
*,
download_filename: str | None = None,
) -> str:
# Thin delegation: the key and expiry are forwarded positionally and download_filename by
# keyword to the configured storage runner; its return value is passed back unchanged.
return self.storage_runner.get_download_url(filename, expires_in, download_filename=download_filename)
storage = Storage()

View File

@ -1,5 +1,6 @@
import logging
from collections.abc import Generator
from urllib.parse import quote
import boto3
from botocore.client import BaseClient, Config
@ -105,23 +106,73 @@ class AwsS3Storage(BaseStorage):
def delete(self, filename):
self.client.delete_object(Bucket=self.bucket_name, Key=filename)
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
def get_download_url(
self,
filename: str,
expires_in: int = 3600,
*,
download_filename: str | None = None,
) -> str:
"""Generate a presigned download URL.
Args:
filename: The S3 object key
expires_in: URL validity duration in seconds
download_filename: If provided, sets Content-Disposition header so browser
downloads the file with this name instead of the S3 key.
"""
params: dict = {"Bucket": self.bucket_name, "Key": filename}
if download_filename:
# RFC 5987 / RFC 6266: Use both filename and filename* for compatibility.
# filename* with UTF-8 encoding handles non-ASCII characters.
encoded = quote(download_filename)
params["ResponseContentDisposition"] = f"attachment; filename=\"{encoded}\"; filename*=UTF-8''{encoded}"
url: str = self.client.generate_presigned_url(
ClientMethod="get_object",
Params={"Bucket": self.bucket_name, "Key": filename},
Params=params,
ExpiresIn=expires_in,
)
return url
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
return [
self.client.generate_presigned_url(
ClientMethod="get_object",
Params={"Bucket": self.bucket_name, "Key": filename},
ExpiresIn=expires_in,
def get_download_urls(
self,
filenames: list[str],
expires_in: int = 3600,
*,
download_filenames: list[str] | None = None,
) -> list[str]:
"""Generate presigned download URLs for multiple files.
Args:
filenames: List of S3 object keys
expires_in: URL validity duration in seconds
download_filenames: If provided, must match len(filenames). Sets
Content-Disposition for each file.
"""
if download_filenames is None:
return [
self.client.generate_presigned_url(
ClientMethod="get_object",
Params={"Bucket": self.bucket_name, "Key": filename},
ExpiresIn=expires_in,
)
for filename in filenames
]
urls: list[str] = []
for filename, download_filename in zip(filenames, download_filenames, strict=True):
params: dict = {"Bucket": self.bucket_name, "Key": filename}
if download_filename:
encoded = quote(download_filename)
params["ResponseContentDisposition"] = f"attachment; filename=\"{encoded}\"; filename*=UTF-8''{encoded}"
urls.append(
self.client.generate_presigned_url(
ClientMethod="get_object",
Params=params,
ExpiresIn=expires_in,
)
)
for filename in filenames
]
return urls
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
url: str = self.client.generate_presigned_url(

View File

@ -39,7 +39,13 @@ class BaseStorage(ABC):
"""
raise NotImplementedError("This storage backend doesn't support scanning")
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
def get_download_url(
self,
filename: str,
expires_in: int = 3600,
*,
download_filename: str | None = None,
) -> str:
"""
Generate a pre-signed URL for downloading a file.
@ -49,6 +55,9 @@ class BaseStorage(ABC):
Args:
filename: The file path/key in storage
expires_in: URL validity duration in seconds (default: 1 hour)
download_filename: If provided, the browser will use this as the downloaded
file name instead of the storage key. Implemented via response header
override (e.g., Content-Disposition) where supported.
Returns:
Pre-signed URL string
@ -58,9 +67,21 @@ class BaseStorage(ABC):
"""
raise NotImplementedError("This storage backend doesn't support pre-signed URLs")
# NOTE(review): rendered diff without +/- markers; the one-line signature directly below appears
# to be the pre-change version superseded by the multi-line signature that follows it.
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
def get_download_urls(
self,
filenames: list[str],
expires_in: int = 3600,
*,
download_filenames: list[str] | None = None,
) -> list[str]:
"""
Generate pre-signed URLs for downloading multiple files.
Args:
filenames: List of file paths/keys in storage
expires_in: URL validity duration in seconds (default: 1 hour)
download_filenames: If provided, must match len(filenames). Each element
specifies the download filename for the corresponding file.
"""
# Base implementation does no argument validation and always raises; the length requirement
# documented above is left to whichever subclass overrides this method.
raise NotImplementedError("This storage backend doesn't support pre-signed URLs")

View File

@ -1,6 +1,8 @@
"""Storage wrapper that caches presigned download URLs."""
import hashlib
import logging
from itertools import starmap
from extensions.ext_redis import redis_client
from extensions.storage.base_storage import BaseStorage
@ -39,23 +41,31 @@ class CachedPresignStorage(StorageWrapper):
super().delete(filename)
self.invalidate([filename])
# NOTE(review): rendered diff without +/- markers; the one-line signature below and the
# single-argument `_cache_key(filename)` / two-argument `get_download_url(filename, expires_in)`
# lines further down appear to be pre-change lines, each followed by its replacement.
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
def get_download_url(
self,
filename: str,
expires_in: int = 3600,
*,
download_filename: str | None = None,
) -> str:
"""Get a presigned download URL, using cache when available.
Args:
filename: The file path/key in storage
expires_in: URL validity duration in seconds (default: 1 hour)
download_filename: If provided, the browser will use this as the downloaded
file name. Cache keys include this value to avoid conflicts.
Returns:
Presigned URL string
"""
cache_key = self._cache_key(filename)
# The key is derived from both the storage key and the download name, so the same object
# requested under different download names is looked up separately.
cache_key = self._cache_key(filename, download_filename)
cached = self._get_cached(cache_key)
# Any falsy cached value (e.g. None or an empty string) is treated as a miss.
if cached:
return cached
url = self._storage.get_download_url(filename, expires_in)
# Miss: obtain a fresh URL from the wrapped storage, then store it under the same key;
# expires_in is handed to _set_cached along with the URL.
url = self._storage.get_download_url(filename, expires_in, download_filename=download_filename)
self._set_cached(cache_key, url, expires_in)
return url
@ -64,12 +74,16 @@ class CachedPresignStorage(StorageWrapper):
self,
filenames: list[str],
expires_in: int = 3600,
*,
download_filenames: list[str] | None = None,
) -> list[str]:
"""Batch get download URLs with cache.
Args:
filenames: List of file paths/keys in storage
expires_in: URL validity duration in seconds (default: 1 hour)
download_filenames: If provided, must match len(filenames). Each element
specifies the download filename for the corresponding file.
Returns:
List of presigned URLs in the same order as filenames
@ -77,22 +91,32 @@ class CachedPresignStorage(StorageWrapper):
if not filenames:
return []
cache_keys = [self._cache_key(f) for f in filenames]
# Build cache keys including download_filename for uniqueness
if download_filenames is None:
cache_keys = [self._cache_key(f, None) for f in filenames]
else:
cache_keys = list(starmap(self._cache_key, zip(filenames, download_filenames, strict=True)))
cached_values = self._get_cached_batch(cache_keys)
# Build results list, tracking which indices need fetching
results: list[str | None] = list(cached_values)
uncached_indices: list[int] = []
uncached_filenames: list[str] = []
uncached_download_filenames: list[str | None] = []
for i, (filename, cached) in enumerate(zip(filenames, cached_values)):
if not cached:
uncached_indices.append(i)
uncached_filenames.append(filename)
uncached_download_filenames.append(download_filenames[i] if download_filenames else None)
# Batch fetch uncached URLs from storage
if uncached_filenames:
uncached_urls = [self._storage.get_download_url(f, expires_in) for f in uncached_filenames]
uncached_urls = [
self._storage.get_download_url(f, expires_in, download_filename=df)
for f, df in zip(uncached_filenames, uncached_download_filenames, strict=True)
]
# Fill results at correct positions
for idx, url in zip(uncached_indices, uncached_urls):
@ -119,8 +143,19 @@ class CachedPresignStorage(StorageWrapper):
except Exception:
logger.warning("Failed to invalidate presign cache", exc_info=True)
def _cache_key(self, filename: str) -> str:
"""Generate cache key for a filename."""
def _cache_key(self, filename: str, download_filename: str | None = None) -> str:
"""Generate cache key for a filename.
When download_filename is provided, its hash is appended to the key to ensure
different download names for the same storage key get separate cache entries.
We use a hash (truncated MD5) instead of the raw string because:
- download_filename may contain special characters unsafe for Redis keys
- Hash collisions only cause a cache miss, no functional impact
"""
if download_filename:
# Use first 16 chars of MD5 hex digest (64 bits) - sufficient for cache key uniqueness
name_hash = hashlib.md5(download_filename.encode("utf-8")).hexdigest()[:16]
return f"{self._cache_key_prefix}:{filename}::{name_hash}"
return f"{self._cache_key_prefix}:{filename}"
def _compute_ttl(self, expires_in: int) -> int:

View File

@ -28,23 +28,40 @@ class FilePresignStorage(StorageWrapper):
Otherwise, generates ticket-based URLs for both download and upload operations.
"""
# NOTE(review): rendered diff without +/- markers; the one-line signature and the shorter
# `return` lines (those without download_filename / filename=) appear to be pre-change lines,
# each immediately followed by its replacement.
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
def get_download_url(
self,
filename: str,
expires_in: int = 3600,
*,
download_filename: str | None = None,
) -> str:
"""Get a presigned download URL, falling back to ticket URL if not supported."""
try:
return self._storage.get_download_url(filename, expires_in)
return self._storage.get_download_url(filename, expires_in, download_filename=download_filename)
except NotImplementedError:
# Imported at call time rather than module top — presumably to avoid an import cycle
# with the services package; TODO confirm.
from services.storage_ticket_service import StorageTicketService
return StorageTicketService.create_download_url(filename, expires_in=expires_in)
# Fallback: the download name is handed to the ticket service under its `filename` keyword,
# while the storage key is the positional argument.
return StorageTicketService.create_download_url(filename, expires_in=expires_in, filename=download_filename)
# NOTE(review): rendered diff without +/- markers; the one-line signature, the first `return`
# inside `try`, and the unconditional list-comprehension `return` after the import appear to be
# pre-change lines superseded by the lines that follow each of them.
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
def get_download_urls(
self,
filenames: list[str],
expires_in: int = 3600,
*,
download_filenames: list[str] | None = None,
) -> list[str]:
"""Get presigned download URLs for multiple files."""
try:
return self._storage.get_download_urls(filenames, expires_in)
return self._storage.get_download_urls(filenames, expires_in, download_filenames=download_filenames)
except NotImplementedError:
# Imported at call time rather than module top — presumably to avoid an import cycle
# with the services package; TODO confirm.
from services.storage_ticket_service import StorageTicketService
return [StorageTicketService.create_download_url(f, expires_in=expires_in) for f in filenames]
# Fallback without download names: one ticket URL per storage key, no `filename` keyword passed.
if download_filenames is None:
return [StorageTicketService.create_download_url(f, expires_in=expires_in) for f in filenames]
# Fallback with download names: keys and names are paired positionally; strict=True makes zip
# raise ValueError when the two lists differ in length.
return [
StorageTicketService.create_download_url(f, expires_in=expires_in, filename=df)
for f, df in zip(filenames, download_filenames, strict=True)
]
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
"""Get a presigned upload URL, falling back to ticket URL if not supported."""

View File

@ -44,11 +44,23 @@ class StorageWrapper(BaseStorage):
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
return self._storage.scan(path, files=files, directories=directories)
# NOTE(review): rendered diff without +/- markers; the one-line signature and two-argument
# delegation directly below appear to be the pre-change version of this method.
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
return self._storage.get_download_url(filename, expires_in)
def get_download_url(
self,
filename: str,
expires_in: int = 3600,
*,
download_filename: str | None = None,
) -> str:
# Pure pass-through to the wrapped storage; download_filename is forwarded by keyword.
return self._storage.get_download_url(filename, expires_in, download_filename=download_filename)
# NOTE(review): rendered diff without +/- markers; the one-line signature and two-argument
# delegation directly below appear to be the pre-change version of this method.
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
return self._storage.get_download_urls(filenames, expires_in)
def get_download_urls(
self,
filenames: list[str],
expires_in: int = 3600,
*,
download_filenames: list[str] | None = None,
) -> list[str]:
# Pure pass-through to the wrapped storage; download_filenames is forwarded by keyword and
# is not validated here.
return self._storage.get_download_urls(filenames, expires_in, download_filenames=download_filenames)
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
return self._storage.get_upload_url(filename, expires_in)

View File

@ -366,7 +366,7 @@ class AppAssetService:
asset_storage = AppAssetService.get_storage()
key = AssetPaths.draft(app_model.tenant_id, app_model.id, node_id)
return asset_storage.get_download_url(key, expires_in)
return asset_storage.get_download_url(key, expires_in, download_filename=node.name)
@staticmethod
def get_source_zip_bytes(tenant_id: str, app_id: str, workflow_id: str) -> bytes | None:

View File

@ -145,8 +145,9 @@ class AppBundleService:
archive = zs.zip(src="bundle_root", include_base=False)
zs.upload(archive, upload_url)
download_url = asset_storage.get_download_url(export_key, expires_in)
return BundleExportResult(download_url=download_url, filename=f"{safe_name}.zip")
bundle_filename = f"{safe_name}.zip"
download_url = asset_storage.get_download_url(export_key, expires_in, download_filename=bundle_filename)
return BundleExportResult(download_url=download_url, filename=bundle_filename)
# ========== Import ==========

View File

@ -32,10 +32,22 @@ class DummyStorage(BaseStorage):
def delete(self, filename: str):
return None
# NOTE(review): rendered diff without +/- markers; the one-line signature directly below appears
# to be the pre-change version superseded by the multi-line signature.
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
def get_download_url(
self,
filename: str,
expires_in: int = 3600,
*,
download_filename: str | None = None,
) -> str:
# Always raises, regardless of arguments — presumably so callers' NotImplementedError
# fallback paths can be exercised; TODO confirm against the tests using this class.
raise NotImplementedError
# NOTE(review): rendered diff without +/- markers; the one-line signature directly below appears
# to be the pre-change version superseded by the multi-line signature.
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
def get_download_urls(
self,
filenames: list[str],
expires_in: int = 3600,
*,
download_filenames: list[str] | None = None,
) -> list[str]:
# Always raises, regardless of arguments — presumably so callers' NotImplementedError
# fallback paths can be exercised; TODO confirm against the tests using this class.
raise NotImplementedError
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:

View File

@ -1,4 +1,5 @@
from unittest.mock import Mock
import hashlib
from unittest.mock import Mock, patch
import pytest
@ -19,12 +20,16 @@ class TestCachedPresignStorage:
return Mock()
@pytest.fixture
def cached_storage(self, mock_storage):
def cached_storage(self, mock_storage, mock_redis):
"""Create CachedPresignStorage with mocks."""
return CachedPresignStorage(
storage=mock_storage,
cache_key_prefix="test_prefix",
)
with patch("extensions.storage.cached_presign_storage.redis_client", mock_redis):
storage = CachedPresignStorage(
storage=mock_storage,
cache_key_prefix="test_prefix",
)
# Inject mock_redis after creation for tests to verify calls
storage._redis = mock_redis
yield storage
def test_get_download_url_returns_cached_on_hit(self, cached_storage, mock_storage, mock_redis):
"""Test that cached URL is returned when cache hit occurs."""
@ -46,7 +51,7 @@ class TestCachedPresignStorage:
assert result == "https://new-url.com/file.txt"
mock_redis.mget.assert_called_once_with(["test_prefix:path/to/file.txt"])
mock_storage.get_download_url.assert_called_once_with("path/to/file.txt", 3600)
mock_storage.get_download_url.assert_called_once_with("path/to/file.txt", 3600, download_filename=None)
mock_redis.setex.assert_called_once()
call_args = mock_redis.setex.call_args
assert call_args[0][0] == "test_prefix:path/to/file.txt"
@ -61,7 +66,7 @@ class TestCachedPresignStorage:
result = cached_storage.get_download_urls(filenames, expires_in=3600)
assert result == ["https://cached1.com", "https://new.com", "https://cached2.com"]
mock_storage.get_download_url.assert_called_once_with("file2.txt", 3600)
mock_storage.get_download_url.assert_called_once_with("file2.txt", 3600, download_filename=None)
# Verify pipeline was used for batch cache write
mock_redis.pipeline.assert_called_once()
mock_redis.pipeline().execute.assert_called_once()
@ -114,7 +119,7 @@ class TestCachedPresignStorage:
result = cached_storage.get_download_url("path/to/file.txt", expires_in=3600)
assert result == "https://new-url.com/file.txt"
mock_storage.get_download_url.assert_called_once_with("path/to/file.txt", 3600)
mock_storage.get_download_url.assert_called_once_with("path/to/file.txt", 3600, download_filename=None)
def test_graceful_degradation_on_redis_setex_error(self, cached_storage, mock_storage, mock_redis):
"""Test that URL is still returned when Redis setex fails."""
@ -177,28 +182,45 @@ class TestCachedPresignStorage:
key = cached_storage._cache_key("path/to/file.txt")
assert key == "test_prefix:path/to/file.txt"
def test_cached_value_decoded_from_bytes(self, cached_storage, mock_storage, mock_redis):
"""Test that bytes cached values are decoded to strings."""
mock_redis.mget.return_value = [b"https://cached-url.com"]
def test_cache_key_with_download_filename(self, cached_storage):
"""Test cache key includes hashed download_filename when provided."""
key = cached_storage._cache_key("path/to/file.txt", "custom_name.txt")
# download_filename is hashed (first 16 chars of MD5 hex digest)
expected_hash = hashlib.md5(b"custom_name.txt").hexdigest()[:16]
assert key == f"test_prefix:path/to/file.txt::{expected_hash}"
result = cached_storage.get_download_url("file.txt")
def test_get_download_url_with_download_filename(self, cached_storage, mock_storage, mock_redis):
"""Test that download_filename is passed to storage and affects cache key."""
mock_redis.mget.return_value = [None]
mock_storage.get_download_url.return_value = "https://new-url.com/file.txt"
assert result == "https://cached-url.com"
assert isinstance(result, str)
result = cached_storage.get_download_url("path/to/file.txt", expires_in=3600, download_filename="custom.txt")
def test_cached_value_decoded_from_bytearray(self, cached_storage, mock_storage, mock_redis):
"""Test that bytearray cached values are decoded to strings."""
mock_redis.mget.return_value = [bytearray(b"https://cached-url.com")]
assert result == "https://new-url.com/file.txt"
expected_hash = hashlib.md5(b"custom.txt").hexdigest()[:16]
mock_redis.mget.assert_called_once_with([f"test_prefix:path/to/file.txt::{expected_hash}"])
mock_storage.get_download_url.assert_called_once_with("path/to/file.txt", 3600, download_filename="custom.txt")
result = cached_storage.get_download_url("file.txt")
def test_get_download_urls_with_download_filenames(self, cached_storage, mock_storage, mock_redis):
"""Test batch URL retrieval with download_filenames."""
mock_redis.mget.return_value = [None, None]
mock_storage.get_download_url.side_effect = ["https://url1.com", "https://url2.com"]
assert result == "https://cached-url.com"
assert isinstance(result, str)
filenames = ["file1.txt", "file2.txt"]
download_filenames = ["custom1.txt", "custom2.txt"]
result = cached_storage.get_download_urls(filenames, expires_in=3600, download_filenames=download_filenames)
def test_default_cache_key_prefix(self, mock_storage):
"""Test default cache key prefix is used when not specified."""
storage = CachedPresignStorage(
storage=mock_storage,
assert result == ["https://url1.com", "https://url2.com"]
# Verify cache keys include hashed download_filenames
hash1 = hashlib.md5(b"custom1.txt").hexdigest()[:16]
hash2 = hashlib.md5(b"custom2.txt").hexdigest()[:16]
mock_redis.mget.assert_called_once_with(
[
f"test_prefix:file1.txt::{hash1}",
f"test_prefix:file2.txt::{hash2}",
]
)
key = storage._cache_key("file.txt")
assert key == "presign_cache:file.txt"
# Verify storage calls include download_filename
assert mock_storage.get_download_url.call_count == 2
mock_storage.get_download_url.assert_any_call("file1.txt", 3600, download_filename="custom1.txt")
mock_storage.get_download_url.assert_any_call("file2.txt", 3600, download_filename="custom2.txt")