mirror of
https://github.com/langgenius/dify.git
synced 2026-05-03 00:48:04 +08:00
feat: enhance download URL generation with optional filename parameter
Added support for an optional, keyword-only `download_filename` parameter on `get_download_url` (and a matching `download_filenames` list on `get_download_urls`) across the various storage classes. This lets callers specify a custom filename for downloads, so browsers save files under a meaningful name instead of the storage key. Updated related methods and tests to accommodate this new functionality.
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from urllib.parse import quote
|
||||
|
||||
import boto3
|
||||
from botocore.client import BaseClient, Config
|
||||
@ -105,23 +106,73 @@ class AwsS3Storage(BaseStorage):
|
||||
def delete(self, filename):
|
||||
self.client.delete_object(Bucket=self.bucket_name, Key=filename)
|
||||
|
||||
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
def get_download_url(
|
||||
self,
|
||||
filename: str,
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filename: str | None = None,
|
||||
) -> str:
|
||||
"""Generate a presigned download URL.
|
||||
|
||||
Args:
|
||||
filename: The S3 object key
|
||||
expires_in: URL validity duration in seconds
|
||||
download_filename: If provided, sets Content-Disposition header so browser
|
||||
downloads the file with this name instead of the S3 key.
|
||||
"""
|
||||
params: dict = {"Bucket": self.bucket_name, "Key": filename}
|
||||
if download_filename:
|
||||
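# RFC 5987 / RFC 6266: Send both filename and filename* for compatibility. Note that the plain filename parameter also carries the percent-encoded value (ASCII-safe fallback).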
|
||||
# filename* with UTF-8 encoding handles non-ASCII characters.
|
||||
encoded = quote(download_filename)
|
||||
params["ResponseContentDisposition"] = f"attachment; filename=\"{encoded}\"; filename*=UTF-8''{encoded}"
|
||||
url: str = self.client.generate_presigned_url(
|
||||
ClientMethod="get_object",
|
||||
Params={"Bucket": self.bucket_name, "Key": filename},
|
||||
Params=params,
|
||||
ExpiresIn=expires_in,
|
||||
)
|
||||
return url
|
||||
|
||||
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
|
||||
return [
|
||||
self.client.generate_presigned_url(
|
||||
ClientMethod="get_object",
|
||||
Params={"Bucket": self.bucket_name, "Key": filename},
|
||||
ExpiresIn=expires_in,
|
||||
def get_download_urls(
|
||||
self,
|
||||
filenames: list[str],
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filenames: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Generate presigned download URLs for multiple files.
|
||||
|
||||
Args:
|
||||
filenames: List of S3 object keys
|
||||
expires_in: URL validity duration in seconds
|
||||
download_filenames: If provided, must have the same length as filenames (a mismatch raises ValueError via zip(strict=True)). Sets
|
||||
Content-Disposition for each file.
|
||||
"""
|
||||
if download_filenames is None:
|
||||
return [
|
||||
self.client.generate_presigned_url(
|
||||
ClientMethod="get_object",
|
||||
Params={"Bucket": self.bucket_name, "Key": filename},
|
||||
ExpiresIn=expires_in,
|
||||
)
|
||||
for filename in filenames
|
||||
]
|
||||
|
||||
urls: list[str] = []
|
||||
for filename, download_filename in zip(filenames, download_filenames, strict=True):
|
||||
params: dict = {"Bucket": self.bucket_name, "Key": filename}
|
||||
if download_filename:
|
||||
encoded = quote(download_filename)
|
||||
params["ResponseContentDisposition"] = f"attachment; filename=\"{encoded}\"; filename*=UTF-8''{encoded}"
|
||||
urls.append(
|
||||
self.client.generate_presigned_url(
|
||||
ClientMethod="get_object",
|
||||
Params=params,
|
||||
ExpiresIn=expires_in,
|
||||
)
|
||||
)
|
||||
for filename in filenames
|
||||
]
|
||||
return urls
|
||||
|
||||
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
url: str = self.client.generate_presigned_url(
|
||||
|
||||
@ -39,7 +39,13 @@ class BaseStorage(ABC):
|
||||
"""
|
||||
raise NotImplementedError("This storage backend doesn't support scanning")
|
||||
|
||||
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
def get_download_url(
|
||||
self,
|
||||
filename: str,
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filename: str | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate a pre-signed URL for downloading a file.
|
||||
|
||||
@ -49,6 +55,9 @@ class BaseStorage(ABC):
|
||||
Args:
|
||||
filename: The file path/key in storage
|
||||
expires_in: URL validity duration in seconds (default: 1 hour)
|
||||
download_filename: If provided, the browser will use this as the downloaded
|
||||
file name instead of the storage key. Implemented via response header
|
||||
override (e.g., Content-Disposition) where supported.
|
||||
|
||||
Returns:
|
||||
Pre-signed URL string
|
||||
@ -58,9 +67,21 @@ class BaseStorage(ABC):
|
||||
"""
|
||||
raise NotImplementedError("This storage backend doesn't support pre-signed URLs")
|
||||
|
||||
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
|
||||
def get_download_urls(
|
||||
self,
|
||||
filenames: list[str],
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filenames: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Generate pre-signed URLs for downloading multiple files.
|
||||
|
||||
Args:
|
||||
filenames: List of file paths/keys in storage
|
||||
expires_in: URL validity duration in seconds (default: 1 hour)
|
||||
download_filenames: If provided, must match len(filenames). Each element
|
||||
specifies the download filename for the corresponding file.
|
||||
"""
|
||||
raise NotImplementedError("This storage backend doesn't support pre-signed URLs")
|
||||
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
"""Storage wrapper that caches presigned download URLs."""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from itertools import starmap
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
@ -39,23 +41,31 @@ class CachedPresignStorage(StorageWrapper):
|
||||
super().delete(filename)
|
||||
self.invalidate([filename])
|
||||
|
||||
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
def get_download_url(
|
||||
self,
|
||||
filename: str,
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filename: str | None = None,
|
||||
) -> str:
|
||||
"""Get a presigned download URL, using cache when available.
|
||||
|
||||
Args:
|
||||
filename: The file path/key in storage
|
||||
expires_in: URL validity duration in seconds (default: 1 hour)
|
||||
download_filename: If provided, the browser will use this as the downloaded
|
||||
file name. Cache keys include this value to avoid conflicts.
|
||||
|
||||
Returns:
|
||||
Presigned URL string
|
||||
"""
|
||||
cache_key = self._cache_key(filename)
|
||||
cache_key = self._cache_key(filename, download_filename)
|
||||
|
||||
cached = self._get_cached(cache_key)
|
||||
if cached:
|
||||
return cached
|
||||
|
||||
url = self._storage.get_download_url(filename, expires_in)
|
||||
url = self._storage.get_download_url(filename, expires_in, download_filename=download_filename)
|
||||
self._set_cached(cache_key, url, expires_in)
|
||||
|
||||
return url
|
||||
@ -64,12 +74,16 @@ class CachedPresignStorage(StorageWrapper):
|
||||
self,
|
||||
filenames: list[str],
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filenames: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Batch get download URLs with cache.
|
||||
|
||||
Args:
|
||||
filenames: List of file paths/keys in storage
|
||||
expires_in: URL validity duration in seconds (default: 1 hour)
|
||||
download_filenames: If provided, must match len(filenames). Each element
|
||||
specifies the download filename for the corresponding file.
|
||||
|
||||
Returns:
|
||||
List of presigned URLs in the same order as filenames
|
||||
@ -77,22 +91,32 @@ class CachedPresignStorage(StorageWrapper):
|
||||
if not filenames:
|
||||
return []
|
||||
|
||||
cache_keys = [self._cache_key(f) for f in filenames]
|
||||
# Build cache keys including download_filename for uniqueness
|
||||
if download_filenames is None:
|
||||
cache_keys = [self._cache_key(f, None) for f in filenames]
|
||||
else:
|
||||
cache_keys = list(starmap(self._cache_key, zip(filenames, download_filenames, strict=True)))
|
||||
|
||||
cached_values = self._get_cached_batch(cache_keys)
|
||||
|
||||
# Build results list, tracking which indices need fetching
|
||||
results: list[str | None] = list(cached_values)
|
||||
uncached_indices: list[int] = []
|
||||
uncached_filenames: list[str] = []
|
||||
uncached_download_filenames: list[str | None] = []
|
||||
|
||||
for i, (filename, cached) in enumerate(zip(filenames, cached_values)):
|
||||
if not cached:
|
||||
uncached_indices.append(i)
|
||||
uncached_filenames.append(filename)
|
||||
uncached_download_filenames.append(download_filenames[i] if download_filenames else None)
|
||||
|
||||
# Batch fetch uncached URLs from storage
|
||||
if uncached_filenames:
|
||||
uncached_urls = [self._storage.get_download_url(f, expires_in) for f in uncached_filenames]
|
||||
uncached_urls = [
|
||||
self._storage.get_download_url(f, expires_in, download_filename=df)
|
||||
for f, df in zip(uncached_filenames, uncached_download_filenames, strict=True)
|
||||
]
|
||||
|
||||
# Fill results at correct positions
|
||||
for idx, url in zip(uncached_indices, uncached_urls):
|
||||
@ -119,8 +143,19 @@ class CachedPresignStorage(StorageWrapper):
|
||||
except Exception:
|
||||
logger.warning("Failed to invalidate presign cache", exc_info=True)
|
||||
|
||||
def _cache_key(self, filename: str) -> str:
|
||||
"""Generate cache key for a filename."""
|
||||
def _cache_key(self, filename: str, download_filename: str | None = None) -> str:
|
||||
"""Generate cache key for a filename.
|
||||
|
||||
When download_filename is provided, its hash is appended to the key to ensure
|
||||
different download names for the same storage key get separate cache entries.
|
||||
We use a hash (truncated MD5) instead of the raw string because:
|
||||
- download_filename may contain special characters unsafe for Redis keys
|
||||
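- A hash collision between two download names for the same storage key would make them share a cache entry (the cached URL would carry the other download name); with a 64-bit hash this is negligibly unlikely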
|
||||
"""
|
||||
if download_filename:
|
||||
# Use first 16 chars of MD5 hex digest (64 bits) - sufficient for cache key uniqueness
|
||||
name_hash = hashlib.md5(download_filename.encode("utf-8")).hexdigest()[:16]
|
||||
return f"{self._cache_key_prefix}:{filename}::{name_hash}"
|
||||
return f"{self._cache_key_prefix}:{filename}"
|
||||
|
||||
def _compute_ttl(self, expires_in: int) -> int:
|
||||
|
||||
@ -28,23 +28,40 @@ class FilePresignStorage(StorageWrapper):
|
||||
Otherwise, generates ticket-based URLs for both download and upload operations.
|
||||
"""
|
||||
|
||||
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
def get_download_url(
|
||||
self,
|
||||
filename: str,
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filename: str | None = None,
|
||||
) -> str:
|
||||
"""Get a presigned download URL, falling back to ticket URL if not supported."""
|
||||
try:
|
||||
return self._storage.get_download_url(filename, expires_in)
|
||||
return self._storage.get_download_url(filename, expires_in, download_filename=download_filename)
|
||||
except NotImplementedError:
|
||||
from services.storage_ticket_service import StorageTicketService
|
||||
|
||||
return StorageTicketService.create_download_url(filename, expires_in=expires_in)
|
||||
return StorageTicketService.create_download_url(filename, expires_in=expires_in, filename=download_filename)
|
||||
|
||||
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
|
||||
def get_download_urls(
|
||||
self,
|
||||
filenames: list[str],
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filenames: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Get presigned download URLs for multiple files."""
|
||||
try:
|
||||
return self._storage.get_download_urls(filenames, expires_in)
|
||||
return self._storage.get_download_urls(filenames, expires_in, download_filenames=download_filenames)
|
||||
except NotImplementedError:
|
||||
from services.storage_ticket_service import StorageTicketService
|
||||
|
||||
return [StorageTicketService.create_download_url(f, expires_in=expires_in) for f in filenames]
|
||||
if download_filenames is None:
|
||||
return [StorageTicketService.create_download_url(f, expires_in=expires_in) for f in filenames]
|
||||
return [
|
||||
StorageTicketService.create_download_url(f, expires_in=expires_in, filename=df)
|
||||
for f, df in zip(filenames, download_filenames, strict=True)
|
||||
]
|
||||
|
||||
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
"""Get a presigned upload URL, falling back to ticket URL if not supported."""
|
||||
|
||||
@ -44,11 +44,23 @@ class StorageWrapper(BaseStorage):
|
||||
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
|
||||
return self._storage.scan(path, files=files, directories=directories)
|
||||
|
||||
def get_download_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
return self._storage.get_download_url(filename, expires_in)
|
||||
def get_download_url(
|
||||
self,
|
||||
filename: str,
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filename: str | None = None,
|
||||
) -> str:
|
||||
return self._storage.get_download_url(filename, expires_in, download_filename=download_filename)
|
||||
|
||||
def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]:
|
||||
return self._storage.get_download_urls(filenames, expires_in)
|
||||
def get_download_urls(
|
||||
self,
|
||||
filenames: list[str],
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filenames: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
return self._storage.get_download_urls(filenames, expires_in, download_filenames=download_filenames)
|
||||
|
||||
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
return self._storage.get_upload_url(filename, expires_in)
|
||||
|
||||
Reference in New Issue
Block a user