refactor: replace localStorage with HTTP-only cookies for auth tokens (#24365)

Signed-off-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Signed-off-by: lyzno1 <yuanyouhuilyz@gmail.com>
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <62310815+github-advanced-security[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Yunlu Wen <wylswz@163.com>
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: GareArc <chen4851@purdue.edu>
Co-authored-by: NFish <douxc512@gmail.com>
Co-authored-by: Davide Delbianco <davide.delbianco@outlook.com>
Co-authored-by: minglu7 <1347866672@qq.com>
Co-authored-by: Ponder <ruan.lj@foxmail.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: heyszt <270985384@qq.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: Guangdong Liu <liugddx@gmail.com>
Co-authored-by: Eric Guo <eric.guocz@gmail.com>
Co-authored-by: NeatGuyCoding <15627489+NeatGuyCoding@users.noreply.github.com>
Co-authored-by: XlKsyt <caixuesen@outlook.com>
Co-authored-by: Dhruv Gorasiya <80987415+DhruvGorasiya@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: lyzno1 <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: hj24 <mambahj24@gmail.com>
Co-authored-by: GuanMu <ballmanjq@gmail.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Tonlo <123lzs123@gmail.com>
Co-authored-by: Yusuke Yamada <yamachu.dev@gmail.com>
Co-authored-by: Novice <novice12185727@gmail.com>
Co-authored-by: kenwoodjw <blackxin55+@gmail.com>
Co-authored-by: Ademílson Tonato <ademilsonft@outlook.com>
Co-authored-by: znn <jubinkumarsoni@gmail.com>
Co-authored-by: yangzheli <43645580+yangzheli@users.noreply.github.com>
This commit is contained in:
-LAN-
2025-10-19 21:29:04 +08:00
committed by GitHub
parent 141ca8904a
commit 9a5f214623
60 changed files with 879 additions and 533 deletions

View File

@ -55,3 +55,12 @@ else:
"properties",
}
DOCUMENT_EXTENSIONS: set[str] = convert_to_lower_and_upper_set(_doc_extensions)
COOKIE_NAME_ACCESS_TOKEN = "access_token"
COOKIE_NAME_REFRESH_TOKEN = "refresh_token"
COOKIE_NAME_PASSPORT = "passport"
COOKIE_NAME_CSRF_TOKEN = "csrf_token"
HEADER_NAME_CSRF_TOKEN = "X-CSRF-Token"
HEADER_NAME_APP_CODE = "X-App-Code"
HEADER_NAME_PASSPORT = "X-App-Passport"

View File

@ -15,6 +15,7 @@ from constants.languages import supported_language
from controllers.console import api, console_ns
from controllers.console.wraps import only_edition_cloud
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, InstalledApp, RecommendedApp
@ -24,19 +25,9 @@ def admin_required(view: Callable[P, R]):
if not dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")
auth_header = request.headers.get("Authorization")
if auth_header is None:
auth_token = extract_access_token(request)
if not auth_token:
raise Unauthorized("Authorization header is missing.")
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
if auth_token != dify_config.ADMIN_API_KEY:
raise Unauthorized("API key is invalid.")

View File

@ -1,5 +1,5 @@
import flask_login
from flask import request
from flask import make_response, request
from flask_restx import Resource, reqparse
import services
@ -25,6 +25,16 @@ from controllers.console.wraps import email_password_login_enabled, setup_requir
from events.tenant_event import tenant_was_created
from libs.helper import email, extract_remote_ip
from libs.login import current_account_with_tenant
from libs.token import (
clear_access_token_from_cookie,
clear_csrf_token_from_cookie,
clear_refresh_token_from_cookie,
extract_access_token,
extract_csrf_token,
set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService
from services.errors.account import AccountRegisterError
@ -89,20 +99,36 @@ class LoginApi(Resource):
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
return response
@console_ns.route("/logout")
class LogoutApi(Resource):
@setup_required
def get(self):
def post(self):
current_user, _ = current_account_with_tenant()
account = current_user
if isinstance(account, flask_login.AnonymousUserMixin):
return {"result": "success"}
AccountService.logout(account=account)
flask_login.logout_user()
return {"result": "success"}
response = make_response({"result": "success"})
else:
AccountService.logout(account=account)
flask_login.logout_user()
response = make_response({"result": "success"})
# Clear cookies on logout
clear_access_token_from_cookie(response)
clear_refresh_token_from_cookie(response)
clear_csrf_token_from_cookie(response)
return response
@console_ns.route("/reset-password")
@ -227,17 +253,46 @@ class EmailCodeLoginApi(Resource):
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": token_pair.model_dump()}
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
# Set HTTP-only secure cookies for tokens
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
return response
@console_ns.route("/refresh-token")
class RefreshTokenApi(Resource):
def post(self):
parser = reqparse.RequestParser().add_argument("refresh_token", type=str, required=True, location="json")
args = parser.parse_args()
# Get refresh token from cookie instead of request body
refresh_token = request.cookies.get("refresh_token")
if not refresh_token:
return {"result": "fail", "message": "No refresh token provided"}, 401
try:
new_token_pair = AccountService.refresh_token(args["refresh_token"])
return {"result": "success", "data": new_token_pair.model_dump()}
new_token_pair = AccountService.refresh_token(refresh_token)
# Create response with new cookies
response = make_response({"result": "success"})
# Update cookies with new tokens
set_csrf_token_to_cookie(request, response, new_token_pair.csrf_token)
set_access_token_to_cookie(request, response, new_token_pair.access_token)
set_refresh_token_to_cookie(request, response, new_token_pair.refresh_token)
return response
except Exception as e:
return {"result": "fail", "data": str(e)}, 401
return {"result": "fail", "message": str(e)}, 401
# this api helps frontend to check whether user is authenticated
# TODO: remove in the future. frontend should redirect to login page by catching 401 status
@console_ns.route("/login/status")
class LoginStatus(Resource):
def get(self):
token = extract_access_token(request)
csrf_token = extract_csrf_token(request)
return {"logged_in": bool(token) and bool(csrf_token)}

View File

@ -14,6 +14,11 @@ from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from libs.token import (
set_access_token_to_cookie,
set_csrf_token_to_cookie,
set_refresh_token_to_cookie,
)
from models import Account, AccountStatus
from services.account_service import AccountService, RegisterService, TenantService
from services.billing_service import BillingService
@ -152,9 +157,12 @@ class OAuthCallback(Resource):
ip_address=extract_remote_ip(request),
)
return redirect(
f"{dify_config.CONSOLE_WEB_URL}?access_token={token_pair.access_token}&refresh_token={token_pair.refresh_token}"
)
response = redirect(f"{dify_config.CONSOLE_WEB_URL}")
set_access_token_to_cookie(request, response, token_pair.access_token)
set_refresh_token_to_cookie(request, response, token_pair.refresh_token)
set_csrf_token_to_cookie(request, response, token_pair.csrf_token)
return response
def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) -> Account | None:

