diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index b801b2b578..5ef73b6ad4 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -118,8 +118,14 @@ class Storage: def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]: return self.storage_runner.scan(path, files=files, directories=directories) - def get_download_url(self, filename: str, expires_in: int = 3600) -> str: - return self.storage_runner.get_download_url(filename, expires_in) + def get_download_url( + self, + filename: str, + expires_in: int = 3600, + *, + download_filename: str | None = None, + ) -> str: + return self.storage_runner.get_download_url(filename, expires_in, download_filename=download_filename) storage = Storage() diff --git a/api/extensions/storage/aws_s3_storage.py b/api/extensions/storage/aws_s3_storage.py index 04d75a0f67..129e512bdf 100644 --- a/api/extensions/storage/aws_s3_storage.py +++ b/api/extensions/storage/aws_s3_storage.py @@ -1,5 +1,6 @@ import logging from collections.abc import Generator +from urllib.parse import quote import boto3 from botocore.client import BaseClient, Config @@ -105,23 +106,73 @@ class AwsS3Storage(BaseStorage): def delete(self, filename): self.client.delete_object(Bucket=self.bucket_name, Key=filename) - def get_download_url(self, filename: str, expires_in: int = 3600) -> str: + def get_download_url( + self, + filename: str, + expires_in: int = 3600, + *, + download_filename: str | None = None, + ) -> str: + """Generate a presigned download URL. + + Args: + filename: The S3 object key + expires_in: URL validity duration in seconds + download_filename: If provided, sets Content-Disposition header so browser + downloads the file with this name instead of the S3 key. + """ + params: dict = {"Bucket": self.bucket_name, "Key": filename} + if download_filename: + # RFC 5987 / RFC 6266: Use both filename and filename* for compatibility. 
+ # filename* with UTF-8 encoding handles non-ASCII characters. + encoded = quote(download_filename) + params["ResponseContentDisposition"] = f"attachment; filename=\"{encoded}\"; filename*=UTF-8''{encoded}" url: str = self.client.generate_presigned_url( ClientMethod="get_object", - Params={"Bucket": self.bucket_name, "Key": filename}, + Params=params, ExpiresIn=expires_in, ) return url - def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]: - return [ - self.client.generate_presigned_url( - ClientMethod="get_object", - Params={"Bucket": self.bucket_name, "Key": filename}, - ExpiresIn=expires_in, + def get_download_urls( + self, + filenames: list[str], + expires_in: int = 3600, + *, + download_filenames: list[str] | None = None, + ) -> list[str]: + """Generate presigned download URLs for multiple files. + + Args: + filenames: List of S3 object keys + expires_in: URL validity duration in seconds + download_filenames: If provided, must match len(filenames). Sets + Content-Disposition for each file. 
+ """ + if download_filenames is None: + return [ + self.client.generate_presigned_url( + ClientMethod="get_object", + Params={"Bucket": self.bucket_name, "Key": filename}, + ExpiresIn=expires_in, + ) + for filename in filenames + ] + + urls: list[str] = [] + for filename, download_filename in zip(filenames, download_filenames, strict=True): + params: dict = {"Bucket": self.bucket_name, "Key": filename} + if download_filename: + encoded = quote(download_filename) + params["ResponseContentDisposition"] = f"attachment; filename=\"{encoded}\"; filename*=UTF-8''{encoded}" + urls.append( + self.client.generate_presigned_url( + ClientMethod="get_object", + Params=params, + ExpiresIn=expires_in, + ) ) - for filename in filenames - ] + return urls def get_upload_url(self, filename: str, expires_in: int = 3600) -> str: url: str = self.client.generate_presigned_url( diff --git a/api/extensions/storage/base_storage.py b/api/extensions/storage/base_storage.py index f72b984454..a58dfe8024 100644 --- a/api/extensions/storage/base_storage.py +++ b/api/extensions/storage/base_storage.py @@ -39,7 +39,13 @@ class BaseStorage(ABC): """ raise NotImplementedError("This storage backend doesn't support scanning") - def get_download_url(self, filename: str, expires_in: int = 3600) -> str: + def get_download_url( + self, + filename: str, + expires_in: int = 3600, + *, + download_filename: str | None = None, + ) -> str: """ Generate a pre-signed URL for downloading a file. @@ -49,6 +55,9 @@ class BaseStorage(ABC): Args: filename: The file path/key in storage expires_in: URL validity duration in seconds (default: 1 hour) + download_filename: If provided, the browser will use this as the downloaded + file name instead of the storage key. Implemented via response header + override (e.g., Content-Disposition) where supported. 
Returns: Pre-signed URL string @@ -58,9 +67,21 @@ class BaseStorage(ABC): """ raise NotImplementedError("This storage backend doesn't support pre-signed URLs") - def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]: + def get_download_urls( + self, + filenames: list[str], + expires_in: int = 3600, + *, + download_filenames: list[str] | None = None, + ) -> list[str]: """ Generate pre-signed URLs for downloading multiple files. + + Args: + filenames: List of file paths/keys in storage + expires_in: URL validity duration in seconds (default: 1 hour) + download_filenames: If provided, must match len(filenames). Each element + specifies the download filename for the corresponding file. """ raise NotImplementedError("This storage backend doesn't support pre-signed URLs") diff --git a/api/extensions/storage/cached_presign_storage.py b/api/extensions/storage/cached_presign_storage.py index d4a6bcc79a..079a43aa70 100644 --- a/api/extensions/storage/cached_presign_storage.py +++ b/api/extensions/storage/cached_presign_storage.py @@ -1,6 +1,8 @@ """Storage wrapper that caches presigned download URLs.""" +import hashlib import logging +from itertools import starmap from extensions.ext_redis import redis_client from extensions.storage.base_storage import BaseStorage @@ -39,23 +41,31 @@ class CachedPresignStorage(StorageWrapper): super().delete(filename) self.invalidate([filename]) - def get_download_url(self, filename: str, expires_in: int = 3600) -> str: + def get_download_url( + self, + filename: str, + expires_in: int = 3600, + *, + download_filename: str | None = None, + ) -> str: """Get a presigned download URL, using cache when available. Args: filename: The file path/key in storage expires_in: URL validity duration in seconds (default: 1 hour) + download_filename: If provided, the browser will use this as the downloaded + file name. Cache keys include this value to avoid conflicts. 
Returns: Presigned URL string """ - cache_key = self._cache_key(filename) + cache_key = self._cache_key(filename, download_filename) cached = self._get_cached(cache_key) if cached: return cached - url = self._storage.get_download_url(filename, expires_in) + url = self._storage.get_download_url(filename, expires_in, download_filename=download_filename) self._set_cached(cache_key, url, expires_in) return url @@ -64,12 +74,16 @@ class CachedPresignStorage(StorageWrapper): self, filenames: list[str], expires_in: int = 3600, + *, + download_filenames: list[str] | None = None, ) -> list[str]: """Batch get download URLs with cache. Args: filenames: List of file paths/keys in storage expires_in: URL validity duration in seconds (default: 1 hour) + download_filenames: If provided, must match len(filenames). Each element + specifies the download filename for the corresponding file. Returns: List of presigned URLs in the same order as filenames @@ -77,22 +91,32 @@ class CachedPresignStorage(StorageWrapper): if not filenames: return [] - cache_keys = [self._cache_key(f) for f in filenames] + # Build cache keys including download_filename for uniqueness + if download_filenames is None: + cache_keys = [self._cache_key(f, None) for f in filenames] + else: + cache_keys = list(starmap(self._cache_key, zip(filenames, download_filenames, strict=True))) + cached_values = self._get_cached_batch(cache_keys) # Build results list, tracking which indices need fetching results: list[str | None] = list(cached_values) uncached_indices: list[int] = [] uncached_filenames: list[str] = [] + uncached_download_filenames: list[str | None] = [] for i, (filename, cached) in enumerate(zip(filenames, cached_values)): if not cached: uncached_indices.append(i) uncached_filenames.append(filename) + uncached_download_filenames.append(download_filenames[i] if download_filenames else None) # Batch fetch uncached URLs from storage if uncached_filenames: - uncached_urls = [self._storage.get_download_url(f, 
expires_in) for f in uncached_filenames] +            uncached_urls = [ +                self._storage.get_download_url(f, expires_in, download_filename=df) +                for f, df in zip(uncached_filenames, uncached_download_filenames, strict=True) +            ]   # Fill results at correct positions for idx, url in zip(uncached_indices, uncached_urls): @@ -119,8 +143,19 @@ class CachedPresignStorage(StorageWrapper): except Exception: logger.warning("Failed to invalidate presign cache", exc_info=True) -    def _cache_key(self, filename: str) -> str: -        """Generate cache key for a filename.""" +    def _cache_key(self, filename: str, download_filename: str | None = None) -> str: +        """Generate cache key for a filename. + +        When download_filename is provided, its hash is appended to the key to ensure +        different download names for the same storage key get separate cache entries. +        We use a hash (truncated MD5) instead of the raw string because: +        - download_filename may contain special characters unsafe for Redis keys +        - Hash collisions only cause a cache miss, no functional impact +        """ +        if download_filename: +            # Use first 16 chars of MD5 hex digest (64 bits) - sufficient for cache key uniqueness +            name_hash = hashlib.md5(download_filename.encode("utf-8")).hexdigest()[:16] +            return f"{self._cache_key_prefix}:{filename}::{name_hash}" +        return f"{self._cache_key_prefix}:{filename}" def _compute_ttl(self, expires_in: int) -> int: diff --git a/api/extensions/storage/file_presign_storage.py b/api/extensions/storage/file_presign_storage.py index 27cc2ea5e2..f43e00b57a 100644 --- a/api/extensions/storage/file_presign_storage.py +++ b/api/extensions/storage/file_presign_storage.py @@ -28,23 +28,40 @@ class FilePresignStorage(StorageWrapper): Otherwise, generates ticket-based URLs for both download and upload operations.
""" - def get_download_url(self, filename: str, expires_in: int = 3600) -> str: + def get_download_url( + self, + filename: str, + expires_in: int = 3600, + *, + download_filename: str | None = None, + ) -> str: """Get a presigned download URL, falling back to ticket URL if not supported.""" try: - return self._storage.get_download_url(filename, expires_in) + return self._storage.get_download_url(filename, expires_in, download_filename=download_filename) except NotImplementedError: from services.storage_ticket_service import StorageTicketService - return StorageTicketService.create_download_url(filename, expires_in=expires_in) + return StorageTicketService.create_download_url(filename, expires_in=expires_in, filename=download_filename) - def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]: + def get_download_urls( + self, + filenames: list[str], + expires_in: int = 3600, + *, + download_filenames: list[str] | None = None, + ) -> list[str]: """Get presigned download URLs for multiple files.""" try: - return self._storage.get_download_urls(filenames, expires_in) + return self._storage.get_download_urls(filenames, expires_in, download_filenames=download_filenames) except NotImplementedError: from services.storage_ticket_service import StorageTicketService - return [StorageTicketService.create_download_url(f, expires_in=expires_in) for f in filenames] + if download_filenames is None: + return [StorageTicketService.create_download_url(f, expires_in=expires_in) for f in filenames] + return [ + StorageTicketService.create_download_url(f, expires_in=expires_in, filename=df) + for f, df in zip(filenames, download_filenames, strict=True) + ] def get_upload_url(self, filename: str, expires_in: int = 3600) -> str: """Get a presigned upload URL, falling back to ticket URL if not supported.""" diff --git a/api/extensions/storage/storage_wrapper.py b/api/extensions/storage/storage_wrapper.py index d3ed8ea317..db472f3c47 100644 --- 
a/api/extensions/storage/storage_wrapper.py +++ b/api/extensions/storage/storage_wrapper.py @@ -44,11 +44,23 @@ class StorageWrapper(BaseStorage): def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]: return self._storage.scan(path, files=files, directories=directories) - def get_download_url(self, filename: str, expires_in: int = 3600) -> str: - return self._storage.get_download_url(filename, expires_in) + def get_download_url( + self, + filename: str, + expires_in: int = 3600, + *, + download_filename: str | None = None, + ) -> str: + return self._storage.get_download_url(filename, expires_in, download_filename=download_filename) - def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]: - return self._storage.get_download_urls(filenames, expires_in) + def get_download_urls( + self, + filenames: list[str], + expires_in: int = 3600, + *, + download_filenames: list[str] | None = None, + ) -> list[str]: + return self._storage.get_download_urls(filenames, expires_in, download_filenames=download_filenames) def get_upload_url(self, filename: str, expires_in: int = 3600) -> str: return self._storage.get_upload_url(filename, expires_in) diff --git a/api/services/app_asset_service.py b/api/services/app_asset_service.py index 215a8aef7a..35daafe0ae 100644 --- a/api/services/app_asset_service.py +++ b/api/services/app_asset_service.py @@ -366,7 +366,7 @@ class AppAssetService: asset_storage = AppAssetService.get_storage() key = AssetPaths.draft(app_model.tenant_id, app_model.id, node_id) - return asset_storage.get_download_url(key, expires_in) + return asset_storage.get_download_url(key, expires_in, download_filename=node.name) @staticmethod def get_source_zip_bytes(tenant_id: str, app_id: str, workflow_id: str) -> bytes | None: diff --git a/api/services/app_bundle_service.py b/api/services/app_bundle_service.py index 56aa1710f0..3e54006607 100644 --- a/api/services/app_bundle_service.py +++ 
b/api/services/app_bundle_service.py @@ -145,8 +145,9 @@ class AppBundleService: archive = zs.zip(src="bundle_root", include_base=False) zs.upload(archive, upload_url) - download_url = asset_storage.get_download_url(export_key, expires_in) - return BundleExportResult(download_url=download_url, filename=f"{safe_name}.zip") + bundle_filename = f"{safe_name}.zip" + download_url = asset_storage.get_download_url(export_key, expires_in, download_filename=bundle_filename) + return BundleExportResult(download_url=download_url, filename=bundle_filename) # ========== Import ========== diff --git a/api/tests/unit_tests/core/app_assets/test_storage.py b/api/tests/unit_tests/core/app_assets/test_storage.py index d97e40fad1..35b3c35c81 100644 --- a/api/tests/unit_tests/core/app_assets/test_storage.py +++ b/api/tests/unit_tests/core/app_assets/test_storage.py @@ -32,10 +32,22 @@ class DummyStorage(BaseStorage): def delete(self, filename: str): return None - def get_download_url(self, filename: str, expires_in: int = 3600) -> str: + def get_download_url( + self, + filename: str, + expires_in: int = 3600, + *, + download_filename: str | None = None, + ) -> str: raise NotImplementedError - def get_download_urls(self, filenames: list[str], expires_in: int = 3600) -> list[str]: + def get_download_urls( + self, + filenames: list[str], + expires_in: int = 3600, + *, + download_filenames: list[str] | None = None, + ) -> list[str]: raise NotImplementedError def get_upload_url(self, filename: str, expires_in: int = 3600) -> str: diff --git a/api/tests/unit_tests/extensions/storage/test_cached_presign_storage.py b/api/tests/unit_tests/extensions/storage/test_cached_presign_storage.py index 3626f8e38c..c2eb26c80b 100644 --- a/api/tests/unit_tests/extensions/storage/test_cached_presign_storage.py +++ b/api/tests/unit_tests/extensions/storage/test_cached_presign_storage.py @@ -1,4 +1,5 @@ -from unittest.mock import Mock +import hashlib +from unittest.mock import Mock, patch import pytest @@ 
-19,12 +20,16 @@ class TestCachedPresignStorage: return Mock() @pytest.fixture - def cached_storage(self, mock_storage): + def cached_storage(self, mock_storage, mock_redis): """Create CachedPresignStorage with mocks.""" - return CachedPresignStorage( - storage=mock_storage, - cache_key_prefix="test_prefix", - ) + with patch("extensions.storage.cached_presign_storage.redis_client", mock_redis): + storage = CachedPresignStorage( + storage=mock_storage, + cache_key_prefix="test_prefix", + ) + # Inject mock_redis after creation for tests to verify calls + storage._redis = mock_redis + yield storage def test_get_download_url_returns_cached_on_hit(self, cached_storage, mock_storage, mock_redis): """Test that cached URL is returned when cache hit occurs.""" @@ -46,7 +51,7 @@ class TestCachedPresignStorage: assert result == "https://new-url.com/file.txt" mock_redis.mget.assert_called_once_with(["test_prefix:path/to/file.txt"]) - mock_storage.get_download_url.assert_called_once_with("path/to/file.txt", 3600) + mock_storage.get_download_url.assert_called_once_with("path/to/file.txt", 3600, download_filename=None) mock_redis.setex.assert_called_once() call_args = mock_redis.setex.call_args assert call_args[0][0] == "test_prefix:path/to/file.txt" @@ -61,7 +66,7 @@ class TestCachedPresignStorage: result = cached_storage.get_download_urls(filenames, expires_in=3600) assert result == ["https://cached1.com", "https://new.com", "https://cached2.com"] - mock_storage.get_download_url.assert_called_once_with("file2.txt", 3600) + mock_storage.get_download_url.assert_called_once_with("file2.txt", 3600, download_filename=None) # Verify pipeline was used for batch cache write mock_redis.pipeline.assert_called_once() mock_redis.pipeline().execute.assert_called_once() @@ -114,7 +119,7 @@ class TestCachedPresignStorage: result = cached_storage.get_download_url("path/to/file.txt", expires_in=3600) assert result == "https://new-url.com/file.txt" - 
mock_storage.get_download_url.assert_called_once_with("path/to/file.txt", 3600) + mock_storage.get_download_url.assert_called_once_with("path/to/file.txt", 3600, download_filename=None) def test_graceful_degradation_on_redis_setex_error(self, cached_storage, mock_storage, mock_redis): """Test that URL is still returned when Redis setex fails.""" @@ -177,28 +182,45 @@ class TestCachedPresignStorage: key = cached_storage._cache_key("path/to/file.txt") assert key == "test_prefix:path/to/file.txt" - def test_cached_value_decoded_from_bytes(self, cached_storage, mock_storage, mock_redis): - """Test that bytes cached values are decoded to strings.""" - mock_redis.mget.return_value = [b"https://cached-url.com"] + def test_cache_key_with_download_filename(self, cached_storage): + """Test cache key includes hashed download_filename when provided.""" + key = cached_storage._cache_key("path/to/file.txt", "custom_name.txt") + # download_filename is hashed (first 16 chars of MD5 hex digest) + expected_hash = hashlib.md5(b"custom_name.txt").hexdigest()[:16] + assert key == f"test_prefix:path/to/file.txt::{expected_hash}" - result = cached_storage.get_download_url("file.txt") + def test_get_download_url_with_download_filename(self, cached_storage, mock_storage, mock_redis): + """Test that download_filename is passed to storage and affects cache key.""" + mock_redis.mget.return_value = [None] + mock_storage.get_download_url.return_value = "https://new-url.com/file.txt" - assert result == "https://cached-url.com" - assert isinstance(result, str) + result = cached_storage.get_download_url("path/to/file.txt", expires_in=3600, download_filename="custom.txt") - def test_cached_value_decoded_from_bytearray(self, cached_storage, mock_storage, mock_redis): - """Test that bytearray cached values are decoded to strings.""" - mock_redis.mget.return_value = [bytearray(b"https://cached-url.com")] + assert result == "https://new-url.com/file.txt" + expected_hash = 
hashlib.md5(b"custom.txt").hexdigest()[:16] + mock_redis.mget.assert_called_once_with([f"test_prefix:path/to/file.txt::{expected_hash}"]) + mock_storage.get_download_url.assert_called_once_with("path/to/file.txt", 3600, download_filename="custom.txt") - result = cached_storage.get_download_url("file.txt") + def test_get_download_urls_with_download_filenames(self, cached_storage, mock_storage, mock_redis): + """Test batch URL retrieval with download_filenames.""" + mock_redis.mget.return_value = [None, None] + mock_storage.get_download_url.side_effect = ["https://url1.com", "https://url2.com"] - assert result == "https://cached-url.com" - assert isinstance(result, str) + filenames = ["file1.txt", "file2.txt"] + download_filenames = ["custom1.txt", "custom2.txt"] + result = cached_storage.get_download_urls(filenames, expires_in=3600, download_filenames=download_filenames) - def test_default_cache_key_prefix(self, mock_storage): - """Test default cache key prefix is used when not specified.""" - storage = CachedPresignStorage( - storage=mock_storage, + assert result == ["https://url1.com", "https://url2.com"] + # Verify cache keys include hashed download_filenames + hash1 = hashlib.md5(b"custom1.txt").hexdigest()[:16] + hash2 = hashlib.md5(b"custom2.txt").hexdigest()[:16] + mock_redis.mget.assert_called_once_with( + [ + f"test_prefix:file1.txt::{hash1}", + f"test_prefix:file2.txt::{hash2}", + ] ) - key = storage._cache_key("file.txt") - assert key == "presign_cache:file.txt" + # Verify storage calls include download_filename + assert mock_storage.get_download_url.call_count == 2 + mock_storage.get_download_url.assert_any_call("file1.txt", 3600, download_filename="custom1.txt") + mock_storage.get_download_url.assert_any_call("file2.txt", 3600, download_filename="custom2.txt")