fix(openapi): close 4 critical OAuth device-flow security findings

1. Host-header injection (sso_initiate / sso_complete): replace
   request.host_url with dify_config.CONSOLE_API_URL via a
   _trusted_origin() helper that fails closed when unset. An
   attacker-controlled Host header on sso-initiate would otherwise
   be sealed into the signed state envelope, causing the IdP to
   redirect the victim's EE-signed SSO assertion to evil.com.

2. Unvalidated JWS claim payloads: add ExtSubjectAssertionClaims
   and ApprovalGrantClaimsPayload pydantic models and route every
   verified payload through model_validate. A signed-but-malformed
   blob now returns BadRequest('invalid_sso_assertion') or
   VerifyError('claim shape invalid') instead of crashing the
   handler with KeyError / 500. ApprovalGrantClaimsPayload is
   imported lazily inside verify_approval_grant to break the
   libs -> controllers cycle.

3. Timing-unsafe CSRF compare in approve_external: replace plain
   != with secrets.compare_digest.

4. Bearer rate-limit bypass on revoked tokens: move
   enforce_bearer_rate_limit to fire after sha256_hex but before
   resolver.resolve, so revoked-token replay is now bounded. Also
   collapse the two distinct error messages (unknown token prefix
   vs token unknown or revoked) into a single generic
   'invalid_bearer' to remove the prefix-validity oracle.

Tests: 4 new unit-test files cover each finding plus one updated
test for the new bearer error string. 744 tests pass.
This commit is contained in:
GareArc
2026-05-24 21:17:36 -07:00
parent da10ea017a
commit eba0041973
9 changed files with 319 additions and 21 deletions

View File

@ -324,3 +324,21 @@ class PermittedExternalAppsListQuery(BaseModel):
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
mode: AppMode | None = None
name: str | None = Field(None, max_length=200)
_EMAIL_FIELD = Field(min_length=3, max_length=320, pattern=r"^[^@\s]+@[^@\s]+$")
class ExtSubjectAssertionClaims(BaseModel):
email: str = _EMAIL_FIELD
issuer: str = Field(min_length=1, max_length=255)
user_code: str = Field(min_length=1, max_length=32)
nonce: str = Field(min_length=1, max_length=128)
class ApprovalGrantClaimsPayload(BaseModel):
subject_email: str = _EMAIL_FIELD
subject_issuer: str = Field(min_length=1, max_length=255)
user_code: str = Field(min_length=1, max_length=32)
nonce: str = Field(min_length=1, max_length=128)
csrf_token: str = Field(min_length=1, max_length=128)

View File