View File

@ -15,7 +15,6 @@ from libs.datetime_utils import naive_utc_now
from libs.login import current_account_with_tenant, login_required
from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@ -67,31 +66,26 @@ class InstalledAppsListApi(Resource):
# Pre-filter out apps without setting or with sso_verified
filtered_installed_apps = []
app_id_to_app_code = {}
for installed_app in installed_app_list:
app_id = installed_app["app"].id
webapp_setting = webapp_settings.get(app_id)
if not webapp_setting or webapp_setting.access_mode == "sso_verified":
continue
app_code = AppService.get_app_code_by_id(str(app_id))
app_id_to_app_code[app_id] = app_code
filtered_installed_apps.append(installed_app)
app_codes = list(app_id_to_app_code.values())
# Batch permission check
app_ids = [installed_app["app"].id for installed_app in filtered_installed_apps]
permissions = EnterpriseService.WebAppAuth.batch_is_user_allowed_to_access_webapps(
user_id=user_id,
app_codes=app_codes,
app_ids=app_ids,
)
# Keep only allowed apps
res = []
for installed_app in filtered_installed_apps:
app_id = installed_app["app"].id
app_code = app_id_to_app_code[app_id]
if permissions.get(app_code):
if permissions.get(app_id):
res.append(installed_app)
installed_app_list = res

View File

@ -10,7 +10,6 @@ from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import InstalledApp
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
@ -56,10 +55,9 @@ def user_allowed_to_access_app(view: Callable[Concatenate[InstalledApp, P], R] |
feature = FeatureService.get_system_features()
if feature.webapp_auth.enabled:
app_id = installed_app.app_id
app_code = AppService.get_app_code_by_id(app_id)
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=str(current_user.id),
app_code=app_code,
app_id=app_id,
)
if not res:
raise AppAccessDeniedError()

View File

@ -4,12 +4,14 @@ from flask import request
from flask_restx import Resource, marshal_with, reqparse
from werkzeug.exceptions import Unauthorized
from constants import HEADER_NAME_APP_CODE
from controllers.common import fields
from controllers.web import web_ns
from controllers.web.error import AppUnavailableError
from controllers.web.wraps import WebApiResource
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from libs.passport import PassportService
from libs.token import extract_webapp_passport
from models.model import App, AppMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
@ -133,18 +135,19 @@ class AppWebAuthPermission(Resource):
)
def get(self):
user_id = "visitor"
app_code = request.headers.get(HEADER_NAME_APP_CODE)
app_id = request.args.get("appId")
if not app_id or not app_code:
raise ValueError("appId must be provided")
require_permission_check = WebAppAuthService.is_app_require_permission_check(app_id=app_id)
if not require_permission_check:
return {"result": True}
try:
auth_header = request.headers.get("Authorization")
if auth_header is None:
raise Unauthorized("Authorization header is missing.")
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, tk = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")
tk = extract_webapp_passport(app_code, request)
if not tk:
raise Unauthorized("Access token is missing.")
decoded = PassportService().verify(tk)
user_id = decoded.get("user_id", "visitor")
except Unauthorized:
@ -157,13 +160,7 @@ class AppWebAuthPermission(Resource):
if not features.webapp_auth.enabled:
return {"result": True}
parser = reqparse.RequestParser().add_argument("appId", type=str, required=True, location="args")
args = parser.parse_args()
app_id = args["appId"]
app_code = AppService.get_app_code_by_id(app_id)
res = True
if WebAppAuthService.is_app_require_permission_check(app_id=app_id):
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code)
res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_id)
return {"result": res}

