Compare commits

..

8 Commits

28 changed files with 1034 additions and 89 deletions

View File

@ -557,7 +557,7 @@ MAX_VARIABLE_SIZE=204800
# GraphEngine Worker Pool Configuration
# Minimum number of workers per GraphEngine instance (default: 1)
GRAPH_ENGINE_MIN_WORKERS=1
GRAPH_ENGINE_MIN_WORKERS=3
# Maximum number of workers per GraphEngine instance (default: 10)
GRAPH_ENGINE_MAX_WORKERS=10
# Queue depth threshold that triggers worker scale up (default: 3)

View File

@ -761,7 +761,7 @@ class WorkflowConfig(BaseSettings):
# GraphEngine Worker Pool Configuration
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
description="Minimum number of workers per GraphEngine instance",
default=1,
default=3,
)
GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(

View File

@ -3,7 +3,7 @@ from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from configs import dify_config
from constants.languages import languages
from constants.languages import get_valid_language, languages
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
@ -15,6 +15,7 @@ from controllers.console.auth.error import (
PasswordMismatchError,
)
from libs.helper import EmailStr, extract_remote_ip
from libs.helper import timezone as validate_timezone_string
from libs.password import valid_password
from models import Account
from services.account_service import AccountService
@ -40,12 +41,19 @@ class EmailRegisterResetPayload(BaseModel):
token: str = Field(...)
new_password: str = Field(...)
password_confirm: str = Field(...)
language: str | None = Field(default=None)
timezone: str | None = Field(default=None)
@field_validator("new_password", "password_confirm")
@classmethod
def validate_password(cls, value: str) -> str:
return valid_password(value)
@field_validator("timezone")
@classmethod
def validate_timezone(cls, value: str | None) -> str | None:
return validate_timezone_string(value) if value else value
register_schema_models(console_ns, EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload)
@ -145,7 +153,15 @@ class EmailRegisterResetApi(Resource):
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(normalized_email, args.password_confirm)
if args.timezone or args.language:
account = self._create_new_account(
normalized_email,
args.password_confirm,
args.timezone,
args.language,
)
else:
account = self._create_new_account(normalized_email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
@ -153,7 +169,13 @@ class EmailRegisterResetApi(Resource):
return {"result": "success", "data": token_pair.model_dump()}
def _create_new_account(self, email: str, password: str) -> Account | None:
def _create_new_account(
self,
email: str,
password: str,
timezone: str | None = None,
language: str | None = None,
) -> Account | None:
# Create new account if allowed
account = None
try:
@ -161,7 +183,8 @@ class EmailRegisterResetApi(Resource):
email=email,
name=email,
password=password,
interface_language=languages[0],
interface_language=get_valid_language(language),
timezone=timezone,
)
except AccountRegisterError:
raise AccountInFreezeError()

View File

@ -3,7 +3,7 @@ import logging
import flask_login
from flask import make_response, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from werkzeug.exceptions import Unauthorized
import services
@ -34,6 +34,7 @@ from controllers.console.wraps import (
)
from events.tenant_event import tenant_was_created
from libs.helper import EmailStr, extract_remote_ip
from libs.helper import timezone as validate_timezone_string
from libs.login import current_account_with_tenant
from libs.token import (
clear_access_token_from_cookie,
@ -69,6 +70,12 @@ class EmailCodeLoginPayload(BaseModel):
code: str = Field(...)
token: str = Field(...)
language: str | None = Field(default=None)
timezone: str | None = Field(default=None)
@field_validator("timezone")
@classmethod
def validate_timezone(cls, value: str | None) -> str | None:
return validate_timezone_string(value) if value else value
register_schema_models(console_ns, LoginPayload, EmailPayload, EmailCodeLoginPayload)
@ -288,6 +295,7 @@ class EmailCodeLoginApi(Resource):
email=user_email,
name=user_email,
interface_language=get_valid_language(language),
timezone=args.timezone,
)
except WorkSpaceNotAllowedCreateError:
raise NotAllowedCreateWorkspace()

View File

@ -12,7 +12,8 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from libs.helper import timezone as validate_timezone_string
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo, decode_oauth_state
from libs.token import (
set_access_token_to_cookie,
set_csrf_token_to_cookie,
@ -53,6 +54,31 @@ def get_oauth_providers():
return OAUTH_PROVIDERS
def _validated_timezone(value: str | None) -> str | None:
if not value:
return None
try:
return validate_timezone_string(value)
except ValueError:
return None
def _validated_language(value: str | None) -> str | None:
if value and value in languages:
return value
return None
def _preferred_interface_language(language: str | None = None) -> str:
if language:
return language
preferred_lang = request.accept_languages.best_match(languages)
if preferred_lang and preferred_lang in languages:
return preferred_lang
return languages[0]
@console_ns.route("/oauth/login/<provider>")
class OAuthLogin(Resource):
@console_ns.doc("oauth_login")
@ -64,13 +90,19 @@ class OAuthLogin(Resource):
@console_ns.response(400, "Invalid provider")
def get(self, provider: str):
invite_token = request.args.get("invite_token") or None
timezone = _validated_timezone(request.args.get("timezone") or None)
language = _validated_language(request.args.get("language") or None)
OAUTH_PROVIDERS = get_oauth_providers()
with current_app.app_context():
oauth_provider = OAUTH_PROVIDERS.get(provider)
if not oauth_provider:
return {"error": "Invalid provider"}, 400
auth_url = oauth_provider.get_authorization_url(invite_token=invite_token)
auth_url = oauth_provider.get_authorization_url(
invite_token=invite_token,
timezone=timezone,
language=language,
)
return redirect(auth_url)
@ -96,9 +128,10 @@ class OAuthCallback(Resource):
code = request.args.get("code")
state = request.args.get("state")
invite_token = None
if state:
invite_token = state
oauth_state = decode_oauth_state(state)
invite_token = oauth_state.get("invite_token")
timezone = _validated_timezone(oauth_state.get("timezone"))
language = _validated_language(oauth_state.get("language"))
if not code:
return {"error": "Authorization code is required"}, 400
@ -129,7 +162,7 @@ class OAuthCallback(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
try:
account, oauth_new_user = _generate_account(provider, user_info)
account, oauth_new_user = _generate_account(provider, user_info, timezone=timezone, language=language)
except AccountNotFoundError:
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
@ -184,7 +217,12 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
return account
def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]:
def _generate_account(
provider: str,
user_info: OAuthUserInfo,
timezone: str | None = None,
language: str | None = None,
) -> tuple[Account, bool]:
# Get account by openid or email.
account = _get_account_by_openid_or_email(provider, user_info)
oauth_new_user = False
@ -211,25 +249,28 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
"30 days and is temporarily unavailable for new account registration"
)
)
else:
raise AccountRegisterError(description=("Invalid email or password"))
raise AccountRegisterError(description=("Invalid email or password"))
account_name = user_info.name or "Dify"
account = RegisterService.register(
email=normalized_email,
name=account_name,
password=None,
open_id=user_info.id,
provider=provider,
)
# Set interface language
preferred_lang = request.accept_languages.best_match(languages)
if preferred_lang and preferred_lang in languages:
interface_language = preferred_lang
interface_language = _preferred_interface_language(language)
if timezone:
account = RegisterService.register(
email=normalized_email,
name=account_name,
password=None,
open_id=user_info.id,
provider=provider,
language=interface_language,
timezone=timezone,
)
else:
interface_language = languages[0]
account.interface_language = interface_language
db.session.commit()
account = RegisterService.register(
email=normalized_email,
name=account_name,
password=None,
open_id=user_info.id,
provider=provider,
language=interface_language,
)
# Link account
AccountService.link_account_integrate(provider, user_info.id, account)

