fix(openapi): close 4 critical OAuth device-flow security findings

1. Host-header injection (sso_initiate / sso_complete): replace
   request.host_url with dify_config.CONSOLE_API_URL via a
   _trusted_origin() helper that fails closed when unset. An
   attacker-controlled Host header on sso-initiate would otherwise
   be sealed into the signed state envelope, causing the IdP to
   redirect the victim's EE-signed SSO assertion to evil.com.

2. Unvalidated JWS claim payloads: add ExtSubjectAssertionClaims
   and ApprovalGrantClaimsPayload pydantic models and route every
   verified payload through model_validate. A signed-but-malformed
   blob now returns BadRequest('invalid_sso_assertion') or
   VerifyError('claim shape invalid') instead of crashing the
   handler with KeyError / 500. ApprovalGrantClaimsPayload is
   imported lazily inside verify_approval_grant to break the
   libs -> controllers cycle.

3. Timing-unsafe CSRF compare in approve_external: replace plain
   != with secrets.compare_digest.

4. Bearer rate-limit bypass on revoked tokens: move
   enforce_bearer_rate_limit to fire after sha256_hex but before
   resolver.resolve, so revoked-token replay is now bounded. Also
   collapse the two distinct error messages (unknown token prefix
   vs token unknown or revoked) into a single generic
   'invalid_bearer' to remove the prefix-validity oracle.

Tests: 4 new unit-test files cover each finding plus one updated
test for the new bearer error string. 744 tests pass.
This commit is contained in:
GareArc
2026-05-24 21:17:36 -07:00
parent da10ea017a
commit eba0041973
9 changed files with 319 additions and 21 deletions

View File

@ -30,7 +30,7 @@ def test_bearer_check_rejects_missing_header():
@patch("controllers.openapi.auth.steps.get_authenticator")
def test_bearer_check_rejects_unknown_prefix(get_auth):
get_auth.return_value.authenticate.side_effect = InvalidBearerError("unknown token prefix")
get_auth.return_value.authenticate.side_effect = InvalidBearerError("invalid_bearer")
app = Flask(__name__)
with app.test_request_context(), pytest.raises(Unauthorized):
BearerCheck()(_ctx("xxx_abc"))

View File

@ -0,0 +1,73 @@
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.openapi import bp as openapi_bp
@pytest.fixture
def app() -> Flask:
a = Flask(__name__)
a.config["TESTING"] = True
a.register_blueprint(openapi_bp)
return a
def _ee_features():
from services.feature_service import LicenseStatus
m = MagicMock()
m.license.status = LicenseStatus.ACTIVE
return m
@patch("controllers.openapi.oauth_device_sso.jws")
@patch("libs.device_flow_security.FeatureService.get_system_features")
def test_sso_complete_rejects_assertion_missing_email(ee_feat, jws_mod, app: Flask):
ee_feat.return_value = _ee_features()
jws_mod.verify.return_value = {"issuer": "https://idp.example", "user_code": "ABCD-EFGH", "nonce": "n"}
jws_mod.AUD_EXT_SUBJECT_ASSERTION = "aud"
jws_mod.KeySet.from_shared_secret.return_value = object()
jws_mod.VerifyError = Exception
client = app.test_client()
resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob")
assert resp.status_code == 400, resp.data
@patch("controllers.openapi.oauth_device_sso.jws")
@patch("libs.device_flow_security.FeatureService.get_system_features")
def test_sso_complete_rejects_assertion_empty_issuer(ee_feat, jws_mod, app: Flask):
ee_feat.return_value = _ee_features()
jws_mod.verify.return_value = {"email": "x@y.com", "issuer": "", "user_code": "ABCD-EFGH", "nonce": "n"}
jws_mod.AUD_EXT_SUBJECT_ASSERTION = "aud"
jws_mod.KeySet.from_shared_secret.return_value = object()
jws_mod.VerifyError = Exception
client = app.test_client()
resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob")
assert resp.status_code == 400
def test_verify_approval_grant_raises_on_missing_field():
from libs import device_flow_security
from libs import jws as jws_mod
class _FakeKeyset:
active_kid = "k"
def lookup(self, kid):
return b"secret"
keyset = _FakeKeyset()
incomplete = jws_mod.sign(
keyset,
payload={"subject_email": "x@y.com", "subject_issuer": "i", "user_code": "ABCD-EFGH", "nonce": "n"},
aud=jws_mod.AUD_APPROVAL_GRANT,
ttl_seconds=60,
)
with pytest.raises(jws_mod.VerifyError):
device_flow_security.verify_approval_grant(keyset, incomplete)