View File

@ -1,7 +1,9 @@
from flask import make_response, request
from flask_restx import Resource, reqparse
from jwt import InvalidTokenError
import services
from configs import dify_config
from controllers.console.auth.error import (
AuthenticationFailedError,
EmailCodeError,
@ -10,9 +12,16 @@ from controllers.console.auth.error import (
from controllers.console.error import AccountBannedError
from controllers.console.wraps import only_edition_enterprise, setup_required
from controllers.web import web_ns
from controllers.web.wraps import decode_jwt_token
from libs.helper import email
from libs.passport import PassportService
from libs.password import valid_password
from libs.token import (
clear_access_token_from_cookie,
extract_access_token,
)
from services.account_service import AccountService
from services.app_service import AppService
from services.webapp_auth_service import WebAppAuthService
@ -52,17 +61,75 @@ class LoginApi(Resource):
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
return {"result": "success", "data": {"access_token": token}}
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response
# class LogoutApi(Resource):
# @setup_required
# def get(self):
# account = cast(Account, flask_login.current_user)
# if isinstance(account, flask_login.AnonymousUserMixin):
# return {"result": "success"}
# flask_login.logout_user()
# return {"result": "success"}
# this api helps frontend to check whether user is authenticated
# TODO: remove in the future. frontend should redirect to login page by catching 401 status
@web_ns.route("/login/status")
class LoginStatusApi(Resource):
@setup_required
@web_ns.doc("web_app_login_status")
@web_ns.doc(description="Check login status")
@web_ns.doc(
responses={
200: "Login status",
401: "Login status",
}
)
def get(self):
app_code = request.args.get("app_code")
token = extract_access_token(request)
if not app_code:
return {
"logged_in": bool(token),
"app_logged_in": False,
}
app_id = AppService.get_app_id_by_code(app_code)
is_public = not dify_config.ENTERPRISE_ENABLED or not WebAppAuthService.is_app_require_permission_check(
app_id=app_id
)
user_logged_in = False
if is_public:
user_logged_in = True
else:
try:
PassportService().verify(token=token)
user_logged_in = True
except Exception:
user_logged_in = False
try:
_ = decode_jwt_token(app_code=app_code)
app_logged_in = True
except Exception:
app_logged_in = False
return {
"logged_in": user_logged_in,
"app_logged_in": app_logged_in,
}
@web_ns.route("/logout")
class LogoutApi(Resource):
@setup_required
@web_ns.doc("web_app_logout")
@web_ns.doc(description="Logout user from web application")
@web_ns.doc(
responses={
200: "Logout successful",
}
)
def post(self):
response = make_response({"result": "success"})
# enterprise SSO sets same site to None in https deployment
# so we need to logout by calling api
clear_access_token_from_cookie(response, samesite="None")
return response
@web_ns.route("/email-code-login")
@ -96,7 +163,6 @@ class EmailCodeLoginSendEmailApi(Resource):
raise AuthenticationFailedError()
else:
token = WebAppAuthService.send_email_code_login_email(account=account, language=language)
return {"result": "success", "data": token}
@ -142,4 +208,6 @@ class EmailCodeLoginApi(Resource):
token = WebAppAuthService.login(account=account)
AccountService.reset_login_error_rate_limit(args["email"])
return {"result": "success", "data": {"access_token": token}}
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response

View File

@ -1,17 +1,20 @@
import uuid
from datetime import UTC, datetime, timedelta
from flask import request
from flask import make_response, request
from flask_restx import Resource
from sqlalchemy import func, select
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
from constants import HEADER_NAME_APP_CODE
from controllers.web import web_ns
from controllers.web.error import WebAppAuthRequiredError
from extensions.ext_database import db
from libs.passport import PassportService
from libs.token import extract_access_token
from models.model import App, EndUser, Site
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService, WebAppAuthType
@ -32,15 +35,15 @@ class PassportResource(Resource):
)
def get(self):
system_features = FeatureService.get_system_features()
app_code = request.headers.get("X-App-Code")
app_code = request.headers.get(HEADER_NAME_APP_CODE)
user_id = request.args.get("user_id")
web_app_access_token = request.args.get("web_app_access_token")
access_token = extract_access_token(request)
if app_code is None:
raise Unauthorized("X-App-Code header is missing.")
app_id = AppService.get_app_id_by_code(app_code)
# exchange token for enterprise logined web user
enterprise_user_decoded = decode_enterprise_webapp_user_id(web_app_access_token)
enterprise_user_decoded = decode_enterprise_webapp_user_id(access_token)
if enterprise_user_decoded:
# a web user has already logged in, exchange a token for this app without redirecting to the login page
return exchange_token_for_existing_web_user(
@ -48,7 +51,7 @@ class PassportResource(Resource):
)
if system_features.webapp_auth.enabled:
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
if not app_settings or not app_settings.access_mode == "public":
raise WebAppAuthRequiredError()
@ -99,9 +102,12 @@ class PassportResource(Resource):
tk = PassportService().issue(payload)
return {
"access_token": tk,
}
response = make_response(
{
"access_token": tk,
}
)
return response
def decode_enterprise_webapp_user_id(jwt_token: str | None):
@ -189,9 +195,12 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
"exp": exp,
}
token: str = PassportService().issue(payload)
return {
"access_token": token,
}
resp = make_response(
{
"access_token": token,
}
)
return resp
def _exchange_for_public_app_token(app_model, site, token_decoded):
@ -224,9 +233,12 @@ def _exchange_for_public_app_token(app_model, site, token_decoded):
tk = PassportService().issue(payload)
return {
"access_token": tk,
}
resp = make_response(
{
"access_token": tk,
}
)
return resp
def generate_session_id():