View File

@ -1,3 +1,6 @@
import base64
import binascii
import json
import logging
import urllib.parse
from dataclasses import dataclass
@ -27,6 +30,12 @@ class AccessTokenResponse(TypedDict, total=False):
access_token: str
class OAuthState(TypedDict, total=False):
invite_token: str
timezone: str
language: str
class GitHubEmailRecord(TypedDict, total=False):
email: str
primary: bool
@ -46,6 +55,7 @@ class GoogleRawUserInfo(TypedDict):
ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse)
OAUTH_STATE_ADAPTER = TypeAdapter(OAuthState)
GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo)
GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord])
GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo)
@ -58,6 +68,37 @@ class OAuthUserInfo:
email: str
def encode_oauth_state(
invite_token: str | None = None,
timezone: str | None = None,
language: str | None = None,
) -> str | None:
state: OAuthState = {}
if invite_token:
state["invite_token"] = invite_token
if timezone:
state["timezone"] = timezone
if language:
state["language"] = language
if not state:
return None
raw_state = json.dumps(state, separators=(",", ":")).encode("utf-8")
return base64.urlsafe_b64encode(raw_state).decode("ascii").rstrip("=")
def decode_oauth_state(state: str | None) -> OAuthState:
if not state:
return {}
try:
padded_state = state + "=" * (-len(state) % 4)
raw_state = base64.urlsafe_b64decode(padded_state.encode("ascii")).decode("utf-8")
return OAUTH_STATE_ADAPTER.validate_python(json.loads(raw_state))
except (binascii.Error, ValueError, UnicodeDecodeError, json.JSONDecodeError, ValidationError):
return {"invite_token": state}
def _json_object(response: httpx.Response) -> JsonObject:
return JSON_OBJECT_ADAPTER.validate_python(response.json())
@ -76,7 +117,12 @@ class OAuth:
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_authorization_url(self, invite_token: str | None = None) -> str:
def get_authorization_url(
self,
invite_token: str | None = None,
timezone: str | None = None,
language: str | None = None,
) -> str:
raise NotImplementedError()
def get_access_token(self, code: str) -> str:
@ -99,14 +145,20 @@ class GitHubOAuth(OAuth):
_USER_INFO_URL = "https://api.github.com/user"
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
def get_authorization_url(self, invite_token: str | None = None) -> str:
def get_authorization_url(
self,
invite_token: str | None = None,
timezone: str | None = None,
language: str | None = None,
) -> str:
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"scope": "user:email", # Request only basic user information
}
if invite_token:
params["state"] = invite_token
state = encode_oauth_state(invite_token=invite_token, timezone=timezone, language=language)
if state:
params["state"] = state
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str) -> str:
@ -186,15 +238,21 @@ class GoogleOAuth(OAuth):
_TOKEN_URL = "https://oauth2.googleapis.com/token"
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
def get_authorization_url(self, invite_token: str | None = None) -> str:
def get_authorization_url(
self,
invite_token: str | None = None,
timezone: str | None = None,
language: str | None = None,
) -> str:
params = {
"client_id": self.client_id,
"response_type": "code",
"redirect_uri": self.redirect_uri,
"scope": "openid email",
}
if invite_token:
params["state"] = invite_token
state = encode_oauth_state(invite_token=invite_token, timezone=timezone, language=language)
if state:
params["state"] = state
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str) -> str:

View File

@ -11596,6 +11596,7 @@ Request payload for bulk downloading documents as a zip archive.
| code | string | | Yes |
| email | string | | Yes |
| language | | | No |
| timezone | | | No |
| token | string | | Yes |
#### EmailPayload
@ -11609,8 +11610,10 @@ Request payload for bulk downloading documents as a zip archive.
| Name | Type | Description | Required |
| ---- | ---- | ----------- | -------- |
| language | | | No |
| new_password | string | | Yes |
| password_confirm | string | | Yes |
| timezone | | | No |
| token | string | | Yes |
#### EmailRegisterSendPayload

View File

