Merge remote-tracking branch 'origin/main' into feat/trigger

This commit is contained in:
yessenia
2025-09-25 17:14:24 +08:00
3013 changed files with 148826 additions and 44294 deletions

View File

@ -7,8 +7,8 @@ eliminates the need for repetitive language switching logic.
"""
from dataclasses import dataclass
from enum import Enum
from typing import Any, Optional, Protocol
from enum import StrEnum, auto
from typing import Any, Protocol
from flask import render_template
from pydantic import BaseModel, Field
@ -17,26 +17,30 @@ from extensions.ext_mail import mail
from services.feature_service import BrandingModel, FeatureService
class EmailType(Enum):
class EmailType(StrEnum):
"""Enumeration of supported email types."""
RESET_PASSWORD = "reset_password"
INVITE_MEMBER = "invite_member"
EMAIL_CODE_LOGIN = "email_code_login"
CHANGE_EMAIL_OLD = "change_email_old"
CHANGE_EMAIL_NEW = "change_email_new"
CHANGE_EMAIL_COMPLETED = "change_email_completed"
OWNER_TRANSFER_CONFIRM = "owner_transfer_confirm"
OWNER_TRANSFER_OLD_NOTIFY = "owner_transfer_old_notify"
OWNER_TRANSFER_NEW_NOTIFY = "owner_transfer_new_notify"
ACCOUNT_DELETION_SUCCESS = "account_deletion_success"
ACCOUNT_DELETION_VERIFICATION = "account_deletion_verification"
ENTERPRISE_CUSTOM = "enterprise_custom"
QUEUE_MONITOR_ALERT = "queue_monitor_alert"
DOCUMENT_CLEAN_NOTIFY = "document_clean_notify"
RESET_PASSWORD = auto()
RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST = auto()
INVITE_MEMBER = auto()
EMAIL_CODE_LOGIN = auto()
CHANGE_EMAIL_OLD = auto()
CHANGE_EMAIL_NEW = auto()
CHANGE_EMAIL_COMPLETED = auto()
OWNER_TRANSFER_CONFIRM = auto()
OWNER_TRANSFER_OLD_NOTIFY = auto()
OWNER_TRANSFER_NEW_NOTIFY = auto()
ACCOUNT_DELETION_SUCCESS = auto()
ACCOUNT_DELETION_VERIFICATION = auto()
ENTERPRISE_CUSTOM = auto()
QUEUE_MONITOR_ALERT = auto()
DOCUMENT_CLEAN_NOTIFY = auto()
EMAIL_REGISTER = auto()
EMAIL_REGISTER_WHEN_ACCOUNT_EXIST = auto()
RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER = auto()
class EmailLanguage(Enum):
class EmailLanguage(StrEnum):
"""Supported email languages with fallback handling."""
EN_US = "en-US"
@ -128,7 +132,7 @@ class FeatureBrandingService:
class EmailSender(Protocol):
"""Protocol for email sending abstraction."""
def send_email(self, to: str, subject: str, html_content: str) -> None:
def send_email(self, to: str, subject: str, html_content: str):
"""Send email with given parameters."""
...
@ -136,7 +140,7 @@ class EmailSender(Protocol):
class FlaskMailSender:
"""Flask-Mail based email sender."""
def send_email(self, to: str, subject: str, html_content: str) -> None:
def send_email(self, to: str, subject: str, html_content: str):
"""Send email using Flask-Mail."""
if mail.is_inited():
mail.send(to=to, subject=subject, html=html_content)
@ -156,7 +160,7 @@ class EmailI18nService:
renderer: EmailRenderer,
branding_service: BrandingService,
sender: EmailSender,
) -> None:
):
self._config = config
self._renderer = renderer
self._branding_service = branding_service
@ -167,8 +171,8 @@ class EmailI18nService:
email_type: EmailType,
language_code: str,
to: str,
template_context: Optional[dict[str, Any]] = None,
) -> None:
template_context: dict[str, Any] | None = None,
):
"""
Send internationalized email with branding support.
@ -192,7 +196,7 @@ class EmailI18nService:
to: str,
code: str,
phase: str,
) -> None:
):
"""
Send change email notification with phase-specific handling.
@ -224,7 +228,7 @@ class EmailI18nService:
to: str | list[str],
subject: str,
html_content: str,
) -> None:
):
"""
Send a raw email directly without template processing.
@ -441,6 +445,54 @@ def create_default_email_config() -> EmailI18nConfig:
branded_template_path="clean_document_job_mail_template_zh-CN.html",
),
},
EmailType.EMAIL_REGISTER: {
EmailLanguage.EN_US: EmailTemplate(
subject="Register Your {application_title} Account",
template_path="register_email_template_en-US.html",
branded_template_path="without-brand/register_email_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="注册您的 {application_title} 账户",
template_path="register_email_template_zh-CN.html",
branded_template_path="without-brand/register_email_template_zh-CN.html",
),
},
EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST: {
EmailLanguage.EN_US: EmailTemplate(
subject="Register Your {application_title} Account",
template_path="register_email_when_account_exist_template_en-US.html",
branded_template_path="without-brand/register_email_when_account_exist_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="注册您的 {application_title} 账户",
template_path="register_email_when_account_exist_template_zh-CN.html",
branded_template_path="without-brand/register_email_when_account_exist_template_zh-CN.html",
),
},
EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST: {
EmailLanguage.EN_US: EmailTemplate(
subject="Reset Your {application_title} Password",
template_path="reset_password_mail_when_account_not_exist_template_en-US.html",
branded_template_path="without-brand/reset_password_mail_when_account_not_exist_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="重置您的 {application_title} 密码",
template_path="reset_password_mail_when_account_not_exist_template_zh-CN.html",
branded_template_path="without-brand/reset_password_mail_when_account_not_exist_template_zh-CN.html",
),
},
EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER: {
EmailLanguage.EN_US: EmailTemplate(
subject="Reset Your {application_title} Password",
template_path="reset_password_mail_when_account_not_exist_no_register_template_en-US.html",
branded_template_path="without-brand/reset_password_mail_when_account_not_exist_no_register_template_en-US.html",
),
EmailLanguage.ZH_HANS: EmailTemplate(
subject="重置您的 {application_title} 密码",
template_path="reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html",
branded_template_path="without-brand/reset_password_mail_when_account_not_exist_no_register_template_zh-CN.html",
),
},
}
return EmailI18nConfig(templates=templates)
@ -463,7 +515,7 @@ def get_default_email_i18n_service() -> EmailI18nService:
# Global instance
_email_i18n_service: Optional[EmailI18nService] = None
_email_i18n_service: EmailI18nService | None = None
def get_email_i18n_service() -> EmailI18nService:

View File

@ -1,11 +1,9 @@
from typing import Optional
from werkzeug.exceptions import HTTPException
class BaseHTTPException(HTTPException):
error_code: str = "unknown"
data: Optional[dict] = None
data: dict | None = None
def __init__(self, description=None, response=None):
super().__init__(description, response)

View File

@ -16,7 +16,7 @@ def http_status_message(code):
return HTTP_STATUS_CODES.get(code, "")
def register_external_error_handlers(api: Api) -> None:
def register_external_error_handlers(api: Api):
@api.errorhandler(HTTPException)
def handle_http_exception(e: HTTPException):
got_request_exception.send(current_app, exception=e)
@ -69,6 +69,8 @@ def register_external_error_handlers(api: Api) -> None:
headers["WWW-Authenticate"] = 'Bearer realm="api"'
return data, status_code, headers
_ = handle_http_exception
@api.errorhandler(ValueError)
def handle_value_error(e: ValueError):
got_request_exception.send(current_app, exception=e)
@ -76,6 +78,8 @@ def register_external_error_handlers(api: Api) -> None:
data = {"code": "invalid_param", "message": str(e), "status": status_code}
return data, status_code
_ = handle_value_error
@api.errorhandler(AppInvokeQuotaExceededError)
def handle_quota_exceeded(e: AppInvokeQuotaExceededError):
got_request_exception.send(current_app, exception=e)
@ -83,15 +87,17 @@ def register_external_error_handlers(api: Api) -> None:
data = {"code": "too_many_requests", "message": str(e), "status": status_code}
return data, status_code
_ = handle_quota_exceeded
@api.errorhandler(Exception)
def handle_general_exception(e: Exception):
got_request_exception.send(current_app, exception=e)
status_code = 500
data: dict[str, Any] = getattr(e, "data", {"message": http_status_message(status_code)})
data = getattr(e, "data", {"message": http_status_message(status_code)})
# 🔒 Normalize non-mapping data (e.g., if someone set e.data = Response)
if not isinstance(data, Mapping):
if not isinstance(data, dict):
data = {"message": str(e)}
data.setdefault("code", "unknown")
@ -105,6 +111,8 @@ def register_external_error_handlers(api: Api) -> None:
return data, status_code
_ = handle_general_exception
class ExternalApi(Api):
_authorizations = {

View File

@ -3,7 +3,7 @@ from collections.abc import Iterator
from contextlib import contextmanager
from typing import TypeVar
from flask import Flask, g, has_request_context
from flask import Flask, g
T = TypeVar("T")
@ -48,7 +48,8 @@ def preserve_flask_contexts(
# Save current user before entering new app context
saved_user = None
if has_request_context() and hasattr(g, "_login_user"):
# Check for user in g (works in both request context and app context)
if hasattr(g, "_login_user"):
saved_user = g._login_user
# Enter Flask app context

View File

@ -136,7 +136,7 @@ class PKCS1OAepCipher:
# Step 3a (OS2IP)
em_int = bytes_to_long(em)
# Step 3b (RSAEP)
m_int = gmpy2.powmod(em_int, self._key.e, self._key.n)
m_int = gmpy2.powmod(em_int, self._key.e, self._key.n) # ty: ignore [unresolved-attribute]
# Step 3c (I2OSP)
c = long_to_bytes(m_int, k)
return c
@ -169,7 +169,7 @@ class PKCS1OAepCipher:
ct_int = bytes_to_long(ciphertext)
# Step 2b (RSADP)
# m_int = self._key._decrypt(ct_int)
m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n)
m_int = gmpy2.powmod(ct_int, self._key.d, self._key.n) # ty: ignore [unresolved-attribute]
# Complete step 2c (I2OSP)
em = long_to_bytes(m_int, k)
# Step 3a

View File

@ -68,7 +68,7 @@ class AppIconUrlField(fields.Raw):
if isinstance(obj, dict) and "app" in obj:
obj = obj["app"]
if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE.value:
if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE:
return file_helpers.get_signed_file_url(obj.icon)
return None
@ -167,13 +167,6 @@ class DatetimeString:
return value
def _get_float(value):
try:
return float(value)
except (TypeError, ValueError):
raise ValueError(f"{value} is not a valid float")
def timezone(timezone_string):
if timezone_string and timezone_string in available_timezones():
return timezone_string
@ -185,7 +178,7 @@ def timezone(timezone_string):
def generate_string(n):
letters_digits = string.ascii_letters + string.digits
result = ""
for i in range(n):
for _ in range(n):
result += secrets.choice(letters_digits)
return result
@ -276,8 +269,8 @@ class TokenManager:
cls,
token_type: str,
account: Optional["Account"] = None,
email: Optional[str] = None,
additional_data: Optional[dict] = None,
email: str | None = None,
additional_data: dict | None = None,
) -> str:
if account is None and email is None:
raise ValueError("Account or email must be provided")
@ -319,19 +312,19 @@ class TokenManager:
redis_client.delete(token_key)
@classmethod
def get_token_data(cls, token: str, token_type: str) -> Optional[dict[str, Any]]:
def get_token_data(cls, token: str, token_type: str) -> dict[str, Any] | None:
key = cls._get_token_key(token, token_type)
token_data_json = redis_client.get(key)
if token_data_json is None:
logger.warning("%s token %s not found with key %s", token_type, token, key)
return None
token_data: Optional[dict[str, Any]] = json.loads(token_data_json)
token_data: dict[str, Any] | None = json.loads(token_data_json)
return token_data
@classmethod
def _get_current_token_for_account(cls, account_id: str, token_type: str) -> Optional[str]:
def _get_current_token_for_account(cls, account_id: str, token_type: str) -> str | None:
key = cls._get_account_token_key(account_id, token_type)
current_token: Optional[str] = redis_client.get(key)
current_token: str | None = redis_client.get(key)
return current_token
@classmethod