View File

@ -9,10 +9,13 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from constants import HEADER_NAME_APP_CODE
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
from extensions.ext_database import db
from libs.passport import PassportService
from libs.token import extract_webapp_passport
from models.model import App, EndUser, Site
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings
from services.feature_service import FeatureService
from services.webapp_auth_service import WebAppAuthService
@ -35,22 +38,14 @@ def validate_jwt_token(view: Callable[Concatenate[App, EndUser, P], R] | None =
return decorator
def decode_jwt_token():
def decode_jwt_token(app_code: str | None = None):
system_features = FeatureService.get_system_features()
app_code = str(request.headers.get("X-App-Code"))
if not app_code:
app_code = str(request.headers.get(HEADER_NAME_APP_CODE))
try:
auth_header = request.headers.get("Authorization")
if auth_header is None:
raise Unauthorized("Authorization header is missing.")
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, tk = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
tk = extract_webapp_passport(app_code, request)
if not tk:
raise Unauthorized("App token is missing.")
decoded = PassportService().verify(tk)
app_code = decoded.get("app_code")
app_id = decoded.get("app_id")
@ -72,7 +67,8 @@ def decode_jwt_token():
app_web_auth_enabled = False
webapp_settings = None
if system_features.webapp_auth.enabled:
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code)
app_id = AppService.get_app_id_by_code(app_code)
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id)
if not webapp_settings:
raise NotFound("Web app settings not found.")
app_web_auth_enabled = webapp_settings.access_mode != "public"
@ -87,8 +83,9 @@ def decode_jwt_token():
if system_features.webapp_auth.enabled:
if not app_code:
raise Unauthorized("Please re-login to access the web app.")
app_id = AppService.get_app_id_by_code(app_code)
app_web_auth_enabled = (
EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=str(app_code)).access_mode != "public"
EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id).access_mode != "public"
)
if app_web_auth_enabled:
raise WebAppAuthRequiredError()
@ -129,7 +126,8 @@ def _validate_user_accessibility(
raise WebAppAuthRequiredError("Web app settings not found.")
if WebAppAuthService.is_app_require_permission_check(access_mode=webapp_settings.access_mode):
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code):
app_id = AppService.get_app_id_by_code(app_code)
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_id):
raise WebAppAuthAccessDeniedError()
auth_type = decoded.get("auth_type")

View File

@ -1,4 +1,5 @@
from configs import dify_config
from constants import HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN
from dify_app import DifyApp
@ -16,7 +17,7 @@ def init_app(app: DifyApp):
CORS(
service_api_bp,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
)
app.register_blueprint(service_api_bp)
@ -25,7 +26,7 @@ def init_app(app: DifyApp):
web_bp,
resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization", "X-App-Code"],
allow_headers=["Content-Type", "Authorization", HEADER_NAME_APP_CODE, HEADER_NAME_CSRF_TOKEN],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
@ -35,7 +36,7 @@ def init_app(app: DifyApp):
console_app_bp,
resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
supports_credentials=True,
allow_headers=["Content-Type", "Authorization"],
allow_headers=["Content-Type", "Authorization", HEADER_NAME_CSRF_TOKEN],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
expose_headers=["X-Version", "X-Env"],
)
@ -43,7 +44,7 @@ def init_app(app: DifyApp):
CORS(
files_bp,
allow_headers=["Content-Type"],
allow_headers=["Content-Type", HEADER_NAME_CSRF_TOKEN],
methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"],
)
app.register_blueprint(files_bp)

View File