@ -17,6 +17,7 @@ import secrets
from dataclasses import dataclass
from flask import jsonify, make_response, redirect, request
from pydantic import ValidationError
from werkzeug.exceptions import (
BadGateway,
BadRequest,
@ -26,7 +27,9 @@ from werkzeug.exceptions import (
Unauthorized,
)
from configs import dify_config
from controllers.openapi import bp
from controllers.openapi._models import ExtSubjectAssertionClaims
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import jws
@ -72,6 +75,13 @@ STATE_ENVELOPE_TTL_SECONDS = 15 * 60
_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete"
def _trusted_origin() -> str:
base = (dify_config.CONSOLE_API_URL or "").rstrip("/")
if not base:
raise BadGateway("console_api_url_unset")
return base
@bp.route("/oauth/device/sso-initiate", methods=["GET"])
@enterprise_only
@rate_limit(LIMIT_SSO_INITIATE_PER_IP)
@ -88,6 +98,7 @@ def sso_initiate():
if state.status is not DeviceFlowStatus.PENDING:
raise BadRequest("invalid_user_code")
origin = _trusted_origin()
keyset = jws.KeySet.from_shared_secret()
signed_state = jws.sign(
keyset,
@ -98,7 +109,7 @@ def sso_initiate():
"user_code": user_code,
"nonce": secrets.token_urlsafe(16),
"return_to": "",
"idp_callback_url": f"{request.host_url.rstrip('/')}{_SSO_COMPLETE_PATH}",
"idp_callback_url": f"{origin}{_SSO_COMPLETE_PATH}",
},
aud=jws.AUD_STATE_ENVELOPE,
ttl_seconds=STATE_ENVELOPE_TTL_SECONDS,
@ -130,15 +141,21 @@ def sso_complete():
keyset = jws.KeySet.from_shared_secret()
try:
claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
raw_claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
except jws.VerifyError as e:
logger.warning("sso-complete: rejected assertion: %s", e)
raise BadRequest("invalid_sso_assertion") from e
if not consume_sso_assertion_nonce(redis_client, claims.get("nonce", "")):
try:
claims = ExtSubjectAssertionClaims.model_validate(raw_claims)
except ValidationError as e:
logger.warning("sso-complete: claim shape invalid: %s", e)
raise BadRequest("invalid_sso_assertion") from e
if not consume_sso_assertion_nonce(redis_client, claims.nonce):
raise BadRequest("invalid_sso_assertion")
user_code = (claims.get("user_code") or "").strip().upper()
user_code = claims.user_code.strip().upper()
store = DeviceFlowRedis(redis_client)
found = store.load_by_user_code(user_code)
if found is None:
@ -147,20 +164,20 @@ def sso_complete():
if state.status is not DeviceFlowStatus.PENDING:
raise Conflict("user_code_not_pending")
if AccountService.has_active_account_with_email(db.session, claims["email"]):
if AccountService.has_active_account_with_email(db.session, claims.email):
_emit_external_rejection_audit(
state,
_RejectedClaims(subject_email=claims["email"], subject_issuer=claims["issuer"]),
_RejectedClaims(subject_email=claims.email, subject_issuer=claims.issuer),
reason="email_belongs_to_dify_account",
)
return redirect("/device?sso_error=email_belongs_to_dify_account", code=302)
iss = request.host_url.rstrip("/")
iss = _trusted_origin()
cookie_value, _ = mint_approval_grant(
keyset=keyset,
iss=iss,
subject_email=claims["email"],
subject_issuer=claims["issuer"],
subject_email=claims.email,
subject_issuer=claims.issuer,
user_code=user_code,
)
@ -211,7 +228,7 @@ def approve_external():
enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}")
csrf_header = request.headers.get("X-CSRF-Token", "")
if not csrf_header or csrf_header != claims.csrf_token:
if not csrf_header or not secrets.compare_digest(csrf_header, claims.csrf_token):
raise Forbidden("csrf_mismatch")
data = request.get_json(silent=True) or {}

View File

@ -12,6 +12,7 @@ from datetime import UTC, datetime, timedelta
from functools import wraps
from flask import Blueprint
from pydantic import ValidationError
from werkzeug.exceptions import NotFound
from libs import jws
@ -107,14 +108,22 @@ def mint_approval_grant(
def verify_approval_grant(keyset: jws.KeySet, token: str) -> ApprovalGrantClaims:
"""Sig + aud + exp only — nonce consumption is the caller's job."""
data = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT)
# lazy import: breaks libs → controllers cycle
from controllers.openapi._models import ApprovalGrantClaimsPayload
raw = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT)
try:
parsed = ApprovalGrantClaimsPayload.model_validate(raw)
except ValidationError as e:
raise jws.VerifyError(f"claim shape invalid: {e}") from e
return ApprovalGrantClaims(
subject_email=data["subject_email"],
subject_issuer=data["subject_issuer"],
user_code=data["user_code"],
nonce=data["nonce"],
csrf_token=data["csrf_token"],
expires_at=datetime.fromtimestamp(data["exp"], tz=UTC),
subject_email=parsed.subject_email,
subject_issuer=parsed.subject_issuer,
user_code=parsed.user_code,
nonce=parsed.nonce,
csrf_token=parsed.csrf_token,
expires_at=datetime.fromtimestamp(raw["exp"], tz=UTC),
)

View File