View File

@ -0,0 +1,22 @@
from __future__ import annotations
import ast
from pathlib import Path
def _repo_root() -> Path:
for parent in Path(__file__).resolve().parents:
if (parent / "api" / "pyproject.toml").exists():
return parent
raise RuntimeError("repo root not found")
def test_approve_external_uses_compare_digest_for_csrf():
src = (_repo_root() / "api" / "controllers" / "openapi" / "oauth_device_sso.py").read_text()
tree = ast.parse(src)
fn = next(n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef) and n.name == "approve_external")
fn_src = ast.unparse(fn)
assert "compare_digest" in fn_src, "approve_external must call secrets.compare_digest for CSRF"
assert "csrf_header != claims.csrf_token" not in fn_src, "approve_external must not use plain != on csrf_token"

View File

@ -0,0 +1,75 @@
from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from controllers.openapi import bp as openapi_bp
@pytest.fixture
def app() -> Flask:
a = Flask(__name__)
a.config["TESTING"] = True
a.register_blueprint(openapi_bp)
return a
def _ee_features():
from services.feature_service import LicenseStatus
m = MagicMock()
m.license.status = LicenseStatus.ACTIVE
return m
@patch("controllers.openapi.oauth_device_sso.EnterpriseService")
@patch("controllers.openapi.oauth_device_sso.jws")
@patch("controllers.openapi.oauth_device_sso.DeviceFlowRedis")
@patch("controllers.openapi.oauth_device_sso.dify_config")
@patch("libs.device_flow_security.FeatureService.get_system_features")
@patch("libs.rate_limit.RateLimiter.is_rate_limited", new=MagicMock(return_value=False))
@patch("libs.rate_limit.RateLimiter.increment_rate_limit", new=MagicMock())
def test_idp_callback_url_uses_console_api_url_not_host_header(ee_feat, cfg, redis_cls, jws_mod, ent, app: Flask):
ee_feat.return_value = _ee_features()
cfg.CONSOLE_API_URL = "https://api.dify.example"
state = MagicMock()
from services.oauth_device_flow import DeviceFlowStatus
state.status = DeviceFlowStatus.PENDING
redis_cls.return_value.load_by_user_code.return_value = ("dc_x", state)
jws_mod.KeySet.from_shared_secret.return_value = MagicMock()
jws_mod.sign.return_value = "signed-state"
jws_mod.AUD_STATE_ENVELOPE = "aud"
ent.initiate_device_flow_sso.return_value = {"url": "https://idp.example/auth"}
client = app.test_client()
client.get(
"/openapi/v1/oauth/device/sso-initiate?user_code=ABCD-EFGH",
headers={"Host": "evil.com"},
)
args, kwargs = jws_mod.sign.call_args
signed_payload = args[1] if len(args) > 1 else kwargs["payload"]
assert signed_payload["idp_callback_url"].startswith("https://api.dify.example")
assert "evil.com" not in signed_payload["idp_callback_url"]
@patch("controllers.openapi.oauth_device_sso.DeviceFlowRedis")
@patch("controllers.openapi.oauth_device_sso.dify_config")
@patch("libs.device_flow_security.FeatureService.get_system_features")
@patch("libs.rate_limit.RateLimiter.is_rate_limited", new=MagicMock(return_value=False))
@patch("libs.rate_limit.RateLimiter.increment_rate_limit", new=MagicMock())
def test_sso_initiate_fails_closed_when_console_api_url_unset(ee_feat, cfg, redis_cls, app: Flask):
ee_feat.return_value = _ee_features()
cfg.CONSOLE_API_URL = ""
from services.oauth_device_flow import DeviceFlowStatus
state = MagicMock()
state.status = DeviceFlowStatus.PENDING
redis_cls.return_value.load_by_user_code.return_value = ("dc_x", state)
client = app.test_client()
resp = client.get("/openapi/v1/oauth/device/sso-initiate?user_code=ABCD-EFGH")
assert resp.status_code == 502