Files
dify/api/extensions/storage/aws_s3_storage.py
Novice cccff6768a Merge commit '9c339239' into sandboxed-agent-rebase
Made-with: Cursor

# Conflicts:
#	api/README.md
#	api/controllers/console/app/workflow_draft_variable.py
#	api/core/agent/cot_agent_runner.py
#	api/core/agent/fc_agent_runner.py
#	api/core/app/apps/advanced_chat/app_runner.py
#	api/core/plugin/backwards_invocation/model.py
#	api/core/prompt/advanced_prompt_transform.py
#	api/core/workflow/nodes/base/node.py
#	api/core/workflow/nodes/llm/llm_utils.py
#	api/core/workflow/nodes/llm/node.py
#	api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
#	api/core/workflow/nodes/question_classifier/question_classifier_node.py
#	api/core/workflow/runtime/graph_runtime_state.py
#	api/extensions/storage/base_storage.py
#	api/factories/variable_factory.py
#	api/pyproject.toml
#	api/services/variable_truncator.py
#	api/uv.lock
#	web/app/account/oauth/authorize/page.tsx
#	web/app/components/app/configuration/config-var/config-modal/field.tsx
#	web/app/components/base/alert.tsx
#	web/app/components/base/chat/chat/answer/human-input-content/executed-action.tsx
#	web/app/components/base/chat/chat/answer/more.tsx
#	web/app/components/base/chat/chat/answer/operation.tsx
#	web/app/components/base/chat/chat/answer/workflow-process.tsx
#	web/app/components/base/chat/chat/citation/index.tsx
#	web/app/components/base/chat/chat/citation/popup.tsx
#	web/app/components/base/chat/chat/citation/progress-tooltip.tsx
#	web/app/components/base/chat/chat/citation/tooltip.tsx
#	web/app/components/base/chat/chat/question.tsx
#	web/app/components/base/chat/embedded-chatbot/inputs-form/index.tsx
#	web/app/components/base/chat/embedded-chatbot/inputs-form/view-form-dropdown.tsx
#	web/app/components/base/markdown-blocks/form.tsx
#	web/app/components/base/prompt-editor/plugins/hitl-input-block/component-ui.tsx
#	web/app/components/base/tag-management/panel.tsx
#	web/app/components/base/tag-management/trigger.tsx
#	web/app/components/header/account-setting/index.tsx
#	web/app/components/header/account-setting/members-page/transfer-ownership-modal/index.tsx
#	web/app/components/header/account-setting/model-provider-page/provider-added-card/index.tsx
#	web/app/signin/utils/post-login-redirect.ts
#	web/eslint-suppressions.json
#	web/package.json
#	web/pnpm-lock.yaml
2026-03-23 09:00:45 +08:00

183 lines
6.8 KiB
Python