View File

@ -3,7 +3,7 @@ import json
from core.llm_generator.output_parser.errors import OutputParserError
def parse_json_markdown(json_string: str) -> dict:
def parse_json_markdown(json_string: str):
# Get json from the backticks/braces
json_string = json_string.strip()
starts = ["```json", "```", "``", "`", "{"]
@ -33,7 +33,7 @@ def parse_json_markdown(json_string: str) -> dict:
return parsed
def parse_and_check_json_markdown(text: str, expected_keys: list[str]) -> dict:
def parse_and_check_json_markdown(text: str, expected_keys: list[str]):
try:
json_obj = parse_json_markdown(text)
except json.JSONDecodeError as e:

View File

@ -1,3 +1,4 @@
from collections.abc import Callable
from functools import wraps
from typing import Union, cast
@ -12,9 +13,13 @@ from models.model import EndUser
#: A proxy for the current user. If no user is logged in, this will be an
#: anonymous user
current_user = cast(Union[Account, EndUser, None], LocalProxy(lambda: _get_user()))
from typing import ParamSpec, TypeVar
P = ParamSpec("P")
R = TypeVar("R")
def login_required(func):
def login_required(func: Callable[P, R]):
"""
If you decorate a view with this, it will ensure that the current user is
logged in and authenticated before calling the actual view. (If they are
@ -49,17 +54,12 @@ def login_required(func):
"""
@wraps(func)
def decorated_view(*args, **kwargs):
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
pass
elif current_user is not None and not current_user.is_authenticated:
return current_app.login_manager.unauthorized() # type: ignore
# flask 1.x compatibility
# current_app.ensure_sync is only available in Flask >= 2.0
if callable(getattr(current_app, "ensure_sync", None)):
return current_app.ensure_sync(func)(*args, **kwargs)
return func(*args, **kwargs)
return current_app.ensure_sync(func)(*args, **kwargs)
return decorated_view