@ -9,6 +9,7 @@ from configs import dify_config
from dify_app import DifyApp
from extensions.ext_database import db
from libs.passport import PassportService
from libs.token import extract_access_token
from models import Account, Tenant, TenantAccountJoin
from models.model import AppMCPServer, EndUser
from services.account_service import AccountService
@ -24,20 +25,10 @@ def load_user_from_request(request_from_flask_login):
if dify_config.SWAGGER_UI_ENABLED and request.path.endswith((dify_config.SWAGGER_UI_PATH, "/swagger.json")):
return None
auth_header = request.headers.get("Authorization", "")
auth_token: str | None = None
if auth_header:
if " " not in auth_header:
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
auth_scheme, auth_token = auth_header.split(maxsplit=1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
else:
auth_token = request.args.get("_token")
auth_token = extract_access_token(request)
# Check for admin API key authentication first
if dify_config.ADMIN_API_KEY_ENABLE and auth_header:
if dify_config.ADMIN_API_KEY_ENABLE and auth_token:
admin_api_key = dify_config.ADMIN_API_KEY
if admin_api_key and admin_api_key == auth_token:
workspace_id = request.headers.get("X-WORKSPACE-ID")

View File

@ -9,7 +9,9 @@ from werkzeug.exceptions import HTTPException
from werkzeug.http import HTTP_STATUS_CODES
from configs import dify_config
from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_CSRF_TOKEN, COOKIE_NAME_REFRESH_TOKEN
from core.errors.error import AppInvokeQuotaExceededError
from libs.token import is_secure
def http_status_message(code):
@ -67,6 +69,19 @@ def register_external_error_handlers(api: Api):
# If you need WWW-Authenticate for 401, add it to headers
if status_code == 401:
headers["WWW-Authenticate"] = 'Bearer realm="api"'
# Check if this is a forced logout error - clear cookies
error_code = getattr(e, "error_code", None)
if error_code == "unauthorized_and_force_logout":
# Add Set-Cookie headers to clear auth cookies
secure = is_secure()
# response is not accessible, so we need to do it ugly
common_part = "Path=/; Expires=Thu, 01 Jan 1970 00:00:00 GMT; HttpOnly"
headers["Set-Cookie"] = [
f'{COOKIE_NAME_ACCESS_TOKEN}=""; {common_part}{"; Secure" if secure else ""}; SameSite=Lax',
f'{COOKIE_NAME_CSRF_TOKEN}=""; {common_part}{"; Secure" if secure else ""}; SameSite=Lax',
f'{COOKIE_NAME_REFRESH_TOKEN}=""; {common_part}{"; Secure" if secure else ""}; SameSite=Lax',
]
return data, status_code, headers
_ = handle_http_exception

View File

@ -7,6 +7,7 @@ from flask_login.config import EXEMPT_METHODS # type: ignore
from werkzeug.local import LocalProxy
from configs import dify_config
from libs.token import check_csrf_token
from models import Account
from models.model import EndUser
@ -73,6 +74,9 @@ def login_required(func: Callable[P, R]):
pass
elif current_user is not None and not current_user.is_authenticated:
return current_app.login_manager.unauthorized() # type: ignore
# we put csrf validation here for less conflicts
# TODO: maybe find a better place for it.
check_csrf_token(request, current_user.id)
return current_app.ensure_sync(func)(*args, **kwargs)
return decorated_view

208
api/libs/token.py Normal file
View File

@ -0,0 +1,208 @@
import logging
import re
from datetime import UTC, datetime, timedelta
from flask import Request
from werkzeug.exceptions import Unauthorized
from werkzeug.wrappers import Response
from configs import dify_config
from constants import (
COOKIE_NAME_ACCESS_TOKEN,
COOKIE_NAME_CSRF_TOKEN,
COOKIE_NAME_PASSPORT,
COOKIE_NAME_REFRESH_TOKEN,
HEADER_NAME_CSRF_TOKEN,
HEADER_NAME_PASSPORT,
)
from libs.passport import PassportService
logger = logging.getLogger(__name__)
CSRF_WHITE_LIST = [
re.compile(r"/console/api/apps/[a-f0-9-]+/workflows/draft"),
]
# server is behind a reverse proxy, so we need to check the url
def is_secure() -> bool:
return dify_config.CONSOLE_WEB_URL.startswith("https") and dify_config.CONSOLE_API_URL.startswith("https")
def _real_cookie_name(cookie_name: str) -> str:
if is_secure():
return "__Host-" + cookie_name
else:
return cookie_name
def _try_extract_from_header(request: Request) -> str | None:
"""
Try to extract access token from header
"""
auth_header = request.headers.get("Authorization")
if auth_header:
if " " not in auth_header:
return None
else:
auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer":
return None
else:
return auth_token
return None
def extract_csrf_token(request: Request) -> str | None:
"""
Try to extract CSRF token from header or cookie.
"""
return request.headers.get(HEADER_NAME_CSRF_TOKEN)
def extract_csrf_token_from_cookie(request: Request) -> str | None:
"""
Try to extract CSRF token from cookie.
"""
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
def extract_access_token(request: Request) -> str | None:
"""
Try to extract access token from cookie, header or params.
Access token is either for console session or webapp passport exchange.
"""
def _try_extract_from_cookie(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
return _try_extract_from_cookie(request) or _try_extract_from_header(request)
def extract_webapp_passport(app_code: str, request: Request) -> str | None:
"""
Try to extract app token from header or params.
Webapp access token (part of passport) is only used for webapp session.
"""
def _try_extract_passport_token_from_cookie(request: Request) -> str | None:
return request.cookies.get(_real_cookie_name(COOKIE_NAME_PASSPORT + "-" + app_code))
def _try_extract_passport_token_from_header(request: Request) -> str | None:
return request.headers.get(HEADER_NAME_PASSPORT)
ret = _try_extract_passport_token_from_cookie(request) or _try_extract_passport_token_from_header(request)
return ret
def set_access_token_to_cookie(request: Request, response: Response, token: str, samesite: str = "Lax"):
response.set_cookie(
_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN),
value=token,
httponly=True,
secure=is_secure(),
samesite=samesite,
max_age=int(dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 60),
path="/",
)
def set_refresh_token_to_cookie(request: Request, response: Response, token: str):
response.set_cookie(
_real_cookie_name(COOKIE_NAME_REFRESH_TOKEN),
value=token,
httponly=True,
secure=is_secure(),
samesite="Lax",
max_age=int(60 * 60 * 24 * dify_config.REFRESH_TOKEN_EXPIRE_DAYS),
path="/",
)
def set_csrf_token_to_cookie(request: Request, response: Response, token: str):
response.set_cookie(
_real_cookie_name(COOKIE_NAME_CSRF_TOKEN),
value=token,
httponly=False,
secure=is_secure(),
samesite="Lax",
max_age=int(60 * dify_config.ACCESS_TOKEN_EXPIRE_MINUTES),
path="/",
)
def _clear_cookie(
response: Response,
cookie_name: str,
samesite: str = "Lax",
http_only: bool = True,
):
response.set_cookie(
_real_cookie_name(cookie_name),
"",
expires=0,
path="/",
secure=is_secure(),
httponly=http_only,
samesite=samesite,
)
def clear_access_token_from_cookie(response: Response, samesite: str = "Lax"):
_clear_cookie(response, COOKIE_NAME_ACCESS_TOKEN, samesite)
def clear_refresh_token_from_cookie(response: Response):
_clear_cookie(response, COOKIE_NAME_REFRESH_TOKEN)
def clear_csrf_token_from_cookie(response: Response):
_clear_cookie(response, COOKIE_NAME_CSRF_TOKEN, http_only=False)
def check_csrf_token(request: Request, user_id: str):
# some apis are sent by beacon, so we need to bypass csrf token check
# since these APIs are post, they are already protected by SameSite: Lax, so csrf is not required.
def _unauthorized():
raise Unauthorized("CSRF token is missing or invalid.")
for pattern in CSRF_WHITE_LIST:
if pattern.match(request.path):
return
csrf_token = extract_csrf_token(request)
csrf_token_from_cookie = extract_csrf_token_from_cookie(request)
if csrf_token != csrf_token_from_cookie:
_unauthorized()
if not csrf_token:
_unauthorized()
verified = {}
try:
verified = PassportService().verify(csrf_token)
except:
_unauthorized()
if verified.get("sub") != user_id:
_unauthorized()
exp: int | None = verified.get("exp")
if not exp:
_unauthorized()
else:
time_now = int(datetime.now().timestamp())
if exp < time_now:
_unauthorized()
def generate_csrf_token(user_id: str) -> str:
exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
payload = {
"exp": int(exp_dt.timestamp()),
"sub": user_id,
}
return PassportService().issue(payload)

View File

@ -22,6 +22,7 @@ from libs.helper import RateLimiter, TokenManager
from libs.passport import PassportService
from libs.password import compare_password, hash_password, valid_password
from libs.rsa import generate_key_pair
from libs.token import generate_csrf_token
from models.account import (
Account,
AccountIntegrate,
@ -76,6 +77,7 @@ logger = logging.getLogger(__name__)
class TokenPair(BaseModel):
access_token: str
refresh_token: str
csrf_token: str
REFRESH_TOKEN_PREFIX = "refresh_token:"
@ -403,10 +405,11 @@ class AccountService:
access_token = AccountService.get_account_jwt_token(account=account)
refresh_token = _generate_refresh_token()
csrf_token = generate_csrf_token(account.id)
AccountService._store_refresh_token(refresh_token, account.id)
return TokenPair(access_token=access_token, refresh_token=refresh_token)
return TokenPair(access_token=access_token, refresh_token=refresh_token, csrf_token=csrf_token)
@staticmethod
def logout(*, account: Account):
@ -431,8 +434,9 @@ class AccountService:
AccountService._delete_refresh_token(refresh_token, account.id)
AccountService._store_refresh_token(new_refresh_token, account.id)
csrf_token = generate_csrf_token(account.id)
return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token)
return TokenPair(access_token=new_access_token, refresh_token=new_refresh_token, csrf_token=csrf_token)
@staticmethod
def load_logged_in_account(*, account_id: str):

View File

@ -46,17 +46,17 @@ class EnterpriseService:
class WebAppAuth:
@classmethod
def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str):
params = {"userId": user_id, "appCode": app_code}
def is_user_allowed_to_access_webapp(cls, user_id: str, app_id: str):
params = {"userId": user_id, "appId": app_id}
data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params)
return data.get("result", False)
@classmethod
def batch_is_user_allowed_to_access_webapps(cls, user_id: str, app_codes: list[str]):
if not app_codes:
def batch_is_user_allowed_to_access_webapps(cls, user_id: str, app_ids: list[str]):
if not app_ids:
return {}
body = {"userId": user_id, "appCodes": app_codes}
body = {"userId": user_id, "appIds": app_ids}
data = EnterpriseRequest.send_request("POST", "/webapp/permission/batch", json=body)
if not data:
raise ValueError("No data found.")

