Files
dify/api/libs/jws.py
GareArc fe8510ad1a feat(api,web): OAuth 2.0 device flow + bearer auth (RFC 8628)
Adds a CLI-friendly authorization flow so difyctl (and future
non-browser clients) can obtain user-scoped tokens without copy-
pasting cookies or raw API keys. Two grant paths share one device
flow surface:

  1. Account branch — user signs in via the existing /signin
     methods, /device page calls console-authed approve, mints a
     dfoa_ token tied to (account_id, tenant).
  2. External-SSO branch (EE) — /v1/oauth/device/sso-initiate signs
     an SSOState envelope, hands off to Enterprise's external ACS,
     receives a signed external-subject assertion, mints a dfoe_
     token tied to (subject_email, subject_issuer).

API surface (under /v1 except the two /console/api routes; EE-only endpoints 404 on CE):

  POST   /v1/oauth/device/code              — RFC 8628 start
  POST   /v1/oauth/device/token             — RFC 8628 poll
  GET    /v1/oauth/device/lookup            — pre-validate user_code
  GET    /v1/oauth/device/sso-initiate      — SSO branch entry
  GET    /v1/device/sso-complete            — SSO callback sink
  GET    /v1/oauth/device/approval-context  — /device cookie probe
  POST   /v1/oauth/device/approve-external  — SSO approve
  GET    /v1/me                             — bearer subject lookup
  DELETE /v1/oauth/authorizations/self      — self-revoke
  POST   /console/api/oauth/device/approve  — account approve
  POST   /console/api/oauth/device/deny     — account deny

Core primitives:
- libs/oauth_bearer.py: prefix-keyed TokenKindRegistry +
  BearerAuthenticator + validate_bearer decorator. Two-tier scope
  (full vs apps:run) stamped from the registry, never from the DB.
- libs/jws.py: HS256 compact JWS keyed on the shared Dify
  SECRET_KEY — same key-set verifies the SSOState envelope, the
  external-subject assertion (minted by Enterprise), and the
  approval-grant cookie.
- libs/device_flow_security.py: enterprise_only gate, approval-
  grant cookie mint/verify/consume (Path=/v1/oauth/device,
  HttpOnly, SameSite=Lax, Secure follows is_secure()), anti-
  framing headers.
- libs/rate_limit.py: typed RateLimit / RateLimitScope dispatch
  with composite-key buckets; both decorator + imperative form.
- services/oauth_device_flow.py: Redis state machine (PENDING ->
  APPROVED|DENIED with atomic consume-on-poll), token mint via
  partial unique index uq_oauth_active_per_device (rotates in
  place), env-driven TTL policy.

Storage: oauth_access_tokens table with partial unique index on
(subject_email, subject_issuer, client_id, device_label) WHERE
revoked_at IS NULL. account_id NULL distinguishes external-SSO
rows. Hard-expire is CAS UPDATE (revoked_at + nullify token_hash)
so audit events keep their token_id. Retention pruner DELETEs
revoked + zombie-expired rows past OAUTH_ACCESS_TOKEN_RETENTION_DAYS.

Frontend: /device page with code-entry, chooser (account vs SSO),
authorize-account, authorize-sso views. SSO branch detaches from
the URL user_code and reads everything from the cookie via
/approval-context. Anti-framing headers on all responses.

Wiring: ENABLE_OAUTH_BEARER feature flag; ext_oauth_bearer binds
the authenticator at startup; clean_oauth_access_tokens_task
scheduled in ext_celery.

Spec: docs/specs/v1.0/server/{device-flow,tokens,middleware,security}.md
2026-04-26 20:06:43 -07:00

107 lines
3.5 KiB
Python

"""HS256 compact JWS keyed on the shared Dify SECRET_KEY. Used by the SSO
state envelope, external subject assertion, and approval-grant cookie —
all three share one key-set so api ↔ enterprise can verify each other.
"""
from __future__ import annotations
from datetime import UTC, datetime, timedelta
import jwt
from configs import dify_config
# Audience strings, one per token use named in the module docstring (SSO state
# envelope, external subject assertion, approval-grant cookie).
# NOTE(review): presumably passed as ``aud`` to sign() and ``expected_aud`` to
# verify() by callers outside this module — confirm at the call sites.
AUD_STATE_ENVELOPE = "api.sso.state_envelope"
AUD_EXT_SUBJECT_ASSERTION = "api.device_flow.external_subject_assertion"
AUD_APPROVAL_GRANT = "api.device_flow.approval_grant"
# Key id under which KeySet.from_shared_secret() registers the SECRET_KEY bytes.
ACTIVE_KID_V1 = "dify-shared-v1"
class KeySetError(Exception):
    """Raised when a key-set cannot be built or its active key cannot be resolved."""