View File

@ -7,10 +7,9 @@ https://github.com/django/django/blob/main/django/utils/module_loading.py
import sys
from importlib import import_module
from typing import Any
def cached_import(module_path: str, class_name: str) -> Any:
def cached_import(module_path: str, class_name: str):
"""
Import a module and return the named attribute/class from it, with caching.
@ -30,7 +29,7 @@ def cached_import(module_path: str, class_name: str) -> Any:
return getattr(module, class_name)
def import_string(dotted_path: str) -> Any:
def import_string(dotted_path: str):
"""
Import a dotted module path and return the attribute/class designated by
the last name in the path. Raise ImportError if the import failed.

View File

@ -1,8 +1,7 @@
import urllib.parse
from dataclasses import dataclass
from typing import Optional
import requests
import httpx
@dataclass
@ -41,7 +40,7 @@ class GitHubOAuth(OAuth):
_USER_INFO_URL = "https://api.github.com/user"
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
def get_authorization_url(self, invite_token: Optional[str] = None):
def get_authorization_url(self, invite_token: str | None = None):
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
@ -59,7 +58,7 @@ class GitHubOAuth(OAuth):
"redirect_uri": self.redirect_uri,
}
headers = {"Accept": "application/json"}
response = requests.post(self._TOKEN_URL, data=data, headers=headers)
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
access_token = response_json.get("access_token")
@ -71,11 +70,11 @@ class GitHubOAuth(OAuth):
def get_raw_user_info(self, token: str):
headers = {"Authorization": f"token {token}"}
response = requests.get(self._USER_INFO_URL, headers=headers)
response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
user_info = response.json()
email_response = requests.get(self._EMAIL_INFO_URL, headers=headers)
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
email_info = email_response.json()
primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
@ -93,7 +92,7 @@ class GoogleOAuth(OAuth):
_TOKEN_URL = "https://oauth2.googleapis.com/token"
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
def get_authorization_url(self, invite_token: Optional[str] = None):
def get_authorization_url(self, invite_token: str | None = None):
params = {
"client_id": self.client_id,
"response_type": "code",
@ -113,7 +112,7 @@ class GoogleOAuth(OAuth):
"redirect_uri": self.redirect_uri,
}
headers = {"Accept": "application/json"}
response = requests.post(self._TOKEN_URL, data=data, headers=headers)
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
access_token = response_json.get("access_token")
@ -125,7 +124,7 @@ class GoogleOAuth(OAuth):
def get_raw_user_info(self, token: str):
headers = {"Authorization": f"Bearer {token}"}
response = requests.get(self._USER_INFO_URL, headers=headers)
response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
return response.json()

View File

