From eba0041973ae2c9fb91ad2bc700ee2232539285e Mon Sep 17 00:00:00 2001 From: GareArc Date: Sun, 24 May 2026 21:17:36 -0700 Subject: [PATCH] fix(openapi): close 4 critical OAuth device-flow security findings 1. Host-header injection (sso_initiate / sso_complete): replace request.host_url with dify_config.CONSOLE_API_URL via a _trusted_origin() helper that fails closed when unset. An attacker-controlled Host header on sso-initiate would otherwise be sealed into the signed state envelope, causing the IdP to redirect the victim's EE-signed SSO assertion to evil.com. 2. Unvalidated JWS claim payloads: add ExtSubjectAssertionClaims and ApprovalGrantClaimsPayload pydantic models and route every verified payload through model_validate. A signed-but-malformed blob now returns BadRequest('invalid_sso_assertion') or VerifyError('claim shape invalid') instead of crashing the handler with KeyError / 500. ApprovalGrantClaimsPayload is imported lazily inside verify_approval_grant to break the libs -> controllers cycle. 3. Timing-unsafe CSRF compare in approve_external: replace plain != with secrets.compare_digest. 4. Bearer rate-limit bypass on revoked tokens: move enforce_bearer_rate_limit to fire after sha256_hex but before resolver.resolve, so revoked-token replay is now bounded. Also collapse the two distinct error messages (unknown token prefix vs token unknown or revoked) into a single generic 'invalid_bearer' to remove the prefix-validity oracle. Tests: 4 new unit-test files cover each finding plus one updated test for the new bearer error string. 744 tests pass. --- api/controllers/openapi/_models.py | 18 ++++ api/controllers/openapi/oauth_device_sso.py | 37 +++++--- api/libs/device_flow_security.py | 23 +++-- api/libs/oauth_bearer.py | 6 +- .../openapi/auth/test_step_bearer.py | 2 +- .../openapi/test_oauth_sso_claims.py | 73 ++++++++++++++++ .../openapi/test_oauth_sso_csrf.py | 22 +++++ .../openapi/test_oauth_sso_host_header.py | 75 +++++++++++++++++ .../test_oauth_bearer_rate_limit_ordering.py | 84 +++++++++++++++++++ 9 files changed, 319 insertions(+), 21 deletions(-) create mode 100644 api/tests/unit_tests/controllers/openapi/test_oauth_sso_claims.py create mode 100644 api/tests/unit_tests/controllers/openapi/test_oauth_sso_csrf.py create mode 100644 api/tests/unit_tests/controllers/openapi/test_oauth_sso_host_header.py create mode 100644 api/tests/unit_tests/libs/test_oauth_bearer_rate_limit_ordering.py diff --git a/api/controllers/openapi/_models.py b/api/controllers/openapi/_models.py index 62d643c30f..128a937549 100644 --- a/api/controllers/openapi/_models.py +++ b/api/controllers/openapi/_models.py @@ -324,3 +324,21 @@ class PermittedExternalAppsListQuery(BaseModel): limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT) mode: AppMode | None = None name: str | None = Field(None, max_length=200) + + +_EMAIL_FIELD = Field(min_length=3, max_length=320, pattern=r"^[^@\s]+@[^@\s]+$") + + +class ExtSubjectAssertionClaims(BaseModel): + email: str = _EMAIL_FIELD + issuer: str = Field(min_length=1, max_length=255) + user_code: str = Field(min_length=1, max_length=32) + nonce: str = Field(min_length=1, max_length=128) + + +class ApprovalGrantClaimsPayload(BaseModel): + subject_email: str = _EMAIL_FIELD + subject_issuer: str = Field(min_length=1, max_length=255) + user_code: str = Field(min_length=1, max_length=32) + nonce: str = Field(min_length=1, max_length=128) + csrf_token: str = Field(min_length=1, max_length=128) diff --git a/api/controllers/openapi/oauth_device_sso.py b/api/controllers/openapi/oauth_device_sso.py index 08ecce0a38..0218d14330 100644 --- a/api/controllers/openapi/oauth_device_sso.py +++ b/api/controllers/openapi/oauth_device_sso.py @@ -17,6 +17,7 @@ import secrets from dataclasses import dataclass from flask import jsonify, make_response, redirect, request +from pydantic import ValidationError from werkzeug.exceptions import ( BadGateway, BadRequest, @@ -26,7 +27,9 @@ from werkzeug.exceptions import ( Unauthorized, ) +from configs import dify_config from controllers.openapi import bp +from controllers.openapi._models import ExtSubjectAssertionClaims from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import jws @@ -72,6 +75,13 @@ STATE_ENVELOPE_TTL_SECONDS = 15 * 60 _SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete" +def _trusted_origin() -> str: + base = (dify_config.CONSOLE_API_URL or "").rstrip("/") + if not base: + raise BadGateway("console_api_url_unset") + return base + + @bp.route("/oauth/device/sso-initiate", methods=["GET"]) @enterprise_only @rate_limit(LIMIT_SSO_INITIATE_PER_IP) @@ -88,6 +98,7 @@ def sso_initiate(): if state.status is not DeviceFlowStatus.PENDING: raise BadRequest("invalid_user_code") + origin = _trusted_origin() keyset = jws.KeySet.from_shared_secret() signed_state = jws.sign( keyset, @@ -98,7 +109,7 @@ def sso_initiate(): "user_code": user_code, "nonce": secrets.token_urlsafe(16), "return_to": "", - "idp_callback_url": f"{request.host_url.rstrip('/')}{_SSO_COMPLETE_PATH}", + "idp_callback_url": f"{origin}{_SSO_COMPLETE_PATH}", }, aud=jws.AUD_STATE_ENVELOPE, ttl_seconds=STATE_ENVELOPE_TTL_SECONDS, @@ -130,15 +141,21 @@ def sso_complete(): keyset = jws.KeySet.from_shared_secret() try: - claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION) + raw_claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION) except jws.VerifyError as e: logger.warning("sso-complete: rejected assertion: %s", e) raise BadRequest("invalid_sso_assertion") from e - if not consume_sso_assertion_nonce(redis_client, claims.get("nonce", "")): + try: + claims = ExtSubjectAssertionClaims.model_validate(raw_claims) + except ValidationError as e: + logger.warning("sso-complete: claim shape invalid: %s", e) + raise BadRequest("invalid_sso_assertion") from e + + if not consume_sso_assertion_nonce(redis_client, claims.nonce): raise BadRequest("invalid_sso_assertion") - user_code = (claims.get("user_code") or "").strip().upper() + user_code = claims.user_code.strip().upper() store = DeviceFlowRedis(redis_client) found = store.load_by_user_code(user_code) if found is None: @@ -147,20 +164,20 @@ def sso_complete(): if state.status is not DeviceFlowStatus.PENDING: raise Conflict("user_code_not_pending") - if AccountService.has_active_account_with_email(db.session, claims["email"]): + if AccountService.has_active_account_with_email(db.session, claims.email): _emit_external_rejection_audit( state, - _RejectedClaims(subject_email=claims["email"], subject_issuer=claims["issuer"]), + _RejectedClaims(subject_email=claims.email, subject_issuer=claims.issuer), reason="email_belongs_to_dify_account", ) return redirect("/device?sso_error=email_belongs_to_dify_account", code=302) - iss = request.host_url.rstrip("/") + iss = _trusted_origin() cookie_value, _ = mint_approval_grant( keyset=keyset, iss=iss, - subject_email=claims["email"], - subject_issuer=claims["issuer"], + subject_email=claims.email, + subject_issuer=claims.issuer, user_code=user_code, ) @@ -211,7 +228,7 @@ def approve_external(): enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}") csrf_header = request.headers.get("X-CSRF-Token", "") - if not csrf_header or csrf_header != claims.csrf_token: + if not csrf_header or not secrets.compare_digest(csrf_header, claims.csrf_token): raise Forbidden("csrf_mismatch") data = request.get_json(silent=True) or {} diff --git a/api/libs/device_flow_security.py b/api/libs/device_flow_security.py index d973a0820b..9f4c1f56f6 100644 --- a/api/libs/device_flow_security.py +++ b/api/libs/device_flow_security.py @@ -12,6 +12,7 @@ from datetime import UTC, datetime, timedelta from functools import wraps from flask import Blueprint +from pydantic import ValidationError from werkzeug.exceptions import NotFound from libs import jws @@ -107,14 +108,22 @@ def mint_approval_grant( def verify_approval_grant(keyset: jws.KeySet, token: str) -> ApprovalGrantClaims: """Sig + aud + exp only — nonce consumption is the caller's job.""" - data = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT) + # lazy import: breaks libs → controllers cycle + from controllers.openapi._models import ApprovalGrantClaimsPayload + + raw = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT) + try: + parsed = ApprovalGrantClaimsPayload.model_validate(raw) + except ValidationError as e: + raise jws.VerifyError(f"claim shape invalid: {e}") from e + return ApprovalGrantClaims( - subject_email=data["subject_email"], - subject_issuer=data["subject_issuer"], - user_code=data["user_code"], - nonce=data["nonce"], - csrf_token=data["csrf_token"], - expires_at=datetime.fromtimestamp(data["exp"], tz=UTC), + subject_email=parsed.subject_email, + subject_issuer=parsed.subject_issuer, + user_code=parsed.user_code, + nonce=parsed.nonce, + csrf_token=parsed.csrf_token, + expires_at=datetime.fromtimestamp(raw["exp"], tz=UTC), ) diff --git a/api/libs/oauth_bearer.py b/api/libs/oauth_bearer.py index ccbda1fd35..6e8678eca0 100644 --- a/api/libs/oauth_bearer.py +++ b/api/libs/oauth_bearer.py @@ -277,12 +277,12 @@ class BearerAuthenticator: """ kind = self._registry.find(token) if kind is None: - raise InvalidBearerError("unknown token prefix") + raise InvalidBearerError("invalid_bearer") token_hash = sha256_hex(token) + enforce_bearer_rate_limit(token_hash) row = kind.resolver.resolve(token_hash) if row is None: - raise InvalidBearerError("token unknown or revoked") - enforce_bearer_rate_limit(token_hash) + raise InvalidBearerError("invalid_bearer") return AuthContext( subject_type=kind.subject_type, subject_email=row.subject_email, diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py index 4ab8596fe2..329f158f30 100644 --- a/api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py +++ b/api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py @@ -30,7 +30,7 @@ def test_bearer_check_rejects_missing_header(): @patch("controllers.openapi.auth.steps.get_authenticator") def test_bearer_check_rejects_unknown_prefix(get_auth): - get_auth.return_value.authenticate.side_effect = InvalidBearerError("unknown token prefix") + get_auth.return_value.authenticate.side_effect = InvalidBearerError("invalid_bearer") app = Flask(__name__) with app.test_request_context(), pytest.raises(Unauthorized): BearerCheck()(_ctx("xxx_abc")) diff --git a/api/tests/unit_tests/controllers/openapi/test_oauth_sso_claims.py b/api/tests/unit_tests/controllers/openapi/test_oauth_sso_claims.py new file mode 100644 index 0000000000..9eb4cf01f4 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/test_oauth_sso_claims.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.openapi import bp as openapi_bp + + +@pytest.fixture +def app() -> Flask: + a = Flask(__name__) + a.config["TESTING"] = True + a.register_blueprint(openapi_bp) + return a + + +def _ee_features(): + from services.feature_service import LicenseStatus + + m = MagicMock() + m.license.status = LicenseStatus.ACTIVE + return m + + +@patch("controllers.openapi.oauth_device_sso.jws") +@patch("libs.device_flow_security.FeatureService.get_system_features") +def test_sso_complete_rejects_assertion_missing_email(ee_feat, jws_mod, app: Flask): + ee_feat.return_value = _ee_features() + jws_mod.verify.return_value = {"issuer": "https://idp.example", "user_code": "ABCD-EFGH", "nonce": "n"} + jws_mod.AUD_EXT_SUBJECT_ASSERTION = "aud" + jws_mod.KeySet.from_shared_secret.return_value = object() + jws_mod.VerifyError = Exception + + client = app.test_client() + resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob") + assert resp.status_code == 400, resp.data + + +@patch("controllers.openapi.oauth_device_sso.jws") +@patch("libs.device_flow_security.FeatureService.get_system_features") +def test_sso_complete_rejects_assertion_empty_issuer(ee_feat, jws_mod, app: Flask): + ee_feat.return_value = _ee_features() + jws_mod.verify.return_value = {"email": "x@y.com", "issuer": "", "user_code": "ABCD-EFGH", "nonce": "n"} + jws_mod.AUD_EXT_SUBJECT_ASSERTION = "aud" + jws_mod.KeySet.from_shared_secret.return_value = object() + jws_mod.VerifyError = Exception + + client = app.test_client() + resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob") + assert resp.status_code == 400 + + +def test_verify_approval_grant_raises_on_missing_field(): + from libs import device_flow_security + from libs import jws as jws_mod + + class _FakeKeyset: + active_kid = "k" + + def lookup(self, kid): + return b"secret" + + keyset = _FakeKeyset() + incomplete = jws_mod.sign( + keyset, + payload={"subject_email": "x@y.com", "subject_issuer": "i", "user_code": "ABCD-EFGH", "nonce": "n"}, + aud=jws_mod.AUD_APPROVAL_GRANT, + ttl_seconds=60, + ) + with pytest.raises(jws_mod.VerifyError): + device_flow_security.verify_approval_grant(keyset, incomplete) diff --git a/api/tests/unit_tests/controllers/openapi/test_oauth_sso_csrf.py b/api/tests/unit_tests/controllers/openapi/test_oauth_sso_csrf.py new file mode 100644 index 0000000000..e31a308506 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/test_oauth_sso_csrf.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +import ast +from pathlib import Path + + +def _repo_root() -> Path: + for parent in Path(__file__).resolve().parents: + if (parent / "api" / "pyproject.toml").exists(): + return parent + raise RuntimeError("repo root not found") + + +def test_approve_external_uses_compare_digest_for_csrf(): + src = (_repo_root() / "api" / "controllers" / "openapi" / "oauth_device_sso.py").read_text() + tree = ast.parse(src) + + fn = next(n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef) and n.name == "approve_external") + fn_src = ast.unparse(fn) + + assert "compare_digest" in fn_src, "approve_external must call secrets.compare_digest for CSRF" + assert "csrf_header != claims.csrf_token" not in fn_src, "approve_external must not use plain != on csrf_token" diff --git a/api/tests/unit_tests/controllers/openapi/test_oauth_sso_host_header.py b/api/tests/unit_tests/controllers/openapi/test_oauth_sso_host_header.py new file mode 100644 index 0000000000..2919cfd353 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/test_oauth_sso_host_header.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask + +from controllers.openapi import bp as openapi_bp + + +@pytest.fixture +def app() -> Flask: + a = Flask(__name__) + a.config["TESTING"] = True + a.register_blueprint(openapi_bp) + return a + + +def _ee_features(): + from services.feature_service import LicenseStatus + + m = MagicMock() + m.license.status = LicenseStatus.ACTIVE + return m + + +@patch("controllers.openapi.oauth_device_sso.EnterpriseService") +@patch("controllers.openapi.oauth_device_sso.jws") +@patch("controllers.openapi.oauth_device_sso.DeviceFlowRedis") +@patch("controllers.openapi.oauth_device_sso.dify_config") +@patch("libs.device_flow_security.FeatureService.get_system_features") +@patch("libs.rate_limit.RateLimiter.is_rate_limited", new=MagicMock(return_value=False)) +@patch("libs.rate_limit.RateLimiter.increment_rate_limit", new=MagicMock()) +def test_idp_callback_url_uses_console_api_url_not_host_header(ee_feat, cfg, redis_cls, jws_mod, ent, app: Flask): + ee_feat.return_value = _ee_features() + cfg.CONSOLE_API_URL = "https://api.dify.example" + state = MagicMock() + from services.oauth_device_flow import DeviceFlowStatus + + state.status = DeviceFlowStatus.PENDING + redis_cls.return_value.load_by_user_code.return_value = ("dc_x", state) + jws_mod.KeySet.from_shared_secret.return_value = MagicMock() + jws_mod.sign.return_value = "signed-state" + jws_mod.AUD_STATE_ENVELOPE = "aud" + ent.initiate_device_flow_sso.return_value = {"url": "https://idp.example/auth"} + + client = app.test_client() + client.get( + "/openapi/v1/oauth/device/sso-initiate?user_code=ABCD-EFGH", + headers={"Host": "evil.com"}, + ) + + args, kwargs = jws_mod.sign.call_args + signed_payload = args[1] if len(args) > 1 else kwargs["payload"] + assert signed_payload["idp_callback_url"].startswith("https://api.dify.example") + assert "evil.com" not in signed_payload["idp_callback_url"] + + +@patch("controllers.openapi.oauth_device_sso.DeviceFlowRedis") +@patch("controllers.openapi.oauth_device_sso.dify_config") +@patch("libs.device_flow_security.FeatureService.get_system_features") +@patch("libs.rate_limit.RateLimiter.is_rate_limited", new=MagicMock(return_value=False)) +@patch("libs.rate_limit.RateLimiter.increment_rate_limit", new=MagicMock()) +def test_sso_initiate_fails_closed_when_console_api_url_unset(ee_feat, cfg, redis_cls, app: Flask): + ee_feat.return_value = _ee_features() + cfg.CONSOLE_API_URL = "" + from services.oauth_device_flow import DeviceFlowStatus + + state = MagicMock() + state.status = DeviceFlowStatus.PENDING + redis_cls.return_value.load_by_user_code.return_value = ("dc_x", state) + + client = app.test_client() + resp = client.get("/openapi/v1/oauth/device/sso-initiate?user_code=ABCD-EFGH") + assert resp.status_code == 502 diff --git a/api/tests/unit_tests/libs/test_oauth_bearer_rate_limit_ordering.py b/api/tests/unit_tests/libs/test_oauth_bearer_rate_limit_ordering.py new file mode 100644 index 0000000000..dd4304ccb1 --- /dev/null +++ b/api/tests/unit_tests/libs/test_oauth_bearer_rate_limit_ordering.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from libs.oauth_bearer import ( + BearerAuthenticator, + InvalidBearerError, + Scope, + SubjectType, + TokenKind, + TokenKindRegistry, +) + + +def _registry_with_resolver(resolver) -> TokenKindRegistry: + return TokenKindRegistry( + [ + TokenKind( + prefix="dfoa_", + subject_type=SubjectType.ACCOUNT, + scopes=frozenset({Scope.FULL}), + source="oauth_account", + resolver=resolver, + ) + ] + ) + + +@patch("libs.oauth_bearer.enforce_bearer_rate_limit") +def test_rate_limit_called_on_unknown_revoked_token(rl): + resolver = MagicMock() + resolver.resolve.return_value = None + auth = BearerAuthenticator(_registry_with_resolver(resolver)) + + with pytest.raises(InvalidBearerError): + auth.authenticate("dfoa_revokedtoken123") + + rl.assert_called_once() + resolver.resolve.assert_called_once() + + +@patch("libs.oauth_bearer.enforce_bearer_rate_limit") +def test_rate_limit_called_before_resolve(rl): + call_order: list[str] = [] + rl.side_effect = lambda _h: call_order.append("rl") + resolver = MagicMock() + resolver.resolve.side_effect = lambda _h: call_order.append("resolve") or None + auth = BearerAuthenticator(_registry_with_resolver(resolver)) + + with pytest.raises(InvalidBearerError): + auth.authenticate("dfoa_xyz") + + assert call_order == ["rl", "resolve"], f"expected rl before resolve, got {call_order}" + + +def test_unknown_prefix_raises_generic_invalid_bearer(): + auth = BearerAuthenticator( + TokenKindRegistry( + [ + TokenKind( + prefix="dfoa_", + subject_type=SubjectType.ACCOUNT, + scopes=frozenset({Scope.FULL}), + source="oauth_account", + resolver=MagicMock(), + ) + ] + ) + ) + with pytest.raises(InvalidBearerError) as exc: + auth.authenticate("zzz_xyz") + assert str(exc.value) == "invalid_bearer" + + +@patch("libs.oauth_bearer.enforce_bearer_rate_limit") +def test_revoked_token_raises_generic_invalid_bearer(rl): + resolver = MagicMock() + resolver.resolve.return_value = None + auth = BearerAuthenticator(_registry_with_resolver(resolver)) + with pytest.raises(InvalidBearerError) as exc: + auth.authenticate("dfoa_revoked") + assert str(exc.value) == "invalid_bearer"