View File

@ -172,7 +172,8 @@ class WebAppAuthService:
return WebAppAuthType.EXTERNAL
if app_code:
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code)
app_id = AppService.get_app_id_by_code(app_code)
webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
return cls.get_app_auth_type(access_mode=webapp_settings.access_mode)
raise ValueError("Could not determine app authentication type.")

View File

@ -863,13 +863,14 @@ class TestWebAppAuthService:
- Mock service integration
"""
# Arrange: Setup mock for enterprise service
mock_webapp_auth = type("MockWebAppAuth", (), {"access_mode": "sso_verified"})()
mock_external_service_dependencies["app_service"].get_app_id_by_code.return_value = "mock_app_id"
setting = type("MockWebAppAuth", (), {"access_mode": "sso_verified"})()
mock_external_service_dependencies[
"enterprise_service"
].WebAppAuth.get_app_access_mode_by_code.return_value = mock_webapp_auth
].WebAppAuth.get_app_access_mode_by_id.return_value = setting
# Act: Execute authentication type determination
result = WebAppAuthService.get_app_auth_type(app_code="mock_app_code")
result: WebAppAuthType = WebAppAuthService.get_app_auth_type(app_code="mock_app_code")
# Assert: Verify correct result
assert result == WebAppAuthType.EXTERNAL
@ -877,7 +878,7 @@ class TestWebAppAuthService:
# Verify mock service was called correctly
mock_external_service_dependencies[
"enterprise_service"
].WebAppAuth.get_app_access_mode_by_code.assert_called_once_with("mock_app_code")
].WebAppAuth.get_app_access_mode_by_id.assert_called_once_with(app_id="mock_app_id")
def test_get_app_auth_type_no_parameters(self, db_session_with_containers, mock_external_service_dependencies):
"""