@ -29,6 +29,7 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
from libs.datetime_utils import naive_utc_now
from libs.helper import RateLimiter, TokenManager
from libs.helper import timezone as validate_timezone
from libs.passport import PassportService
from libs.password import compare_password, hash_password, valid_password
from libs.rsa import generate_key_pair
@ -271,8 +272,9 @@ class AccountService:
password: str | None = None,
interface_theme: str = "light",
is_setup: bool | None = False,
timezone: str | None = None,
) -> Account:
"""create account"""
"""Create an account, preferring explicit user timezone over language-derived defaults."""
if not FeatureService.get_system_features().is_allow_register and not is_setup:
from controllers.console.error import AccountNotFound
@ -309,7 +311,9 @@ class AccountService:
password_salt=salt_to_set,
interface_language=interface_language,
interface_theme=interface_theme,
timezone=language_timezone_mapping.get(interface_language, "UTC"),
timezone=(
validate_timezone(timezone) if timezone else language_timezone_mapping.get(interface_language, "UTC")
),
)
db.session.add(account)
@ -318,12 +322,24 @@ class AccountService:
@staticmethod
def create_account_and_tenant(
email: str, name: str, interface_language: str, password: str | None = None
email: str, name: str, interface_language: str, password: str | None = None, timezone: str | None = None
) -> Account:
"""create account"""
account = AccountService.create_account(
email=email, name=name, interface_language=interface_language, password=password
)
if timezone:
account = AccountService.create_account(
email=email,
name=name,
interface_language=interface_language,
password=password,
timezone=timezone,
)
else:
account = AccountService.create_account(
email=email,
name=name,
interface_language=interface_language,
password=password,
)
try:
TenantService.create_owner_tenant_if_not_exist(account=account)
@ -1483,17 +1499,29 @@ class RegisterService:
status: AccountStatus | None = None,
is_setup: bool | None = False,
create_workspace_required: bool | None = True,
timezone: str | None = None,
) -> Account:
db.session.begin_nested()
"""Register account"""
try:
account = AccountService.create_account(
email=email,
name=name,
interface_language=get_valid_language(language),
password=password,
is_setup=is_setup,
)
interface_language = get_valid_language(language)
if timezone:
account = AccountService.create_account(
email=email,
name=name,
interface_language=interface_language,
password=password,
is_setup=is_setup,
timezone=timezone,
)
else:
account = AccountService.create_account(
email=email,
name=name,
interface_language=interface_language,
password=password,
is_setup=is_setup,
)
account.status = status or AccountStatus.ACTIVE
account.initialized_at = naive_utc_now()

View File

@ -148,6 +148,102 @@ class TestEmailRegisterResetApi:
mock_revoke_token.assert_called_once_with("token-123")
mock_extract_ip.assert_called_once()
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
@patch("controllers.console.auth.email_register.AccountService.login")
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
def test_reset_passes_timezone_to_new_account(
self,
mock_extract_ip,
mock_get_data,
mock_revoke_token,
mock_get_account,
mock_create_account,
mock_login,
mock_reset_login_rate,
app: Flask,
):
mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
mock_create_account.return_value = MagicMock()
token_pair = MagicMock()
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
mock_login.return_value = token_pair
mock_get_account.return_value = None
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with (
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
):
with app.test_request_context(
"/email-register",
method="POST",
json={
"token": "token-123",
"new_password": "ValidPass123!",
"password_confirm": "ValidPass123!",
"timezone": "Asia/Shanghai",
},
):
response = EmailRegisterResetApi().post()
assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}}
mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!", "Asia/Shanghai", None)
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
mock_revoke_token.assert_called_once_with("token-123")
mock_extract_ip.assert_called_once()
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
@patch("controllers.console.auth.email_register.AccountService.login")
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
def test_reset_passes_language_to_new_account(
self,
mock_extract_ip,
mock_get_data,
mock_revoke_token,
mock_get_account,
mock_create_account,
mock_login,
mock_reset_login_rate,
app: Flask,
):
mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
mock_create_account.return_value = MagicMock()
token_pair = MagicMock()
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
mock_login.return_value = token_pair
mock_get_account.return_value = None
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
with (
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
):
with app.test_request_context(
"/email-register",
method="POST",
json={
"token": "token-123",
"new_password": "ValidPass123!",
"password_confirm": "ValidPass123!",
"language": "zh-Hans",
},
):
response = EmailRegisterResetApi().post()
assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}}
mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!", None, "zh-Hans")
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
mock_revoke_token.assert_called_once_with("token-123")
mock_extract_ip.assert_called_once()
def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
"""Test that case fallback tries lowercase when exact match fails."""

View File

