feat: enhance download URL generation with optional filename parameter

Added an optional `download_filename` parameter to the `get_download_url` and `get_download_urls` methods across the storage classes. Callers can now specify a custom filename for downloads (applied via a Content-Disposition override where the backend supports it), so users see meaningful file names instead of raw storage keys. The wrapper classes and tests were updated accordingly.
This commit is contained in:
Harry
2026-02-03 14:40:14 +08:00
parent 5441b9c3ad
commit 49befa6d3f
10 changed files with 240 additions and 63 deletions

View File

@ -1,5 +1,6 @@
import logging
from collections.abc import Generator
from urllib.parse import quote
import boto3
from botocore.client import BaseClient, Config
@ -105,23 +106,73 @@ class AwsS3Storage(BaseStorage):
def delete(self, filename):
"""Permanently remove the S3 object stored under ``filename`` from the bucket."""
self.client.delete_object(Bucket=self.bucket_name, Key=filename)
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
def get_download_url(
self,
filename: str,
expires_in: int = 3600,
*,
download_filename: str | None = None,
) -> str:
"""Generate a presigned download URL.
Args:
filename: The S3 object key
expires_in: URL validity duration in seconds
download_filename: If provided, sets Content-Disposition header so browser
downloads the file with this name instead of the S3 key.
"""
params: dict = {"Bucket": self.bucket_name, "Key": filename}
if download_filename:
# RFC 5987 / RFC 6266: Use both filename and filename* for compatibility.
# filename* with UTF-8 encoding handles non-ASCII characters.
encoded = quote(download_filename)
params["ResponseContentDisposition"] = f"attachment; filename=\"{encoded}\"; filename*=UTF-8''{encoded}"
url: str = self.client.generate_presigned_url(
ClientMethod="get_object",
Params={"Bucket": self.bucket_name, "Key": filename},
Params=params,
ExpiresIn=expires_in,
)
return url
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
return [
self.client.generate_presigned_url(
ClientMethod="get_object",
Params={"Bucket": self.bucket_name, "Key": filename},
ExpiresIn=expires_in,
def get_download_urls(
self,
filenames: list[str],
expires_in: int = 3600,
*,
download_filenames: list[str] | None = None,
) -> list[str]:
"""Generate presigned download URLs for multiple files.
Args:
filenames: List of S3 object keys
expires_in: URL validity duration in seconds
download_filenames: If provided, must match len(filenames). Sets
Content-Disposition for each file.
"""
if download_filenames is None:
return [
self.client.generate_presigned_url(
ClientMethod="get_object",
Params={"Bucket": self.bucket_name, "Key": filename},
ExpiresIn=expires_in,
)
for filename in filenames
]
urls: list[str] = []
for filename, download_filename in zip(filenames, download_filenames, strict=True):
params: dict = {"Bucket": self.bucket_name, "Key": filename}
if download_filename:
encoded = quote(download_filename)
params["ResponseContentDisposition"] = f"attachment; filename=\"{encoded}\"; filename*=UTF-8''{encoded}"
urls.append(
self.client.generate_presigned_url(
ClientMethod="get_object",
Params=params,
ExpiresIn=expires_in,
)
)
for filename in filenames
]
return urls
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
url: str = self.client.generate_presigned_url(

View File

@ -39,7 +39,13 @@ class BaseStorage(ABC):
"""
raise NotImplementedError("This storage backend doesn't support scanning")
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
def get_download_url(
self,
filename: str,
expires_in: int = 3600,
*,
download_filename: str | None = None,
) -> str:
"""
Generate a pre-signed URL for downloading a file.
@ -49,6 +55,9 @@ class BaseStorage(ABC):
Args:
filename: The file path/key in storage
expires_in: URL validity duration in seconds (default: 1 hour)
download_filename: If provided, the browser will use this as the downloaded
file name instead of the storage key. Implemented via response header
override (e.g., Content-Disposition) where supported.
Returns:
Pre-signed URL string
@ -58,9 +67,21 @@ class BaseStorage(ABC):
"""
raise NotImplementedError("This storage backend doesn't support pre-signed URLs")
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
def get_download_urls(
self,
filenames: list[str],
expires_in: int = 3600,
*,
download_filenames: list[str] | None = None,
) -> list[str]:
"""
Generate pre-signed URLs for downloading multiple files.
Args:
filenames: List of file paths/keys in storage
expires_in: URL validity duration in seconds (default: 1 hour)
download_filenames: If provided, must match len(filenames). Each element
specifies the download filename for the corresponding file.
"""
raise NotImplementedError("This storage backend doesn't support pre-signed URLs")

View File

@ -1,6 +1,8 @@
"""Storage wrapper that caches presigned download URLs."""
import hashlib
import logging
from itertools import starmap
from extensions.ext_redis import redis_client
from extensions.storage.base_storage import BaseStorage
@ -39,23 +41,31 @@ class CachedPresignStorage(StorageWrapper):
super().delete(filename)
self.invalidate([filename])
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
def get_download_url(
self,
filename: str,
expires_in: int = 3600,
*,
download_filename: str | None = None,
) -> str:
"""Get a presigned download URL, using cache when available.
Args:
filename: The file path/key in storage
expires_in: URL validity duration in seconds (default: 1 hour)
download_filename: If provided, the browser will use this as the downloaded
file name. Cache keys include this value to avoid conflicts.
Returns:
Presigned URL string
"""
cache_key = self._cache_key(filename)
cache_key = self._cache_key(filename, download_filename)
cached = self._get_cached(cache_key)
if cached:
return cached
url = self._storage.get_download_url(filename, expires_in)
url = self._storage.get_download_url(filename, expires_in, download_filename=download_filename)
self._set_cached(cache_key, url, expires_in)
return url
@ -64,12 +74,16 @@ class CachedPresignStorage(StorageWrapper):
self,
filenames: list[str],
expires_in: int = 3600,
*,
download_filenames: list[str] | None = None,
) -> list[str]:
"""Batch get download URLs with cache.
Args:
filenames: List of file paths/keys in storage
expires_in: URL validity duration in seconds (default: 1 hour)
download_filenames: If provided, must match len(filenames). Each element
specifies the download filename for the corresponding file.
Returns:
List of presigned URLs in the same order as filenames
@ -77,22 +91,32 @@ class CachedPresignStorage(StorageWrapper):
# NOTE(review): diff fragment — the "def get_download_urls(" opener is elided
# as unchanged context, and the tail that writes fetched URLs back into
# `results` and returns them lies beyond this hunk.
if not filenames:
return []
cache_keys = [self._cache_key(f) for f in filenames]
# Build cache keys including download_filename for uniqueness
if download_filenames is None:
cache_keys = [self._cache_key(f, None) for f in filenames]
else:
# strict=True enforces len(download_filenames) == len(filenames) (ValueError).
cache_keys = list(starmap(self._cache_key, zip(filenames, download_filenames, strict=True)))
cached_values = self._get_cached_batch(cache_keys)
# Build results list, tracking which indices need fetching
results: list[str | None] = list(cached_values)
uncached_indices: list[int] = []
uncached_filenames: list[str] = []
uncached_download_filenames: list[str | None] = []
# NOTE(review): zip() here is not strict; this assumes _get_cached_batch
# returns exactly one entry per key — confirm against its implementation.
for i, (filename, cached) in enumerate(zip(filenames, cached_values)):
if not cached:
uncached_indices.append(i)
uncached_filenames.append(filename)
uncached_download_filenames.append(download_filenames[i] if download_filenames else None)
# Batch fetch uncached URLs from storage
if uncached_filenames:
uncached_urls = [self._storage.get_download_url(f, expires_in) for f in uncached_filenames]
uncached_urls = [
self._storage.get_download_url(f, expires_in, download_filename=df)
for f, df in zip(uncached_filenames, uncached_download_filenames, strict=True)
]
# Fill results at correct positions
for idx, url in zip(uncached_indices, uncached_urls):
@ -119,8 +143,19 @@ class CachedPresignStorage(StorageWrapper):
except Exception:
logger.warning("Failed to invalidate presign cache", exc_info=True)
def _cache_key(self, filename: str) -> str:
"""Generate cache key for a filename."""
def _cache_key(self, filename: str, download_filename: str | None = None) -> str:
"""Generate cache key for a filename.
When download_filename is provided, its hash is appended to the key to ensure
different download names for the same storage key get separate cache entries.
We use a hash (truncated MD5) instead of the raw string because:
- download_filename may contain special characters unsafe for Redis keys
- Hash collisions only cause a cache miss, no functional impact
"""
if download_filename:
# Use first 16 chars of MD5 hex digest (64 bits) - sufficient for cache key uniqueness
name_hash = hashlib.md5(download_filename.encode("utf-8")).hexdigest()[:16]
# NOTE(review): "(unknown)" in both returns looks like a rendering artifact
# of this diff view; the original almost certainly interpolates the storage
# filename (e.g. {filename}) at that position — confirm against the repo.
return f"{self._cache_key_prefix}:(unknown)::{name_hash}"
return f"{self._cache_key_prefix}:(unknown)"
def _compute_ttl(self, expires_in: int) -> int:

View File

@ -28,23 +28,40 @@ class FilePresignStorage(StorageWrapper):
Otherwise, generates ticket-based URLs for both download and upload operations.
"""
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
def get_download_url(
self,
filename: str,
expires_in: int = 3600,
*,
download_filename: str | None = None,
) -> str:
"""Get a presigned download URL, falling back to ticket URL if not supported."""
try:
return self._storage.get_download_url(filename, expires_in)
return self._storage.get_download_url(filename, expires_in, download_filename=download_filename)
except NotImplementedError:
from services.storage_ticket_service import StorageTicketService
return StorageTicketService.create_download_url(filename, expires_in=expires_in)
return StorageTicketService.create_download_url(filename, expires_in=expires_in, filename=download_filename)
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
def get_download_urls(
self,
filenames: list[str],
expires_in: int = 3600,
*,
download_filenames: list[str] | None = None,
) -> list[str]:
"""Get presigned download URLs for multiple files."""
try:
return self._storage.get_download_urls(filenames, expires_in)
return self._storage.get_download_urls(filenames, expires_in, download_filenames=download_filenames)
except NotImplementedError:
from services.storage_ticket_service import StorageTicketService
return [StorageTicketService.create_download_url(f, expires_in=expires_in) for f in filenames]
if download_filenames is None:
return [StorageTicketService.create_download_url(f, expires_in=expires_in) for f in filenames]
return [
StorageTicketService.create_download_url(f, expires_in=expires_in, filename=df)
for f, df in zip(filenames, download_filenames, strict=True)
]
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
"""Get a presigned upload URL, falling back to ticket URL if not supported."""

View File

@ -44,11 +44,23 @@ class StorageWrapper(BaseStorage):
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
return self._storage.scan(path, files=files, directories=directories)
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
return self._storage.get_download_url(filename, expires_in)
def get_download_url(
self,
filename: str,
expires_in: int = 3600,
*,
download_filename: str | None = None,
) -> str:
return self._storage.get_download_url(filename, expires_in, download_filename=download_filename)
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
return self._storage.get_download_urls(filenames, expires_in)
def get_download_urls(
self,
filenames: list[str],
expires_in: int = 3600,
*,
download_filenames: list[str] | None = None,
) -> list[str]:
return self._storage.get_download_urls(filenames, expires_in, download_filenames=download_filenames)
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
"""Delegate pre-signed upload-URL generation to the wrapped storage backend."""
return self._storage.get_upload_url(filename, expires_in)