View File

@ -179,9 +179,7 @@ class TestOAuthCallback:
oauth_setup["provider"].get_access_token.assert_called_once_with("test_code")
oauth_setup["provider"].get_user_info.assert_called_once_with("access_token")
mock_redirect.assert_called_once_with(
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token"
)
mock_redirect.assert_called_once_with("http://localhost:3000")
@pytest.mark.parametrize(
("exception", "expected_error"),
@ -224,8 +222,8 @@ class TestOAuthCallback:
# CLOSED status: Currently NOT handled, will proceed to login (security issue)
# This documents actual behavior. See test_defensive_check_for_closed_account_status for details
(
AccountStatus.CLOSED,
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token",
AccountStatus.CLOSED.value,
"http://localhost:3000",
),
],
)
@ -268,6 +266,7 @@ class TestOAuthCallback:
mock_token_pair = MagicMock()
mock_token_pair.access_token = "jwt_access_token"
mock_token_pair.refresh_token = "jwt_refresh_token"
mock_token_pair.csrf_token = "csrf_token"
mock_account_service.login.return_value = mock_token_pair
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
@ -299,6 +298,12 @@ class TestOAuthCallback:
mock_account.status = AccountStatus.PENDING
mock_generate_account.return_value = mock_account
mock_token_pair = MagicMock()
mock_token_pair.access_token = "jwt_access_token"
mock_token_pair.refresh_token = "jwt_refresh_token"
mock_token_pair.csrf_token = "csrf_token"
mock_account_service.login.return_value = mock_token_pair
with app.test_request_context("/auth/oauth/github/callback?code=test_code"):
resource.get("github")
@ -361,6 +366,7 @@ class TestOAuthCallback:
mock_token_pair = MagicMock()
mock_token_pair.access_token = "jwt_access_token"
mock_token_pair.refresh_token = "jwt_refresh_token"
mock_token_pair.csrf_token = "csrf_token"
mock_account_service.login.return_value = mock_token_pair
# Execute OAuth callback
@ -368,9 +374,7 @@ class TestOAuthCallback:
resource.get("github")
# Verify current behavior: login succeeds (this is NOT ideal)
mock_redirect.assert_called_once_with(
"http://localhost:3000?access_token=jwt_access_token&refresh_token=jwt_refresh_token"
)
mock_redirect.assert_called_once_with("http://localhost:3000")
mock_account_service.login.assert_called_once()
# Document expected behavior in comments:

View File

