mirror of
https://github.com/langgenius/dify.git
synced 2026-05-06 02:18:08 +08:00
refactor: replace localStorage with HTTP-only cookies for auth tokens (#24365)
Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com> Signed-off-by: kenwoodjw <blackxin55+@gmail.com> Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Yunlu Wen <wylswz@163.com> Co-authored-by: Joel <iamjoel007@gmail.com> Co-authored-by: GareArc <chen4851@purdue.edu> Co-authored-by: NFish <douxc512@gmail.com> Co-authored-by: Davide Delbianco <davide.delbianco@outlook.com> Co-authored-by: minglu7 <1347866672@qq.com> Co-authored-by: Ponder <ruan.lj@foxmail.com> Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: heyszt <270985384@qq.com> Co-authored-by: Asuka Minato <i@asukaminato.eu.org> Co-authored-by: Guangdong Liu <liugddx@gmail.com> Co-authored-by: Eric Guo <eric.guocz@gmail.com> Co-authored-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com> Co-authored-by: XlKsyt <caixuesen@outlook.com> Co-authored-by: Dhruv Gorasiya <80987415+DhruvGorasiya@users.noreply.github.com> Co-authored-by: crazywoola <427733928@qq.com> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com> Co-authored-by: hj24 <mambahj24@gmail.com> Co-authored-by: GuanMu <ballmanjq@gmail.com> Co-authored-by: 非法操作 <hjlarry@163.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Tonlo <123lzs123@gmail.com> Co-authored-by: Yusuke Yamada <yamachu.dev@gmail.com> Co-authored-by: Novice <novice12185727@gmail.com> Co-authored-by: kenwoodjw <blackxin55+@gmail.com> Co-authored-by: Ademílson Tonato <ademilsonft@outlook.com> Co-authored-by: znn <jubinkumarsoni@gmail.com> Co-authored-by: yangzheli <43645580+yangzheli@users.noreply.github.com>
This commit is contained in:
@ -55,3 +55,12 @@ else:
|
||||
"properties",
|
||||
}
|
||||
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
|
||||
|
||||
COOKIE_NAME_ACCESS_TOKEN = "access_token"
|
||||
COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
|
||||
COOKIE_NAME_PASSPORT = "passport"
|
||||
COOKIE_NAME_CSRF_TOKEN = "csrf_token"
|
||||
|
||||
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
|
||||
HEADER_NAME_APP_CODE = "X-App-Code"
|
||||
HEADER_NAME_PASSPORT = "X-App-Passport"
|
||||
|
||||
@ -15,6 +15,7 @@ from constants.languages import supported_language
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console.wraps import only_edition_cloud
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
|
||||
|
||||
@ -24,19 +25,9 @@ def admin_required(view: Callable[P, R]):
|
||||
if not dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header is None:
|
||||
auth_token = extract_access_token(request)
|
||||
if not auth_token:
|
||||
raise Unauthorized("Authorization header is missing.")
|
||||
|
||||
if " " not in auth_header:
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
|
||||
if auth_token != dify_config.ADMIN_API_KEY:
|
||||
raise Unauthorized("API key is invalid.")
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import flask_login
|
||||
from flask import request
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
import services
|
||||
@ -25,6 +25,16 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir
|
||||
from events.tenant_event import tenant_was_created
|
||||
from libs.helper import email, extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.token import (
|
||||
clear_access_token_from_cookie,
|
||||
clear_csrf_token_from_cookie,
|
||||
clear_refresh_token_from_cookie,
|
||||
extract_access_token,
|
||||
extract_csrf_token,
|
||||
set_access_token_to_cookie,
|
||||
set_csrf_token_to_cookie,
|
||||
set_refresh_token_to_cookie,
|
||||
)
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.errors.account import AccountRegisterError
|
||||
@ -89,20 +99,36 @@ class LoginApi(Resource):
|
||||
|
||||
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/logout")
|
||||
class LogoutApi(Resource):
|
||||
@setup_required
|
||||
def get(self):
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
account = current_user
|
||||
if isinstance(account, flask_login.AnonymousUserMixin):
|
||||
return {"result": "success"}
|
||||
AccountService.logout(account=account)
|
||||
flask_login.logout_user()
|
||||
return {"result": "success"}
|
||||
response = make_response({"result": "success"})
|
||||
else:
|
||||
AccountService.logout(account=account)
|
||||
flask_login.logout_user()
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
# Clear cookies on logout
|
||||
clear_access_token_from_cookie(response)
|
||||
clear_refresh_token_from_cookie(response)
|
||||
clear_csrf_token_from_cookie(response)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/reset-password")
|
||||
@ -227,17 +253,46 @@ class EmailCodeLoginApi(Resource):
|
||||
raise WorkspacesLimitExceeded()
|
||||
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
return {"result": "success", "data": token_pair.model_dump()}
|
||||
|
||||
# Create response with cookies instead of returning tokens in body
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||
# Set HTTP-only secure cookies for tokens
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
return response
|
||||
|
||||
|
||||
@console_ns.route("/refresh-token")
|
||||
class RefreshTokenApi(Resource):
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser().add_argument("refresh_token", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
# Get refresh token from cookie instead of request body
|
||||
refresh_token = request.cookies.get("refresh_token")
|
||||
|
||||
if not refresh_token:
|
||||
return {"result": "fail", "message": "No refresh token provided"}, 401
|
||||
|
||||
try:
|
||||
new_token_pair = AccountService.refresh_token(args["refresh_token"])
|
||||
return {"result": "success", "data": new_token_pair.model_dump()}
|
||||
new_token_pair = AccountService.refresh_token(refresh_token)
|
||||
|
||||
# Create response with new cookies
|
||||
response = make_response({"result": "success"})
|
||||
|
||||
# Update cookies with new tokens
|
||||
set_csrf_token_to_cookie(request, response, new_token_pair.csrf_token)
|
||||
set_access_token_to_cookie(request, response, new_token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, new_token_pair.refresh_token)
|
||||
return response
|
||||
except Exception as e:
|
||||
return {"result": "fail", "data": str(e)}, 401
|
||||
return {"result": "fail", "message": str(e)}, 401
|
||||
|
||||
|
||||
# this api helps frontend to check whether user is authenticated
|
||||
# TODO: remove in the future. frontend should redirect to login page by catching 401 status
|
||||
@console_ns.route("/login/status")
|
||||
class LoginStatus(Resource):
|
||||
def get(self):
|
||||
token = extract_access_token(request)
|
||||
csrf_token = extract_csrf_token(request)
|
||||
return {"logged_in": bool(token) and bool(csrf_token)}
|
||||
|
||||
@ -14,6 +14,11 @@ from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
|
||||
from libs.token import (
|
||||
set_access_token_to_cookie,
|
||||
set_csrf_token_to_cookie,
|
||||
set_refresh_token_to_cookie,
|
||||
)
|
||||
from models import Account, AccountStatus
|
||||
from services.account_service import AccountService, RegisterService, TenantService
|
||||
from services.billing_service import BillingService
|
||||
@ -152,9 +157,12 @@ class OAuthCallback(Resource):
|
||||
ip_address=extract_remote_ip(request),
|
||||
)
|
||||
|
||||
return redirect(
|
||||
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
|
||||
)
|
||||
response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
|
||||
|
||||
set_access_token_to_cookie(request, response, token_pair.access_token)
|
||||
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
|
||||
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
|
||||
return response
|
||||
|
||||
|
||||
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None:
|
||||
|
||||
@ -15,7 +15,6 @@ from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App, InstalledApp, RecommendedApp
|
||||
from services.account_service import TenantService
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@ -67,31 +66,26 @@ class InstalledAppsListApi(Resource):
|
||||
|
||||
# Pre-filter out apps without setting or with sso_verified
|
||||
filtered_installed_apps = []
|
||||
app_id_to_app_code = {}
|
||||
|
||||
for installed_app in installed_app_list:
|
||||
app_id = installed_app["app"].id
|
||||
webapp_setting = webapp_settings.get(app_id)
|
||||
if not webapp_setting or webapp_setting.access_mode == "sso_verified":
|
||||
continue
|
||||
app_code = AppService.get_app_code_by_id(str(app_id))
|
||||
app_id_to_app_code[app_id] = app_code
|
||||
filtered_installed_apps.append(installed_app)
|
||||
|
||||
app_codes = list(app_id_to_app_code.values())
|
||||
|
||||
# Batch permission check
|
||||
app_ids = [installed_app["app"].id for installed_app in filtered_installed_apps]
|
||||
permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps(
|
||||
user_id=user_id,
|
||||
app_codes=app_codes,
|
||||
app_ids=app_ids,
|
||||
)
|
||||
|
||||
# Keep only allowed apps
|
||||
res = []
|
||||
for installed_app in filtered_installed_apps:
|
||||
app_id = installed_app["app"].id
|
||||
app_code = app_id_to_app_code[app_id]
|
||||
if permissions.get(app_code):
|
||||
if permissions.get(app_id):
|
||||
res.append(installed_app)
|
||||
|
||||
installed_app_list = res
|
||||
|
||||
@ -10,7 +10,6 @@ from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import InstalledApp
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@ -56,10 +55,9 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
|
||||
feature = FeatureService.get_system_features()
|
||||
if feature.webapp_auth.enabled:
|
||||
app_id = installed_app.app_id
|
||||
app_code = AppService.get_app_code_by_id(app_id)
|
||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
|
||||
user_id=str(current_user.id),
|
||||
app_code=app_code,
|
||||
app_id=app_id,
|
||||
)
|
||||
if not res:
|
||||
raise AppAccessDeniedError()
|
||||
|
||||
@ -4,12 +4,14 @@ from flask import request
|
||||
from flask_restx import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from constants import HEADER_NAME_APP_CODE
|
||||
from controllers.common import fields
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import AppUnavailableError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_webapp_passport
|
||||
from models.model import App, AppMode
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
@ -133,18 +135,19 @@ class AppWebAuthPermission(Resource):
|
||||
)
|
||||
def get(self):
|
||||
user_id = "visitor"
|
||||
app_code = request.headers.get(HEADER_NAME_APP_CODE)
|
||||
app_id = request.args.get("appId")
|
||||
if not app_id or not app_code:
|
||||
raise ValueError("appId must be provided")
|
||||
|
||||
require_permission_check = WebAppAuthService.is_app_require_permission_check(app_id=app_id)
|
||||
if not require_permission_check:
|
||||
return {"result": True}
|
||||
|
||||
try:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header is None:
|
||||
raise Unauthorized("Authorization header is missing.")
|
||||
if " " not in auth_header:
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
|
||||
auth_scheme, tk = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Authorization scheme must be 'Bearer'")
|
||||
|
||||
tk = extract_webapp_passport(app_code, request)
|
||||
if not tk:
|
||||
raise Unauthorized("Access token is missing.")
|
||||
decoded = PassportService().verify(tk)
|
||||
user_id = decoded.get("user_id", "visitor")
|
||||
except Unauthorized:
|
||||
@ -157,13 +160,7 @@ class AppWebAuthPermission(Resource):
|
||||
if not features.webapp_auth.enabled:
|
||||
return {"result": True}
|
||||
|
||||
parser = reqparse.RequestParser().add_argument("appId", type=str, required=True, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
app_id = args["appId"]
|
||||
app_code = AppService.get_app_code_by_id(app_id)
|
||||
|
||||
res = True
|
||||
if WebAppAuthService.is_app_require_permission_check(app_id=app_id):
|
||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code)
|
||||
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_id)
|
||||
return {"result": res}
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from jwt import InvalidTokenError
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.console.auth.error import (
|
||||
AuthenticationFailedError,
|
||||
EmailCodeError,
|
||||
@ -10,9 +12,16 @@ from controllers.console.auth.error import (
|
||||
from controllers.console.error import AccountBannedError
|
||||
from controllers.console.wraps import only_edition_enterprise, setup_required
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.wraps import decode_jwt_token
|
||||
from libs.helper import email
|
||||
from libs.passport import PassportService
|
||||
from libs.password import valid_password
|
||||
from libs.token import (
|
||||
clear_access_token_from_cookie,
|
||||
extract_access_token,
|
||||
)
|
||||
from services.account_service import AccountService
|
||||
from services.app_service import AppService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
|
||||
|
||||
@ -52,17 +61,75 @@ class LoginApi(Resource):
|
||||
raise AuthenticationFailedError()
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
return {"result": "success", "data": {"access_token": token}}
|
||||
response = make_response({"result": "success", "data": {"access_token": token}})
|
||||
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
|
||||
return response
|
||||
|
||||
|
||||
# class LogoutApi(Resource):
|
||||
# @setup_required
|
||||
# def get(self):
|
||||
# account = cast(Account, flask_login.current_user)
|
||||
# if isinstance(account, flask_login.AnonymousUserMixin):
|
||||
# return {"result": "success"}
|
||||
# flask_login.logout_user()
|
||||
# return {"result": "success"}
|
||||
# this api helps frontend to check whether user is authenticated
|
||||
# TODO: remove in the future. frontend should redirect to login page by catching 401 status
|
||||
@web_ns.route("/login/status")
|
||||
class LoginStatusApi(Resource):
|
||||
@setup_required
|
||||
@web_ns.doc("web_app_login_status")
|
||||
@web_ns.doc(description="Check login status")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Login status",
|
||||
401: "Login status",
|
||||
}
|
||||
)
|
||||
def get(self):
|
||||
app_code = request.args.get("app_code")
|
||||
token = extract_access_token(request)
|
||||
if not app_code:
|
||||
return {
|
||||
"logged_in": bool(token),
|
||||
"app_logged_in": False,
|
||||
}
|
||||
app_id = AppService.get_app_id_by_code(app_code)
|
||||
is_public = not dify_config.ENTERPRISE_ENABLED or not WebAppAuthService.is_app_require_permission_check(
|
||||
app_id=app_id
|
||||
)
|
||||
user_logged_in = False
|
||||
|
||||
if is_public:
|
||||
user_logged_in = True
|
||||
else:
|
||||
try:
|
||||
PassportService().verify(token=token)
|
||||
user_logged_in = True
|
||||
except Exception:
|
||||
user_logged_in = False
|
||||
|
||||
try:
|
||||
_ = decode_jwt_token(app_code=app_code)
|
||||
app_logged_in = True
|
||||
except Exception:
|
||||
app_logged_in = False
|
||||
|
||||
return {
|
||||
"logged_in": user_logged_in,
|
||||
"app_logged_in": app_logged_in,
|
||||
}
|
||||
|
||||
|
||||
@web_ns.route("/logout")
|
||||
class LogoutApi(Resource):
|
||||
@setup_required
|
||||
@web_ns.doc("web_app_logout")
|
||||
@web_ns.doc(description="Logout user from web application")
|
||||
@web_ns.doc(
|
||||
responses={
|
||||
200: "Logout successful",
|
||||
}
|
||||
)
|
||||
def post(self):
|
||||
response = make_response({"result": "success"})
|
||||
# enterprise SSO sets same site to None in https deployment
|
||||
# so we need to logout by calling api
|
||||
clear_access_token_from_cookie(response, samesite="None")
|
||||
return response
|
||||
|
||||
|
||||
@web_ns.route("/email-code-login")
|
||||
@ -96,7 +163,6 @@ class EmailCodeLoginSendEmailApi(Resource):
|
||||
raise AuthenticationFailedError()
|
||||
else:
|
||||
token = WebAppAuthService.send_email_code_login_email(account=account, language=language)
|
||||
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
|
||||
@ -142,4 +208,6 @@ class EmailCodeLoginApi(Resource):
|
||||
|
||||
token = WebAppAuthService.login(account=account)
|
||||
AccountService.reset_login_error_rate_limit(args["email"])
|
||||
return {"result": "success", "data": {"access_token": token}}
|
||||
response = make_response({"result": "success", "data": {"access_token": token}})
|
||||
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
|
||||
return response
|
||||
|
||||
@ -1,17 +1,20 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from flask import request
|
||||
from flask import make_response, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import func, select
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants import HEADER_NAME_APP_CODE
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import WebAppAuthRequiredError
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, EndUser, Site
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
|
||||
@ -32,15 +35,15 @@ class PassportResource(Resource):
|
||||
)
|
||||
def get(self):
|
||||
system_features = FeatureService.get_system_features()
|
||||
app_code = request.headers.get("X-App-Code")
|
||||
app_code = request.headers.get(HEADER_NAME_APP_CODE)
|
||||
user_id = request.args.get("user_id")
|
||||
web_app_access_token = request.args.get("web_app_access_token")
|
||||
access_token = extract_access_token(request)
|
||||
|
||||
if app_code is None:
|
||||
raise Unauthorized("X-App-Code header is missing.")
|
||||
|
||||
app_id = AppService.get_app_id_by_code(app_code)
|
||||
# exchange token for enterprise logined web user
|
||||
enterprise_user_decoded = decode_enterprise_webapp_user_id(web_app_access_token)
|
||||
enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token)
|
||||
if enterprise_user_decoded:
|
||||
# a web user has already logged in, exchange a token for this app without redirecting to the login page
|
||||
return exchange_token_for_existing_web_user(
|
||||
@ -48,7 +51,7 @@ class PassportResource(Resource):
|
||||
)
|
||||
|
||||
if system_features.webapp_auth.enabled:
|
||||
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
|
||||
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
|
||||
if not app_settings or not app_settings.access_mode == "public":
|
||||
raise WebAppAuthRequiredError()
|
||||
|
||||
@ -99,9 +102,12 @@ class PassportResource(Resource):
|
||||
|
||||
tk = PassportService().issue(payload)
|
||||
|
||||
return {
|
||||
"access_token": tk,
|
||||
}
|
||||
response = make_response(
|
||||
{
|
||||
"access_token": tk,
|
||||
}
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
def decode_enterprise_webapp_user_id(jwt_token: str | None):
|
||||
@ -189,9 +195,12 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
|
||||
"exp": exp,
|
||||
}
|
||||
token: str = PassportService().issue(payload)
|
||||
return {
|
||||
"access_token": token,
|
||||
}
|
||||
resp = make_response(
|
||||
{
|
||||
"access_token": token,
|
||||
}
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
def _exchange_for_public_app_token(app_model, site, token_decoded):
|
||||
@ -224,9 +233,12 @@ def _exchange_for_public_app_token(app_model, site, token_decoded):
|
||||
|
||||
tk = PassportService().issue(payload)
|
||||
|
||||
return {
|
||||
"access_token": tk,
|
||||
}
|
||||
resp = make_response(
|
||||
{
|
||||
"access_token": tk,
|
||||
}
|
||||
)
|
||||
return resp
|
||||
|
||||
|
||||
def generate_session_id():
|
||||
|
||||
@ -9,10 +9,13 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from constants import HEADER_NAME_APP_CODE
|
||||
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_webapp_passport
|
||||
from models.model import App, EndUser, Site
|
||||
from services.app_service import AppService
|
||||
from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings
|
||||
from services.feature_service import FeatureService
|
||||
from services.webapp_auth_service import WebAppAuthService
|
||||
@ -35,22 +38,14 @@ def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None =
|
||||
return decorator
|
||||
|
||||
|
||||
def decode_jwt_token():
|
||||
def decode_jwt_token(app_code: str | None = None):
|
||||
system_features = FeatureService.get_system_features()
|
||||
app_code = str(request.headers.get("X-App-Code"))
|
||||
if not app_code:
|
||||
app_code = str(request.headers.get(HEADER_NAME_APP_CODE))
|
||||
try:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header is None:
|
||||
raise Unauthorized("Authorization header is missing.")
|
||||
|
||||
if " " not in auth_header:
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
|
||||
auth_scheme, tk = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
tk = extract_webapp_passport(app_code, request)
|
||||
if not tk:
|
||||
raise Unauthorized("App token is missing.")
|
||||
decoded = PassportService().verify(tk)
|
||||
app_code = decoded.get("app_code")
|
||||
app_id = decoded.get("app_id")
|
||||
@ -72,7 +67,8 @@ def decode_jwt_token():
|
||||
app_web_auth_enabled = False
|
||||
webapp_settings = None
|
||||
if system_features.webapp_auth.enabled:
|
||||
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
|
||||
app_id = AppService.get_app_id_by_code(app_code)
|
||||
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
|
||||
if not webapp_settings:
|
||||
raise NotFound("Web app settings not found.")
|
||||
app_web_auth_enabled = webapp_settings.access_mode != "public"
|
||||
@ -87,8 +83,9 @@ def decode_jwt_token():
|
||||
if system_features.webapp_auth.enabled:
|
||||
if not app_code:
|
||||
raise Unauthorized("Please re-login to access the web app.")
|
||||
app_id = AppService.get_app_id_by_code(app_code)
|
||||
app_web_auth_enabled = (
|
||||
EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=str(app_code)).access_mode != "public"
|
||||
EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode != "public"
|
||||
)
|
||||
if app_web_auth_enabled:
|
||||
raise WebAppAuthRequiredError()
|
||||
@ -129,7 +126,8 @@ def _validate_user_accessibility(
|
||||
raise WebAppAuthRequiredError("Web app settings not found.")
|
||||
|
||||
if WebAppAuthService.is_app_require_permission_check(access_mode=webapp_settings.access_mode):
|
||||
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
|
||||
app_id = AppService.get_app_id_by_code(app_code)
|
||||
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_id):
|
||||
raise WebAppAuthAccessDeniedError()
|
||||
|
||||
auth_type = decoded.get("auth_type")
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
from configs import dify_config
|
||||
from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN
|
||||
from dify_app import DifyApp
|
||||
|
||||
|
||||
@ -16,7 +17,7 @@ def init_app(app: DifyApp):
|
||||
|
||||
CORS(
|
||||
service_api_bp,
|
||||
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||
allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
)
|
||||
app.register_blueprint(service_api_bp)
|
||||
@ -25,7 +26,7 @@ def init_app(app: DifyApp):
|
||||
web_bp,
|
||||
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
|
||||
supports_credentials=True,
|
||||
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
|
||||
allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
expose_headers=["X-Version", "X-Env"],
|
||||
)
|
||||
@ -35,7 +36,7 @@ def init_app(app: DifyApp):
|
||||
console_app_bp,
|
||||
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
|
||||
supports_credentials=True,
|
||||
allow_headers=["Content-Type", "Authorization"],
|
||||
allow_headers=["Content-Type", "Authorization", HEADER_NAME_CSRF_TOKEN],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
expose_headers=["X-Version", "X-Env"],
|
||||
)
|
||||
@ -43,7 +44,7 @@ def init_app(app: DifyApp):
|
||||
|
||||
CORS(
|
||||
files_bp,
|
||||
allow_headers=["Content-Type"],
|
||||
allow_headers=["Content-Type", HEADER_NAME_CSRF_TOKEN],
|
||||
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
|
||||
)
|
||||
app.register_blueprint(files_bp)
|
||||
|
||||
@ -9,6 +9,7 @@ from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token
|
||||
from models import Account, Tenant, TenantAccountJoin
|
||||
from models.model import AppMCPServer, EndUser
|
||||
from services.account_service import AccountService
|
||||
@ -24,20 +25,10 @@ def load_user_from_request(request_from_flask_login):
|
||||
if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
|
||||
return None
|
||||
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
auth_token: str | None = None
|
||||
if auth_header:
|
||||
if " " not in auth_header:
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
auth_scheme, auth_token = auth_header.split(maxsplit=1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
|
||||
else:
|
||||
auth_token = request.args.get("_token")
|
||||
auth_token = extract_access_token(request)
|
||||
|
||||
# Check for admin API key authentication first
|
||||
if dify_config.ADMIN_API_KEY_ENABLE and auth_header:
|
||||
if dify_config.ADMIN_API_KEY_ENABLE and auth_token:
|
||||
admin_api_key = dify_config.ADMIN_API_KEY
|
||||
if admin_api_key and admin_api_key == auth_token:
|
||||
workspace_id = request.headers.get("X-WORKSPACE-ID")
|
||||
|
||||
@ -9,7 +9,9 @@ from werkzeug.exceptions import HTTPException
|
||||
from werkzeug.http import HTTP_STATUS_CODES
|
||||
|
||||
from configs import dify_config
|
||||
from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_CSRF_TOKEN, COOKIE_NAME_REFRESH_TOKEN
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
from libs.token import is_secure
|
||||
|
||||
|
||||
def http_status_message(code):
|
||||
@ -67,6 +69,19 @@ def register_external_error_handlers(api: Api):
|
||||
# If you need WWW-Authenticate for 401, add it to headers
|
||||
if status_code == 401:
|
||||
headers["WWW-Authenticate"] = 'Bearer realm="api"'
|
||||
# Check if this is a forced logout error - clear cookies
|
||||
error_code = getattr(e, "error_code", None)
|
||||
if error_code == "unauthorized_and_force_logout":
|
||||
# Add Set-Cookie headers to clear auth cookies
|
||||
|
||||
secure = is_secure()
|
||||
# response is not accessible, so we need to do it ugly
|
||||
common_part = "Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT; HttpOnly"
|
||||
headers["Set-Cookie"] = [
|
||||
f'{COOKIE_NAME_ACCESS_TOKEN}=""; {common_part}{"; Secure" if secure else ""}; SameSite=Lax',
|
||||
f'{COOKIE_NAME_CSRF_TOKEN}=""; {common_part}{"; Secure" if secure else ""}; SameSite=Lax',
|
||||
f'{COOKIE_NAME_REFRESH_TOKEN}=""; {common_part}{"; Secure" if secure else ""}; SameSite=Lax',
|
||||
]
|
||||
return data, status_code, headers
|
||||
|
||||
_ = handle_http_exception
|
||||
|
||||
@ -7,6 +7,7 @@ from flask_login.config import EXEMPT_METHODS # type: ignore
|
||||
from werkzeug.local import LocalProxy
|
||||
|
||||
from configs import dify_config
|
||||
from libs.token import check_csrf_token
|
||||
from models import Account
|
||||
from models.model import EndUser
|
||||
|
||||
@ -73,6 +74,9 @@ def login_required(func: Callable[P, R]):
|
||||
pass
|
||||
elif current_user is not None and not current_user.is_authenticated:
|
||||
return current_app.login_manager.unauthorized() # type: ignore
|
||||
# we put csrf validation here for less conflicts
|
||||
# TODO: maybe find a better place for it.
|
||||
check_csrf_token(request, current_user.id)
|
||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
208
api/libs/token.py
Normal file
208
api/libs/token.py
Normal file
@ -0,0 +1,208 @@
|
||||
import logging
|
||||
import re
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from flask import Request
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
from werkzeug.wrappers import Response
|
||||
|
||||
from configs import dify_config
|
||||
from constants import (
|
||||
COOKIE_NAME_ACCESS_TOKEN,
|
||||
COOKIE_NAME_CSRF_TOKEN,
|
||||
COOKIE_NAME_PASSPORT,
|
||||
COOKIE_NAME_REFRESH_TOKEN,
|
||||
HEADER_NAME_CSRF_TOKEN,
|
||||
HEADER_NAME_PASSPORT,
|
||||
)
|
||||
from libs.passport import PassportService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
CSRF_WHITE_LIST = [
|
||||
re.compile(r"/console/api/apps/[a-f0-9-]+/workflows/draft"),
|
||||
]
|
||||
|
||||
|
||||
# server is behind a reverse proxy, so we need to check the url
|
||||
def is_secure() -> bool:
|
||||
return dify_config.CONSOLE_WEB_URL.startswith("https") and dify_config.CONSOLE_API_URL.startswith("https")
|
||||
|
||||
|
||||
def _real_cookie_name(cookie_name: str) -> str:
|
||||
if is_secure():
|
||||
return "__Host-" + cookie_name
|
||||
else:
|
||||
return cookie_name
|
||||
|
||||
|
||||
def _try_extract_from_header(request: Request) -> str | None:
|
||||
"""
|
||||
Try to extract access token from header
|
||||
"""
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header:
|
||||
if " " not in auth_header:
|
||||
return None
|
||||
else:
|
||||
auth_scheme, auth_token = auth_header.split(None, 1)
|
||||
auth_scheme = auth_scheme.lower()
|
||||
if auth_scheme != "bearer":
|
||||
return None
|
||||
else:
|
||||
return auth_token
|
||||
return None
|
||||
|
||||
|
||||
def extract_csrf_token(request: Request) -> str | None:
|
||||
"""
|
||||
Try to extract CSRF token from header or cookie.
|
||||
"""
|
||||
return request.headers.get(HEADER_NAME_CSRF_TOKEN)
|
||||
|
||||
|
||||
def extract_csrf_token_from_cookie(request: Request) -> str | None:
|
||||
"""
|
||||
Try to extract CSRF token from cookie.
|
||||
"""
|
||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
|
||||
|
||||
|
||||
def extract_access_token(request: Request) -> str | None:
|
||||
"""
|
||||
Try to extract access token from cookie, header or params.
|
||||
|
||||
Access token is either for console session or webapp passport exchange.
|
||||
"""
|
||||
|
||||
def _try_extract_from_cookie(request: Request) -> str | None:
|
||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
|
||||
|
||||
return _try_extract_from_cookie(request) or _try_extract_from_header(request)
|
||||
|
||||
|
||||
def extract_webapp_passport(app_code: str, request: Request) -> str | None:
|
||||
"""
|
||||
Try to extract app token from header or params.
|
||||
|
||||
Webapp access token (part of passport) is only used for webapp session.
|
||||
"""
|
||||
|
||||
def _try_extract_passport_token_from_cookie(request: Request) -> str | None:
|
||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code))
|
||||
|
||||
def _try_extract_passport_token_from_header(request: Request) -> str | None:
|
||||
return request.headers.get(HEADER_NAME_PASSPORT)
|
||||
|
||||
ret = _try_extract_passport_token_from_cookie(request) or _try_extract_passport_token_from_header(request)
|
||||
return ret
|
||||
|
||||
|
||||
def set_access_token_to_cookie(request: Request, response: Response, token: str, samesite: str = "Lax"):
|
||||
response.set_cookie(
|
||||
_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN),
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=is_secure(),
|
||||
samesite=samesite,
|
||||
max_age=int(dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 60),
|
||||
path="/",
|
||||
)
|
||||
|
||||
|
||||
def set_refresh_token_to_cookie(request: Request, response: Response, token: str):
|
||||
response.set_cookie(
|
||||
_real_cookie_name(COOKIE_NAME_REFRESH_TOKEN),
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=is_secure(),
|
||||
samesite="Lax",
|
||||
max_age=int(60 * 60 * 24 * dify_config.REFRESH_TOKEN_EXPIRE_DAYS),
|
||||
path="/",
|
||||
)
|
||||
|
||||
|
||||
def set_csrf_token_to_cookie(request: Request, response: Response, token: str):
|
||||
response.set_cookie(
|
||||
_real_cookie_name(COOKIE_NAME_CSRF_TOKEN),
|
||||
value=token,
|
||||
httponly=False,
|
||||
secure=is_secure(),
|
||||
samesite="Lax",
|
||||
max_age=int(60 * dify_config.ACCESS_TOKEN_EXPIRE_MINUTES),
|
||||
path="/",
|
||||
)
|
||||
|
||||
|
||||
def _clear_cookie(
|
||||
response: Response,
|
||||
cookie_name: str,
|
||||
samesite: str = "Lax",
|
||||
http_only: bool = True,
|
||||
):
|
||||
response.set_cookie(
|
||||
_real_cookie_name(cookie_name),
|
||||
"",
|
||||
expires=0,
|
||||
path="/",
|
||||
secure=is_secure(),
|
||||
httponly=http_only,
|
||||
samesite=samesite,
|
||||
)
|
||||
|
||||
|
||||
def clear_access_token_from_cookie(response: Response, samesite: str = "Lax"):
|
||||
_clear_cookie(response, COOKIE_NAME_ACCESS_TOKEN, samesite)
|
||||
|
||||
|
||||
def clear_refresh_token_from_cookie(response: Response):
|
||||
_clear_cookie(response, COOKIE_NAME_REFRESH_TOKEN)
|
||||
|
||||
|
||||
def clear_csrf_token_from_cookie(response: Response):
|
||||
_clear_cookie(response, COOKIE_NAME_CSRF_TOKEN, http_only=False)
|
||||
|
||||
|
||||
def check_csrf_token(request: Request, user_id: str):
|
||||
# some apis are sent by beacon, so we need to bypass csrf token check
|
||||
# since these APIs are post, they are already protected by SameSite: Lax, so csrf is not required.
|
||||
def _unauthorized():
|
||||
raise Unauthorized("CSRF token is missing or invalid.")
|
||||
|
||||
for pattern in CSRF_WHITE_LIST:
|
||||
if pattern.match(request.path):
|
||||
return
|
||||
|
||||
csrf_token = extract_csrf_token(request)
|
||||
csrf_token_from_cookie = extract_csrf_token_from_cookie(request)
|
||||
|
||||
if csrf_token != csrf_token_from_cookie:
|
||||
_unauthorized()
|
||||
|
||||
if not csrf_token:
|
||||
_unauthorized()
|
||||
verified = {}
|
||||
try:
|
||||
verified = PassportService().verify(csrf_token)
|
||||
except:
|
||||
_unauthorized()
|
||||
|
||||
if verified.get("sub") != user_id:
|
||||
_unauthorized()
|
||||
|
||||
exp: int | None = verified.get("exp")
|
||||
if not exp:
|
||||
_unauthorized()
|
||||
else:
|
||||
time_now = int(datetime.now().timestamp())
|
||||
if exp < time_now:
|
||||
_unauthorized()
|
||||
|
||||
|
||||
def generate_csrf_token(user_id: str) -> str:
|
||||
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
payload = {
|
||||
"exp": int(exp_dt.timestamp()),
|
||||
"sub": user_id,
|
||||
}
|
||||
return PassportService().issue(payload)
|
||||
@ -22,6 +22,7 @@ from libs.helper import RateLimiter, TokenManager
|
||||
from libs.passport import PassportService
|
||||
from libs.password import compare_password, hash_password, valid_password
|
||||
from libs.rsa import generate_key_pair
|
||||
from libs.token import generate_csrf_token
|
||||
from models.account import (
|
||||
Account,
|
||||
AccountIntegrate,
|
||||
@ -76,6 +77,7 @@ logger = logging.getLogger(__name__)
|
||||
class TokenPair(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
csrf_token: str
|
||||
|
||||
|
||||
REFRESH_TOKEN_PREFIX = "refresh_token:"
|
||||
@ -403,10 +405,11 @@ class AccountService:
|
||||
|
||||
access_token = AccountService.get_account_jwt_token(account=account)
|
||||
refresh_token = _generate_refresh_token()
|
||||
csrf_token = generate_csrf_token(account.id)
|
||||
|
||||
AccountService._store_refresh_token(refresh_token, account.id)
|
||||
|
||||
return TokenPair(access_token=access_token, refresh_token=refresh_token)
|
||||
return TokenPair(access_token=access_token, refresh_token=refresh_token, csrf_token=csrf_token)
|
||||
|
||||
@staticmethod
|
||||
def logout(*, account: Account):
|
||||
@ -431,8 +434,9 @@ class AccountService:
|
||||
|
||||
AccountService._delete_refresh_token(refresh_token, account.id)
|
||||
AccountService._store_refresh_token(new_refresh_token, account.id)
|
||||
csrf_token = generate_csrf_token(account.id)
|
||||
|
||||
return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token)
|
||||
return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token, csrf_token=csrf_token)
|
||||
|
||||
@staticmethod
|
||||
def load_logged_in_account(*, account_id: str):
|
||||
|
||||
@ -46,17 +46,17 @@ class EnterpriseService:
|
||||
|
||||
class WebAppAuth:
|
||||
@classmethod
|
||||
def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str):
|
||||
params = {"userId": user_id, "appCode": app_code}
|
||||
def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str):
|
||||
params = {"userId": user_id, "appId": app_id}
|
||||
data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params)
|
||||
|
||||
return data.get("result", False)
|
||||
|
||||
@classmethod
|
||||
def batch_is_user_allowed_to_access_webapps(cls, user_id: str, app_codes: list[str]):
|
||||
if not app_codes:
|
||||
def batch_is_user_allowed_to_access_webapps(cls, user_id: str, app_ids: list[str]):
|
||||
if not app_ids:
|
||||
return {}
|
||||
body = {"userId": user_id, "appCodes": app_codes}
|
||||
body = {"userId": user_id, "appIds": app_ids}
|
||||
data = EnterpriseRequest.send_request("POST", "/webapp/permission/batch", json=body)
|
||||
if not data:
|
||||
raise ValueError("No data found.")
|
||||
|
||||
@ -172,7 +172,8 @@ class WebAppAuthService:
|
||||
return WebAppAuthType.EXTERNAL
|
||||
|
||||
if app_code:
|
||||
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code)
|
||||
app_id = AppService.get_app_id_by_code(app_code)
|
||||
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
|
||||
return cls.get_app_auth_type(access_mode=webapp_settings.access_mode)
|
||||
|
||||
raise ValueError("Could not determine app authentication type.")
|
||||
|
||||
@ -863,13 +863,14 @@ class TestWebAppAuthService:
|
||||
- Mock service integration
|
||||
"""
|
||||
# Arrange: Setup mock for enterprise service
|
||||
mock_webapp_auth = type("MockWebAppAuth", (), {"access_mode": "sso_verified"})()
|
||||
mock_external_service_dependencies["app_service"].get_app_id_by_code.return_value = "mock_app_id"
|
||||
setting = type("MockWebAppAuth", (), {"access_mode": "sso_verified"})()
|
||||
mock_external_service_dependencies[
|
||||
"enterprise_service"
|
||||
].WebAppAuth.get_app_access_mode_by_code.return_value = mock_webapp_auth
|
||||
].WebAppAuth.get_app_access_mode_by_id.return_value = setting
|
||||
|
||||
# Act: Execute authentication type determination
|
||||
result = WebAppAuthService.get_app_auth_type(app_code="mock_app_code")
|
||||
result: WebAppAuthType = WebAppAuthService.get_app_auth_type(app_code="mock_app_code")
|
||||
|
||||
# Assert: Verify correct result
|
||||
assert result == WebAppAuthType.EXTERNAL
|
||||
@ -877,7 +878,7 @@ class TestWebAppAuthService:
|
||||
# Verify mock service was called correctly
|
||||
mock_external_service_dependencies[
|
||||
"enterprise_service"
|
||||
].WebAppAuth.get_app_access_mode_by_code.assert_called_once_with("mock_app_code")
|
||||
].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with(app_id="mock_app_id")
|
||||
|
||||
def test_get_app_auth_type_no_parameters(self, db_session_with_containers, mock_external_service_dependencies):
|
||||
"""
|
||||
|
||||
@ -179,9 +179,7 @@ class TestOAuthCallback:
|
||||
|
||||
oauth_setup["provider"].get_access_token.assert_called_once_with("test_code")
|
||||
oauth_setup["provider"].get_user_info.assert_called_once_with("access_token")
|
||||
mock_redirect.assert_called_once_with(
|
||||
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token"
|
||||
)
|
||||
mock_redirect.assert_called_once_with("http://localhost:3000")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("exception", "expected_error"),
|
||||
@ -224,8 +222,8 @@ class TestOAuthCallback:
|
||||
# CLOSED status: Currently NOT handled, will proceed to login (security issue)
|
||||
# This documents actual behavior. See test_defensive_check_for_closed_account_status for details
|
||||
(
|
||||
AccountStatus.CLOSED,
|
||||
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token",
|
||||
AccountStatus.CLOSED.value,
|
||||
"http://localhost:3000",
|
||||
),
|
||||
],
|
||||
)
|
||||
@ -268,6 +266,7 @@ class TestOAuthCallback:
|
||||
mock_token_pair = MagicMock()
|
||||
mock_token_pair.access_token = "jwt_access_token"
|
||||
mock_token_pair.refresh_token = "jwt_refresh_token"
|
||||
mock_token_pair.csrf_token = "csrf_token"
|
||||
mock_account_service.login.return_value = mock_token_pair
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||
@ -299,6 +298,12 @@ class TestOAuthCallback:
|
||||
mock_account.status = AccountStatus.PENDING
|
||||
mock_generate_account.return_value = mock_account
|
||||
|
||||
mock_token_pair = MagicMock()
|
||||
mock_token_pair.access_token = "jwt_access_token"
|
||||
mock_token_pair.refresh_token = "jwt_refresh_token"
|
||||
mock_token_pair.csrf_token = "csrf_token"
|
||||
mock_account_service.login.return_value = mock_token_pair
|
||||
|
||||
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
|
||||
resource.get("github")
|
||||
|
||||
@ -361,6 +366,7 @@ class TestOAuthCallback:
|
||||
mock_token_pair = MagicMock()
|
||||
mock_token_pair.access_token = "jwt_access_token"
|
||||
mock_token_pair.refresh_token = "jwt_refresh_token"
|
||||
mock_token_pair.csrf_token = "csrf_token"
|
||||
mock_account_service.login.return_value = mock_token_pair
|
||||
|
||||
# Execute OAuth callback
|
||||
@ -368,9 +374,7 @@ class TestOAuthCallback:
|
||||
resource.get("github")
|
||||
|
||||
# Verify current behavior: login succeeds (this is NOT ideal)
|
||||
mock_redirect.assert_called_once_with(
|
||||
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token"
|
||||
)
|
||||
mock_redirect.assert_called_once_with("http://localhost:3000")
|
||||
mock_account_service.login.assert_called_once()
|
||||
|
||||
# Document expected behavior in comments:
|
||||
|
||||
@ -2,7 +2,9 @@ from flask import Blueprint, Flask
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import BadRequest, Unauthorized
|
||||
|
||||
from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_CSRF_TOKEN, COOKIE_NAME_REFRESH_TOKEN
|
||||
from core.errors.error import AppInvokeQuotaExceededError
|
||||
from libs.exception import BaseHTTPException
|
||||
from libs.external_api import ExternalApi
|
||||
|
||||
|
||||
@ -120,3 +122,66 @@ def test_external_api_param_mapping_and_quota_and_exc_info_none():
|
||||
assert res.status_code in (400, 429)
|
||||
finally:
|
||||
ext.sys.exc_info = orig_exc_info # type: ignore[assignment]
|
||||
|
||||
|
||||
def test_unauthorized_and_force_logout_clears_cookies():
|
||||
"""Test that UnauthorizedAndForceLogout error clears auth cookies"""
|
||||
|
||||
class UnauthorizedAndForceLogout(BaseHTTPException):
|
||||
error_code = "unauthorized_and_force_logout"
|
||||
description = "Unauthorized and force logout."
|
||||
code = 401
|
||||
|
||||
app = Flask(__name__)
|
||||
bp = Blueprint("test", __name__)
|
||||
api = ExternalApi(bp)
|
||||
|
||||
@api.route("/force-logout")
|
||||
class ForceLogout(Resource): # type: ignore
|
||||
def get(self): # type: ignore
|
||||
raise UnauthorizedAndForceLogout()
|
||||
|
||||
app.register_blueprint(bp, url_prefix="/api")
|
||||
client = app.test_client()
|
||||
|
||||
# Set cookies first
|
||||
client.set_cookie(COOKIE_NAME_ACCESS_TOKEN, "test_access_token")
|
||||
client.set_cookie(COOKIE_NAME_CSRF_TOKEN, "test_csrf_token")
|
||||
client.set_cookie(COOKIE_NAME_REFRESH_TOKEN, "test_refresh_token")
|
||||
|
||||
# Make request that should trigger cookie clearing
|
||||
res = client.get("/api/force-logout")
|
||||
|
||||
# Verify response
|
||||
assert res.status_code == 401
|
||||
data = res.get_json()
|
||||
assert data["code"] == "unauthorized_and_force_logout"
|
||||
assert data["status"] == 401
|
||||
assert "WWW-Authenticate" in res.headers
|
||||
|
||||
# Verify Set-Cookie headers are present to clear cookies
|
||||
set_cookie_headers = res.headers.getlist("Set-Cookie")
|
||||
assert len(set_cookie_headers) == 3, f"Expected 3 Set-Cookie headers, got {len(set_cookie_headers)}"
|
||||
|
||||
# Verify each cookie is being cleared (empty value and expired)
|
||||
cookie_names_found = set()
|
||||
for cookie_header in set_cookie_headers:
|
||||
# Check for cookie names
|
||||
if COOKIE_NAME_ACCESS_TOKEN in cookie_header:
|
||||
cookie_names_found.add(COOKIE_NAME_ACCESS_TOKEN)
|
||||
assert '""' in cookie_header or "=" in cookie_header # Empty value
|
||||
assert "Expires=Thu, 01 Jan 1970" in cookie_header # Expired
|
||||
elif COOKIE_NAME_CSRF_TOKEN in cookie_header:
|
||||
cookie_names_found.add(COOKIE_NAME_CSRF_TOKEN)
|
||||
assert '""' in cookie_header or "=" in cookie_header
|
||||
assert "Expires=Thu, 01 Jan 1970" in cookie_header
|
||||
elif COOKIE_NAME_REFRESH_TOKEN in cookie_header:
|
||||
cookie_names_found.add(COOKIE_NAME_REFRESH_TOKEN)
|
||||
assert '""' in cookie_header or "=" in cookie_header
|
||||
assert "Expires=Thu, 01 Jan 1970" in cookie_header
|
||||
|
||||
# Verify all three cookies are present
|
||||
assert len(cookie_names_found) == 3
|
||||
assert COOKIE_NAME_ACCESS_TOKEN in cookie_names_found
|
||||
assert COOKIE_NAME_CSRF_TOKEN in cookie_names_found
|
||||
assert COOKIE_NAME_REFRESH_TOKEN in cookie_names_found
|
||||
|
||||
@ -19,10 +19,15 @@ class MockUser(UserMixin):
|
||||
return self._is_authenticated
|
||||
|
||||
|
||||
def mock_csrf_check(*args, **kwargs):
|
||||
return
|
||||
|
||||
|
||||
class TestLoginRequired:
|
||||
"""Test cases for login_required decorator."""
|
||||
|
||||
@pytest.fixture
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def setup_app(self, app: Flask):
|
||||
"""Set up Flask app with login manager."""
|
||||
# Initialize login manager
|
||||
@ -39,6 +44,7 @@ class TestLoginRequired:
|
||||
|
||||
return app
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_authenticated_user_can_access_protected_view(self, setup_app: Flask):
|
||||
"""Test that authenticated users can access protected views."""
|
||||
|
||||
@ -53,6 +59,7 @@ class TestLoginRequired:
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_unauthenticated_user_cannot_access_protected_view(self, setup_app: Flask):
|
||||
"""Test that unauthenticated users are redirected."""
|
||||
|
||||
@ -68,6 +75,7 @@ class TestLoginRequired:
|
||||
assert result == "Unauthorized"
|
||||
setup_app.login_manager.unauthorized.assert_called_once()
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_login_disabled_allows_unauthenticated_access(self, setup_app: Flask):
|
||||
"""Test that LOGIN_DISABLED config bypasses authentication."""
|
||||
|
||||
@ -87,6 +95,7 @@ class TestLoginRequired:
|
||||
# Ensure unauthorized was not called
|
||||
setup_app.login_manager.unauthorized.assert_not_called()
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_options_request_bypasses_authentication(self, setup_app: Flask):
|
||||
"""Test that OPTIONS requests are exempt from authentication."""
|
||||
|
||||
@ -103,6 +112,7 @@ class TestLoginRequired:
|
||||
# Ensure unauthorized was not called
|
||||
setup_app.login_manager.unauthorized.assert_not_called()
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_flask_2_compatibility(self, setup_app: Flask):
|
||||
"""Test Flask 2.x compatibility with ensure_sync."""
|
||||
|
||||
@ -120,6 +130,7 @@ class TestLoginRequired:
|
||||
assert result == "Synced content"
|
||||
setup_app.ensure_sync.assert_called_once()
|
||||
|
||||
@patch("libs.login.check_csrf_token", mock_csrf_check)
|
||||
def test_flask_1_compatibility(self, setup_app: Flask):
|
||||
"""Test Flask 1.x compatibility without ensure_sync."""
|
||||
|
||||
|
||||
23
api/tests/unit_tests/libs/test_token.py
Normal file
23
api/tests/unit_tests/libs/test_token.py
Normal file
@ -0,0 +1,23 @@
|
||||
from constants import COOKIE_NAME_ACCESS_TOKEN
|
||||
from libs.token import extract_access_token
|
||||
|
||||
|
||||
class MockRequest:
|
||||
def __init__(self, headers: dict[str, str], cookies: dict[str, str], args: dict[str, str]):
|
||||
self.headers: dict[str, str] = headers
|
||||
self.cookies: dict[str, str] = cookies
|
||||
self.args: dict[str, str] = args
|
||||
|
||||
|
||||
def test_extract_access_token():
|
||||
def _mock_request(headers: dict[str, str], cookies: dict[str, str], args: dict[str, str]):
|
||||
return MockRequest(headers, cookies, args)
|
||||
|
||||
test_cases = [
|
||||
(_mock_request({"Authorization": "Bearer 123"}, {}, {}), "123"),
|
||||
(_mock_request({}, {COOKIE_NAME_ACCESS_TOKEN: "123"}, {}), "123"),
|
||||
(_mock_request({}, {}, {}), None),
|
||||
(_mock_request({"Authorization": "Bearer_aaa 123"}, {}, {}), None),
|
||||
]
|
||||
for request, expected in test_cases:
|
||||
assert extract_access_token(request) == expected # pyright: ignore[reportArgumentType]
|
||||
Reference in New Issue
Block a user