@ -14,7 +14,7 @@ from controllers.console.auth.oauth import (
_get_account_by_openid_or_email,
get_oauth_providers,
)
from libs.oauth import OAuthUserInfo
from libs.oauth import OAuthUserInfo, encode_oauth_state
from models.account import AccountStatus
from services.account_service import AccountService
from services.errors.account import AccountRegisterError
@ -101,7 +101,55 @@ class TestOAuthLogin:
with app.test_request_context(f"/auth/oauth/github?{query_string}"):
resource.get("github")
mock_oauth_provider.get_authorization_url.assert_called_once_with(invite_token=expected_token)
mock_oauth_provider.get_authorization_url.assert_called_once_with(
invite_token=expected_token,
timezone=None,
language=None,
)
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth.redirect")
def test_should_pass_timezone_to_oauth_state(
self,
mock_redirect,
mock_get_providers,
resource,
app: Flask,
mock_oauth_provider,
):
mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None}
with app.test_request_context("/auth/oauth/github?timezone=Asia/Shanghai"):
resource.get("github")
mock_oauth_provider.get_authorization_url.assert_called_once_with(
invite_token=None,
timezone="Asia/Shanghai",
language=None,
)
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
@patch("controllers.console.auth.oauth.get_oauth_providers")
@patch("controllers.console.auth.oauth.redirect")
def test_should_pass_language_to_oauth_state(
self,
mock_redirect,
mock_get_providers,
resource,
app: Flask,
mock_oauth_provider,
):
mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None}
with app.test_request_context("/auth/oauth/github?language=zh-Hans"):
resource.get("github")
mock_oauth_provider.get_authorization_url.assert_called_once_with(
invite_token=None,
timezone=None,
language="zh-Hans",
)
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
@pytest.mark.parametrize(
@ -229,7 +277,8 @@ class TestOAuthCallback:
mock_register_service.is_valid_invite_token.return_value = True
mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"}
with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"):
state = encode_oauth_state(invite_token="invite123", timezone="Asia/Shanghai")
with app.test_request_context(f"/auth/oauth/github/callback?code=test_code&state={state}"):
resource.get("github")
mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123")
@ -488,7 +537,12 @@ class TestAccountGeneration:
if should_create:
mock_register_service.register.assert_called_once_with(
email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
email="test@example.com",
name="Test User",
password=None,
open_id="123",
provider="github",
language="en-US",
)
else:
mock_register_service.register.assert_not_called()
@ -515,7 +569,73 @@ class TestAccountGeneration:
_generate_account("github", user_info)
mock_register_service.register.assert_called_once_with(
email="upper@example.com", name="Test User", password=None, open_id="123", provider="github"
email="upper@example.com",
name="Test User",
password=None,
open_id="123",
provider="github",
language="en-US",
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
def test_should_register_with_browser_timezone(
self,
mock_tenant_service,
mock_account_service,
mock_register_service,
mock_feature_service,
mock_get_account,
app: Flask,
user_info,
):
mock_feature_service.get_system_features.return_value.is_allow_register = True
mock_register_service.register.return_value = MagicMock()
with app.test_request_context(headers={"Accept-Language": "zh-Hans,zh;q=0.9"}):
_generate_account("github", user_info, timezone="Asia/Shanghai")
mock_register_service.register.assert_called_once_with(
email="test@example.com",
name="Test User",
password=None,
open_id="123",
provider="github",
language="zh-Hans",
timezone="Asia/Shanghai",
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.AccountService")
@patch("controllers.console.auth.oauth.TenantService")
def test_should_register_with_state_language(
self,
mock_tenant_service,
mock_account_service,
mock_register_service,
mock_feature_service,
mock_get_account,
app: Flask,
user_info,
):
mock_feature_service.get_system_features.return_value.is_allow_register = True
mock_register_service.register.return_value = MagicMock()
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
_generate_account("github", user_info, language="zh-Hans")
mock_register_service.register.assert_called_once_with(
email="test@example.com",
name="Test User",
password=None,
open_id="123",
provider="github",
language="zh-Hans",
)
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")

View File

@ -0,0 +1,25 @@
from unittest.mock import MagicMock, patch
from controllers.console.auth.email_register import EmailRegisterResetApi
@patch("controllers.console.auth.email_register.AccountService.create_account_and_tenant")
def test_create_new_account_uses_requested_language(mock_create_account):
account = MagicMock()
mock_create_account.return_value = account
result = EmailRegisterResetApi()._create_new_account(
"invitee@example.com",
"ValidPass123!",
timezone="Asia/Shanghai",
language="zh-Hans",
)
assert result is account
mock_create_account.assert_called_once_with(
email="invitee@example.com",
name="invitee@example.com",
password="ValidPass123!",
interface_language="zh-Hans",
timezone="Asia/Shanghai",
)

View File

@ -342,6 +342,7 @@ class TestEmailCodeLoginApi:
"code": encode_code("123456"),
"token": "valid_token",
"language": "en-US",
"timezone": "Asia/Shanghai",
},
):
api = EmailCodeLoginApi()
@ -349,7 +350,12 @@ class TestEmailCodeLoginApi:
# Assert
assert response.json["result"] == "success"
mock_create_account.assert_called_once()
mock_create_account.assert_called_once_with(
email="newuser@example.com",
name="newuser@example.com",
interface_language="en-US",
timezone="Asia/Shanghai",
)
@patch("controllers.console.wraps.db")
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")

View File

@ -0,0 +1,122 @@
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.console.auth.oauth import OAuthLogin, _generate_account
from libs.oauth import OAuthUserInfo
from services.errors.account import AccountRegisterError
@pytest.fixture
def app() -> Flask:
app = Flask(__name__)
app.config["TESTING"] = True
return app
@patch("controllers.console.auth.oauth.redirect")
@patch("controllers.console.auth.oauth.get_oauth_providers")
def test_oauth_login_passes_language_and_timezone_to_authorization_url(
mock_get_oauth_providers,
mock_redirect,
app: Flask,
):
oauth_provider = MagicMock()
oauth_provider.get_authorization_url.return_value = "https://github.com/login/oauth/authorize?state=..."
mock_get_oauth_providers.return_value = {"github": oauth_provider}
with app.test_request_context("/oauth/login/github?language=zh-Hans&timezone=Asia/Shanghai"):
OAuthLogin().get("github")
oauth_provider.get_authorization_url.assert_called_once_with(
invite_token=None,
timezone="Asia/Shanghai",
language="zh-Hans",
)
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?state=...")
@patch("controllers.console.auth.oauth.AccountService.link_account_integrate")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
def test_generate_account_registers_with_browser_timezone(
mock_get_account,
mock_feature_service,
mock_register_service,
mock_link_account,
app: Flask,
):
account = MagicMock()
mock_register_service.register.return_value = account
mock_feature_service.get_system_features.return_value.is_allow_register = True
user_info = OAuthUserInfo(id="github-123", name="Test User", email="User@Example.com")
with app.test_request_context(headers={"Accept-Language": "zh-Hans,zh;q=0.9"}):
result, oauth_new_user = _generate_account("github", user_info, timezone="Asia/Shanghai")
assert result is account
assert oauth_new_user is True
mock_register_service.register.assert_called_once_with(
email="user@example.com",
name="Test User",
password=None,
open_id="github-123",
provider="github",
language="zh-Hans",
timezone="Asia/Shanghai",
)
mock_link_account.assert_called_once_with("github", "github-123", account)
@patch("controllers.console.auth.oauth.AccountService.link_account_integrate")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
def test_generate_account_prefers_state_language_over_accept_language(
mock_get_account,
mock_feature_service,
mock_register_service,
mock_link_account,
app: Flask,
):
account = MagicMock()
mock_register_service.register.return_value = account
mock_feature_service.get_system_features.return_value.is_allow_register = True
user_info = OAuthUserInfo(id="github-123", name="Test User", email="User@Example.com")
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
_generate_account("github", user_info, language="zh-Hans")
mock_register_service.register.assert_called_once_with(
email="user@example.com",
name="Test User",
password=None,
open_id="github-123",
provider="github",
language="zh-Hans",
)
mock_link_account.assert_called_once_with("github", "github-123", account)
@patch("controllers.console.auth.oauth.dify_config")
@patch("controllers.console.auth.oauth.RegisterService")
@patch("controllers.console.auth.oauth.FeatureService")
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
def test_generate_account_rejects_new_user_when_registration_disabled(
mock_get_account,
mock_feature_service,
mock_register_service,
mock_config,
app: Flask,
):
mock_feature_service.get_system_features.return_value.is_allow_register = False
mock_config.BILLING_ENABLED = False
user_info = OAuthUserInfo(id="github-123", name="Test User", email="user@example.com")
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
with pytest.raises(AccountRegisterError):
_generate_account("github", user_info)
mock_register_service.register.assert_not_called()

View File

@ -297,7 +297,7 @@ class TableTestRunner:
max_workers: int = 4,
enable_logging: bool = False,
log_level: str = "INFO",
graph_engine_min_workers: int = 1,
graph_engine_min_workers: int = 3,
graph_engine_max_workers: int = 1,
graph_engine_scale_up_threshold: int = 5,
graph_engine_scale_down_idle_time: float = 30.0,
@ -310,7 +310,7 @@ class TableTestRunner:
max_workers: Maximum number of parallel workers for test execution
enable_logging: Enable detailed logging
log_level: Logging level (DEBUG, INFO, WARNING, ERROR)
graph_engine_min_workers: Minimum workers for GraphEngine (default: 1)
graph_engine_min_workers: Minimum workers for GraphEngine (default: 3)
graph_engine_max_workers: Maximum workers for GraphEngine (default: 1)
graph_engine_scale_up_threshold: Queue depth to trigger scale up
graph_engine_scale_down_idle_time: Idle time before scaling down

View File

@ -1,6 +1,6 @@
import pytest
from libs.oauth import OAuth
from libs.oauth import OAuth, decode_oauth_state, encode_oauth_state
def test_oauth_base_methods_raise_not_implemented():
@ -17,3 +17,17 @@ def test_oauth_base_methods_raise_not_implemented():
with pytest.raises(NotImplementedError):
oauth._transform_user_info({})
def test_oauth_state_round_trips_invite_token_timezone_and_language():
state = encode_oauth_state(invite_token="invite-123", timezone="Asia/Shanghai", language="zh-Hans")
assert decode_oauth_state(state) == {
"invite_token": "invite-123",
"timezone": "Asia/Shanghai",
"language": "zh-Hans",
}
def test_oauth_state_keeps_legacy_raw_invite_token_compatible():
assert decode_oauth_state("legacy-invite-token") == {"invite_token": "legacy-invite-token"}

View File

@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
import httpx
import pytest
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo, decode_oauth_state
class BaseOAuthTest:
@ -37,15 +37,25 @@ class TestGitHubOAuth(BaseOAuthTest):
return GitHubOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
@pytest.mark.parametrize(
("invite_token", "expected_state"),
("invite_token", "timezone", "language", "expected_state"),
[
(None, None),
("test_invite_token", "test_invite_token"),
("", None),
(None, None, None, None),
("test_invite_token", None, None, {"invite_token": "test_invite_token"}),
("", None, None, None),
(None, "Asia/Shanghai", None, {"timezone": "Asia/Shanghai"}),
(None, None, "zh-Hans", {"language": "zh-Hans"}),
(
"test_invite_token",
"Asia/Shanghai",
"zh-Hans",
{"invite_token": "test_invite_token", "timezone": "Asia/Shanghai", "language": "zh-Hans"},
),
],
)
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
url = oauth.get_authorization_url(invite_token)
def test_should_generate_authorization_url_correctly(
self, oauth, oauth_config, invite_token, timezone, language, expected_state
):
url = oauth.get_authorization_url(invite_token, timezone=timezone, language=language)
parsed, params = self.parse_auth_url(url)
assert parsed.scheme == "https"
@ -56,7 +66,7 @@ class TestGitHubOAuth(BaseOAuthTest):
assert params["scope"][0] == "user:email"
if expected_state:
assert params["state"][0] == expected_state
assert decode_oauth_state(params["state"][0]) == expected_state
else:
assert "state" not in params
@ -208,15 +218,25 @@ class TestGoogleOAuth(BaseOAuthTest):
return GoogleOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
@pytest.mark.parametrize(
("invite_token", "expected_state"),
("invite_token", "timezone", "language", "expected_state"),
[
(None, None),
("test_invite_token", "test_invite_token"),
("", None),
(None, None, None, None),
("test_invite_token", None, None, {"invite_token": "test_invite_token"}),
("", None, None, None),
(None, "Asia/Shanghai", None, {"timezone": "Asia/Shanghai"}),
(None, None, "zh-Hans", {"language": "zh-Hans"}),
(
"test_invite_token",
"Asia/Shanghai",
"zh-Hans",
{"invite_token": "test_invite_token", "timezone": "Asia/Shanghai", "language": "zh-Hans"},
),
],
)
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
url = oauth.get_authorization_url(invite_token)
def test_should_generate_authorization_url_correctly(
self, oauth, oauth_config, invite_token, timezone, language, expected_state
):
url = oauth.get_authorization_url(invite_token, timezone=timezone, language=language)
parsed, params = self.parse_auth_url(url)
assert parsed.scheme == "https"
@ -228,7 +248,7 @@ class TestGoogleOAuth(BaseOAuthTest):
assert params["scope"][0] == "openid email"
if expected_state:
assert params["state"][0] == expected_state
assert decode_oauth_state(params["state"][0]) == expected_state
else:
assert "state" not in params

View File

@ -260,7 +260,7 @@ class TestAccountService:
assert result.interface_theme == "light"
assert result.password is not None
assert result.password_salt is not None
assert result.timezone is not None
assert result.timezone == "America/New_York"
# Verify database operations
mock_db_dependencies["db"].session.add.assert_called_once()
@ -271,7 +271,28 @@ class TestAccountService:
assert added_account.interface_theme == "light"
assert added_account.password is not None
assert added_account.password_salt is not None
assert added_account.timezone is not None
assert added_account.timezone == "America/New_York"
self._assert_database_operations_called(mock_db_dependencies["db"])
def test_create_account_uses_explicit_timezone(
self, mock_db_dependencies, mock_password_dependencies, mock_external_service_dependencies
):
"""Test account creation prefers explicit browser timezone."""
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
mock_password_dependencies["hash_password"].return_value = b"hashed_password"
result = AccountService.create_account(
email="test@example.com",
name="Test User",
interface_language="en-US",
password="password123",
timezone="Asia/Shanghai",
)
assert result.timezone == "Asia/Shanghai"
added_account = mock_db_dependencies["db"].session.add.call_args[0][0]
assert added_account.timezone == "Asia/Shanghai"
self._assert_database_operations_called(mock_db_dependencies["db"])
def test_create_account_registration_disabled(self, mock_external_service_dependencies):

View File

@ -177,7 +177,7 @@ WORKFLOW_MAX_EXECUTION_TIME=1200
WORKFLOW_CALL_MAX_DEPTH=5
MAX_VARIABLE_SIZE=204800
WORKFLOW_FILE_UPLOAD_LIMIT=10
GRAPH_ENGINE_MIN_WORKERS=1
GRAPH_ENGINE_MIN_WORKERS=3
GRAPH_ENGINE_MAX_WORKERS=10
GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0

View File

@ -13,6 +13,7 @@ import { useLocale } from '@/context/i18n'
import { useRouter, useSearchParams } from '@/next/navigation'
import { emailLoginWithCode, sendEMailLoginCode } from '@/service/common'
import { encryptVerificationCode } from '@/utils/encryption'
import { getBrowserTimezone } from '@/utils/timezone'
import { resolvePostLoginRedirect } from '../utils/post-login-redirect'
export default function CheckCode() {
@ -39,7 +40,13 @@ export default function CheckCode() {
return
}
setIsLoading(true)
const ret = await emailLoginWithCode({ email, code: encryptVerificationCode(code), token, language })
const ret = await emailLoginWithCode({
email,
code: encryptVerificationCode(code),
token,
language,
timezone: getBrowserTimezone(),
})
if (ret.result === 'success') {
// Track login success event
trackEvent('user_login_success', {

View File

@ -0,0 +1,86 @@
import { render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { useLocale } from '@/context/i18n'
import { useSearchParams } from '@/next/navigation'
import { getBrowserTimezone } from '@/utils/timezone'
import SocialAuth from '../social-auth'
vi.mock('@/next/navigation', () => ({
useSearchParams: vi.fn(),
}))
vi.mock('@/context/i18n', () => ({
useLocale: vi.fn(),
}))
vi.mock('@/utils/timezone', () => ({
getBrowserTimezone: vi.fn(),
}))
const mockUseSearchParams = vi.mocked(useSearchParams)
const mockUseLocale = vi.mocked(useLocale)
const mockGetBrowserTimezone = vi.mocked(getBrowserTimezone)
describe('SocialAuth', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseSearchParams.mockReturnValue(new URLSearchParams() as unknown as ReturnType<typeof useSearchParams>)
mockUseLocale.mockReturnValue('zh-Hans')
mockGetBrowserTimezone.mockReturnValue('Asia/Shanghai')
})
describe('Rendering', () => {
it('should render oauth provider links', () => {
render(<SocialAuth />)
expect(screen.getByRole('link', { name: 'login.withGitHub' })).toBeInTheDocument()
expect(screen.getByRole('link', { name: 'login.withGoogle' })).toBeInTheDocument()
})
})
describe('OAuth params', () => {
it('should include browser timezone and locale in oauth links', () => {
render(<SocialAuth />)
expect(screen.getByRole('link', { name: 'login.withGitHub' })).toHaveAttribute(
'href',
expect.stringContaining('timezone=Asia%2FShanghai'),
)
expect(screen.getByRole('link', { name: 'login.withGitHub' })).toHaveAttribute(
'href',
expect.stringContaining('language=zh-Hans'),
)
expect(screen.getByRole('link', { name: 'login.withGoogle' })).toHaveAttribute(
'href',
expect.stringContaining('timezone=Asia%2FShanghai'),
)
expect(screen.getByRole('link', { name: 'login.withGoogle' })).toHaveAttribute(
'href',
expect.stringContaining('language=zh-Hans'),
)
})
it('should preserve invite token when adding timezone', () => {
mockUseSearchParams.mockReturnValue(
new URLSearchParams('invite_token=invite-123') as unknown as ReturnType<typeof useSearchParams>,
)
render(<SocialAuth />)
const githubLink = screen.getByRole('link', { name: 'login.withGitHub' })
expect(githubLink).toHaveAttribute('href', expect.stringContaining('invite_token=invite-123'))
expect(githubLink).toHaveAttribute('href', expect.stringContaining('timezone=Asia%2FShanghai'))
expect(githubLink).toHaveAttribute('href', expect.stringContaining('language=zh-Hans'))
})
})
describe('Edge Cases', () => {
it('should omit timezone when browser timezone is unavailable', () => {
mockGetBrowserTimezone.mockReturnValue(undefined)
render(<SocialAuth />)
expect(screen.getByRole('link', { name: 'login.withGitHub' }).getAttribute('href')).not.toContain('timezone=')
})
})
})

View File

@ -2,8 +2,10 @@ import { Button } from '@langgenius/dify-ui/button'
import { cn } from '@langgenius/dify-ui/cn'
import { useTranslation } from 'react-i18next'
import { API_PREFIX } from '@/config'
import { useLocale } from '@/context/i18n'
import { useSearchParams } from '@/next/navigation'
import { getPurifyHref } from '@/utils'
import { getBrowserTimezone } from '@/utils/timezone'
import style from '../page.module.css'
type SocialAuthProps = {
@ -13,11 +15,19 @@ type SocialAuthProps = {
export default function SocialAuth(props: SocialAuthProps) {
const { t } = useTranslation()
const searchParams = useSearchParams()
const locale = useLocale()
const getOAuthLink = (href: string) => {
const url = getPurifyHref(`${API_PREFIX}${href}`)
if (searchParams.has('invite_token'))
return `${url}?${searchParams.toString()}`
const params = new URLSearchParams(searchParams.toString())
const timezone = getBrowserTimezone()
if (timezone)
params.set('timezone', timezone)
params.set('language', locale)
const query = params.toString()
if (query)
return `${url}?${query}`
return url
}

View File

@ -0,0 +1,139 @@
import type { MockedFunction } from 'vitest'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { useLocale } from '@/context/i18n'
import { useRouter, useSearchParams } from '@/next/navigation'
import { activateMember } from '@/service/common'
import { useInvitationCheck } from '@/service/use-common'
import { getBrowserTimezone } from '@/utils/timezone'
import InviteSettingsPage from '../page'
vi.mock('@tanstack/react-query', async () => {
const actual = await vi.importActual<typeof import('@tanstack/react-query')>('@tanstack/react-query')
return {
...actual,
useSuspenseQuery: vi.fn(() => ({
data: {
branding: {
enabled: true,
},
},
})),
}
})
vi.mock('@/context/i18n', () => ({
useLocale: vi.fn(),
}))
vi.mock('@/i18n-config', () => ({
i18n: {
defaultLocale: 'en-US',
},
setLocaleOnClient: vi.fn(() => Promise.resolve()),
}))
vi.mock('@/next/navigation', () => ({
useRouter: vi.fn(),
useSearchParams: vi.fn(),
}))
vi.mock('@/service/common', () => ({
activateMember: vi.fn(),
}))
vi.mock('@/service/use-common', () => ({
useInvitationCheck: vi.fn(),
}))
vi.mock('@/utils/timezone', () => ({
getBrowserTimezone: vi.fn(),
timezones: [
{ value: 'Asia/Shanghai', name: 'Asia/Shanghai' },
{ value: 'America/Los_Angeles', name: 'America/Los_Angeles' },
],
}))
vi.mock('../utils/post-login-redirect', () => ({
resolvePostLoginRedirect: vi.fn(() => null),
}))
const mockReplace = vi.fn()
const mockRefetch = vi.fn()
const mockUseLocale = useLocale as unknown as MockedFunction<typeof useLocale>
const mockUseRouter = useRouter as unknown as MockedFunction<typeof useRouter>
const mockUseSearchParams = useSearchParams as unknown as MockedFunction<typeof useSearchParams>
const mockActivateMember = activateMember as unknown as MockedFunction<typeof activateMember>
const mockUseInvitationCheck = useInvitationCheck as unknown as MockedFunction<typeof useInvitationCheck>
const mockGetBrowserTimezone = getBrowserTimezone as unknown as MockedFunction<typeof getBrowserTimezone>
describe('InviteSettingsPage', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseLocale.mockReturnValue('zh-Hans')
mockUseRouter.mockReturnValue({ replace: mockReplace } as unknown as ReturnType<typeof useRouter>)
mockUseSearchParams.mockReturnValue(
new URLSearchParams('invite_token=invite-token') as unknown as ReturnType<typeof useSearchParams>,
)
mockUseInvitationCheck.mockReturnValue({
data: {
is_valid: true,
data: {
workspace_name: 'Acme',
workspace_id: 'workspace-id',
email: 'invitee@example.com',
},
},
refetch: mockRefetch,
} as unknown as ReturnType<typeof useInvitationCheck>)
mockGetBrowserTimezone.mockReturnValue('Asia/Shanghai')
mockActivateMember.mockResolvedValue({ result: 'success' })
})
describe('Activation payload', () => {
it('should default language to the current UI locale', async () => {
render(<InviteSettingsPage />)
fireEvent.change(screen.getByLabelText('login.name'), {
target: { value: 'Invitee' },
})
fireEvent.click(screen.getByRole('button', { name: 'login.join Acme' }))
await waitFor(() => {
expect(mockActivateMember).toHaveBeenCalledWith({
url: '/activate',
body: {
token: 'invite-token',
name: 'Invitee',
interface_language: 'zh-Hans',
timezone: 'Asia/Shanghai',
},
})
})
})
it('should fall back to configured default locale when current locale is unsupported', async () => {
mockUseLocale.mockReturnValue('unsupported-locale' as ReturnType<typeof useLocale>)
render(<InviteSettingsPage />)
fireEvent.change(screen.getByLabelText('login.name'), {
target: { value: 'Invitee' },
})
fireEvent.click(screen.getByRole('button', { name: 'login.join Acme' }))
await waitFor(() => {
expect(mockActivateMember).toHaveBeenCalledWith({
url: '/activate',
body: {
token: 'invite-token',
name: 'Invitee',
interface_language: 'en-US',
timezone: 'Asia/Shanghai',
},
})
})
})
})
})

View File

@ -11,14 +11,15 @@ import { useTranslation } from 'react-i18next'
import Input from '@/app/components/base/input'
import Loading from '@/app/components/base/loading'
import { LICENSE_LINK } from '@/constants/link'
import { setLocaleOnClient } from '@/i18n-config'
import { languages, LanguagesSupported } from '@/i18n-config/language'
import { useLocale } from '@/context/i18n'
import { i18n, setLocaleOnClient } from '@/i18n-config'
import { languages } from '@/i18n-config/language'
import Link from '@/next/link'
import { useRouter, useSearchParams } from '@/next/navigation'
import { activateMember } from '@/service/common'
import { systemFeaturesQueryOptions } from '@/service/system-features'
import { useInvitationCheck } from '@/service/use-common'
import { timezones } from '@/utils/timezone'
import { getBrowserTimezone, timezones } from '@/utils/timezone'
import { resolvePostLoginRedirect } from '../utils/post-login-redirect'
type LanguageSelectOption = {
@ -43,15 +44,23 @@ const TIMEZONE_OPTIONS: TimezoneSelectOption[] = timezones.map(item => ({
name: item.name,
}))
const getInitialLanguage = (locale: Locale): Locale => {
if (LANGUAGE_OPTIONS.some(item => item.value === locale))
return locale
return i18n.defaultLocale
}
export default function InviteSettingsPage() {
const { t } = useTranslation()
const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions())
const router = useRouter()
const searchParams = useSearchParams()
const token = decodeURIComponent(searchParams.get('invite_token') as string)
const locale = useLocale()
const [name, setName] = useState('')
const [language, setLanguage] = useState(LanguagesSupported[0])
const [timezone, setTimezone] = useState(() => Intl.DateTimeFormat().resolvedOptions().timeZone || 'America/Los_Angeles')
const [language, setLanguage] = useState(() => getInitialLanguage(locale))
const [timezone, setTimezone] = useState(() => getBrowserTimezone() || 'America/Los_Angeles')
const selectedLanguage = LANGUAGE_OPTIONS.find(item => item.value === language)
const selectedTimezone = TIMEZONE_OPTIONS.find(item => item.value === timezone)

View File

@ -0,0 +1,85 @@
import type { MockedFunction } from 'vitest'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { useLocale } from '@/context/i18n'
import { useRouter, useSearchParams } from '@/next/navigation'
import { useMailRegister } from '@/service/use-common'
import { getBrowserTimezone } from '@/utils/timezone'
import ChangePasswordForm from '../page'
vi.mock('@/context/i18n', () => ({
useLocale: vi.fn(),
}))
vi.mock('@/next/navigation', () => ({
useRouter: vi.fn(),
useSearchParams: vi.fn(),
}))
vi.mock('@/service/use-common', () => ({
useMailRegister: vi.fn(),
}))
vi.mock('@/utils/timezone', () => ({
getBrowserTimezone: vi.fn(),
}))
vi.mock('@/utils/gtag', () => ({
sendGAEvent: vi.fn(),
}))
vi.mock('@/app/components/base/amplitude', () => ({
trackEvent: vi.fn(),
}))
vi.mock('@/utils/create-app-tracking', () => ({
rememberCreateAppExternalAttribution: vi.fn(),
}))
const mockRegister = vi.fn()
const mockReplace = vi.fn()
const mockUseLocale = useLocale as unknown as MockedFunction<typeof useLocale>
const mockUseSearchParams = useSearchParams as unknown as MockedFunction<typeof useSearchParams>
const mockUseRouter = useRouter as unknown as MockedFunction<typeof useRouter>
const mockUseMailRegister = useMailRegister as unknown as MockedFunction<typeof useMailRegister>
const mockGetBrowserTimezone = getBrowserTimezone as unknown as MockedFunction<typeof getBrowserTimezone>
describe('Signup Set Password Page', () => {
beforeEach(() => {
vi.clearAllMocks()
mockUseLocale.mockReturnValue('zh-Hans')
mockUseSearchParams.mockReturnValue(new URLSearchParams('token=register-token') as unknown as ReturnType<typeof useSearchParams>)
mockUseRouter.mockReturnValue({ replace: mockReplace } as unknown as ReturnType<typeof useRouter>)
mockUseMailRegister.mockReturnValue({
mutateAsync: mockRegister,
isPending: false,
} as unknown as ReturnType<typeof useMailRegister>)
mockGetBrowserTimezone.mockReturnValue('Asia/Shanghai')
mockRegister.mockResolvedValue({ result: 'fail', data: {} })
})
describe('Registration payload', () => {
it('should submit locale and browser timezone when setting password', async () => {
render(<ChangePasswordForm />)
fireEvent.change(screen.getByLabelText('common.account.newPassword'), {
target: { value: 'ValidPass123!' },
})
fireEvent.change(screen.getByLabelText('common.account.confirmPassword'), {
target: { value: 'ValidPass123!' },
})
fireEvent.click(screen.getByRole('button', { name: 'login.changePasswordBtn' }))
await waitFor(() => {
expect(mockRegister).toHaveBeenCalledWith({
token: 'register-token',
new_password: 'ValidPass123!',
password_confirm: 'ValidPass123!',
language: 'zh-Hans',
timezone: 'Asia/Shanghai',
})
})
})
})
})

View File

@ -9,10 +9,12 @@ import { useTranslation } from 'react-i18next'
import { trackEvent } from '@/app/components/base/amplitude'
import Input from '@/app/components/base/input'
import { validPassword } from '@/config'
import { useLocale } from '@/context/i18n'
import { useRouter, useSearchParams } from '@/next/navigation'
import { useMailRegister } from '@/service/use-common'
import { rememberCreateAppExternalAttribution } from '@/utils/create-app-tracking'
import { sendGAEvent } from '@/utils/gtag'
import { getBrowserTimezone } from '@/utils/timezone'
const parseUtmInfo = () => {
const utmInfoStr = Cookies.get('utm_info')
@ -32,6 +34,7 @@ const ChangePasswordForm = () => {
const router = useRouter()
const searchParams = useSearchParams()
const token = decodeURIComponent(searchParams.get('token') || '')
const locale = useLocale()
const [password, setPassword] = useState('')
const [confirmPassword, setConfirmPassword] = useState('')
@ -65,6 +68,8 @@ const ChangePasswordForm = () => {
token,
new_password: password,
password_confirm: confirmPassword,
language: locale,
timezone: getBrowserTimezone(),
})
const { result } = res as MailRegisterResponse
if (result === 'success') {
@ -88,7 +93,7 @@ const ChangePasswordForm = () => {
catch (error) {
console.error(error)
}
}, [password, token, valid, confirmPassword, register])
}, [password, token, valid, confirmPassword, register, locale])
return (
<div className={

View File

@ -339,7 +339,13 @@ export const uploadRemoteFileInfo = (url: string, isPublic?: boolean, silent?: b
export const sendEMailLoginCode = (email: string, language = 'en-US'): Promise<CommonResponse & { data: string }> =>
post<CommonResponse & { data: string }>('/email-code-login', { body: { email, language } })
export const emailLoginWithCode = (data: { email: string, code: string, token: string, language: string }): Promise<LoginResponse> =>
export const emailLoginWithCode = (data: {
email: string
code: string
token: string
language: string
timezone?: string
}): Promise<LoginResponse> =>
post<LoginResponse>('/email-code-login/validity', { body: data })
export const sendResetPasswordCode = (email: string, language = 'en-US'): Promise<CommonResponse & { data: string, message?: string, code?: string }> =>

View File

@ -178,7 +178,13 @@ export type MailRegisterResponse = { result: string, data: {} }
export const useMailRegister = () => {
return useMutation({
mutationKey: [NAME_SPACE, 'mail-register'],
mutationFn: (body: { token: string, new_password: string, password_confirm: string }) => {
mutationFn: (body: {
token: string
new_password: string
password_confirm: string
language?: string
timezone?: string
}) => {
return post<MailRegisterResponse>('/email-register', { body })
},
})

View File

@ -5,3 +5,10 @@ type Item = {
name: string
}
export const timezones: Item[] = tz
export const getBrowserTimezone = () => {
if (typeof Intl === 'undefined')
return undefined
return Intl.DateTimeFormat().resolvedOptions().timeZone || undefined
}