mirror of
https://github.com/langgenius/dify.git
synced 2026-05-15 06:27:58 +08:00
Compare commits
8 Commits
copilot/re
...
codex/init
| Author | SHA1 | Date | |
|---|---|---|---|
| 540a510d8e | |||
| 0836071203 | |||
| 454637060d | |||
| f086dbb9a8 | |||
| 926e3f8b29 | |||
| 9dc32f2318 | |||
| 356c1a21b9 | |||
| 6880c621ec |
@ -557,7 +557,7 @@ MAX_VARIABLE_SIZE=204800
|
||||
|
||||
# GraphEngine Worker Pool Configuration
|
||||
# Minimum number of workers per GraphEngine instance (default: 1)
|
||||
GRAPH_ENGINE_MIN_WORKERS=1
|
||||
GRAPH_ENGINE_MIN_WORKERS=3
|
||||
# Maximum number of workers per GraphEngine instance (default: 10)
|
||||
GRAPH_ENGINE_MAX_WORKERS=10
|
||||
# Queue depth threshold that triggers worker scale up (default: 3)
|
||||
|
||||
@ -761,7 +761,7 @@ class WorkflowConfig(BaseSettings):
|
||||
# GraphEngine Worker Pool Configuration
|
||||
GRAPH_ENGINE_MIN_WORKERS: PositiveInt = Field(
|
||||
description="Minimum number of workers per GraphEngine instance",
|
||||
default=1,
|
||||
default=3,
|
||||
)
|
||||
|
||||
GRAPH_ENGINE_MAX_WORKERS: PositiveInt = Field(
|
||||
|
||||
@ -3,7 +3,7 @@ from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import languages
|
||||
from constants.languages import get_valid_language, languages
|
||||
from controllers.common.schema import register_schema_models
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.auth.error import (
|
||||
@ -15,6 +15,7 @@ from controllers.console.auth.error import (
|
||||
PasswordMismatchError,
|
||||
)
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.helper import timezone as validate_timezone_string
|
||||
from libs.password import valid_password
|
||||
from models import Account
|
||||
from services.account_service import AccountService
|
||||
@ -40,12 +41,19 @@ class EmailRegisterResetPayload(BaseModel):
|
||||
token: str = Field(...)
|
||||
new_password: str = Field(...)
|
||||
password_confirm: str = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
timezone: str | None = Field(default=None)
|
||||
|
||||
@field_validator("new_password", "password_confirm")
|
||||
@classmethod
|
||||
def validate_password(cls, value: str) -> str:
|
||||
return valid_password(value)
|
||||
|
||||
@field_validator("timezone")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: str | None) -> str | None:
|
||||
return validate_timezone_string(value) if value else value
|
||||
|
||||
|
||||
register_schema_models(console_ns, EmailRegisterSendPayload, EmailRegisterValidityPayload, EmailRegisterResetPayload)
|
||||
|
||||
@ -145,7 +153,15 @@ class EmailRegisterResetApi(Resource):
|
||||
if account:
|
||||
raise EmailAlreadyInUseError()
|
||||
else:
|
||||
account = self._create_new_account(normalized_email, args.password_confirm)
|
||||
if args.timezone or args.language:
|
||||
account = self._create_new_account(
|
||||
normalized_email,
|
||||
args.password_confirm,
|
||||
args.timezone,
|
||||
args.language,
|
||||
)
|
||||
else:
|
||||
account = self._create_new_account(normalized_email, args.password_confirm)
|
||||
if not account:
|
||||
raise AccountNotFoundError()
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
@ -153,7 +169,13 @@ class EmailRegisterResetApi(Resource):
|
||||
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
def _create_new_account(self, email: str, password: str) -> Account | None:
|
||||
def _create_new_account(
|
||||
self,
|
||||
email: str,
|
||||
password: str,
|
||||
timezone: str | None = None,
|
||||
language: str | None = None,
|
||||
) -> Account | None:
|
||||
# Create new account if allowed
|
||||
account = None
|
||||
try:
|
||||
@ -161,7 +183,8 @@ class EmailRegisterResetApi(Resource):
|
||||
email=email,
|
||||
name=email,
|
||||
password=password,
|
||||
interface_language=languages[0],
|
||||
interface_language=get_valid_language(language),
|
||||
timezone=timezone,
|
||||
)
|
||||
except AccountRegisterError:
|
||||
raise AccountInFreezeError()
|
||||
|
||||
@ -3,7 +3,7 @@ import logging
|
||||
import flask_login
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
@ -34,6 +34,7 @@ from controllers.console.wraps import (
|
||||
)
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import EmailStr, extract_remote_ip
|
||||
from libs.helper import timezone as validate_timezone_string
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.token import (
|
||||
clear_access_token_from_cookie,
|
||||
@ -69,6 +70,12 @@ class EmailCodeLoginPayload(BaseModel):
|
||||
code: str = Field(...)
|
||||
token: str = Field(...)
|
||||
language: str | None = Field(default=None)
|
||||
timezone: str | None = Field(default=None)
|
||||
|
||||
@field_validator("timezone")
|
||||
@classmethod
|
||||
def validate_timezone(cls, value: str | None) -> str | None:
|
||||
return validate_timezone_string(value) if value else value
|
||||
|
||||
|
||||
register_schema_models(console_ns, LoginPayload, EmailPayload, EmailCodeLoginPayload)
|
||||
@ -288,6 +295,7 @@ class EmailCodeLoginApi(Resource):
|
||||
email=user_email,
|
||||
name=user_email,
|
||||
interface_language=get_valid_language(language),
|
||||
timezone=args.timezone,
|
||||
)
|
||||
except WorkSpaceNotAllowedCreateError:
|
||||
raise NotAllowedCreateWorkspace()
|
||||
|
||||
@ -12,7 +12,8 @@ from events.tenant_event import tenant_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
from libs.helper import timezone as validate_timezone_string
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo, decode_oauth_state
|
||||
from libs.token import (
|
||||
set_access_token_to_cookie,
|
||||
set_csrf_token_to_cookie,
|
||||
@ -53,6 +54,31 @@ def get_oauth_providers():
|
||||
return OAUTH_PROVIDERS
|
||||
|
||||
|
||||
def _validated_timezone(value: str | None) -> str | None:
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return validate_timezone_string(value)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
|
||||
def _validated_language(value: str | None) -> str | None:
|
||||
if value and value in languages:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
def _preferred_interface_language(language: str | None = None) -> str:
|
||||
if language:
|
||||
return language
|
||||
|
||||
preferred_lang = request.accept_languages.best_match(languages)
|
||||
if preferred_lang and preferred_lang in languages:
|
||||
return preferred_lang
|
||||
return languages[0]
|
||||
|
||||
|
||||
@console_ns.route("/oauth/login/<provider>")
|
||||
class OAuthLogin(Resource):
|
||||
@console_ns.doc("oauth_login")
|
||||
@ -64,13 +90,19 @@ class OAuthLogin(Resource):
|
||||
@console_ns.response(400, "Invalid provider")
|
||||
def get(self, provider: str):
|
||||
invite_token = request.args.get("invite_token") or None
|
||||
timezone = _validated_timezone(request.args.get("timezone") or None)
|
||||
language = _validated_language(request.args.get("language") or None)
|
||||
OAUTH_PROVIDERS = get_oauth_providers()
|
||||
with current_app.app_context():
|
||||
oauth_provider = OAUTH_PROVIDERS.get(provider)
|
||||
if not oauth_provider:
|
||||
return {"error": "Invalid provider"}, 400
|
||||
|
||||
auth_url = oauth_provider.get_authorization_url(invite_token=invite_token)
|
||||
auth_url = oauth_provider.get_authorization_url(
|
||||
invite_token=invite_token,
|
||||
timezone=timezone,
|
||||
language=language,
|
||||
)
|
||||
return redirect(auth_url)
|
||||
|
||||
|
||||
@ -96,9 +128,10 @@ class OAuthCallback(Resource):
|
||||
|
||||
code = request.args.get("code")
|
||||
state = request.args.get("state")
|
||||
invite_token = None
|
||||
if state:
|
||||
invite_token = state
|
||||
oauth_state = decode_oauth_state(state)
|
||||
invite_token = oauth_state.get("invite_token")
|
||||
timezone = _validated_timezone(oauth_state.get("timezone"))
|
||||
language = _validated_language(oauth_state.get("language"))
|
||||
|
||||
if not code:
|
||||
return {"error": "Authorization code is required"}, 400
|
||||
@ -129,7 +162,7 @@ class OAuthCallback(Resource):
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
|
||||
|
||||
try:
|
||||
account, oauth_new_user = _generate_account(provider, user_info)
|
||||
account, oauth_new_user = _generate_account(provider, user_info, timezone=timezone, language=language)
|
||||
except AccountNotFoundError:
|
||||
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Account not found.")
|
||||
except (WorkSpaceNotFoundError, WorkSpaceNotAllowedCreateError):
|
||||
@ -184,7 +217,12 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
|
||||
return account
|
||||
|
||||
|
||||
def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account, bool]:
|
||||
def _generate_account(
|
||||
provider: str,
|
||||
user_info: OAuthUserInfo,
|
||||
timezone: str | None = None,
|
||||
language: str | None = None,
|
||||
) -> tuple[Account, bool]:
|
||||
# Get account by openid or email.
|
||||
account = _get_account_by_openid_or_email(provider, user_info)
|
||||
oauth_new_user = False
|
||||
@ -211,25 +249,28 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
|
||||
"30 days and is temporarily unavailable for new account registration"
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise AccountRegisterError(description=("Invalid email or password"))
|
||||
raise AccountRegisterError(description=("Invalid email or password"))
|
||||
account_name = user_info.name or "Dify"
|
||||
account = RegisterService.register(
|
||||
email=normalized_email,
|
||||
name=account_name,
|
||||
password=None,
|
||||
open_id=user_info.id,
|
||||
provider=provider,
|
||||
)
|
||||
|
||||
# Set interface language
|
||||
preferred_lang = request.accept_languages.best_match(languages)
|
||||
if preferred_lang and preferred_lang in languages:
|
||||
interface_language = preferred_lang
|
||||
interface_language = _preferred_interface_language(language)
|
||||
if timezone:
|
||||
account = RegisterService.register(
|
||||
email=normalized_email,
|
||||
name=account_name,
|
||||
password=None,
|
||||
open_id=user_info.id,
|
||||
provider=provider,
|
||||
language=interface_language,
|
||||
timezone=timezone,
|
||||
)
|
||||
else:
|
||||
interface_language = languages[0]
|
||||
account.interface_language = interface_language
|
||||
db.session.commit()
|
||||
account = RegisterService.register(
|
||||
email=normalized_email,
|
||||
name=account_name,
|
||||
password=None,
|
||||
open_id=user_info.id,
|
||||
provider=provider,
|
||||
language=interface_language,
|
||||
)
|
||||
|
||||
# Link account
|
||||
AccountService.link_account_integrate(provider, user_info.id, account)
|
||||
|
||||
@ -1,3 +1,6 @@
|
||||
import base64
|
||||
import binascii
|
||||
import json
|
||||
import logging
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
@ -27,6 +30,12 @@ class AccessTokenResponse(TypedDict, total=False):
|
||||
access_token: str
|
||||
|
||||
|
||||
class OAuthState(TypedDict, total=False):
|
||||
invite_token: str
|
||||
timezone: str
|
||||
language: str
|
||||
|
||||
|
||||
class GitHubEmailRecord(TypedDict, total=False):
|
||||
email: str
|
||||
primary: bool
|
||||
@ -46,6 +55,7 @@ class GoogleRawUserInfo(TypedDict):
|
||||
|
||||
|
||||
ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse)
|
||||
OAUTH_STATE_ADAPTER = TypeAdapter(OAuthState)
|
||||
GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo)
|
||||
GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord])
|
||||
GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo)
|
||||
@ -58,6 +68,37 @@ class OAuthUserInfo:
|
||||
email: str
|
||||
|
||||
|
||||
def encode_oauth_state(
|
||||
invite_token: str | None = None,
|
||||
timezone: str | None = None,
|
||||
language: str | None = None,
|
||||
) -> str | None:
|
||||
state: OAuthState = {}
|
||||
if invite_token:
|
||||
state["invite_token"] = invite_token
|
||||
if timezone:
|
||||
state["timezone"] = timezone
|
||||
if language:
|
||||
state["language"] = language
|
||||
if not state:
|
||||
return None
|
||||
|
||||
raw_state = json.dumps(state, separators=(",", ":")).encode("utf-8")
|
||||
return base64.urlsafe_b64encode(raw_state).decode("ascii").rstrip("=")
|
||||
|
||||
|
||||
def decode_oauth_state(state: str | None) -> OAuthState:
|
||||
if not state:
|
||||
return {}
|
||||
|
||||
try:
|
||||
padded_state = state + "=" * (-len(state) % 4)
|
||||
raw_state = base64.urlsafe_b64decode(padded_state.encode("ascii")).decode("utf-8")
|
||||
return OAUTH_STATE_ADAPTER.validate_python(json.loads(raw_state))
|
||||
except (binascii.Error, ValueError, UnicodeDecodeError, json.JSONDecodeError, ValidationError):
|
||||
return {"invite_token": state}
|
||||
|
||||
|
||||
def _json_object(response: httpx.Response) -> JsonObject:
|
||||
return JSON_OBJECT_ADAPTER.validate_python(response.json())
|
||||
|
||||
@ -76,7 +117,12 @@ class OAuth:
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
def get_authorization_url(
|
||||
self,
|
||||
invite_token: str | None = None,
|
||||
timezone: str | None = None,
|
||||
language: str | None = None,
|
||||
) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_access_token(self, code: str) -> str:
|
||||
@ -99,14 +145,20 @@ class GitHubOAuth(OAuth):
|
||||
_USER_INFO_URL = "https://api.github.com/user"
|
||||
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
|
||||
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
def get_authorization_url(
|
||||
self,
|
||||
invite_token: str | None = None,
|
||||
timezone: str | None = None,
|
||||
language: str | None = None,
|
||||
) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"scope": "user:email", # Request only basic user information
|
||||
}
|
||||
if invite_token:
|
||||
params["state"] = invite_token
|
||||
state = encode_oauth_state(invite_token=invite_token, timezone=timezone, language=language)
|
||||
if state:
|
||||
params["state"] = state
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str) -> str:
|
||||
@ -186,15 +238,21 @@ class GoogleOAuth(OAuth):
|
||||
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
|
||||
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
def get_authorization_url(
|
||||
self,
|
||||
invite_token: str | None = None,
|
||||
timezone: str | None = None,
|
||||
language: str | None = None,
|
||||
) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"response_type": "code",
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"scope": "openid email",
|
||||
}
|
||||
if invite_token:
|
||||
params["state"] = invite_token
|
||||
state = encode_oauth_state(invite_token=invite_token, timezone=timezone, language=language)
|
||||
if state:
|
||||
params["state"] = state
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str) -> str:
|
||||
|
||||
@ -11596,6 +11596,7 @@ Request payload for bulk downloading documents as a zip archive.
|
||||
| code | string | | Yes |
|
||||
| email | string | | Yes |
|
||||
| language | | | No |
|
||||
| timezone | | | No |
|
||||
| token | string | | Yes |
|
||||
|
||||
#### EmailPayload
|
||||
@ -11609,8 +11610,10 @@ Request payload for bulk downloading documents as a zip archive.
|
||||
|
||||
| Name | Type | Description | Required |
|
||||
| ---- | ---- | ----------- | -------- |
|
||||
| language | | | No |
|
||||
| new_password | string | | Yes |
|
||||
| password_confirm | string | | Yes |
|
||||
| timezone | | | No |
|
||||
| token | string | | Yes |
|
||||
|
||||
#### EmailRegisterSendPayload
|
||||
|
||||
@ -29,6 +29,7 @@ from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client, redis_fallback
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import RateLimiter, TokenManager
|
||||
from libs.helper import timezone as validate_timezone
|
||||
from libs.passport import PassportService
|
||||
from libs.password import compare_password, hash_password, valid_password
|
||||
from libs.rsa import generate_key_pair
|
||||
@ -271,8 +272,9 @@ class AccountService:
|
||||
password: str | None = None,
|
||||
interface_theme: str = "light",
|
||||
is_setup: bool | None = False,
|
||||
timezone: str | None = None,
|
||||
) -> Account:
|
||||
"""create account"""
|
||||
"""Create an account, preferring explicit user timezone over language-derived defaults."""
|
||||
if not FeatureService.get_system_features().is_allow_register and not is_setup:
|
||||
from controllers.console.error import AccountNotFound
|
||||
|
||||
@ -309,7 +311,9 @@ class AccountService:
|
||||
password_salt=salt_to_set,
|
||||
interface_language=interface_language,
|
||||
interface_theme=interface_theme,
|
||||
timezone=language_timezone_mapping.get(interface_language, "UTC"),
|
||||
timezone=(
|
||||
validate_timezone(timezone) if timezone else language_timezone_mapping.get(interface_language, "UTC")
|
||||
),
|
||||
)
|
||||
|
||||
db.session.add(account)
|
||||
@ -318,12 +322,24 @@ class AccountService:
|
||||
|
||||
@staticmethod
|
||||
def create_account_and_tenant(
|
||||
email: str, name: str, interface_language: str, password: str | None = None
|
||||
email: str, name: str, interface_language: str, password: str | None = None, timezone: str | None = None
|
||||
) -> Account:
|
||||
"""create account"""
|
||||
account = AccountService.create_account(
|
||||
email=email, name=name, interface_language=interface_language, password=password
|
||||
)
|
||||
if timezone:
|
||||
account = AccountService.create_account(
|
||||
email=email,
|
||||
name=name,
|
||||
interface_language=interface_language,
|
||||
password=password,
|
||||
timezone=timezone,
|
||||
)
|
||||
else:
|
||||
account = AccountService.create_account(
|
||||
email=email,
|
||||
name=name,
|
||||
interface_language=interface_language,
|
||||
password=password,
|
||||
)
|
||||
|
||||
try:
|
||||
TenantService.create_owner_tenant_if_not_exist(account=account)
|
||||
@ -1483,17 +1499,29 @@ class RegisterService:
|
||||
status: AccountStatus | None = None,
|
||||
is_setup: bool | None = False,
|
||||
create_workspace_required: bool | None = True,
|
||||
timezone: str | None = None,
|
||||
) -> Account:
|
||||
db.session.begin_nested()
|
||||
"""Register account"""
|
||||
try:
|
||||
account = AccountService.create_account(
|
||||
email=email,
|
||||
name=name,
|
||||
interface_language=get_valid_language(language),
|
||||
password=password,
|
||||
is_setup=is_setup,
|
||||
)
|
||||
interface_language = get_valid_language(language)
|
||||
if timezone:
|
||||
account = AccountService.create_account(
|
||||
email=email,
|
||||
name=name,
|
||||
interface_language=interface_language,
|
||||
password=password,
|
||||
is_setup=is_setup,
|
||||
timezone=timezone,
|
||||
)
|
||||
else:
|
||||
account = AccountService.create_account(
|
||||
email=email,
|
||||
name=name,
|
||||
interface_language=interface_language,
|
||||
password=password,
|
||||
is_setup=is_setup,
|
||||
)
|
||||
account.status = status or AccountStatus.ACTIVE
|
||||
account.initialized_at = naive_utc_now()
|
||||
|
||||
|
||||
@ -148,6 +148,102 @@ class TestEmailRegisterResetApi:
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_extract_ip.assert_called_once()
|
||||
|
||||
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.login")
|
||||
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
|
||||
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_reset_passes_timezone_to_new_account(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_get_account,
|
||||
mock_create_account,
|
||||
mock_login,
|
||||
mock_reset_login_rate,
|
||||
app: Flask,
|
||||
):
|
||||
mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
|
||||
mock_create_account.return_value = MagicMock()
|
||||
token_pair = MagicMock()
|
||||
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
|
||||
mock_login.return_value = token_pair
|
||||
mock_get_account.return_value = None
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/email-register",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "token-123",
|
||||
"new_password": "ValidPass123!",
|
||||
"password_confirm": "ValidPass123!",
|
||||
"timezone": "Asia/Shanghai",
|
||||
},
|
||||
):
|
||||
response = EmailRegisterResetApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}}
|
||||
mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!", "Asia/Shanghai", None)
|
||||
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_extract_ip.assert_called_once()
|
||||
|
||||
@patch("controllers.console.auth.email_register.AccountService.reset_login_error_rate_limit")
|
||||
@patch("controllers.console.auth.email_register.AccountService.login")
|
||||
@patch("controllers.console.auth.email_register.EmailRegisterResetApi._create_new_account")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.auth.email_register.AccountService.revoke_email_register_token")
|
||||
@patch("controllers.console.auth.email_register.AccountService.get_email_register_data")
|
||||
@patch("controllers.console.auth.email_register.extract_remote_ip", return_value="127.0.0.1")
|
||||
def test_reset_passes_language_to_new_account(
|
||||
self,
|
||||
mock_extract_ip,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_get_account,
|
||||
mock_create_account,
|
||||
mock_login,
|
||||
mock_reset_login_rate,
|
||||
app: Flask,
|
||||
):
|
||||
mock_get_data.return_value = {"phase": "register", "email": "Invitee@Example.com"}
|
||||
mock_create_account.return_value = MagicMock()
|
||||
token_pair = MagicMock()
|
||||
token_pair.model_dump.return_value = {"access_token": "a", "refresh_token": "r"}
|
||||
mock_login.return_value = token_pair
|
||||
mock_get_account.return_value = None
|
||||
|
||||
feature_flags = SimpleNamespace(enable_email_password_login=True, is_allow_register=True)
|
||||
with (
|
||||
patch("controllers.console.wraps.dify_config", SimpleNamespace(EDITION="CLOUD")),
|
||||
patch("controllers.console.wraps.FeatureService.get_system_features", return_value=feature_flags),
|
||||
):
|
||||
with app.test_request_context(
|
||||
"/email-register",
|
||||
method="POST",
|
||||
json={
|
||||
"token": "token-123",
|
||||
"new_password": "ValidPass123!",
|
||||
"password_confirm": "ValidPass123!",
|
||||
"language": "zh-Hans",
|
||||
},
|
||||
):
|
||||
response = EmailRegisterResetApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": {"access_token": "a", "refresh_token": "r"}}
|
||||
mock_create_account.assert_called_once_with("invitee@example.com", "ValidPass123!", None, "zh-Hans")
|
||||
mock_reset_login_rate.assert_called_once_with("invitee@example.com")
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_extract_ip.assert_called_once()
|
||||
|
||||
|
||||
def test_get_account_by_email_with_case_fallback_falls_back_to_lowercase():
|
||||
"""Test that case fallback tries lowercase when exact match fails."""
|
||||
|
||||
@ -14,7 +14,7 @@ from controllers.console.auth.oauth import (
|
||||
_get_account_by_openid_or_email,
|
||||
get_oauth_providers,
|
||||
)
|
||||
from libs.oauth import OAuthUserInfo
|
||||
from libs.oauth import OAuthUserInfo, encode_oauth_state
|
||||
from models.account import AccountStatus
|
||||
from services.account_service import AccountService
|
||||
from services.errors.account import AccountRegisterError
|
||||
@ -101,7 +101,55 @@ class TestOAuthLogin:
|
||||
with app.test_request_context(f"/auth/oauth/github?{query_string}"):
|
||||
resource.get("github")
|
||||
|
||||
mock_oauth_provider.get_authorization_url.assert_called_once_with(invite_token=expected_token)
|
||||
mock_oauth_provider.get_authorization_url.assert_called_once_with(
|
||||
invite_token=expected_token,
|
||||
timezone=None,
|
||||
language=None,
|
||||
)
|
||||
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
|
||||
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
def test_should_pass_timezone_to_oauth_state(
|
||||
self,
|
||||
mock_redirect,
|
||||
mock_get_providers,
|
||||
resource,
|
||||
app: Flask,
|
||||
mock_oauth_provider,
|
||||
):
|
||||
mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None}
|
||||
|
||||
with app.test_request_context("/auth/oauth/github?timezone=Asia/Shanghai"):
|
||||
resource.get("github")
|
||||
|
||||
mock_oauth_provider.get_authorization_url.assert_called_once_with(
|
||||
invite_token=None,
|
||||
timezone="Asia/Shanghai",
|
||||
language=None,
|
||||
)
|
||||
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
|
||||
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
def test_should_pass_language_to_oauth_state(
|
||||
self,
|
||||
mock_redirect,
|
||||
mock_get_providers,
|
||||
resource,
|
||||
app: Flask,
|
||||
mock_oauth_provider,
|
||||
):
|
||||
mock_get_providers.return_value = {"github": mock_oauth_provider, "google": None}
|
||||
|
||||
with app.test_request_context("/auth/oauth/github?language=zh-Hans"):
|
||||
resource.get("github")
|
||||
|
||||
mock_oauth_provider.get_authorization_url.assert_called_once_with(
|
||||
invite_token=None,
|
||||
timezone=None,
|
||||
language="zh-Hans",
|
||||
)
|
||||
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?...")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@ -229,7 +277,8 @@ class TestOAuthCallback:
|
||||
mock_register_service.is_valid_invite_token.return_value = True
|
||||
mock_register_service.get_invitation_by_token.return_value = {"email": "user@example.com"}
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code&state=invite123"):
|
||||
state = encode_oauth_state(invite_token="invite123", timezone="Asia/Shanghai")
|
||||
with app.test_request_context(f"/auth/oauth/github/callback?code=test_code&state={state}"):
|
||||
resource.get("github")
|
||||
|
||||
mock_register_service.get_invitation_by_token.assert_called_once_with(token="invite123")
|
||||
@ -488,7 +537,12 @@ class TestAccountGeneration:
|
||||
|
||||
if should_create:
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="test@example.com", name="Test User", password=None, open_id="123", provider="github"
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
password=None,
|
||||
open_id="123",
|
||||
provider="github",
|
||||
language="en-US",
|
||||
)
|
||||
else:
|
||||
mock_register_service.register.assert_not_called()
|
||||
@ -515,7 +569,73 @@ class TestAccountGeneration:
|
||||
_generate_account("github", user_info)
|
||||
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="upper@example.com", name="Test User", password=None, open_id="123", provider="github"
|
||||
email="upper@example.com",
|
||||
name="Test User",
|
||||
password=None,
|
||||
open_id="123",
|
||||
provider="github",
|
||||
language="en-US",
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
def test_should_register_with_browser_timezone(
|
||||
self,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
mock_feature_service,
|
||||
mock_get_account,
|
||||
app: Flask,
|
||||
user_info,
|
||||
):
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
mock_register_service.register.return_value = MagicMock()
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "zh-Hans,zh;q=0.9"}):
|
||||
_generate_account("github", user_info, timezone="Asia/Shanghai")
|
||||
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
password=None,
|
||||
open_id="123",
|
||||
provider="github",
|
||||
language="zh-Hans",
|
||||
timezone="Asia/Shanghai",
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.AccountService")
|
||||
@patch("controllers.console.auth.oauth.TenantService")
|
||||
def test_should_register_with_state_language(
|
||||
self,
|
||||
mock_tenant_service,
|
||||
mock_account_service,
|
||||
mock_register_service,
|
||||
mock_feature_service,
|
||||
mock_get_account,
|
||||
app: Flask,
|
||||
user_info,
|
||||
):
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
mock_register_service.register.return_value = MagicMock()
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
|
||||
_generate_account("github", user_info, language="zh-Hans")
|
||||
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
password=None,
|
||||
open_id="123",
|
||||
provider="github",
|
||||
language="zh-Hans",
|
||||
)
|
||||
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email")
|
||||
|
||||
@ -0,0 +1,25 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from controllers.console.auth.email_register import EmailRegisterResetApi
|
||||
|
||||
|
||||
@patch("controllers.console.auth.email_register.AccountService.create_account_and_tenant")
|
||||
def test_create_new_account_uses_requested_language(mock_create_account):
|
||||
account = MagicMock()
|
||||
mock_create_account.return_value = account
|
||||
|
||||
result = EmailRegisterResetApi()._create_new_account(
|
||||
"invitee@example.com",
|
||||
"ValidPass123!",
|
||||
timezone="Asia/Shanghai",
|
||||
language="zh-Hans",
|
||||
)
|
||||
|
||||
assert result is account
|
||||
mock_create_account.assert_called_once_with(
|
||||
email="invitee@example.com",
|
||||
name="invitee@example.com",
|
||||
password="ValidPass123!",
|
||||
interface_language="zh-Hans",
|
||||
timezone="Asia/Shanghai",
|
||||
)
|
||||
@ -342,6 +342,7 @@ class TestEmailCodeLoginApi:
|
||||
"code": encode_code("123456"),
|
||||
"token": "valid_token",
|
||||
"language": "en-US",
|
||||
"timezone": "Asia/Shanghai",
|
||||
},
|
||||
):
|
||||
api = EmailCodeLoginApi()
|
||||
@ -349,7 +350,12 @@ class TestEmailCodeLoginApi:
|
||||
|
||||
# Assert
|
||||
assert response.json["result"] == "success"
|
||||
mock_create_account.assert_called_once()
|
||||
mock_create_account.assert_called_once_with(
|
||||
email="newuser@example.com",
|
||||
name="newuser@example.com",
|
||||
interface_language="en-US",
|
||||
timezone="Asia/Shanghai",
|
||||
)
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.auth.login.AccountService.get_email_code_login_data")
|
||||
|
||||
@ -0,0 +1,122 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.console.auth.oauth import OAuthLogin, _generate_account
|
||||
from libs.oauth import OAuthUserInfo
|
||||
from services.errors.account import AccountRegisterError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
@patch("controllers.console.auth.oauth.redirect")
|
||||
@patch("controllers.console.auth.oauth.get_oauth_providers")
|
||||
def test_oauth_login_passes_language_and_timezone_to_authorization_url(
|
||||
mock_get_oauth_providers,
|
||||
mock_redirect,
|
||||
app: Flask,
|
||||
):
|
||||
oauth_provider = MagicMock()
|
||||
oauth_provider.get_authorization_url.return_value = "https://github.com/login/oauth/authorize?state=..."
|
||||
mock_get_oauth_providers.return_value = {"github": oauth_provider}
|
||||
|
||||
with app.test_request_context("/oauth/login/github?language=zh-Hans&timezone=Asia/Shanghai"):
|
||||
OAuthLogin().get("github")
|
||||
|
||||
oauth_provider.get_authorization_url.assert_called_once_with(
|
||||
invite_token=None,
|
||||
timezone="Asia/Shanghai",
|
||||
language="zh-Hans",
|
||||
)
|
||||
mock_redirect.assert_called_once_with("https://github.com/login/oauth/authorize?state=...")
|
||||
|
||||
|
||||
@patch("controllers.console.auth.oauth.AccountService.link_account_integrate")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
def test_generate_account_registers_with_browser_timezone(
|
||||
mock_get_account,
|
||||
mock_feature_service,
|
||||
mock_register_service,
|
||||
mock_link_account,
|
||||
app: Flask,
|
||||
):
|
||||
account = MagicMock()
|
||||
mock_register_service.register.return_value = account
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
user_info = OAuthUserInfo(id="github-123", name="Test User", email="User@Example.com")
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "zh-Hans,zh;q=0.9"}):
|
||||
result, oauth_new_user = _generate_account("github", user_info, timezone="Asia/Shanghai")
|
||||
|
||||
assert result is account
|
||||
assert oauth_new_user is True
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="user@example.com",
|
||||
name="Test User",
|
||||
password=None,
|
||||
open_id="github-123",
|
||||
provider="github",
|
||||
language="zh-Hans",
|
||||
timezone="Asia/Shanghai",
|
||||
)
|
||||
mock_link_account.assert_called_once_with("github", "github-123", account)
|
||||
|
||||
|
||||
@patch("controllers.console.auth.oauth.AccountService.link_account_integrate")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
def test_generate_account_prefers_state_language_over_accept_language(
|
||||
mock_get_account,
|
||||
mock_feature_service,
|
||||
mock_register_service,
|
||||
mock_link_account,
|
||||
app: Flask,
|
||||
):
|
||||
account = MagicMock()
|
||||
mock_register_service.register.return_value = account
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = True
|
||||
user_info = OAuthUserInfo(id="github-123", name="Test User", email="User@Example.com")
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
|
||||
_generate_account("github", user_info, language="zh-Hans")
|
||||
|
||||
mock_register_service.register.assert_called_once_with(
|
||||
email="user@example.com",
|
||||
name="Test User",
|
||||
password=None,
|
||||
open_id="github-123",
|
||||
provider="github",
|
||||
language="zh-Hans",
|
||||
)
|
||||
mock_link_account.assert_called_once_with("github", "github-123", account)
|
||||
|
||||
|
||||
@patch("controllers.console.auth.oauth.dify_config")
|
||||
@patch("controllers.console.auth.oauth.RegisterService")
|
||||
@patch("controllers.console.auth.oauth.FeatureService")
|
||||
@patch("controllers.console.auth.oauth._get_account_by_openid_or_email", return_value=None)
|
||||
def test_generate_account_rejects_new_user_when_registration_disabled(
|
||||
mock_get_account,
|
||||
mock_feature_service,
|
||||
mock_register_service,
|
||||
mock_config,
|
||||
app: Flask,
|
||||
):
|
||||
mock_feature_service.get_system_features.return_value.is_allow_register = False
|
||||
mock_config.BILLING_ENABLED = False
|
||||
user_info = OAuthUserInfo(id="github-123", name="Test User", email="user@example.com")
|
||||
|
||||
with app.test_request_context(headers={"Accept-Language": "en-US,en;q=0.9"}):
|
||||
with pytest.raises(AccountRegisterError):
|
||||
_generate_account("github", user_info)
|
||||
|
||||
mock_register_service.register.assert_not_called()
|
||||
@ -297,7 +297,7 @@ class TableTestRunner:
|
||||
max_workers: int = 4,
|
||||
enable_logging: bool = False,
|
||||
log_level: str = "INFO",
|
||||
graph_engine_min_workers: int = 1,
|
||||
graph_engine_min_workers: int = 3,
|
||||
graph_engine_max_workers: int = 1,
|
||||
graph_engine_scale_up_threshold: int = 5,
|
||||
graph_engine_scale_down_idle_time: float = 30.0,
|
||||
@ -310,7 +310,7 @@ class TableTestRunner:
|
||||
max_workers: Maximum number of parallel workers for test execution
|
||||
enable_logging: Enable detailed logging
|
||||
log_level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||
graph_engine_min_workers: Minimum workers for GraphEngine (default: 1)
|
||||
graph_engine_min_workers: Minimum workers for GraphEngine (default: 3)
|
||||
graph_engine_max_workers: Maximum workers for GraphEngine (default: 1)
|
||||
graph_engine_scale_up_threshold: Queue depth to trigger scale up
|
||||
graph_engine_scale_down_idle_time: Idle time before scaling down
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from libs.oauth import OAuth
|
||||
from libs.oauth import OAuth, decode_oauth_state, encode_oauth_state
|
||||
|
||||
|
||||
def test_oauth_base_methods_raise_not_implemented():
|
||||
@ -17,3 +17,17 @@ def test_oauth_base_methods_raise_not_implemented():
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
oauth._transform_user_info({})
|
||||
|
||||
|
||||
def test_oauth_state_round_trips_invite_token_timezone_and_language():
|
||||
state = encode_oauth_state(invite_token="invite-123", timezone="Asia/Shanghai", language="zh-Hans")
|
||||
|
||||
assert decode_oauth_state(state) == {
|
||||
"invite_token": "invite-123",
|
||||
"timezone": "Asia/Shanghai",
|
||||
"language": "zh-Hans",
|
||||
}
|
||||
|
||||
|
||||
def test_oauth_state_keeps_legacy_raw_invite_token_compatible():
|
||||
assert decode_oauth_state("legacy-invite-token") == {"invite_token": "legacy-invite-token"}
|
||||
|
||||
@ -4,7 +4,7 @@ from unittest.mock import MagicMock, patch
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo, decode_oauth_state
|
||||
|
||||
|
||||
class BaseOAuthTest:
|
||||
@ -37,15 +37,25 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
return GitHubOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("invite_token", "expected_state"),
|
||||
("invite_token", "timezone", "language", "expected_state"),
|
||||
[
|
||||
(None, None),
|
||||
("test_invite_token", "test_invite_token"),
|
||||
("", None),
|
||||
(None, None, None, None),
|
||||
("test_invite_token", None, None, {"invite_token": "test_invite_token"}),
|
||||
("", None, None, None),
|
||||
(None, "Asia/Shanghai", None, {"timezone": "Asia/Shanghai"}),
|
||||
(None, None, "zh-Hans", {"language": "zh-Hans"}),
|
||||
(
|
||||
"test_invite_token",
|
||||
"Asia/Shanghai",
|
||||
"zh-Hans",
|
||||
{"invite_token": "test_invite_token", "timezone": "Asia/Shanghai", "language": "zh-Hans"},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
|
||||
url = oauth.get_authorization_url(invite_token)
|
||||
def test_should_generate_authorization_url_correctly(
|
||||
self, oauth, oauth_config, invite_token, timezone, language, expected_state
|
||||
):
|
||||
url = oauth.get_authorization_url(invite_token, timezone=timezone, language=language)
|
||||
parsed, params = self.parse_auth_url(url)
|
||||
|
||||
assert parsed.scheme == "https"
|
||||
@ -56,7 +66,7 @@ class TestGitHubOAuth(BaseOAuthTest):
|
||||
assert params["scope"][0] == "user:email"
|
||||
|
||||
if expected_state:
|
||||
assert params["state"][0] == expected_state
|
||||
assert decode_oauth_state(params["state"][0]) == expected_state
|
||||
else:
|
||||
assert "state" not in params
|
||||
|
||||
@ -208,15 +218,25 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
return GoogleOAuth(oauth_config["client_id"], oauth_config["client_secret"], oauth_config["redirect_uri"])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("invite_token", "expected_state"),
|
||||
("invite_token", "timezone", "language", "expected_state"),
|
||||
[
|
||||
(None, None),
|
||||
("test_invite_token", "test_invite_token"),
|
||||
("", None),
|
||||
(None, None, None, None),
|
||||
("test_invite_token", None, None, {"invite_token": "test_invite_token"}),
|
||||
("", None, None, None),
|
||||
(None, "Asia/Shanghai", None, {"timezone": "Asia/Shanghai"}),
|
||||
(None, None, "zh-Hans", {"language": "zh-Hans"}),
|
||||
(
|
||||
"test_invite_token",
|
||||
"Asia/Shanghai",
|
||||
"zh-Hans",
|
||||
{"invite_token": "test_invite_token", "timezone": "Asia/Shanghai", "language": "zh-Hans"},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_should_generate_authorization_url_correctly(self, oauth, oauth_config, invite_token, expected_state):
|
||||
url = oauth.get_authorization_url(invite_token)
|
||||
def test_should_generate_authorization_url_correctly(
|
||||
self, oauth, oauth_config, invite_token, timezone, language, expected_state
|
||||
):
|
||||
url = oauth.get_authorization_url(invite_token, timezone=timezone, language=language)
|
||||
parsed, params = self.parse_auth_url(url)
|
||||
|
||||
assert parsed.scheme == "https"
|
||||
@ -228,7 +248,7 @@ class TestGoogleOAuth(BaseOAuthTest):
|
||||
assert params["scope"][0] == "openid email"
|
||||
|
||||
if expected_state:
|
||||
assert params["state"][0] == expected_state
|
||||
assert decode_oauth_state(params["state"][0]) == expected_state
|
||||
else:
|
||||
assert "state" not in params
|
||||
|
||||
|
||||
@ -260,7 +260,7 @@ class TestAccountService:
|
||||
assert result.interface_theme == "light"
|
||||
assert result.password is not None
|
||||
assert result.password_salt is not None
|
||||
assert result.timezone is not None
|
||||
assert result.timezone == "America/New_York"
|
||||
|
||||
# Verify database operations
|
||||
mock_db_dependencies["db"].session.add.assert_called_once()
|
||||
@ -271,7 +271,28 @@ class TestAccountService:
|
||||
assert added_account.interface_theme == "light"
|
||||
assert added_account.password is not None
|
||||
assert added_account.password_salt is not None
|
||||
assert added_account.timezone is not None
|
||||
assert added_account.timezone == "America/New_York"
|
||||
self._assert_database_operations_called(mock_db_dependencies["db"])
|
||||
|
||||
def test_create_account_uses_explicit_timezone(
|
||||
self, mock_db_dependencies, mock_password_dependencies, mock_external_service_dependencies
|
||||
):
|
||||
"""Test account creation prefers explicit browser timezone."""
|
||||
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
|
||||
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
|
||||
mock_password_dependencies["hash_password"].return_value = b"hashed_password"
|
||||
|
||||
result = AccountService.create_account(
|
||||
email="test@example.com",
|
||||
name="Test User",
|
||||
interface_language="en-US",
|
||||
password="password123",
|
||||
timezone="Asia/Shanghai",
|
||||
)
|
||||
|
||||
assert result.timezone == "Asia/Shanghai"
|
||||
added_account = mock_db_dependencies["db"].session.add.call_args[0][0]
|
||||
assert added_account.timezone == "Asia/Shanghai"
|
||||
self._assert_database_operations_called(mock_db_dependencies["db"])
|
||||
|
||||
def test_create_account_registration_disabled(self, mock_external_service_dependencies):
|
||||
|
||||
@ -177,7 +177,7 @@ WORKFLOW_MAX_EXECUTION_TIME=1200
|
||||
WORKFLOW_CALL_MAX_DEPTH=5
|
||||
MAX_VARIABLE_SIZE=204800
|
||||
WORKFLOW_FILE_UPLOAD_LIMIT=10
|
||||
GRAPH_ENGINE_MIN_WORKERS=1
|
||||
GRAPH_ENGINE_MIN_WORKERS=3
|
||||
GRAPH_ENGINE_MAX_WORKERS=10
|
||||
GRAPH_ENGINE_SCALE_UP_THRESHOLD=3
|
||||
GRAPH_ENGINE_SCALE_DOWN_IDLE_TIME=5.0
|
||||
|
||||
@ -13,6 +13,7 @@ import { useLocale } from '@/context/i18n'
|
||||
import { useRouter, useSearchParams } from '@/next/navigation'
|
||||
import { emailLoginWithCode, sendEMailLoginCode } from '@/service/common'
|
||||
import { encryptVerificationCode } from '@/utils/encryption'
|
||||
import { getBrowserTimezone } from '@/utils/timezone'
|
||||
import { resolvePostLoginRedirect } from '../utils/post-login-redirect'
|
||||
|
||||
export default function CheckCode() {
|
||||
@ -39,7 +40,13 @@ export default function CheckCode() {
|
||||
return
|
||||
}
|
||||
setIsLoading(true)
|
||||
const ret = await emailLoginWithCode({ email, code: encryptVerificationCode(code), token, language })
|
||||
const ret = await emailLoginWithCode({
|
||||
email,
|
||||
code: encryptVerificationCode(code),
|
||||
token,
|
||||
language,
|
||||
timezone: getBrowserTimezone(),
|
||||
})
|
||||
if (ret.result === 'success') {
|
||||
// Track login success event
|
||||
trackEvent('user_login_success', {
|
||||
|
||||
86
web/app/signin/components/__tests__/social-auth.spec.tsx
Normal file
86
web/app/signin/components/__tests__/social-auth.spec.tsx
Normal file
@ -0,0 +1,86 @@
|
||||
import { render, screen } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import { useSearchParams } from '@/next/navigation'
|
||||
import { getBrowserTimezone } from '@/utils/timezone'
|
||||
import SocialAuth from '../social-auth'
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useSearchParams: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useLocale: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/utils/timezone', () => ({
|
||||
getBrowserTimezone: vi.fn(),
|
||||
}))
|
||||
|
||||
const mockUseSearchParams = vi.mocked(useSearchParams)
|
||||
const mockUseLocale = vi.mocked(useLocale)
|
||||
const mockGetBrowserTimezone = vi.mocked(getBrowserTimezone)
|
||||
|
||||
describe('SocialAuth', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseSearchParams.mockReturnValue(new URLSearchParams() as unknown as ReturnType<typeof useSearchParams>)
|
||||
mockUseLocale.mockReturnValue('zh-Hans')
|
||||
mockGetBrowserTimezone.mockReturnValue('Asia/Shanghai')
|
||||
})
|
||||
|
||||
describe('Rendering', () => {
|
||||
it('should render oauth provider links', () => {
|
||||
render(<SocialAuth />)
|
||||
|
||||
expect(screen.getByRole('link', { name: 'login.withGitHub' })).toBeInTheDocument()
|
||||
expect(screen.getByRole('link', { name: 'login.withGoogle' })).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
describe('OAuth params', () => {
|
||||
it('should include browser timezone and locale in oauth links', () => {
|
||||
render(<SocialAuth />)
|
||||
|
||||
expect(screen.getByRole('link', { name: 'login.withGitHub' })).toHaveAttribute(
|
||||
'href',
|
||||
expect.stringContaining('timezone=Asia%2FShanghai'),
|
||||
)
|
||||
expect(screen.getByRole('link', { name: 'login.withGitHub' })).toHaveAttribute(
|
||||
'href',
|
||||
expect.stringContaining('language=zh-Hans'),
|
||||
)
|
||||
expect(screen.getByRole('link', { name: 'login.withGoogle' })).toHaveAttribute(
|
||||
'href',
|
||||
expect.stringContaining('timezone=Asia%2FShanghai'),
|
||||
)
|
||||
expect(screen.getByRole('link', { name: 'login.withGoogle' })).toHaveAttribute(
|
||||
'href',
|
||||
expect.stringContaining('language=zh-Hans'),
|
||||
)
|
||||
})
|
||||
|
||||
it('should preserve invite token when adding timezone', () => {
|
||||
mockUseSearchParams.mockReturnValue(
|
||||
new URLSearchParams('invite_token=invite-123') as unknown as ReturnType<typeof useSearchParams>,
|
||||
)
|
||||
|
||||
render(<SocialAuth />)
|
||||
|
||||
const githubLink = screen.getByRole('link', { name: 'login.withGitHub' })
|
||||
expect(githubLink).toHaveAttribute('href', expect.stringContaining('invite_token=invite-123'))
|
||||
expect(githubLink).toHaveAttribute('href', expect.stringContaining('timezone=Asia%2FShanghai'))
|
||||
expect(githubLink).toHaveAttribute('href', expect.stringContaining('language=zh-Hans'))
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should omit timezone when browser timezone is unavailable', () => {
|
||||
mockGetBrowserTimezone.mockReturnValue(undefined)
|
||||
|
||||
render(<SocialAuth />)
|
||||
|
||||
expect(screen.getByRole('link', { name: 'login.withGitHub' }).getAttribute('href')).not.toContain('timezone=')
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -2,8 +2,10 @@ import { Button } from '@langgenius/dify-ui/button'
|
||||
import { cn } from '@langgenius/dify-ui/cn'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { API_PREFIX } from '@/config'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import { useSearchParams } from '@/next/navigation'
|
||||
import { getPurifyHref } from '@/utils'
|
||||
import { getBrowserTimezone } from '@/utils/timezone'
|
||||
import style from '../page.module.css'
|
||||
|
||||
type SocialAuthProps = {
|
||||
@ -13,11 +15,19 @@ type SocialAuthProps = {
|
||||
export default function SocialAuth(props: SocialAuthProps) {
|
||||
const { t } = useTranslation()
|
||||
const searchParams = useSearchParams()
|
||||
const locale = useLocale()
|
||||
|
||||
const getOAuthLink = (href: string) => {
|
||||
const url = getPurifyHref(`${API_PREFIX}${href}`)
|
||||
if (searchParams.has('invite_token'))
|
||||
return `${url}?${searchParams.toString()}`
|
||||
const params = new URLSearchParams(searchParams.toString())
|
||||
const timezone = getBrowserTimezone()
|
||||
if (timezone)
|
||||
params.set('timezone', timezone)
|
||||
params.set('language', locale)
|
||||
|
||||
const query = params.toString()
|
||||
if (query)
|
||||
return `${url}?${query}`
|
||||
|
||||
return url
|
||||
}
|
||||
|
||||
139
web/app/signin/invite-settings/__tests__/page.spec.tsx
Normal file
139
web/app/signin/invite-settings/__tests__/page.spec.tsx
Normal file
@ -0,0 +1,139 @@
|
||||
import type { MockedFunction } from 'vitest'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import { useRouter, useSearchParams } from '@/next/navigation'
|
||||
import { activateMember } from '@/service/common'
|
||||
import { useInvitationCheck } from '@/service/use-common'
|
||||
import { getBrowserTimezone } from '@/utils/timezone'
|
||||
import InviteSettingsPage from '../page'
|
||||
|
||||
vi.mock('@tanstack/react-query', async () => {
|
||||
const actual = await vi.importActual<typeof import('@tanstack/react-query')>('@tanstack/react-query')
|
||||
return {
|
||||
...actual,
|
||||
useSuspenseQuery: vi.fn(() => ({
|
||||
data: {
|
||||
branding: {
|
||||
enabled: true,
|
||||
},
|
||||
},
|
||||
})),
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useLocale: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/i18n-config', () => ({
|
||||
i18n: {
|
||||
defaultLocale: 'en-US',
|
||||
},
|
||||
setLocaleOnClient: vi.fn(() => Promise.resolve()),
|
||||
}))
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: vi.fn(),
|
||||
useSearchParams: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/common', () => ({
|
||||
activateMember: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-common', () => ({
|
||||
useInvitationCheck: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/utils/timezone', () => ({
|
||||
getBrowserTimezone: vi.fn(),
|
||||
timezones: [
|
||||
{ value: 'Asia/Shanghai', name: 'Asia/Shanghai' },
|
||||
{ value: 'America/Los_Angeles', name: 'America/Los_Angeles' },
|
||||
],
|
||||
}))
|
||||
|
||||
vi.mock('../utils/post-login-redirect', () => ({
|
||||
resolvePostLoginRedirect: vi.fn(() => null),
|
||||
}))
|
||||
|
||||
const mockReplace = vi.fn()
|
||||
const mockRefetch = vi.fn()
|
||||
|
||||
const mockUseLocale = useLocale as unknown as MockedFunction<typeof useLocale>
|
||||
const mockUseRouter = useRouter as unknown as MockedFunction<typeof useRouter>
|
||||
const mockUseSearchParams = useSearchParams as unknown as MockedFunction<typeof useSearchParams>
|
||||
const mockActivateMember = activateMember as unknown as MockedFunction<typeof activateMember>
|
||||
const mockUseInvitationCheck = useInvitationCheck as unknown as MockedFunction<typeof useInvitationCheck>
|
||||
const mockGetBrowserTimezone = getBrowserTimezone as unknown as MockedFunction<typeof getBrowserTimezone>
|
||||
|
||||
describe('InviteSettingsPage', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseLocale.mockReturnValue('zh-Hans')
|
||||
mockUseRouter.mockReturnValue({ replace: mockReplace } as unknown as ReturnType<typeof useRouter>)
|
||||
mockUseSearchParams.mockReturnValue(
|
||||
new URLSearchParams('invite_token=invite-token') as unknown as ReturnType<typeof useSearchParams>,
|
||||
)
|
||||
mockUseInvitationCheck.mockReturnValue({
|
||||
data: {
|
||||
is_valid: true,
|
||||
data: {
|
||||
workspace_name: 'Acme',
|
||||
workspace_id: 'workspace-id',
|
||||
email: 'invitee@example.com',
|
||||
},
|
||||
},
|
||||
refetch: mockRefetch,
|
||||
} as unknown as ReturnType<typeof useInvitationCheck>)
|
||||
mockGetBrowserTimezone.mockReturnValue('Asia/Shanghai')
|
||||
mockActivateMember.mockResolvedValue({ result: 'success' })
|
||||
})
|
||||
|
||||
describe('Activation payload', () => {
|
||||
it('should default language to the current UI locale', async () => {
|
||||
render(<InviteSettingsPage />)
|
||||
|
||||
fireEvent.change(screen.getByLabelText('login.name'), {
|
||||
target: { value: 'Invitee' },
|
||||
})
|
||||
fireEvent.click(screen.getByRole('button', { name: 'login.join Acme' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockActivateMember).toHaveBeenCalledWith({
|
||||
url: '/activate',
|
||||
body: {
|
||||
token: 'invite-token',
|
||||
name: 'Invitee',
|
||||
interface_language: 'zh-Hans',
|
||||
timezone: 'Asia/Shanghai',
|
||||
},
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
it('should fall back to configured default locale when current locale is unsupported', async () => {
|
||||
mockUseLocale.mockReturnValue('unsupported-locale' as ReturnType<typeof useLocale>)
|
||||
|
||||
render(<InviteSettingsPage />)
|
||||
|
||||
fireEvent.change(screen.getByLabelText('login.name'), {
|
||||
target: { value: 'Invitee' },
|
||||
})
|
||||
fireEvent.click(screen.getByRole('button', { name: 'login.join Acme' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockActivateMember).toHaveBeenCalledWith({
|
||||
url: '/activate',
|
||||
body: {
|
||||
token: 'invite-token',
|
||||
name: 'Invitee',
|
||||
interface_language: 'en-US',
|
||||
timezone: 'Asia/Shanghai',
|
||||
},
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -11,14 +11,15 @@ import { useTranslation } from 'react-i18next'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import { LICENSE_LINK } from '@/constants/link'
|
||||
import { setLocaleOnClient } from '@/i18n-config'
|
||||
import { languages, LanguagesSupported } from '@/i18n-config/language'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import { i18n, setLocaleOnClient } from '@/i18n-config'
|
||||
import { languages } from '@/i18n-config/language'
|
||||
import Link from '@/next/link'
|
||||
import { useRouter, useSearchParams } from '@/next/navigation'
|
||||
import { activateMember } from '@/service/common'
|
||||
import { systemFeaturesQueryOptions } from '@/service/system-features'
|
||||
import { useInvitationCheck } from '@/service/use-common'
|
||||
import { timezones } from '@/utils/timezone'
|
||||
import { getBrowserTimezone, timezones } from '@/utils/timezone'
|
||||
import { resolvePostLoginRedirect } from '../utils/post-login-redirect'
|
||||
|
||||
type LanguageSelectOption = {
|
||||
@ -43,15 +44,23 @@ const TIMEZONE_OPTIONS: TimezoneSelectOption[] = timezones.map(item => ({
|
||||
name: item.name,
|
||||
}))
|
||||
|
||||
const getInitialLanguage = (locale: Locale): Locale => {
|
||||
if (LANGUAGE_OPTIONS.some(item => item.value === locale))
|
||||
return locale
|
||||
|
||||
return i18n.defaultLocale
|
||||
}
|
||||
|
||||
export default function InviteSettingsPage() {
|
||||
const { t } = useTranslation()
|
||||
const { data: systemFeatures } = useSuspenseQuery(systemFeaturesQueryOptions())
|
||||
const router = useRouter()
|
||||
const searchParams = useSearchParams()
|
||||
const token = decodeURIComponent(searchParams.get('invite_token') as string)
|
||||
const locale = useLocale()
|
||||
const [name, setName] = useState('')
|
||||
const [language, setLanguage] = useState(LanguagesSupported[0])
|
||||
const [timezone, setTimezone] = useState(() => Intl.DateTimeFormat().resolvedOptions().timeZone || 'America/Los_Angeles')
|
||||
const [language, setLanguage] = useState(() => getInitialLanguage(locale))
|
||||
const [timezone, setTimezone] = useState(() => getBrowserTimezone() || 'America/Los_Angeles')
|
||||
const selectedLanguage = LANGUAGE_OPTIONS.find(item => item.value === language)
|
||||
const selectedTimezone = TIMEZONE_OPTIONS.find(item => item.value === timezone)
|
||||
|
||||
|
||||
85
web/app/signup/set-password/__tests__/page.spec.tsx
Normal file
85
web/app/signup/set-password/__tests__/page.spec.tsx
Normal file
@ -0,0 +1,85 @@
|
||||
import type { MockedFunction } from 'vitest'
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import { useRouter, useSearchParams } from '@/next/navigation'
|
||||
import { useMailRegister } from '@/service/use-common'
|
||||
import { getBrowserTimezone } from '@/utils/timezone'
|
||||
import ChangePasswordForm from '../page'
|
||||
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useLocale: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/next/navigation', () => ({
|
||||
useRouter: vi.fn(),
|
||||
useSearchParams: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/service/use-common', () => ({
|
||||
useMailRegister: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/utils/timezone', () => ({
|
||||
getBrowserTimezone: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/utils/gtag', () => ({
|
||||
sendGAEvent: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/base/amplitude', () => ({
|
||||
trackEvent: vi.fn(),
|
||||
}))
|
||||
|
||||
vi.mock('@/utils/create-app-tracking', () => ({
|
||||
rememberCreateAppExternalAttribution: vi.fn(),
|
||||
}))
|
||||
|
||||
const mockRegister = vi.fn()
|
||||
const mockReplace = vi.fn()
|
||||
|
||||
const mockUseLocale = useLocale as unknown as MockedFunction<typeof useLocale>
|
||||
const mockUseSearchParams = useSearchParams as unknown as MockedFunction<typeof useSearchParams>
|
||||
const mockUseRouter = useRouter as unknown as MockedFunction<typeof useRouter>
|
||||
const mockUseMailRegister = useMailRegister as unknown as MockedFunction<typeof useMailRegister>
|
||||
const mockGetBrowserTimezone = getBrowserTimezone as unknown as MockedFunction<typeof getBrowserTimezone>
|
||||
|
||||
describe('Signup Set Password Page', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockUseLocale.mockReturnValue('zh-Hans')
|
||||
mockUseSearchParams.mockReturnValue(new URLSearchParams('token=register-token') as unknown as ReturnType<typeof useSearchParams>)
|
||||
mockUseRouter.mockReturnValue({ replace: mockReplace } as unknown as ReturnType<typeof useRouter>)
|
||||
mockUseMailRegister.mockReturnValue({
|
||||
mutateAsync: mockRegister,
|
||||
isPending: false,
|
||||
} as unknown as ReturnType<typeof useMailRegister>)
|
||||
mockGetBrowserTimezone.mockReturnValue('Asia/Shanghai')
|
||||
mockRegister.mockResolvedValue({ result: 'fail', data: {} })
|
||||
})
|
||||
|
||||
describe('Registration payload', () => {
|
||||
it('should submit locale and browser timezone when setting password', async () => {
|
||||
render(<ChangePasswordForm />)
|
||||
|
||||
fireEvent.change(screen.getByLabelText('common.account.newPassword'), {
|
||||
target: { value: 'ValidPass123!' },
|
||||
})
|
||||
fireEvent.change(screen.getByLabelText('common.account.confirmPassword'), {
|
||||
target: { value: 'ValidPass123!' },
|
||||
})
|
||||
fireEvent.click(screen.getByRole('button', { name: 'login.changePasswordBtn' }))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockRegister).toHaveBeenCalledWith({
|
||||
token: 'register-token',
|
||||
new_password: 'ValidPass123!',
|
||||
password_confirm: 'ValidPass123!',
|
||||
language: 'zh-Hans',
|
||||
timezone: 'Asia/Shanghai',
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@ -9,10 +9,12 @@ import { useTranslation } from 'react-i18next'
|
||||
import { trackEvent } from '@/app/components/base/amplitude'
|
||||
import Input from '@/app/components/base/input'
|
||||
import { validPassword } from '@/config'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import { useRouter, useSearchParams } from '@/next/navigation'
|
||||
import { useMailRegister } from '@/service/use-common'
|
||||
import { rememberCreateAppExternalAttribution } from '@/utils/create-app-tracking'
|
||||
import { sendGAEvent } from '@/utils/gtag'
|
||||
import { getBrowserTimezone } from '@/utils/timezone'
|
||||
|
||||
const parseUtmInfo = () => {
|
||||
const utmInfoStr = Cookies.get('utm_info')
|
||||
@ -32,6 +34,7 @@ const ChangePasswordForm = () => {
|
||||
const router = useRouter()
|
||||
const searchParams = useSearchParams()
|
||||
const token = decodeURIComponent(searchParams.get('token') || '')
|
||||
const locale = useLocale()
|
||||
|
||||
const [password, setPassword] = useState('')
|
||||
const [confirmPassword, setConfirmPassword] = useState('')
|
||||
@ -65,6 +68,8 @@ const ChangePasswordForm = () => {
|
||||
token,
|
||||
new_password: password,
|
||||
password_confirm: confirmPassword,
|
||||
language: locale,
|
||||
timezone: getBrowserTimezone(),
|
||||
})
|
||||
const { result } = res as MailRegisterResponse
|
||||
if (result === 'success') {
|
||||
@ -88,7 +93,7 @@ const ChangePasswordForm = () => {
|
||||
catch (error) {
|
||||
console.error(error)
|
||||
}
|
||||
}, [password, token, valid, confirmPassword, register])
|
||||
}, [password, token, valid, confirmPassword, register, locale])
|
||||
|
||||
return (
|
||||
<div className={
|
||||
|
||||
@ -339,7 +339,13 @@ export const uploadRemoteFileInfo = (url: string, isPublic?: boolean, silent?: b
|
||||
export const sendEMailLoginCode = (email: string, language = 'en-US'): Promise<CommonResponse & { data: string }> =>
|
||||
post<CommonResponse & { data: string }>('/email-code-login', { body: { email, language } })
|
||||
|
||||
export const emailLoginWithCode = (data: { email: string, code: string, token: string, language: string }): Promise<LoginResponse> =>
|
||||
export const emailLoginWithCode = (data: {
|
||||
email: string
|
||||
code: string
|
||||
token: string
|
||||
language: string
|
||||
timezone?: string
|
||||
}): Promise<LoginResponse> =>
|
||||
post<LoginResponse>('/email-code-login/validity', { body: data })
|
||||
|
||||
export const sendResetPasswordCode = (email: string, language = 'en-US'): Promise<CommonResponse & { data: string, message?: string, code?: string }> =>
|
||||
|
||||
@ -178,7 +178,13 @@ export type MailRegisterResponse = { result: string, data: {} }
|
||||
export const useMailRegister = () => {
|
||||
return useMutation({
|
||||
mutationKey: [NAME_SPACE, 'mail-register'],
|
||||
mutationFn: (body: { token: string, new_password: string, password_confirm: string }) => {
|
||||
mutationFn: (body: {
|
||||
token: string
|
||||
new_password: string
|
||||
password_confirm: string
|
||||
language?: string
|
||||
timezone?: string
|
||||
}) => {
|
||||
return post<MailRegisterResponse>('/email-register', { body })
|
||||
},
|
||||
})
|
||||
|
||||
@ -5,3 +5,10 @@ type Item = {
|
||||
name: string
|
||||
}
|
||||
export const timezones: Item[] = tz
|
||||
|
||||
export const getBrowserTimezone = () => {
|
||||
if (typeof Intl === 'undefined')
|
||||
return undefined
|
||||
|
||||
return Intl.DateTimeFormat().resolvedOptions().timeZone || undefined
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user