mirror of
https://github.com/langgenius/dify.git
synced 2026-05-31 06:06:20 +08:00
fix(openapi): close 4 critical OAuth device-flow security findings
1. Host-header injection (sso_initiate / sso_complete): replace
request.host_url with dify_config.CONSOLE_API_URL via a
_trusted_origin() helper that fails closed when unset. An
attacker-controlled Host header on sso-initiate would otherwise
be sealed into the signed state envelope, causing the IdP to
redirect the victim's EE-signed SSO assertion to evil.com.
2. Unvalidated JWS claim payloads: add ExtSubjectAssertionClaims
and ApprovalGrantClaimsPayload pydantic models and route every
verified payload through model_validate. A signed-but-malformed
blob now returns BadRequest('invalid_sso_assertion') or
VerifyError('claim shape invalid') instead of crashing the
handler with KeyError / 500. ApprovalGrantClaimsPayload is
imported lazily inside verify_approval_grant to break the
libs -> controllers cycle.
3. Timing-unsafe CSRF compare in approve_external: replace plain
!= with secrets.compare_digest.
4. Bearer rate-limit bypass on revoked tokens: move
enforce_bearer_rate_limit to fire after sha256_hex but before
resolver.resolve, so revoked-token replay is now bounded. Also
collapse the two distinct error messages (unknown token prefix
vs token unknown or revoked) into a single generic
'invalid_bearer' to remove the prefix-validity oracle.
Tests: 4 new unit-test files cover each finding plus one updated
test for the new bearer error string. 744 tests pass.
This commit is contained in:
@ -324,3 +324,21 @@ class PermittedExternalAppsListQuery(BaseModel):
|
||||
limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
|
||||
mode: AppMode | None = None
|
||||
name: str | None = Field(None, max_length=200)
|
||||
|
||||
|
||||
_EMAIL_FIELD = Field(min_length=3, max_length=320, pattern=r"^[^@\s]+@[^@\s]+$")
|
||||
|
||||
|
||||
class ExtSubjectAssertionClaims(BaseModel):
|
||||
email: str = _EMAIL_FIELD
|
||||
issuer: str = Field(min_length=1, max_length=255)
|
||||
user_code: str = Field(min_length=1, max_length=32)
|
||||
nonce: str = Field(min_length=1, max_length=128)
|
||||
|
||||
|
||||
class ApprovalGrantClaimsPayload(BaseModel):
|
||||
subject_email: str = _EMAIL_FIELD
|
||||
subject_issuer: str = Field(min_length=1, max_length=255)
|
||||
user_code: str = Field(min_length=1, max_length=32)
|
||||
nonce: str = Field(min_length=1, max_length=128)
|
||||
csrf_token: str = Field(min_length=1, max_length=128)
|
||||
|
||||
@ -17,6 +17,7 @@ import secrets
|
||||
from dataclasses import dataclass
|
||||
|
||||
from flask import jsonify, make_response, redirect, request
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import (
|
||||
BadGateway,
|
||||
BadRequest,
|
||||
@ -26,7 +27,9 @@ from werkzeug.exceptions import (
|
||||
Unauthorized,
|
||||
)
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi import bp
|
||||
from controllers.openapi._models import ExtSubjectAssertionClaims
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import jws
|
||||
@ -72,6 +75,13 @@ STATE_ENVELOPE_TTL_SECONDS = 15 * 60
|
||||
_SSO_COMPLETE_PATH = "/openapi/v1/oauth/device/sso-complete"
|
||||
|
||||
|
||||
def _trusted_origin() -> str:
|
||||
base = (dify_config.CONSOLE_API_URL or "").rstrip("/")
|
||||
if not base:
|
||||
raise BadGateway("console_api_url_unset")
|
||||
return base
|
||||
|
||||
|
||||
@bp.route("/oauth/device/sso-initiate", methods=["GET"])
|
||||
@enterprise_only
|
||||
@rate_limit(LIMIT_SSO_INITIATE_PER_IP)
|
||||
@ -88,6 +98,7 @@ def sso_initiate():
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise BadRequest("invalid_user_code")
|
||||
|
||||
origin = _trusted_origin()
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
signed_state = jws.sign(
|
||||
keyset,
|
||||
@ -98,7 +109,7 @@ def sso_initiate():
|
||||
"user_code": user_code,
|
||||
"nonce": secrets.token_urlsafe(16),
|
||||
"return_to": "",
|
||||
"idp_callback_url": f"{request.host_url.rstrip('/')}{_SSO_COMPLETE_PATH}",
|
||||
"idp_callback_url": f"{origin}{_SSO_COMPLETE_PATH}",
|
||||
},
|
||||
aud=jws.AUD_STATE_ENVELOPE,
|
||||
ttl_seconds=STATE_ENVELOPE_TTL_SECONDS,
|
||||
@ -130,15 +141,21 @@ def sso_complete():
|
||||
keyset = jws.KeySet.from_shared_secret()
|
||||
|
||||
try:
|
||||
claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
|
||||
raw_claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
|
||||
except jws.VerifyError as e:
|
||||
logger.warning("sso-complete: rejected assertion: %s", e)
|
||||
raise BadRequest("invalid_sso_assertion") from e
|
||||
|
||||
if not consume_sso_assertion_nonce(redis_client, claims.get("nonce", "")):
|
||||
try:
|
||||
claims = ExtSubjectAssertionClaims.model_validate(raw_claims)
|
||||
except ValidationError as e:
|
||||
logger.warning("sso-complete: claim shape invalid: %s", e)
|
||||
raise BadRequest("invalid_sso_assertion") from e
|
||||
|
||||
if not consume_sso_assertion_nonce(redis_client, claims.nonce):
|
||||
raise BadRequest("invalid_sso_assertion")
|
||||
|
||||
user_code = (claims.get("user_code") or "").strip().upper()
|
||||
user_code = claims.user_code.strip().upper()
|
||||
store = DeviceFlowRedis(redis_client)
|
||||
found = store.load_by_user_code(user_code)
|
||||
if found is None:
|
||||
@ -147,20 +164,20 @@ def sso_complete():
|
||||
if state.status is not DeviceFlowStatus.PENDING:
|
||||
raise Conflict("user_code_not_pending")
|
||||
|
||||
if AccountService.has_active_account_with_email(db.session, claims["email"]):
|
||||
if AccountService.has_active_account_with_email(db.session, claims.email):
|
||||
_emit_external_rejection_audit(
|
||||
state,
|
||||
_RejectedClaims(subject_email=claims["email"], subject_issuer=claims["issuer"]),
|
||||
_RejectedClaims(subject_email=claims.email, subject_issuer=claims.issuer),
|
||||
reason="email_belongs_to_dify_account",
|
||||
)
|
||||
return redirect("/device?sso_error=email_belongs_to_dify_account", code=302)
|
||||
|
||||
iss = request.host_url.rstrip("/")
|
||||
iss = _trusted_origin()
|
||||
cookie_value, _ = mint_approval_grant(
|
||||
keyset=keyset,
|
||||
iss=iss,
|
||||
subject_email=claims["email"],
|
||||
subject_issuer=claims["issuer"],
|
||||
subject_email=claims.email,
|
||||
subject_issuer=claims.issuer,
|
||||
user_code=user_code,
|
||||
)
|
||||
|
||||
@ -211,7 +228,7 @@ def approve_external():
|
||||
enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}")
|
||||
|
||||
csrf_header = request.headers.get("X-CSRF-Token", "")
|
||||
if not csrf_header or csrf_header != claims.csrf_token:
|
||||
if not csrf_header or not secrets.compare_digest(csrf_header, claims.csrf_token):
|
||||
raise Forbidden("csrf_mismatch")
|
||||
|
||||
data = request.get_json(silent=True) or {}
|
||||
|
||||
@ -12,6 +12,7 @@ from datetime import UTC, datetime, timedelta
|
||||
from functools import wraps
|
||||
|
||||
from flask import Blueprint
|
||||
from pydantic import ValidationError
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from libs import jws
|
||||
@ -107,14 +108,22 @@ def mint_approval_grant(
|
||||
|
||||
def verify_approval_grant(keyset: jws.KeySet, token: str) -> ApprovalGrantClaims:
|
||||
"""Sig + aud + exp only — nonce consumption is the caller's job."""
|
||||
data = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT)
|
||||
# lazy import: breaks libs → controllers cycle
|
||||
from controllers.openapi._models import ApprovalGrantClaimsPayload
|
||||
|
||||
raw = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT)
|
||||
try:
|
||||
parsed = ApprovalGrantClaimsPayload.model_validate(raw)
|
||||
except ValidationError as e:
|
||||
raise jws.VerifyError(f"claim shape invalid: {e}") from e
|
||||
|
||||
return ApprovalGrantClaims(
|
||||
subject_email=data["subject_email"],
|
||||
subject_issuer=data["subject_issuer"],
|
||||
user_code=data["user_code"],
|
||||
nonce=data["nonce"],
|
||||
csrf_token=data["csrf_token"],
|
||||
expires_at=datetime.fromtimestamp(data["exp"], tz=UTC),
|
||||
subject_email=parsed.subject_email,
|
||||
subject_issuer=parsed.subject_issuer,
|
||||
user_code=parsed.user_code,
|
||||
nonce=parsed.nonce,
|
||||
csrf_token=parsed.csrf_token,
|
||||
expires_at=datetime.fromtimestamp(raw["exp"], tz=UTC),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -277,12 +277,12 @@ class BearerAuthenticator:
|
||||
"""
|
||||
kind = self._registry.find(token)
|
||||
if kind is None:
|
||||
raise InvalidBearerError("unknown token prefix")
|
||||
raise InvalidBearerError("invalid_bearer")
|
||||
token_hash = sha256_hex(token)
|
||||
enforce_bearer_rate_limit(token_hash)
|
||||
row = kind.resolver.resolve(token_hash)
|
||||
if row is None:
|
||||
raise InvalidBearerError("token unknown or revoked")
|
||||
enforce_bearer_rate_limit(token_hash)
|
||||
raise InvalidBearerError("invalid_bearer")
|
||||
return AuthContext(
|
||||
subject_type=kind.subject_type,
|
||||
subject_email=row.subject_email,
|
||||
|
||||
@ -30,7 +30,7 @@ def test_bearer_check_rejects_missing_header():
|
||||
|
||||
@patch("controllers.openapi.auth.steps.get_authenticator")
|
||||
def test_bearer_check_rejects_unknown_prefix(get_auth):
|
||||
get_auth.return_value.authenticate.side_effect = InvalidBearerError("unknown token prefix")
|
||||
get_auth.return_value.authenticate.side_effect = InvalidBearerError("invalid_bearer")
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context(), pytest.raises(Unauthorized):
|
||||
BearerCheck()(_ctx("xxx_abc"))
|
||||
|
||||
@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
a = Flask(__name__)
|
||||
a.config["TESTING"] = True
|
||||
a.register_blueprint(openapi_bp)
|
||||
return a
|
||||
|
||||
|
||||
def _ee_features():
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
m = MagicMock()
|
||||
m.license.status = LicenseStatus.ACTIVE
|
||||
return m
|
||||
|
||||
|
||||
@patch("controllers.openapi.oauth_device_sso.jws")
|
||||
@patch("libs.device_flow_security.FeatureService.get_system_features")
|
||||
def test_sso_complete_rejects_assertion_missing_email(ee_feat, jws_mod, app: Flask):
|
||||
ee_feat.return_value = _ee_features()
|
||||
jws_mod.verify.return_value = {"issuer": "https://idp.example", "user_code": "ABCD-EFGH", "nonce": "n"}
|
||||
jws_mod.AUD_EXT_SUBJECT_ASSERTION = "aud"
|
||||
jws_mod.KeySet.from_shared_secret.return_value = object()
|
||||
jws_mod.VerifyError = Exception
|
||||
|
||||
client = app.test_client()
|
||||
resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob")
|
||||
assert resp.status_code == 400, resp.data
|
||||
|
||||
|
||||
@patch("controllers.openapi.oauth_device_sso.jws")
|
||||
@patch("libs.device_flow_security.FeatureService.get_system_features")
|
||||
def test_sso_complete_rejects_assertion_empty_issuer(ee_feat, jws_mod, app: Flask):
|
||||
ee_feat.return_value = _ee_features()
|
||||
jws_mod.verify.return_value = {"email": "x@y.com", "issuer": "", "user_code": "ABCD-EFGH", "nonce": "n"}
|
||||
jws_mod.AUD_EXT_SUBJECT_ASSERTION = "aud"
|
||||
jws_mod.KeySet.from_shared_secret.return_value = object()
|
||||
jws_mod.VerifyError = Exception
|
||||
|
||||
client = app.test_client()
|
||||
resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob")
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_verify_approval_grant_raises_on_missing_field():
|
||||
from libs import device_flow_security
|
||||
from libs import jws as jws_mod
|
||||
|
||||
class _FakeKeyset:
|
||||
active_kid = "k"
|
||||
|
||||
def lookup(self, kid):
|
||||
return b"secret"
|
||||
|
||||
keyset = _FakeKeyset()
|
||||
incomplete = jws_mod.sign(
|
||||
keyset,
|
||||
payload={"subject_email": "x@y.com", "subject_issuer": "i", "user_code": "ABCD-EFGH", "nonce": "n"},
|
||||
aud=jws_mod.AUD_APPROVAL_GRANT,
|
||||
ttl_seconds=60,
|
||||
)
|
||||
with pytest.raises(jws_mod.VerifyError):
|
||||
device_flow_security.verify_approval_grant(keyset, incomplete)
|
||||
@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _repo_root() -> Path:
|
||||
for parent in Path(__file__).resolve().parents:
|
||||
if (parent / "api" / "pyproject.toml").exists():
|
||||
return parent
|
||||
raise RuntimeError("repo root not found")
|
||||
|
||||
|
||||
def test_approve_external_uses_compare_digest_for_csrf():
|
||||
src = (_repo_root() / "api" / "controllers" / "openapi" / "oauth_device_sso.py").read_text()
|
||||
tree = ast.parse(src)
|
||||
|
||||
fn = next(n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef) and n.name == "approve_external")
|
||||
fn_src = ast.unparse(fn)
|
||||
|
||||
assert "compare_digest" in fn_src, "approve_external must call secrets.compare_digest for CSRF"
|
||||
assert "csrf_header != claims.csrf_token" not in fn_src, "approve_external must not use plain != on csrf_token"
|
||||
@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
a = Flask(__name__)
|
||||
a.config["TESTING"] = True
|
||||
a.register_blueprint(openapi_bp)
|
||||
return a
|
||||
|
||||
|
||||
def _ee_features():
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
m = MagicMock()
|
||||
m.license.status = LicenseStatus.ACTIVE
|
||||
return m
|
||||
|
||||
|
||||
@patch("controllers.openapi.oauth_device_sso.EnterpriseService")
|
||||
@patch("controllers.openapi.oauth_device_sso.jws")
|
||||
@patch("controllers.openapi.oauth_device_sso.DeviceFlowRedis")
|
||||
@patch("controllers.openapi.oauth_device_sso.dify_config")
|
||||
@patch("libs.device_flow_security.FeatureService.get_system_features")
|
||||
@patch("libs.rate_limit.RateLimiter.is_rate_limited", new=MagicMock(return_value=False))
|
||||
@patch("libs.rate_limit.RateLimiter.increment_rate_limit", new=MagicMock())
|
||||
def test_idp_callback_url_uses_console_api_url_not_host_header(ee_feat, cfg, redis_cls, jws_mod, ent, app: Flask):
|
||||
ee_feat.return_value = _ee_features()
|
||||
cfg.CONSOLE_API_URL = "https://api.dify.example"
|
||||
state = MagicMock()
|
||||
from services.oauth_device_flow import DeviceFlowStatus
|
||||
|
||||
state.status = DeviceFlowStatus.PENDING
|
||||
redis_cls.return_value.load_by_user_code.return_value = ("dc_x", state)
|
||||
jws_mod.KeySet.from_shared_secret.return_value = MagicMock()
|
||||
jws_mod.sign.return_value = "signed-state"
|
||||
jws_mod.AUD_STATE_ENVELOPE = "aud"
|
||||
ent.initiate_device_flow_sso.return_value = {"url": "https://idp.example/auth"}
|
||||
|
||||
client = app.test_client()
|
||||
client.get(
|
||||
"/openapi/v1/oauth/device/sso-initiate?user_code=ABCD-EFGH",
|
||||
headers={"Host": "evil.com"},
|
||||
)
|
||||
|
||||
args, kwargs = jws_mod.sign.call_args
|
||||
signed_payload = args[1] if len(args) > 1 else kwargs["payload"]
|
||||
assert signed_payload["idp_callback_url"].startswith("https://api.dify.example")
|
||||
assert "evil.com" not in signed_payload["idp_callback_url"]
|
||||
|
||||
|
||||
@patch("controllers.openapi.oauth_device_sso.DeviceFlowRedis")
|
||||
@patch("controllers.openapi.oauth_device_sso.dify_config")
|
||||
@patch("libs.device_flow_security.FeatureService.get_system_features")
|
||||
@patch("libs.rate_limit.RateLimiter.is_rate_limited", new=MagicMock(return_value=False))
|
||||
@patch("libs.rate_limit.RateLimiter.increment_rate_limit", new=MagicMock())
|
||||
def test_sso_initiate_fails_closed_when_console_api_url_unset(ee_feat, cfg, redis_cls, app: Flask):
|
||||
ee_feat.return_value = _ee_features()
|
||||
cfg.CONSOLE_API_URL = ""
|
||||
from services.oauth_device_flow import DeviceFlowStatus
|
||||
|
||||
state = MagicMock()
|
||||
state.status = DeviceFlowStatus.PENDING
|
||||
redis_cls.return_value.load_by_user_code.return_value = ("dc_x", state)
|
||||
|
||||
client = app.test_client()
|
||||
resp = client.get("/openapi/v1/oauth/device/sso-initiate?user_code=ABCD-EFGH")
|
||||
assert resp.status_code == 502
|
||||
@ -0,0 +1,84 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.oauth_bearer import (
|
||||
BearerAuthenticator,
|
||||
InvalidBearerError,
|
||||
Scope,
|
||||
SubjectType,
|
||||
TokenKind,
|
||||
TokenKindRegistry,
|
||||
)
|
||||
|
||||
|
||||
def _registry_with_resolver(resolver) -> TokenKindRegistry:
|
||||
return TokenKindRegistry(
|
||||
[
|
||||
TokenKind(
|
||||
prefix="dfoa_",
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
source="oauth_account",
|
||||
resolver=resolver,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@patch("libs.oauth_bearer.enforce_bearer_rate_limit")
|
||||
def test_rate_limit_called_on_unknown_revoked_token(rl):
|
||||
resolver = MagicMock()
|
||||
resolver.resolve.return_value = None
|
||||
auth = BearerAuthenticator(_registry_with_resolver(resolver))
|
||||
|
||||
with pytest.raises(InvalidBearerError):
|
||||
auth.authenticate("dfoa_revokedtoken123")
|
||||
|
||||
rl.assert_called_once()
|
||||
resolver.resolve.assert_called_once()
|
||||
|
||||
|
||||
@patch("libs.oauth_bearer.enforce_bearer_rate_limit")
|
||||
def test_rate_limit_called_before_resolve(rl):
|
||||
call_order: list[str] = []
|
||||
rl.side_effect = lambda _h: call_order.append("rl")
|
||||
resolver = MagicMock()
|
||||
resolver.resolve.side_effect = lambda _h: call_order.append("resolve") or None
|
||||
auth = BearerAuthenticator(_registry_with_resolver(resolver))
|
||||
|
||||
with pytest.raises(InvalidBearerError):
|
||||
auth.authenticate("dfoa_xyz")
|
||||
|
||||
assert call_order == ["rl", "resolve"], f"expected rl before resolve, got {call_order}"
|
||||
|
||||
|
||||
def test_unknown_prefix_raises_generic_invalid_bearer():
|
||||
auth = BearerAuthenticator(
|
||||
TokenKindRegistry(
|
||||
[
|
||||
TokenKind(
|
||||
prefix="dfoa_",
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
source="oauth_account",
|
||||
resolver=MagicMock(),
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
with pytest.raises(InvalidBearerError) as exc:
|
||||
auth.authenticate("zzz_xyz")
|
||||
assert str(exc.value) == "invalid_bearer"
|
||||
|
||||
|
||||
@patch("libs.oauth_bearer.enforce_bearer_rate_limit")
|
||||
def test_revoked_token_raises_generic_invalid_bearer(rl):
|
||||
resolver = MagicMock()
|
||||
resolver.resolve.return_value = None
|
||||
auth = BearerAuthenticator(_registry_with_resolver(resolver))
|
||||
with pytest.raises(InvalidBearerError) as exc:
|
||||
auth.authenticate("dfoa_revoked")
|
||||
assert str(exc.value) == "invalid_bearer"
|
||||
Reference in New Issue
Block a user