mirror of
https://github.com/langgenius/dify.git
synced 2026-06-08 09:27:39 +08:00
feat: initialize user timezone and language from browser (#36170)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@ -3,7 +3,7 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from constants.languages import get_valid_language, languages
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
@ -15,11 +15,12 @@ from controllers.console.auth.error import (
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.helper import timezone as validate_timezone_string
|
||||
from libs.password import valid_password
|
||||
from models import Account
|
||||
from services.account_service import AccountService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountNotFoundError, AccountRegisterError
|
||||
from services.errors.account import AccountRegisterError
|
||||
|
||||
from ..error import AccountInFreezeError, EmailSendIpLimitError
|
||||
from ..wraps import email_password_login_enabled, email_register_enabled, setup_required
|
||||
@ -40,12 +41,21 @@ class EmailRegisterResetPayload(BaseModel):
|
||||
token: str = Field(...)
|
||||
new_password: str = Field(...)
|
||||
password_confirm: str = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
timezone: str | None = Field(default=None)
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
@field_validator("timezone")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return validate_timezone_string(value)
|
||||
|
||||
|
||||
register_schema_models(console_ns, EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload)
|
||||
|
||||
@ -144,26 +154,32 @@ class EmailRegisterResetApi(Resource):
|
||||
|
||||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
else:
|
||||
account = self._create_new_account(normalized_email, args.password_confirm)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
account = self._create_new_account(
|
||||
email=normalized_email,
|
||||
password=args.password_confirm,
|
||||
timezone=args.timezone,
|
||||
language=args.language,
|
||||
)
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(normalized_email)
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
def _create_new_account(self, email: str, password: str) -> Account | None:
|
||||
# Create new account if allowed
|
||||
account = None
|
||||
def _create_new_account(
|
||||
self,
|
||||
email: str,
|
||||
password: str,
|
||||
timezone: str | None = None,
|
||||
language: str | None = None,
|
||||
) -> Account:
|
||||
try:
|
||||
account = AccountService.create_account_and_tenant(
|
||||
return AccountService.create_account_and_tenant(
|
||||
email=email,
|
||||
name=email,
|
||||
password=password,
|
||||
interface_language=languages[0],
|
||||
interface_language=get_valid_language(language),
|
||||
timezone=timezone,
|
||||
)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
return account
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import flask_login
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
@ -34,6 +34,7 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.helper import timezone as validate_timezone_string
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.token import (
|
||||
clear_access_token_from_cookie,
|
||||
@ -69,6 +70,14 @@ class EmailCodeLoginPayload(BaseModel):
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
timezone: str | None = Field(default=None)
|
||||
|
||||
@field_validator("timezone")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return validate_timezone_string(value)
|
||||
|
||||
|
||||
register_schema_models(console_ns, LoginPayload, EmailPayload, EmailCodeLoginPayload)
|
||||
@ -288,6 +297,7 @@ class EmailCodeLoginApi(Resource):
|
||||
email=user_email,
|
||||
name=user_email,
|
||||
interface_language=get_valid_language(language),
|
||||
timezone=args.timezone,
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
|
||||
@ -12,7 +12,8 @@ from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
from libs.helper import timezone as validate_timezone_string
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo, decode_oauth_state
|
||||
from libs.token import (
|
||||
set_access_token_to_cookie,
|
||||
set_csrf_token_to_cookie,
|
||||
@ -53,6 +54,31 @@ def get_oauth_providers():
|
||||
return OAUTH_PROVIDERS
|
||||
|
||||
|
||||
def _validated_timezone(value: str | None) -> str | None:
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return validate_timezone_string(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _validated_language(value: str | None) -> str | None:
|
||||
if value and value in languages:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _preferred_interface_language(language: str | None = None) -> str:
|
||||
if language:
|
||||
return language
|
||||
|
||||
preferred_lang = request.accept_languages.best_match(languages)
|
||||
if preferred_lang and preferred_lang in languages:
|
||||
return preferred_lang
|
||||
return languages[0]
|
||||
|
||||
|
||||
@console_ns.route("/oauth/login/<provider>")
|
||||
class OAuthLogin(Resource):
|
||||
@console_ns.doc("oauth_login")
|
||||
@ -64,13 +90,19 @@ class OAuthLogin(Resource):
|
||||
@console_ns.response(400, "Invalid provider")
|
||||
def get(self, provider: str):
|
||||
invite_token = request.args.get("invite_token") or None
|
||||
timezone = _validated_timezone(request.args.get("timezone") or None)
|
||||
language = _validated_language(request.args.get("language") or None)
|
||||
OAUTH_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_PROVIDERS.get(provider)
|
||||
if not oauth_provider:
|
||||
return {"error": "Invalid provider"}, 400
|
||||
|
||||
auth_url = oauth_provider.get_authorization_url(invite_token=invite_token)
|
||||
auth_url = oauth_provider.get_authorization_url(
|
||||
invite_token=invite_token,
|
||||
timezone=timezone,
|
||||
language=language,
|
||||
)
|
||||
return redirect(auth_url)
|
||||
|
||||
|
||||
@ -96,9 +128,10 @@ class OAuthCallback(Resource):
|
||||
|
||||
code = request.args.get("code")
|
||||
state = request.args.get("state")
|
||||
invite_token = None
|
||||
if state:
|
||||
invite_token = state
|
||||
oauth_state = decode_oauth_state(state)
|
||||
invite_token = oauth_state.get("invite_token")
|
||||
timezone = _validated_timezone(oauth_state.get("timezone"))
|
||||
language = _validated_language(oauth_state.get("language"))
|
||||
|
||||
if not code:
|
||||
return {"error": "Authorization code is required"}, 400
|
||||
@ -129,7 +162,7 @@ class OAuthCallback(Resource):
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
|
||||
|
||||
try:
|
||||
account, oauth_new_user = _generate_account(provider, user_info)
|
||||
account, oauth_new_user = _generate_account(provider, user_info, timezone=timezone, language=language)
|
||||
except AccountNotFoundError:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
|
||||
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
|
||||
@ -184,7 +217,12 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
||||
return account
|
||||
|
||||
|
||||
def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]:
|
||||
def _generate_account(
|
||||
provider: str,
|
||||
user_info: OAuthUserInfo,
|
||||
timezone: str | None = None,
|
||||
language: str | None = None,
|
||||
) -> tuple[Account, bool]:
|
||||
# Get account by openid or email.
|
||||
account = _get_account_by_openid_or_email(provider, user_info)
|
||||
oauth_new_user = False
|
||||
@ -211,26 +249,19 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
|
||||
"30 days and is temporarily unavailable for new account registration"
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise AccountRegisterError(description=("Invalid email or password"))
|
||||
raise AccountRegisterError(description=("Invalid email or password"))
|
||||
account_name = user_info.name or "Dify"
|
||||
interface_language = _preferred_interface_language(language)
|
||||
account = RegisterService.register(
|
||||
email=normalized_email,
|
||||
name=account_name,
|
||||
password=None,
|
||||
open_id=user_info.id,
|
||||
provider=provider,
|
||||
language=interface_language,
|
||||
timezone=timezone,
|
||||
)
|
||||
|
||||
# Set interface language
|
||||
preferred_lang = request.accept_languages.best_match(languages)
|
||||
if preferred_lang and preferred_lang in languages:
|
||||
interface_language = preferred_lang
|
||||
else:
|
||||
interface_language = languages[0]
|
||||
account.interface_language = interface_language
|
||||
db.session.commit()
|
||||
|
||||
# Link account
|
||||
AccountService.link_account_integrate(provider, user_info.id, account)
|
||||
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
import base64
|
||||
import binascii
|
||||
import json
|
||||
import logging
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
@ -27,6 +30,12 @@ class AccessTokenResponse(TypedDict, total=False):
|
||||
access_token: str
|
||||
|
||||
|
||||
class OAuthState(TypedDict, total=False):
|
||||
invite_token: str
|
||||
timezone: str
|
||||
language: str
|
||||
|
||||
|
||||
class GitHubEmailRecord(TypedDict, total=False):
|
||||
email: str
|
||||
primary: bool
|
||||
@ -46,6 +55,7 @@ class GoogleRawUserInfo(TypedDict):
|
||||
|
||||
|
||||
ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse)
|
||||
OAUTH_STATE_ADAPTER = TypeAdapter(OAuthState)
|
||||
GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo)
|
||||
GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord])
|
||||
GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo)
|
||||
@ -58,6 +68,37 @@ class OAuthUserInfo:
|
||||
email: str
|
||||
|
||||
|
||||
def encode_oauth_state(
|
||||
invite_token: str | None = None,
|
||||
timezone: str | None = None,
|
||||
language: str | None = None,
|
||||
) -> str | None:
|
||||
state: OAuthState = {}
|
||||
if invite_token:
|
||||
state["invite_token"] = invite_token
|
||||
if timezone:
|
||||
state["timezone"] = timezone
|
||||
if language:
|
||||
state["language"] = language
|
||||
if not state:
|
||||
return None
|
||||
|
||||
raw_state = json.dumps(state, separators=(",", ":")).encode("utf-8")
|
||||
return base64.urlsafe_b64encode(raw_state).decode("ascii").rstrip("=")
|
||||
|
||||
|
||||
def decode_oauth_state(state: str | None) -> OAuthState:
|
||||
if not state:
|
||||
return {}
|
||||
|
||||
try:
|
||||
padded_state = state + "=" * (-len(state) % 4)
|
||||
raw_state = base64.urlsafe_b64decode(padded_state.encode("ascii")).decode("utf-8")
|
||||
return OAUTH_STATE_ADAPTER.validate_python(json.loads(raw_state))
|
||||
except (binascii.Error, ValueError, UnicodeDecodeError, json.JSONDecodeError, ValidationError):
|
||||
return {}
|
||||
|
||||
|
||||
def _json_object(response: httpx.Response) -> JsonObject:
|
||||
return JSON_OBJECT_ADAPTER.validate_python(response.json())
|
||||
|
||||
@ -76,7 +117,12 @@ class OAuth:
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
def get_authorization_url(
|
||||
self,
|
||||
invite_token: str | None = None,
|
||||
timezone: str | None = None,
|
||||
language: str | None = None,
|
||||
) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_access_token(self, code: str) -> str:
|
||||
@ -99,14 +145,20 @@ class GitHubOAuth(OAuth):
|
||||
_USER_INFO_URL = "https://api.github.com/user"
|
||||
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
|
||||
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
def get_authorization_url(
|
||||
self,
|
||||
invite_token: str | None = None,
|
||||
timezone: str | None = None,
|
||||
language: str | None = None,
|
||||
) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"scope": "user:email", # Request only basic user information
|
||||
}
|
||||
if invite_token:
|
||||
params["state"] = invite_token
|
||||
state = encode_oauth_state(invite_token=invite_token, timezone=timezone, language=language)
|
||||
if state:
|
||||
params["state"] = state
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str) -> str:
|
||||
@ -186,15 +238,21 @@ class GoogleOAuth(OAuth):
|
||||
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
|
||||
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
def get_authorization_url(
|
||||
self,
|
||||
invite_token: str | None = None,
|
||||
timezone: str | None = None,
|
||||
language: str | None = None,
|
||||
) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"response_type": "code",
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"scope": "openid email",
|
||||
}
|
||||
if invite_token:
|
||||
params["state"] = invite_token
|
||||
state = encode_oauth_state(invite_token=invite_token, timezone=timezone, language=language)
|
||||
if state:
|
||||
params["state"] = state
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str) -> str:
|
||||
|
||||
@ -11596,6 +11596,7 @@ Request payload for bulk downloading documents as a zip archive.
|
||||
| code | string | | Yes |
|
||||
| email | string | | Yes |
|
||||
| language | | | No |
|
||||
| timezone | | | No |
|
||||
| token | string | | Yes |
|
||||
|
||||
#### EmailPayload
|
||||
@ -11609,8 +11610,10 @@ Request payload for bulk downloading documents as a zip archive.
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| language | | | No |
|
||||
| new_password | string | | Yes |
|
||||
| password_confirm | string | | Yes |
|
||||
| timezone | | | No |
|
||||
| token | string | | Yes |
|
||||
|
||||
#### EmailRegisterSendPayload
|
||||
|
||||
@ -29,6 +29,7 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client, redis_fallback
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import RateLimiter, TokenManager
|
||||
from libs.helper import timezone as validate_timezone
|
||||
from libs.passport import PassportService
|
||||
from libs.password import compare_password, hash_password, valid_password
|
||||
from libs.rsa import generate_key_pair
|
||||
@ -271,8 +272,9 @@ class AccountService:
|
||||
password: str | None = None,
|
||||
interface_theme: str = "light",
|
||||
is_setup: bool | None = False,
|
||||
timezone: str | None = None,
|
||||
) -> Account:
|
||||
"""create account"""
|
||||
"""Create an account, preferring explicit user timezone over language-derived defaults."""
|
||||
if not FeatureService.get_system_features().is_allow_register and not is_setup:
|
||||
from controllers.console.error import AccountNotFound
|
||||
|
||||
@ -302,6 +304,10 @@ class AccountService:
|
||||
password_to_set = base64_password_hashed
|
||||
salt_to_set = base64_salt
|
||||
|
||||
resolved_timezone = language_timezone_mapping.get(interface_language, "UTC")
|
||||
if timezone is not None:
|
||||
resolved_timezone = validate_timezone(timezone)
|
||||
|
||||
account = Account(
|
||||
name=name,
|
||||
email=email,
|
||||
@ -309,7 +315,7 @@ class AccountService:
|
||||
password_salt=salt_to_set,
|
||||
interface_language=interface_language,
|
||||
interface_theme=interface_theme,
|
||||
timezone=language_timezone_mapping.get(interface_language, "UTC"),
|
||||
timezone=resolved_timezone,
|
||||
)
|
||||
|
||||
db.session.add(account)
|
||||
@ -318,11 +324,15 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
def create_account_and_tenant(
|
||||
email: str, name: str, interface_language: str, password: str | None = None
|
||||
email: str, name: str, interface_language: str, password: str | None = None, timezone: str | None = None
|
||||
) -> Account:
|
||||
"""create account"""
|
||||
"""Create an account and owner workspace."""
|
||||
account = AccountService.create_account(
|
||||
email=email, name=name, interface_language=interface_language, password=password
|
||||
email=email,
|
||||
name=name,
|
||||
interface_language=interface_language,
|
||||
password=password,
|
||||
timezone=timezone,
|
||||
)
|
||||
|
||||
try:
|
||||
@ -1474,8 +1484,8 @@ class RegisterService:
|
||||
@classmethod
|
||||
def register(
|
||||
cls,
|
||||
email,
|
||||
name,
|
||||
email: str,
|
||||
name: str,
|
||||
password: str | None = None,
|
||||
open_id: str | None = None,
|
||||
provider: str | None = None,
|
||||
@ -1483,16 +1493,19 @@ class RegisterService:
|
||||
status: AccountStatus | None = None,
|
||||
is_setup: bool | None = False,
|
||||
create_workspace_required: bool | None = True,
|
||||
timezone: str | None = None,
|
||||
) -> Account:
|
||||
db.session.begin_nested()
|
||||
"""Register account"""
|
||||
db.session.begin_nested()
|
||||
try:
|
||||
interface_language = get_valid_language(language)
|
||||
account = AccountService.create_account(
|
||||
email=email,
|
||||
name=name,
|
||||
interface_language=get_valid_language(language),
|
||||
interface_language=interface_language,
|
||||
password=password,
|
||||
is_setup=is_setup,
|
||||
timezone=timezone,
|
||||
)
|
||||
account.status = status or AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
||||
@ -143,7 +143,118 @@ class TestEmailRegisterResetApi:
|
||||
response = EmailRegisterResetApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}}
|
||||
mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!")
|
||||
mock_create_account.assert_called_once_with(
|
||||
email="invitee@example.com",
|
||||
password="ValidPass123!",
|
||||
timezone=None,
|
||||
language=None,
|
||||
)
|
||||
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_extract_ip.assert_called_once()
|
||||
|
||||
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.login")
|
||||
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
|
||||
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_reset_passes_timezone_to_new_account(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_get_account,
|
||||
mock_create_account,
|
||||
mock_login,
|
||||
mock_reset_login_rate,
|
||||
app: Flask,
|
||||
):
|
||||
mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
|
||||
mock_create_account.return_value = MagicMock()
|
||||
token_pair = MagicMock()
|
||||
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
|
||||
mock_login.return_value = token_pair
|
||||
mock_get_account.return_value = None
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/email-register",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "token-123",
|
||||
"new_password": "ValidPass123!",
|
||||
"password_confirm": "ValidPass123!",
|
||||
"timezone": "Asia/Shanghai",
|
||||
},
|
||||
):
|
||||
response = EmailRegisterResetApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}}
|
||||
mock_create_account.assert_called_once_with(
|
||||
email="invitee@example.com",
|
||||
password="ValidPass123!",
|
||||
timezone="Asia/Shanghai",
|
||||
language=None,
|
||||
)
|
||||
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_extract_ip.assert_called_once()
|
||||
|
||||
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.login")
|
||||
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
|
||||
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_reset_passes_language_to_new_account(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_get_account,
|
||||
mock_create_account,
|
||||
mock_login,
|
||||
mock_reset_login_rate,
|
||||
app: Flask,
|
||||
):
|
||||
mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
|
||||
mock_create_account.return_value = MagicMock()
|
||||
token_pair = MagicMock()
|
||||
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
|
||||
mock_login.return_value = token_pair
|
||||
mock_get_account.return_value = None
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/email-register",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "token-123",
|
||||
"new_password": "ValidPass123!",
|
||||
"password_confirm": "ValidPass123!",
|
||||
"language": "zh-Hans",
|
||||
},
|
||||
):
|
||||
response = EmailRegisterResetApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}}
|
||||
mock_create_account.assert_called_once_with(
|
||||
email="invitee@example.com",
|
||||
password="ValidPass123!",
|
||||
timezone=None,
|
||||
language="zh-Hans",
|
||||
)
|
||||
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_extract_ip.assert_called_once()
|
||||
|
||||
@ -14,7 +14,7 @@ from controllers.console.auth.oauth import (
|
||||
_get_account_by_openid_or_email,
|
||||
get_oauth_providers,
|
||||
)
|
||||
from libs.oauth import OAuthUserInfo
|
||||
from libs.oauth import OAuthUserInfo, encode_oauth_state
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import AccountService
|
||||
from services.errors.account import AccountRegisterError
|
||||
@ -101,7 +101,55 @@ class TestOAuthLogin:
|
||||
with app.test_request_context(f"/auth/oauth/github?{query_string}"):
|
||||
resource.get("github")
|
||||
|
||||
mock_oauth_provider.get_authorization_url.assert_called_once_with(invite_token=expected_token)
|
||||
mock_oauth_provider.get_authorization_url.assert_called_once_with(
|
||||
invite_token=expected_token,
|
||||
timezone=None,
|
||||
language=None,
|
||||
)
|
||||
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
|
||||
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
def test_should_pass_timezone_to_oauth_state(
|
||||
self,
|
||||
mock_redirect,
|
||||
mock_get_providers,
|
||||
resource,
|
||||
app: Flask,
|
||||
mock_oauth_provider,
|
||||
):
|
||||
mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None}
|
||||
|
||||
with app.test_request_context("/auth/oauth/github?timezone=Asia/Shanghai"):
|
||||
resource.get("github")
|
||||
|
||||
mock_oauth_provider.get_authorization_url.assert_called_once_with(
|
||||
invite_token=None,
|
||||
timezone="Asia/Shanghai",
|
||||
language=None,
|
||||
)
|
||||
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
|
||||
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
def test_should_pass_language_to_oauth_state(
|
||||
self,
|
||||
mock_redirect,
|
||||
mock_get_providers,
|
||||
resource,
|
||||
app: Flask,
|
||||
mock_oauth_provider,
|
||||
):
|
||||
mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None}
|
||||
|
||||
with app.test_request_context("/auth/oauth/github?language=zh-Hans"):
|
||||
resource.get("github")
|
||||
|
||||
mock_oauth_provider.get_authorization_url.assert_called_once_with(
|
||||
invite_token=None,
|
||||
timezone=None,
|
||||
language="zh-Hans",
|
||||
)
|
||||
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -229,7 +277,8 @@ class TestOAuthCallback:
|
||||
mock_register_service.is_valid_invite_token.return_value = True
|
||||
mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"}
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"):
|
||||
state = encode_oauth_state(invite_token="invite123", timezone="Asia/Shanghai")
|
||||
with app.test_request_context(f"/auth/oauth/github/callback?code=test_code&state={state}"):
|
||||
resource.get("github")
|
||||
|
||||
mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123")
|
||||
@ -488,7 +537,13 @@ class TestAccountGeneration:
|
||||
|
||||
if should_create:
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
password=None,
|
||||
open_id="123",
|
||||
provider="github",
|
||||
language="en-US",
|
||||
timezone=None,
|
||||
)
|
||||
else:
|
||||
mock_register_service.register.assert_not_called()
|
||||
@ -515,7 +570,75 @@ class TestAccountGeneration:
|
||||
_generate_account("github", user_info)
|
||||
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="upper@example.com", name="Test User", password=None, open_id="123", provider="github"
|
||||
email="upper@example.com",
|
||||
name="Test User",
|
||||
password=None,
|
||||
open_id="123",
|
||||
provider="github",
|
||||
language="en-US",
|
||||
timezone=None,
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
def test_should_register_with_browser_timezone(
|
||||
self,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
mock_feature_service,
|
||||
mock_get_account,
|
||||
app: Flask,
|
||||
user_info,
|
||||
):
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
mock_register_service.register.return_value = MagicMock()
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "zh-Hans,zh;q=0.9"}):
|
||||
_generate_account("github", user_info, timezone="Asia/Shanghai")
|
||||
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
password=None,
|
||||
open_id="123",
|
||||
provider="github",
|
||||
language="zh-Hans",
|
||||
timezone="Asia/Shanghai",
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
def test_should_register_with_state_language(
|
||||
self,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
mock_feature_service,
|
||||
mock_get_account,
|
||||
app: Flask,
|
||||
user_info,
|
||||
):
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
mock_register_service.register.return_value = MagicMock()
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
|
||||
_generate_account("github", user_info, language="zh-Hans")
|
||||
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
password=None,
|
||||
open_id="123",
|
||||
provider="github",
|
||||
language="zh-Hans",
|
||||
timezone=None,
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
|
||||
|
||||
@ -0,0 +1,40 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.console.auth.email_register import EmailRegisterResetApi, EmailRegisterResetPayload
|
||||
|
||||
|
||||
@patch("controllers.console.auth.email_register.AccountService.create_account_and_tenant")
|
||||
def test_create_new_account_uses_requested_language(mock_create_account):
|
||||
account = MagicMock()
|
||||
mock_create_account.return_value = account
|
||||
|
||||
result = EmailRegisterResetApi()._create_new_account(
|
||||
"invitee@example.com",
|
||||
"ValidPass123!",
|
||||
timezone="Asia/Shanghai",
|
||||
language="zh-Hans",
|
||||
)
|
||||
|
||||
assert result is account
|
||||
mock_create_account.assert_called_once_with(
|
||||
email="invitee@example.com",
|
||||
name="invitee@example.com",
|
||||
password="ValidPass123!",
|
||||
interface_language="zh-Hans",
|
||||
timezone="Asia/Shanghai",
|
||||
)
|
||||
|
||||
|
||||
def test_reset_payload_rejects_invalid_timezone():
|
||||
with pytest.raises(ValidationError):
|
||||
EmailRegisterResetPayload.model_validate(
|
||||
{
|
||||
"token": "token-123",
|
||||
"new_password": "ValidPass123!",
|
||||
"password_confirm": "ValidPass123!",
|
||||
"timezone": "",
|
||||
}
|
||||
)
|
||||
@ -13,9 +13,10 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.console.auth.error import EmailCodeError, InvalidEmailError, InvalidTokenError
|
||||
from controllers.console.auth.login import EmailCodeLoginApi, EmailCodeLoginSendEmailApi
|
||||
from controllers.console.auth.login import EmailCodeLoginApi, EmailCodeLoginPayload, EmailCodeLoginSendEmailApi
|
||||
from controllers.console.error import (
|
||||
AccountInFreezeError,
|
||||
AccountNotFound,
|
||||
@ -31,6 +32,18 @@ def encode_code(code: str) -> str:
|
||||
return base64.b64encode(code.encode("utf-8")).decode()
|
||||
|
||||
|
||||
def test_email_code_login_payload_rejects_invalid_timezone():
|
||||
with pytest.raises(ValidationError):
|
||||
EmailCodeLoginPayload.model_validate(
|
||||
{
|
||||
"email": "newuser@example.com",
|
||||
"code": "123456",
|
||||
"token": "token-123",
|
||||
"timezone": "",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestEmailCodeLoginSendEmailApi:
|
||||
"""Test cases for sending email verification codes."""
|
||||
|
||||
@ -342,6 +355,7 @@ class TestEmailCodeLoginApi:
|
||||
"code": encode_code("123456"),
|
||||
"token": "valid_token",
|
||||
"language": "en-US",
|
||||
"timezone": "Asia/Shanghai",
|
||||
},
|
||||
):
|
||||
api = EmailCodeLoginApi()
|
||||
@ -349,7 +363,12 @@ class TestEmailCodeLoginApi:
|
||||
|
||||
# Assert
|
||||
assert response.json["result"] == "success"
|
||||
mock_create_account.assert_called_once()
|
||||
mock_create_account.assert_called_once_with(
|
||||
email="newuser@example.com",
|
||||
name="newuser@example.com",
|
||||
interface_language="en-US",
|
||||
timezone="Asia/Shanghai",
|
||||
)
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
|
||||
|
||||
@ -0,0 +1,123 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.oauth import OAuthLogin, _generate_account
|
||||
from libs.oauth import OAuthUserInfo
|
||||
from services.errors.account import AccountRegisterError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
def test_oauth_login_passes_language_and_timezone_to_authorization_url(
|
||||
mock_get_oauth_providers,
|
||||
mock_redirect,
|
||||
app: Flask,
|
||||
):
|
||||
oauth_provider = MagicMock()
|
||||
oauth_provider.get_authorization_url.return_value = "https://github.com/login/oauth/authorize?state=..."
|
||||
mock_get_oauth_providers.return_value = {"github": oauth_provider}
|
||||
|
||||
with app.test_request_context("/oauth/login/github?language=zh-Hans&timezone=Asia/Shanghai"):
|
||||
OAuthLogin().get("github")
|
||||
|
||||
oauth_provider.get_authorization_url.assert_called_once_with(
|
||||
invite_token=None,
|
||||
timezone="Asia/Shanghai",
|
||||
language="zh-Hans",
|
||||
)
|
||||
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?state=...")
|
||||
|
||||
|
||||
@patch("controllers.console.auth.oauth.AccountService.link_account_integrate")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
def test_generate_account_registers_with_browser_timezone(
|
||||
mock_get_account,
|
||||
mock_feature_service,
|
||||
mock_register_service,
|
||||
mock_link_account,
|
||||
app: Flask,
|
||||
):
|
||||
account = MagicMock()
|
||||
mock_register_service.register.return_value = account
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
user_info = OAuthUserInfo(id="github-123", name="Test User", email="User@Example.com")
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "zh-Hans,zh;q=0.9"}):
|
||||
result, oauth_new_user = _generate_account("github", user_info, timezone="Asia/Shanghai")
|
||||
|
||||
assert result is account
|
||||
assert oauth_new_user is True
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="user@example.com",
|
||||
name="Test User",
|
||||
password=None,
|
||||
open_id="github-123",
|
||||
provider="github",
|
||||
language="zh-Hans",
|
||||
timezone="Asia/Shanghai",
|
||||
)
|
||||
mock_link_account.assert_called_once_with("github", "github-123", account)
|
||||
|
||||
|
||||
@patch("controllers.console.auth.oauth.AccountService.link_account_integrate")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
def test_generate_account_prefers_state_language_over_accept_language(
|
||||
mock_get_account,
|
||||
mock_feature_service,
|
||||
mock_register_service,
|
||||
mock_link_account,
|
||||
app: Flask,
|
||||
):
|
||||
account = MagicMock()
|
||||
mock_register_service.register.return_value = account
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
user_info = OAuthUserInfo(id="github-123", name="Test User", email="User@Example.com")
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
|
||||
_generate_account("github", user_info, language="zh-Hans")
|
||||
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="user@example.com",
|
||||
name="Test User",
|
||||
password=None,
|
||||
open_id="github-123",
|
||||
provider="github",
|
||||
language="zh-Hans",
|
||||
timezone=None,
|
||||
)
|
||||
mock_link_account.assert_called_once_with("github", "github-123", account)
|
||||
|
||||
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
def test_generate_account_rejects_new_user_when_registration_disabled(
|
||||
mock_get_account,
|
||||
mock_feature_service,
|
||||
mock_register_service,
|
||||
mock_config,
|
||||
app: Flask,
|
||||
):
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = False
|
||||
mock_config.BILLING_ENABLED = False
|
||||
user_info = OAuthUserInfo(id="github-123", name="Test User", email="user@example.com")
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
|
||||
with pytest.raises(AccountRegisterError):
|
||||
_generate_account("github", user_info)
|
||||
|
||||
mock_register_service.register.assert_not_called()
|
||||
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from libs.oauth import OAuth
|
||||
from libs.oauth import OAuth, decode_oauth_state, encode_oauth_state
|
||||
|
||||
|
||||
def test_oauth_base_methods_raise_not_implemented():
|
||||
@ -17,3 +17,17 @@ def test_oauth_base_methods_raise_not_implemented():
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth._transform_user_info({})
|
||||
|
||||
|
||||
def test_oauth_state_round_trips_invite_token_timezone_and_language():
|
||||
state = encode_oauth_state(invite_token="invite-123", timezone="Asia/Shanghai", language="zh-Hans")
|
||||
|
||||
assert decode_oauth_state(state) == {
|
||||
"invite_token": "invite-123",
|
||||
"timezone": "Asia/Shanghai",
|
||||
"language": "zh-Hans",
|
||||
}
|
||||
|
||||
|
||||
def test_oauth_state_returns_empty_payload_for_invalid_state():
|
||||
assert decode_oauth_state("invalid-state") == {}
|
||||
|
||||
@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo, decode_oauth_state
|
||||
|
||||
|
||||
class BaseOAuthTest:
|
||||
@ -37,15 +37,25 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
return GitHubOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("invite_token", "expected_state"),
|
||||
("invite_token", "timezone", "language", "expected_state"),
|
||||
[
|
||||
(None, None),
|
||||
("test_invite_token", "test_invite_token"),
|
||||
("", None),
|
||||
(None, None, None, None),
|
||||
("test_invite_token", None, None, {"invite_token": "test_invite_token"}),
|
||||
("", None, None, None),
|
||||
(None, "Asia/Shanghai", None, {"timezone": "Asia/Shanghai"}),
|
||||
(None, None, "zh-Hans", {"language": "zh-Hans"}),
|
||||
(
|
||||
"test_invite_token",
|
||||
"Asia/Shanghai",
|
||||
"zh-Hans",
|
||||
{"invite_token": "test_invite_token", "timezone": "Asia/Shanghai", "language": "zh-Hans"},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
|
||||
url = oauth.get_authorization_url(invite_token)
|
||||
def test_should_generate_authorization_url_correctly(
|
||||
self, oauth, oauth_config, invite_token, timezone, language, expected_state
|
||||
):
|
||||
url = oauth.get_authorization_url(invite_token, timezone=timezone, language=language)
|
||||
parsed, params = self.parse_auth_url(url)
|
||||
|
||||
assert parsed.scheme == "https"
|
||||
@ -56,7 +66,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
assert params["scope"][0] == "user:email"
|
||||
|
||||
if expected_state:
|
||||
assert params["state"][0] == expected_state
|
||||
assert decode_oauth_state(params["state"][0]) == expected_state
|
||||
else:
|
||||
assert "state" not in params
|
||||
|
||||
@ -208,15 +218,25 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
return GoogleOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("invite_token", "expected_state"),
|
||||
("invite_token", "timezone", "language", "expected_state"),
|
||||
[
|
||||
(None, None),
|
||||
("test_invite_token", "test_invite_token"),
|
||||
("", None),
|
||||
(None, None, None, None),
|
||||
("test_invite_token", None, None, {"invite_token": "test_invite_token"}),
|
||||
("", None, None, None),
|
||||
(None, "Asia/Shanghai", None, {"timezone": "Asia/Shanghai"}),
|
||||
(None, None, "zh-Hans", {"language": "zh-Hans"}),
|
||||
(
|
||||
"test_invite_token",
|
||||
"Asia/Shanghai",
|
||||
"zh-Hans",
|
||||
{"invite_token": "test_invite_token", "timezone": "Asia/Shanghai", "language": "zh-Hans"},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
|
||||
url = oauth.get_authorization_url(invite_token)
|
||||
def test_should_generate_authorization_url_correctly(
|
||||
self, oauth, oauth_config, invite_token, timezone, language, expected_state
|
||||
):
|
||||
url = oauth.get_authorization_url(invite_token, timezone=timezone, language=language)
|
||||
parsed, params = self.parse_auth_url(url)
|
||||
|
||||
assert parsed.scheme == "https"
|
||||
@ -228,7 +248,7 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
assert params["scope"][0] == "openid email"
|
||||
|
||||
if expected_state:
|
||||
assert params["state"][0] == expected_state
|
||||
assert decode_oauth_state(params["state"][0]) == expected_state
|
||||
else:
|
||||
assert "state" not in params
|
||||
|
||||
|
||||
@ -260,7 +260,7 @@ class TestAccountService:
|
||||
assert result.interface_theme == "light"
|
||||
assert result.password is not None
|
||||
assert result.password_salt is not None
|
||||
assert result.timezone is not None
|
||||
assert result.timezone == "America/New_York"
|
||||
|
||||
# Verify database operations
|
||||
mock_db_dependencies["db"].session.add.assert_called_once()
|
||||
@ -271,7 +271,28 @@ class TestAccountService:
|
||||
assert added_account.interface_theme == "light"
|
||||
assert added_account.password is not None
|
||||
assert added_account.password_salt is not None
|
||||
assert added_account.timezone is not None
|
||||
assert added_account.timezone == "America/New_York"
|
||||
self._assert_database_operations_called(mock_db_dependencies["db"])
|
||||
|
||||
def test_create_account_uses_explicit_timezone(
|
||||
self, mock_db_dependencies, mock_password_dependencies, mock_external_service_dependencies
|
||||
):
|
||||
"""Test account creation prefers explicit browser timezone."""
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
mock_password_dependencies["hash_password"].return_value = b"hashed_password"
|
||||
|
||||
result = AccountService.create_account(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
password="password123",
|
||||
timezone="Asia/Shanghai",
|
||||
)
|
||||
|
||||
assert result.timezone == "Asia/Shanghai"
|
||||
added_account = mock_db_dependencies["db"].session.add.call_args[0][0]
|
||||
assert added_account.timezone == "Asia/Shanghai"
|
||||
self._assert_database_operations_called(mock_db_dependencies["db"])
|
||||
|
||||
def test_create_account_registration_disabled(self, mock_external_service_dependencies):
|
||||
@ -1221,6 +1242,7 @@ class TestRegisterService:
|
||||
interface_language="en-US",
|
||||
password="password123",
|
||||
is_setup=False,
|
||||
timezone=None,
|
||||
)
|
||||
mock_create_tenant.assert_called_once_with("Test User's Workspace")
|
||||
mock_create_member.assert_called_once_with(mock_tenant, mock_account, role="owner")
|
||||
|
||||
Reference in New Issue
Block a user