@ -2,7 +2,9 @@ from flask import Blueprint, Flask
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, Unauthorized
from constants import COOKIE_NAME_ACCESS_TOKEN, COOKIE_NAME_CSRF_TOKEN, COOKIE_NAME_REFRESH_TOKEN
from core.errors.error import AppInvokeQuotaExceededError
from libs.exception import BaseHTTPException
from libs.external_api import ExternalApi
@ -120,3 +122,66 @@ def test_external_api_param_mapping_and_quota_and_exc_info_none():
assert res.status_code in (400, 429)
finally:
ext.sys.exc_info = orig_exc_info # type: ignore[assignment]
def test_unauthorized_and_force_logout_clears_cookies():
"""Test that UnauthorizedAndForceLogout error clears auth cookies"""
class UnauthorizedAndForceLogout(BaseHTTPException):
error_code = "unauthorized_and_force_logout"
description = "Unauthorized and force logout."
code = 401
app = Flask(__name__)
bp = Blueprint("test", __name__)
api = ExternalApi(bp)
@api.route("/force-logout")
class ForceLogout(Resource): # type: ignore
def get(self): # type: ignore
raise UnauthorizedAndForceLogout()
app.register_blueprint(bp, url_prefix="/api")
client = app.test_client()
# Set cookies first
client.set_cookie(COOKIE_NAME_ACCESS_TOKEN, "test_access_token")
client.set_cookie(COOKIE_NAME_CSRF_TOKEN, "test_csrf_token")
client.set_cookie(COOKIE_NAME_REFRESH_TOKEN, "test_refresh_token")
# Make request that should trigger cookie clearing
res = client.get("/api/force-logout")
# Verify response
assert res.status_code == 401
data = res.get_json()
assert data["code"] == "unauthorized_and_force_logout"
assert data["status"] == 401
assert "WWW-Authenticate" in res.headers
# Verify Set-Cookie headers are present to clear cookies
set_cookie_headers = res.headers.getlist("Set-Cookie")
assert len(set_cookie_headers) == 3, f"Expected 3 Set-Cookie headers, got {len(set_cookie_headers)}"
# Verify each cookie is being cleared (empty value and expired)
cookie_names_found = set()
for cookie_header in set_cookie_headers:
# Check for cookie names
if COOKIE_NAME_ACCESS_TOKEN in cookie_header:
cookie_names_found.add(COOKIE_NAME_ACCESS_TOKEN)
assert '""' in cookie_header or "=" in cookie_header # Empty value
assert "Expires=Thu, 01 Jan 1970" in cookie_header # Expired
elif COOKIE_NAME_CSRF_TOKEN in cookie_header:
cookie_names_found.add(COOKIE_NAME_CSRF_TOKEN)
assert '""' in cookie_header or "=" in cookie_header
assert "Expires=Thu, 01 Jan 1970" in cookie_header
elif COOKIE_NAME_REFRESH_TOKEN in cookie_header:
cookie_names_found.add(COOKIE_NAME_REFRESH_TOKEN)
assert '""' in cookie_header or "=" in cookie_header
assert "Expires=Thu, 01 Jan 1970" in cookie_header
# Verify all three cookies are present
assert len(cookie_names_found) == 3
assert COOKIE_NAME_ACCESS_TOKEN in cookie_names_found
assert COOKIE_NAME_CSRF_TOKEN in cookie_names_found
assert COOKIE_NAME_REFRESH_TOKEN in cookie_names_found

View File

@ -19,10 +19,15 @@ class MockUser(UserMixin):
return self._is_authenticated
def mock_csrf_check(*args, **kwargs):
return
class TestLoginRequired:
"""Test cases for login_required decorator."""
@pytest.fixture
@patch("libs.login.check_csrf_token", mock_csrf_check)
def setup_app(self, app: Flask):
"""Set up Flask app with login manager."""
# Initialize login manager
@ -39,6 +44,7 @@ class TestLoginRequired:
return app
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_authenticated_user_can_access_protected_view(self, setup_app: Flask):
"""Test that authenticated users can access protected views."""
@ -53,6 +59,7 @@ class TestLoginRequired:
result = protected_view()
assert result == "Protected content"
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_unauthenticated_user_cannot_access_protected_view(self, setup_app: Flask):
"""Test that unauthenticated users are redirected."""
@ -68,6 +75,7 @@ class TestLoginRequired:
assert result == "Unauthorized"
setup_app.login_manager.unauthorized.assert_called_once()
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_login_disabled_allows_unauthenticated_access(self, setup_app: Flask):
"""Test that LOGIN_DISABLED config bypasses authentication."""
@ -87,6 +95,7 @@ class TestLoginRequired:
# Ensure unauthorized was not called
setup_app.login_manager.unauthorized.assert_not_called()
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_options_request_bypasses_authentication(self, setup_app: Flask):
"""Test that OPTIONS requests are exempt from authentication."""
@ -103,6 +112,7 @@ class TestLoginRequired:
# Ensure unauthorized was not called
setup_app.login_manager.unauthorized.assert_not_called()
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_flask_2_compatibility(self, setup_app: Flask):
"""Test Flask 2.x compatibility with ensure_sync."""
@ -120,6 +130,7 @@ class TestLoginRequired:
assert result == "Synced content"
setup_app.ensure_sync.assert_called_once()
@patch("libs.login.check_csrf_token", mock_csrf_check)
def test_flask_1_compatibility(self, setup_app: Flask):
"""Test Flask 1.x compatibility without ensure_sync."""

View File

@ -0,0 +1,23 @@
from constants import COOKIE_NAME_ACCESS_TOKEN
from libs.token import extract_access_token
class MockRequest:
def __init__(self, headers: dict[str, str], cookies: dict[str, str], args: dict[str, str]):
self.headers: dict[str, str] = headers
self.cookies: dict[str, str] = cookies
self.args: dict[str, str] = args
def test_extract_access_token():
def _mock_request(headers: dict[str, str], cookies: dict[str, str], args: dict[str, str]):
return MockRequest(headers, cookies, args)
test_cases = [
(_mock_request({"Authorization": "Bearer 123"}, {}, {}), "123"),
(_mock_request({}, {COOKIE_NAME_ACCESS_TOKEN: "123"}, {}), "123"),
(_mock_request({}, {}, {}), None),
(_mock_request({"Authorization": "Bearer_aaa 123"}, {}, {}), None),
]
for request, expected in test_cases:
assert extract_access_token(request) == expected # pyright: ignore[reportArgumentType]