class KeySet:
    """Map key ids (``kid``) to HMAC secrets and record which kid is active.

    ``from_entries`` reserves multi-kid construction for rotation slots.
    """

    def __init__(self, entries: dict[str, bytes], active_kid: str) -> None:
        """Build a key-set from ``entries`` with ``active_kid`` as the signing key.

        Args:
            entries: Mapping of kid to secret bytes. Each value is copied into
                a fresh ``bytes`` object.
            active_kid: The kid exposed through ``active_kid``; must be a key
                of ``entries``.

        Raises:
            KeySetError: ``active_kid`` is not in ``entries``, or any entry
                (active or not) has an empty secret.
        """
        if active_kid not in entries:
            raise KeySetError(f"active kid {active_kid!r} missing from key-set")
        if not entries[active_kid]:
            raise KeySetError(f"active kid {active_kid!r} has empty secret")

        # Copy so that later mutation of the caller's mapping or buffers cannot
        # change the keys held here.
        normalized: dict[str, bytes] = {kid: bytes(secret) for kid, secret in entries.items()}

        # Every kid in the set is resolvable through lookup(), so an empty
        # secret in a non-active slot would make tokens under that kid
        # forgeable by anyone. Reject it up front, same as for the active kid.
        for kid, secret in normalized.items():
            if not secret:
                raise KeySetError(f"kid {kid!r} has empty secret")

        self._entries: dict[str, bytes] = normalized
        self._active_kid = active_kid

    @classmethod
    def from_shared_secret(cls) -> "KeySet":
        """Build a single-kid key-set from ``dify_config.SECRET_KEY``.

        The secret is UTF-8 encoded and registered under ``ACTIVE_KID_V1``.

        Raises:
            KeySetError: ``dify_config.SECRET_KEY`` is empty.
        """
        secret = dify_config.SECRET_KEY
        if not secret:
            raise KeySetError("dify_config.SECRET_KEY is empty; cannot build key-set")
        return cls({ACTIVE_KID_V1: secret.encode("utf-8")}, ACTIVE_KID_V1)

    @classmethod
    def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> "KeySet":
        """Build a key-set from an explicit kid -> secret mapping.

        Same validation as the constructor, to which this delegates.
        """
        return cls(entries, active_kid)

    @property
    def active_kid(self) -> str:
        """The kid given as ``active_kid`` at construction."""
        return self._active_kid

    def lookup(self, kid: str) -> bytes | None:
        """Return the secret registered for ``kid``, or ``None`` if unknown."""
        return self._entries.get(kid)
def sign(keyset: KeySet, payload: dict, aud: str, ttl_seconds: int) -> str:
"""``iat`` + ``exp`` are injected here; callers must not set them."""
if "aud" in payload or "iat" in payload or "exp" in payload:
raise ValueError("reserved claim present in payload (aud/iat/exp)")
if ttl_seconds <= 0:
raise ValueError("ttl_seconds must be positive")
kid = keyset.active_kid
secret = keyset.lookup(kid)
if secret is None:
raise KeySetError(f"active kid {kid!r} lookup miss")
iat = datetime.now(UTC)
exp = iat + timedelta(seconds=ttl_seconds)
claims = {**payload, "aud": aud, "iat": iat, "exp": exp}
return jwt.encode(
claims,
secret,
algorithm="HS256",
headers={"kid": kid, "typ": "JWT"},
)
class VerifyError(Exception):
    """Raised by ``verify`` when a token's header, kid, signature or claims are rejected."""
def verify(keyset: KeySet, token: str, expected_aud: str) -> dict:
    """Validate ``token`` against ``keyset`` and return its decoded claims.

    The secret is chosen by the token's ``kid`` header. An unknown kid is
    rejected outright rather than retried with the active kid, since a past
    kid value would otherwise be forgeable by anyone who saw it.

    Raises:
        VerifyError: the header cannot be parsed, the kid is missing or
            unknown, the token is expired, the audience does not match
            ``expected_aud``, or any other PyJWT validation fails.
    """
    try:
        unverified_header = jwt.get_unverified_header(token)
    except jwt.PyJWTError as exc:
        raise VerifyError(f"decode header: {exc}") from exc

    key_id = unverified_header.get("kid")
    if not key_id:
        raise VerifyError("no kid in header")

    signing_key = keyset.lookup(key_id)
    if signing_key is None:
        raise VerifyError(f"unknown kid {key_id!r}")

    # NOTE(review): PyJWT appears to validate ``exp`` only when the claim is
    # present; confirm every issuer sets it, or consider requiring it here.
    try:
        claims = jwt.decode(
            token,
            signing_key,
            algorithms=["HS256"],
            audience=expected_aud,
        )
    except jwt.PyJWTError as exc:
        # Most specific failure first; everything else gets the generic
        # "decode: ..." message.
        if isinstance(exc, jwt.ExpiredSignatureError):
            reason = "token expired"
        elif isinstance(exc, jwt.InvalidAudienceError):
            reason = "aud mismatch"
        else:
            reason = f"decode: {exc}"
        raise VerifyError(reason) from exc
    return claims