@ -1,7 +1,7 @@
import urllib.parse
from typing import Any
import requests
import httpx
from flask_login import current_user
from sqlalchemy import select
@ -43,7 +43,7 @@ class NotionOAuth(OAuthDataSource):
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
headers = {"Accept": "application/json"}
auth = (self.client_id, self.client_secret)
response = requests.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
response = httpx.post(self._TOKEN_URL, data=data, auth=auth, headers=headers)
response_json = response.json()
access_token = response_json.get("access_token")
@ -239,7 +239,7 @@ class NotionOAuth(OAuthDataSource):
"Notion-Version": "2022-06-28",
}
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json()
results.extend(response_json.get("results", []))
@ -254,7 +254,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = requests.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response = httpx.get(url=f"{self._NOTION_BLOCK_SEARCH}/{block_id}", headers=headers)
response_json = response.json()
if response.status_code != 200:
message = response_json.get("message", "unknown error")
@ -270,7 +270,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = requests.get(url=self._NOTION_BOT_USER, headers=headers)
response = httpx.get(url=self._NOTION_BOT_USER, headers=headers)
response_json = response.json()
if "object" in response_json and response_json["object"] == "user":
user_type = response_json["type"]
@ -294,7 +294,7 @@ class NotionOAuth(OAuthDataSource):
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
}
response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response = httpx.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers)
response_json = response.json()
results.extend(response_json.get("results", []))

View File

@ -1,4 +1,4 @@
from typing import Any, Optional
from typing import Any
import orjson
@ -6,6 +6,6 @@ import orjson
def orjson_dumps(
obj: Any,
encoding: str = "utf-8",
option: Optional[int] = None,
option: int | None = None,
) -> str:
return orjson.dumps(obj, option=option).decode(encoding)

View File

@ -14,11 +14,11 @@ class PassportService:
def verify(self, token):
try:
return jwt.decode(token, self.sk, algorithms=["HS256"])
except jwt.exceptions.ExpiredSignatureError:
except jwt.ExpiredSignatureError:
raise Unauthorized("Token has expired.")
except jwt.exceptions.InvalidSignatureError:
except jwt.InvalidSignatureError:
raise Unauthorized("Invalid token signature.")
except jwt.exceptions.DecodeError:
except jwt.DecodeError:
raise Unauthorized("Invalid token.")
except jwt.exceptions.PyJWTError: # Catch-all for other JWT errors
except jwt.PyJWTError: # Catch-all for other JWT errors
raise Unauthorized("Invalid token.")

View File

@ -26,22 +26,22 @@ class SendGridClient:
to_email = To(_to)
subject = mail["subject"]
content = Content("text/html", mail["html"])
mail = Mail(from_email, to_email, subject, content)
mail_json = mail.get() # type: ignore
response = sg.client.mail.send.post(request_body=mail_json)
sg_mail = Mail(from_email, to_email, subject, content)
mail_json = sg_mail.get()
response = sg.client.mail.send.post(request_body=mail_json) # ty: ignore [call-non-callable]
logger.debug(response.status_code)
logger.debug(response.body)
logger.debug(response.headers)
except TimeoutError as e:
except TimeoutError:
logger.exception("SendGridClient Timeout occurred while sending email")
raise
except (UnauthorizedError, ForbiddenError) as e:
except (UnauthorizedError, ForbiddenError):
logger.exception(
"SendGridClient Authentication failed. "
"Verify that your credentials and the 'from' email address are correct"
)
raise
except Exception as e:
except Exception:
logger.exception("SendGridClient Unexpected error occurred while sending email to %s", _to)
raise

View File

@ -45,13 +45,13 @@ class SMTPClient:
msg.attach(MIMEText(mail["html"], "html"))
smtp.sendmail(self._from, mail["to"], msg.as_string())
except smtplib.SMTPException as e:
except smtplib.SMTPException:
logger.exception("SMTP error occurred")
raise
except TimeoutError as e:
except TimeoutError:
logger.exception("Timeout occurred while sending email")
raise
except Exception as e:
except Exception:
logger.exception("Unexpected error occurred while sending email to %s", mail["to"])
raise
finally:

9
api/libs/typing.py Normal file
View File

@ -0,0 +1,9 @@
from typing import TypeGuard
def is_str_dict(v: object) -> TypeGuard[dict[str, object]]:
return isinstance(v, dict)
def is_str(v: object) -> TypeGuard[str]:
return isinstance(v, str)