@ -277,12 +277,12 @@ class BearerAuthenticator:
"""
kind = self._registry.find(token)
if kind is None:
raise InvalidBearerError("unknown token prefix")
raise InvalidBearerError("invalid_bearer")
token_hash = sha256_hex(token)
enforce_bearer_rate_limit(token_hash)
row = kind.resolver.resolve(token_hash)
if row is None:
raise InvalidBearerError("token unknown or revoked")
enforce_bearer_rate_limit(token_hash)
raise InvalidBearerError("invalid_bearer")
return AuthContext(
subject_type=kind.subject_type,
subject_email=row.subject_email,

View File

@ -30,7 +30,7 @@ def test_bearer_check_rejects_missing_header():
@patch("controllers.openapi.auth.steps.get_authenticator")
def test_bearer_check_rejects_unknown_prefix(get_auth):
get_auth.return_value.authenticate.side_effect = InvalidBearerError("unknown token prefix")
get_auth.return_value.authenticate.side_effect = InvalidBearerError("invalid_bearer")
app = Flask(__name__)
with app.test_request_context(), pytest.raises(Unauthorized):
BearerCheck()(_ctx("xxx_abc"))

View File

@ -0,0 +1,73 @@
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.openapi import bp as openapi_bp
@pytest.fixture
def app() -> Flask:
a = Flask(__name__)
a.config["TESTING"] = True
a.register_blueprint(openapi_bp)
return a
def _ee_features():
from services.feature_service import LicenseStatus
m = MagicMock()
m.license.status = LicenseStatus.ACTIVE
return m
@patch("controllers.openapi.oauth_device_sso.jws")
@patch("libs.device_flow_security.FeatureService.get_system_features")
def test_sso_complete_rejects_assertion_missing_email(ee_feat, jws_mod, app: Flask):
ee_feat.return_value = _ee_features()
jws_mod.verify.return_value = {"issuer": "https://idp.example", "user_code": "ABCD-EFGH", "nonce": "n"}
jws_mod.AUD_EXT_SUBJECT_ASSERTION = "aud"
jws_mod.KeySet.from_shared_secret.return_value = object()
jws_mod.VerifyError = Exception
client = app.test_client()
resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob")
assert resp.status_code == 400, resp.data
@patch("controllers.openapi.oauth_device_sso.jws")
@patch("libs.device_flow_security.FeatureService.get_system_features")
def test_sso_complete_rejects_assertion_empty_issuer(ee_feat, jws_mod, app: Flask):
ee_feat.return_value = _ee_features()
jws_mod.verify.return_value = {"email": "x@y.com", "issuer": "", "user_code": "ABCD-EFGH", "nonce": "n"}
jws_mod.AUD_EXT_SUBJECT_ASSERTION = "aud"
jws_mod.KeySet.from_shared_secret.return_value = object()
jws_mod.VerifyError = Exception
client = app.test_client()
resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob")
assert resp.status_code == 400
def test_verify_approval_grant_raises_on_missing_field():
from libs import device_flow_security
from libs import jws as jws_mod
class _FakeKeyset:
active_kid = "k"
def lookup(self, kid):
return b"secret"
keyset = _FakeKeyset()
incomplete = jws_mod.sign(
keyset,
payload={"subject_email": "x@y.com", "subject_issuer": "i", "user_code": "ABCD-EFGH", "nonce": "n"},
aud=jws_mod.AUD_APPROVAL_GRANT,
ttl_seconds=60,
)
with pytest.raises(jws_mod.VerifyError):
device_flow_security.verify_approval_grant(keyset, incomplete)

View File

@ -0,0 +1,22 @@
from __future__ import annotations
import ast
from pathlib import Path
def _repo_root() -> Path:
for parent in Path(__file__).resolve().parents:
if (parent / "api" / "pyproject.toml").exists():
return parent
raise RuntimeError("repo root not found")
def test_approve_external_uses_compare_digest_for_csrf():
src = (_repo_root() / "api" / "controllers" / "openapi" / "oauth_device_sso.py").read_text()
tree = ast.parse(src)
fn = next(n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef) and n.name == "approve_external")
fn_src = ast.unparse(fn)
assert "compare_digest" in fn_src, "approve_external must call secrets.compare_digest for CSRF"
assert "csrf_header != claims.csrf_token" not in fn_src, "approve_external must not use plain != on csrf_token"

View File

@ -0,0 +1,75 @@
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.openapi import bp as openapi_bp
@pytest.fixture
def app() -> Flask:
a = Flask(__name__)
a.config["TESTING"] = True
a.register_blueprint(openapi_bp)
return a
def _ee_features():
from services.feature_service import LicenseStatus
m = MagicMock()
m.license.status = LicenseStatus.ACTIVE
return m
@patch("controllers.openapi.oauth_device_sso.EnterpriseService")
@patch("controllers.openapi.oauth_device_sso.jws")
@patch("controllers.openapi.oauth_device_sso.DeviceFlowRedis")
@patch("controllers.openapi.oauth_device_sso.dify_config")
@patch("libs.device_flow_security.FeatureService.get_system_features")
@patch("libs.rate_limit.RateLimiter.is_rate_limited", new=MagicMock(return_value=False))
@patch("libs.rate_limit.RateLimiter.increment_rate_limit", new=MagicMock())
def test_idp_callback_url_uses_console_api_url_not_host_header(ee_feat, cfg, redis_cls, jws_mod, ent, app: Flask):
ee_feat.return_value = _ee_features()
cfg.CONSOLE_API_URL = "https://api.dify.example"
state = MagicMock()
from services.oauth_device_flow import DeviceFlowStatus
state.status = DeviceFlowStatus.PENDING
redis_cls.return_value.load_by_user_code.return_value = ("dc_x", state)
jws_mod.KeySet.from_shared_secret.return_value = MagicMock()
jws_mod.sign.return_value = "signed-state"
jws_mod.AUD_STATE_ENVELOPE = "aud"
ent.initiate_device_flow_sso.return_value = {"url": "https://idp.example/auth"}
client = app.test_client()
client.get(
"/openapi/v1/oauth/device/sso-initiate?user_code=ABCD-EFGH",
headers={"Host": "evil.com"},
)
args, kwargs = jws_mod.sign.call_args
signed_payload = args[1] if len(args) > 1 else kwargs["payload"]
assert signed_payload["idp_callback_url"].startswith("https://api.dify.example")
assert "evil.com" not in signed_payload["idp_callback_url"]
@patch("controllers.openapi.oauth_device_sso.DeviceFlowRedis")
@patch("controllers.openapi.oauth_device_sso.dify_config")
@patch("libs.device_flow_security.FeatureService.get_system_features")
@patch("libs.rate_limit.RateLimiter.is_rate_limited", new=MagicMock(return_value=False))
@patch("libs.rate_limit.RateLimiter.increment_rate_limit", new=MagicMock())
def test_sso_initiate_fails_closed_when_console_api_url_unset(ee_feat, cfg, redis_cls, app: Flask):
ee_feat.return_value = _ee_features()
cfg.CONSOLE_API_URL = ""
from services.oauth_device_flow import DeviceFlowStatus
state = MagicMock()
state.status = DeviceFlowStatus.PENDING
redis_cls.return_value.load_by_user_code.return_value = ("dc_x", state)
client = app.test_client()
resp = client.get("/openapi/v1/oauth/device/sso-initiate?user_code=ABCD-EFGH")
assert resp.status_code == 502

View File

@ -0,0 +1,84 @@
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from libs.oauth_bearer import (
BearerAuthenticator,
InvalidBearerError,
Scope,
SubjectType,
TokenKind,
TokenKindRegistry,
)
def _registry_with_resolver(resolver) -> TokenKindRegistry:
return TokenKindRegistry(
[
TokenKind(
prefix="dfoa_",
subject_type=SubjectType.ACCOUNT,
scopes=frozenset({Scope.FULL}),
source="oauth_account",
resolver=resolver,
)
]
)
@patch("libs.oauth_bearer.enforce_bearer_rate_limit")
def test_rate_limit_called_on_unknown_revoked_token(rl):
resolver = MagicMock()
resolver.resolve.return_value = None
auth = BearerAuthenticator(_registry_with_resolver(resolver))
with pytest.raises(InvalidBearerError):
auth.authenticate("dfoa_revokedtoken123")
rl.assert_called_once()
resolver.resolve.assert_called_once()
@patch("libs.oauth_bearer.enforce_bearer_rate_limit")
def test_rate_limit_called_before_resolve(rl):
call_order: list[str] = []
rl.side_effect = lambda _h: call_order.append("rl")
resolver = MagicMock()
resolver.resolve.side_effect = lambda _h: call_order.append("resolve") or None
auth = BearerAuthenticator(_registry_with_resolver(resolver))
with pytest.raises(InvalidBearerError):
auth.authenticate("dfoa_xyz")
assert call_order == ["rl", "resolve"], f"expected rl before resolve, got {call_order}"
def test_unknown_prefix_raises_generic_invalid_bearer():
auth = BearerAuthenticator(
TokenKindRegistry(
[
TokenKind(
prefix="dfoa_",
subject_type=SubjectType.ACCOUNT,
scopes=frozenset({Scope.FULL}),
source="oauth_account",
resolver=MagicMock(),
)
]
)
)
with pytest.raises(InvalidBearerError) as exc:
auth.authenticate("zzz_xyz")
assert str(exc.value) == "invalid_bearer"
@patch("libs.oauth_bearer.enforce_bearer_rate_limit")
def test_revoked_token_raises_generic_invalid_bearer(rl):
resolver = MagicMock()
resolver.resolve.return_value = None
auth = BearerAuthenticator(_registry_with_resolver(resolver))
with pytest.raises(InvalidBearerError) as exc:
auth.authenticate("dfoa_revoked")
assert str(exc.value) == "invalid_bearer"