mirror of
https://github.com/langgenius/dify.git
synced 2026-06-08 09:27:39 +08:00
fix(openapi): close 4 critical OAuth device-flow security findings
1. Host-header injection (sso_initiate / sso_complete): replace
request.host_url with dify_config.CONSOLE_API_URL via a
_trusted_origin() helper that fails closed when unset. An
attacker-controlled Host header on sso-initiate would otherwise
be sealed into the signed state envelope, causing the IdP to
redirect the victim's EE-signed SSO assertion to evil.com.
2. Unvalidated JWS claim payloads: add ExtSubjectAssertionClaims
and ApprovalGrantClaimsPayload pydantic models and route every
verified payload through model_validate. A signed-but-malformed
blob now returns BadRequest('invalid_sso_assertion') or
VerifyError('claim shape invalid') instead of crashing the
handler with KeyError / 500. ApprovalGrantClaimsPayload is
imported lazily inside verify_approval_grant to break the
libs -> controllers cycle.
3. Timing-unsafe CSRF compare in approve_external: replace plain
!= with secrets.compare_digest.
4. Bearer rate-limit bypass on revoked tokens: move
enforce_bearer_rate_limit to fire after sha256_hex but before
resolver.resolve, so revoked-token replay is now bounded. Also
collapse the two distinct error messages (unknown token prefix
vs token unknown or revoked) into a single generic
'invalid_bearer' to remove the prefix-validity oracle.
Tests: 4 new unit-test files cover each finding plus one updated
test for the new bearer error string. 744 tests pass.
This commit is contained in:
@ -30,7 +30,7 @@ def test_bearer_check_rejects_missing_header():
|
||||
|
||||
@patch("controllers.openapi.auth.steps.get_authenticator")
|
||||
def test_bearer_check_rejects_unknown_prefix(get_auth):
|
||||
get_auth.return_value.authenticate.side_effect = InvalidBearerError("unknown token prefix")
|
||||
get_auth.return_value.authenticate.side_effect = InvalidBearerError("invalid_bearer")
|
||||
app = Flask(__name__)
|
||||
with app.test_request_context(), pytest.raises(Unauthorized):
|
||||
BearerCheck()(_ctx("xxx_abc"))
|
||||
|
||||
@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
a = Flask(__name__)
|
||||
a.config["TESTING"] = True
|
||||
a.register_blueprint(openapi_bp)
|
||||
return a
|
||||
|
||||
|
||||
def _ee_features():
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
m = MagicMock()
|
||||
m.license.status = LicenseStatus.ACTIVE
|
||||
return m
|
||||
|
||||
|
||||
@patch("controllers.openapi.oauth_device_sso.jws")
|
||||
@patch("libs.device_flow_security.FeatureService.get_system_features")
|
||||
def test_sso_complete_rejects_assertion_missing_email(ee_feat, jws_mod, app: Flask):
|
||||
ee_feat.return_value = _ee_features()
|
||||
jws_mod.verify.return_value = {"issuer": "https://idp.example", "user_code": "ABCD-EFGH", "nonce": "n"}
|
||||
jws_mod.AUD_EXT_SUBJECT_ASSERTION = "aud"
|
||||
jws_mod.KeySet.from_shared_secret.return_value = object()
|
||||
jws_mod.VerifyError = Exception
|
||||
|
||||
client = app.test_client()
|
||||
resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob")
|
||||
assert resp.status_code == 400, resp.data
|
||||
|
||||
|
||||
@patch("controllers.openapi.oauth_device_sso.jws")
|
||||
@patch("libs.device_flow_security.FeatureService.get_system_features")
|
||||
def test_sso_complete_rejects_assertion_empty_issuer(ee_feat, jws_mod, app: Flask):
|
||||
ee_feat.return_value = _ee_features()
|
||||
jws_mod.verify.return_value = {"email": "x@y.com", "issuer": "", "user_code": "ABCD-EFGH", "nonce": "n"}
|
||||
jws_mod.AUD_EXT_SUBJECT_ASSERTION = "aud"
|
||||
jws_mod.KeySet.from_shared_secret.return_value = object()
|
||||
jws_mod.VerifyError = Exception
|
||||
|
||||
client = app.test_client()
|
||||
resp = client.get("/openapi/v1/oauth/device/sso-complete?sso_assertion=blob")
|
||||
assert resp.status_code == 400
|
||||
|
||||
|
||||
def test_verify_approval_grant_raises_on_missing_field():
|
||||
from libs import device_flow_security
|
||||
from libs import jws as jws_mod
|
||||
|
||||
class _FakeKeyset:
|
||||
active_kid = "k"
|
||||
|
||||
def lookup(self, kid):
|
||||
return b"secret"
|
||||
|
||||
keyset = _FakeKeyset()
|
||||
incomplete = jws_mod.sign(
|
||||
keyset,
|
||||
payload={"subject_email": "x@y.com", "subject_issuer": "i", "user_code": "ABCD-EFGH", "nonce": "n"},
|
||||
aud=jws_mod.AUD_APPROVAL_GRANT,
|
||||
ttl_seconds=60,
|
||||
)
|
||||
with pytest.raises(jws_mod.VerifyError):
|
||||
device_flow_security.verify_approval_grant(keyset, incomplete)
|
||||
@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import ast
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _repo_root() -> Path:
|
||||
for parent in Path(__file__).resolve().parents:
|
||||
if (parent / "api" / "pyproject.toml").exists():
|
||||
return parent
|
||||
raise RuntimeError("repo root not found")
|
||||
|
||||
|
||||
def test_approve_external_uses_compare_digest_for_csrf():
|
||||
src = (_repo_root() / "api" / "controllers" / "openapi" / "oauth_device_sso.py").read_text()
|
||||
tree = ast.parse(src)
|
||||
|
||||
fn = next(n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef) and n.name == "approve_external")
|
||||
fn_src = ast.unparse(fn)
|
||||
|
||||
assert "compare_digest" in fn_src, "approve_external must call secrets.compare_digest for CSRF"
|
||||
assert "csrf_header != claims.csrf_token" not in fn_src, "approve_external must not use plain != on csrf_token"
|
||||
@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
a = Flask(__name__)
|
||||
a.config["TESTING"] = True
|
||||
a.register_blueprint(openapi_bp)
|
||||
return a
|
||||
|
||||
|
||||
def _ee_features():
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
m = MagicMock()
|
||||
m.license.status = LicenseStatus.ACTIVE
|
||||
return m
|
||||
|
||||
|
||||
@patch("controllers.openapi.oauth_device_sso.EnterpriseService")
|
||||
@patch("controllers.openapi.oauth_device_sso.jws")
|
||||
@patch("controllers.openapi.oauth_device_sso.DeviceFlowRedis")
|
||||
@patch("controllers.openapi.oauth_device_sso.dify_config")
|
||||
@patch("libs.device_flow_security.FeatureService.get_system_features")
|
||||
@patch("libs.rate_limit.RateLimiter.is_rate_limited", new=MagicMock(return_value=False))
|
||||
@patch("libs.rate_limit.RateLimiter.increment_rate_limit", new=MagicMock())
|
||||
def test_idp_callback_url_uses_console_api_url_not_host_header(ee_feat, cfg, redis_cls, jws_mod, ent, app: Flask):
|
||||
ee_feat.return_value = _ee_features()
|
||||
cfg.CONSOLE_API_URL = "https://api.dify.example"
|
||||
state = MagicMock()
|
||||
from services.oauth_device_flow import DeviceFlowStatus
|
||||
|
||||
state.status = DeviceFlowStatus.PENDING
|
||||
redis_cls.return_value.load_by_user_code.return_value = ("dc_x", state)
|
||||
jws_mod.KeySet.from_shared_secret.return_value = MagicMock()
|
||||
jws_mod.sign.return_value = "signed-state"
|
||||
jws_mod.AUD_STATE_ENVELOPE = "aud"
|
||||
ent.initiate_device_flow_sso.return_value = {"url": "https://idp.example/auth"}
|
||||
|
||||
client = app.test_client()
|
||||
client.get(
|
||||
"/openapi/v1/oauth/device/sso-initiate?user_code=ABCD-EFGH",
|
||||
headers={"Host": "evil.com"},
|
||||
)
|
||||
|
||||
args, kwargs = jws_mod.sign.call_args
|
||||
signed_payload = args[1] if len(args) > 1 else kwargs["payload"]
|
||||
assert signed_payload["idp_callback_url"].startswith("https://api.dify.example")
|
||||
assert "evil.com" not in signed_payload["idp_callback_url"]
|
||||
|
||||
|
||||
@patch("controllers.openapi.oauth_device_sso.DeviceFlowRedis")
|
||||
@patch("controllers.openapi.oauth_device_sso.dify_config")
|
||||
@patch("libs.device_flow_security.FeatureService.get_system_features")
|
||||
@patch("libs.rate_limit.RateLimiter.is_rate_limited", new=MagicMock(return_value=False))
|
||||
@patch("libs.rate_limit.RateLimiter.increment_rate_limit", new=MagicMock())
|
||||
def test_sso_initiate_fails_closed_when_console_api_url_unset(ee_feat, cfg, redis_cls, app: Flask):
|
||||
ee_feat.return_value = _ee_features()
|
||||
cfg.CONSOLE_API_URL = ""
|
||||
from services.oauth_device_flow import DeviceFlowStatus
|
||||
|
||||
state = MagicMock()
|
||||
state.status = DeviceFlowStatus.PENDING
|
||||
redis_cls.return_value.load_by_user_code.return_value = ("dc_x", state)
|
||||
|
||||
client = app.test_client()
|
||||
resp = client.get("/openapi/v1/oauth/device/sso-initiate?user_code=ABCD-EFGH")
|
||||
assert resp.status_code == 502
|
||||
Reference in New Issue
Block a user