import logging
from collections.abc import Generator
from urllib.parse import quote
import boto3
from botocore.client import BaseClient, Config
from botocore.exceptions import ClientError
from configs import dify_config
from extensions.storage.base_storage import BaseStorage
logger = logging.getLogger(__name__)
class AwsS3Storage(BaseStorage):
    """Implementation for Amazon Web Services S3 storage.

    Also works with S3-compatible providers (e.g. MinIO). Credentials come
    either from an AWS managed IAM role or from explicit access/secret keys,
    selected by ``dify_config.S3_USE_AWS_MANAGED_IAM``.
    """

    # boto3 low-level S3 client, created in __init__.
    client: BaseClient

    def __init__(self):
        super().__init__()
        self.bucket_name = dify_config.S3_BUCKET_NAME
        # NOTE:
        # Some S3-compatible providers (e.g. MinIO) are strict about presigned request
        # signature calculation. If the client sends a `Content-Type` header while the
        # server generates a SigV2 presigned URL (or otherwise signs an empty Content-Type),
        # the request may be rejected with 403 (signature mismatch).
        #
        # Enforce SigV4 for presigned URLs to avoid accidental SigV2 fallback and to keep
        # header signing behavior consistent across providers.
        s3_client_config = Config(
            signature_version="s3v4",
            s3={"addressing_style": dify_config.S3_ADDRESS_STYLE},
        )
        if dify_config.S3_USE_AWS_MANAGED_IAM:
            logger.info("Using AWS managed IAM role for S3")
            session = boto3.Session()
            region_name = dify_config.S3_REGION
            self.client = session.client(
                service_name="s3",
                region_name=region_name,
                config=s3_client_config,
            )
        else:
            logger.info("Using ak and sk for S3")
            self.client = boto3.client(
                "s3",
                aws_secret_access_key=dify_config.S3_SECRET_KEY,
                aws_access_key_id=dify_config.S3_ACCESS_KEY,
                endpoint_url=dify_config.S3_ENDPOINT,
                region_name=dify_config.S3_REGION,
                config=s3_client_config,
            )
        # Ensure the bucket exists, creating it if necessary.
        try:
            self.client.head_bucket(Bucket=self.bucket_name)
        except ClientError as e:
            error_code = e.response.get("Error", {}).get("Code")
            if error_code == "404":
                # Bucket does not exist yet: create it.
                self.client.create_bucket(Bucket=self.bucket_name)
            elif error_code == "403":
                # Bucket is not accessible to this principal; it may still exist
                # and be usable, so proceed and let later operations surface
                # any real problem.
                pass
            else:
                # Other error: re-raise unchanged.
                raise

    @staticmethod
    def _content_disposition(filename: str) -> str:
        """Build an ``attachment`` Content-Disposition header value.

        RFC 5987 / RFC 6266: emit both ``filename`` and ``filename*`` for
        compatibility; ``filename*`` with UTF-8 percent-encoding handles
        non-ASCII names.
        """
        encoded = quote(filename)
        return f"attachment; filename=\"{encoded}\"; filename*=UTF-8''{encoded}"

    def save(self, filename, data):
        """Store *data* under key *filename* in the configured bucket."""
        self.client.put_object(Bucket=self.bucket_name, Key=filename, Body=data)

    def load_once(self, filename: str) -> bytes:
        """Read the whole object *filename* into memory and return its bytes.

        Raises:
            FileNotFoundError: if the key does not exist.
        """
        try:
            data: bytes = self.client.get_object(Bucket=self.bucket_name, Key=filename)["Body"].read()
        except ClientError as ex:
            if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
                raise FileNotFoundError("File not found") from ex
            raise
        return data

    def load_stream(self, filename: str) -> Generator:
        """Yield the object's content in chunks (lazy generator).

        Raises:
            FileNotFoundError: if the key does not exist.
            ValueError: if boto3 reports its retry limit was reached
                (likely the same file requested too frequently).
        """
        try:
            response = self.client.get_object(Bucket=self.bucket_name, Key=filename)
            yield from response["Body"].iter_chunks()
        except ClientError as ex:
            if ex.response.get("Error", {}).get("Code") == "NoSuchKey":
                raise FileNotFoundError("file not found") from ex
            if "reached max retries" in str(ex):
                raise ValueError("please do not request the same file too frequently") from ex
            raise

    def download(self, filename, target_filepath):
        """Download the object *filename* to the local path *target_filepath*."""
        self.client.download_file(self.bucket_name, filename, target_filepath)

    def exists(self, filename):
        """Return True if *filename* exists and is accessible, else False.

        head_object failures (404 missing, 403 forbidden, ...) surface as
        ClientError; anything else (e.g. a programming error) propagates
        instead of being silently swallowed.
        """
        try:
            self.client.head_object(Bucket=self.bucket_name, Key=filename)
            return True
        except ClientError:
            return False

    def delete(self, filename: str):
        """Delete the object stored under *filename*."""
        self.client.delete_object(Bucket=self.bucket_name, Key=filename)

    def get_download_url(
        self,
        filename: str,
        expires_in: int = 3600,
        *,
        download_filename: str | None = None,
    ) -> str:
        """Generate a presigned download URL.

        Args:
            filename: The S3 object key.
            expires_in: URL validity duration in seconds.
            download_filename: If provided, sets the Content-Disposition header
                so the browser downloads the file with this name instead of the
                S3 key. When omitted, the key itself is used as the name.
        """
        params: dict = {
            "Bucket": self.bucket_name,
            "Key": filename,
            "ResponseContentDisposition": self._content_disposition(download_filename or filename),
        }
        url: str = self.client.generate_presigned_url(
            ClientMethod="get_object",
            Params=params,
            ExpiresIn=expires_in,
        )
        return url

    def get_download_urls(
        self,
        filenames: list[str],
        expires_in: int = 3600,
        *,
        download_filenames: list[str] | None = None,
    ) -> list[str]:
        """Generate presigned download URLs for multiple files.

        Args:
            filenames: List of S3 object keys.
            expires_in: URL validity duration in seconds.
            download_filenames: If provided, must match len(filenames); sets
                Content-Disposition per file. Falsy entries (and the default
                None) leave Content-Disposition unset for that file.

        Raises:
            ValueError: if download_filenames is given with a length different
                from filenames (via zip(..., strict=True)).
        """
        if download_filenames is None:
            return [
                self.client.generate_presigned_url(
                    ClientMethod="get_object",
                    Params={"Bucket": self.bucket_name, "Key": filename},
                    ExpiresIn=expires_in,
                )
                for filename in filenames
            ]
        urls: list[str] = []
        for filename, download_filename in zip(filenames, download_filenames, strict=True):
            params: dict = {"Bucket": self.bucket_name, "Key": filename}
            if download_filename:
                params["ResponseContentDisposition"] = self._content_disposition(download_filename)
            urls.append(
                self.client.generate_presigned_url(
                    ClientMethod="get_object",
                    Params=params,
                    ExpiresIn=expires_in,
                )
            )
        return urls

    def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
        """Generate a presigned URL for uploading (PUT) to key *filename*."""
        url: str = self.client.generate_presigned_url(
            ClientMethod="put_object",
            Params={"Bucket": self.bucket_name, "Key": filename},
            ExpiresIn=expires_in,
        )
        return url