diff --git a/api/configs/deploy/__init__.py b/api/configs/deploy/__init__.py index 63f4dfba63..adeb0e7ea6 100644 --- a/api/configs/deploy/__init__.py +++ b/api/configs/deploy/__init__.py @@ -1,3 +1,4 @@ +from typing import Literal from pydantic import Field from pydantic_settings import BaseSettings @@ -23,7 +24,7 @@ class DeploymentConfig(BaseSettings): default=False, ) - EDITION: str = Field( + EDITION: Literal["SELF_HOSTED", "CLOUD"] = Field( description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')", default="SELF_HOSTED", ) diff --git a/api/controllers/openapi/_meta.py b/api/controllers/openapi/_meta.py index e1c380bf55..086fd9002f 100644 --- a/api/controllers/openapi/_meta.py +++ b/api/controllers/openapi/_meta.py @@ -5,6 +5,7 @@ compatibility without needing to be logged in. Mirrors the `_health` endpoint in `index.py`. """ +from typing import Literal from flask_restx import Resource from configs import dify_config diff --git a/api/controllers/openapi/apps.py b/api/controllers/openapi/apps.py index e536bb3b69..3c8b27f4f3 100644 --- a/api/controllers/openapi/apps.py +++ b/api/controllers/openapi/apps.py @@ -206,6 +206,7 @@ class AppListApi(Resource): else: parsed_uuid = None + tenant_name: str | None = None if parsed_uuid is not None: app: App | None = db.session.get(App, str(parsed_uuid)) if not app or app.status != "normal" or str(app.tenant_id) != workspace_id or not is_openapi_visible(app): @@ -250,7 +251,7 @@ class AppListApi(Resource): if pagination is None: return empty - tenant_name: str | None = None + tenant_name = None if pagination.items: tenant_name = db.session.execute( sa.select(Tenant.name).where(Tenant.id == workspace_id) diff --git a/api/controllers/openapi/auth/surface_gate.py b/api/controllers/openapi/auth/surface_gate.py index 1396f82728..e555506ba7 100644 --- a/api/controllers/openapi/auth/surface_gate.py +++ b/api/controllers/openapi/auth/surface_gate.py @@ -77,10 +77,9 @@ def _coerce_subject_type(raw: object) -> SubjectType | None: return None if isinstance(raw, SubjectType): return raw - try: + if isinstance(raw, str): return SubjectType(raw) - except ValueError: - return None + return None def _stringify(value: object) -> str | None: diff --git a/api/controllers/openapi/oauth_device.py b/api/controllers/openapi/oauth_device.py index 8a2aa781fe..87ca64f0e7 100644 --- a/api/controllers/openapi/oauth_device.py +++ b/api/controllers/openapi/oauth_device.py @@ -16,6 +16,7 @@ SSO branch lives in oauth_device_sso.py. from __future__ import annotations import logging +from typing import Any from flask import request from flask_login import login_required @@ -56,6 +57,7 @@ from services.oauth_device_flow import ( DeviceFlowRedis, DeviceFlowStatus, InvalidTransitionError, + PollPayload, SlowDownDecision, StateNotFoundError, mint_oauth_token, @@ -147,7 +149,7 @@ class OAuthDeviceTokenApi(Resource): if terminal.status is DeviceFlowStatus.DENIED: return {"error": "access_denied"}, 400 - poll_payload = terminal.poll_payload or {} + poll_payload: PollPayload | dict[str, Any] = terminal.poll_payload or {} if "token" not in poll_payload: logger.error("device_flow: approved state missing poll_payload for %s", device_code) return {"error": "expired_token"}, 400 @@ -330,9 +332,10 @@ def _audit_cross_ip_if_needed(state) -> None: ) -def _build_account_poll_payload(account, tenant, mint) -> dict: - """Pre-render the poll-response body so the unauthenticated poll - handler doesn't re-query accounts/tenants for authz data. +def _build_account_poll_payload(account, tenant, mint) -> PollPayload: + """Account branch of the shared `PollPayload` contract. SSO-only fields + (`subject_email`, `subject_issuer`) are intentionally omitted; see the + `PollPayload` docstring in `services.oauth_device_flow`. """ from models import Tenant, TenantAccountJoin @@ -355,7 +358,7 @@ def _build_account_poll_payload(account, tenant, mint) -> dict: if default_ws_id is None and workspaces: default_ws_id = workspaces[0].id - return { + payload: PollPayload = { "token": mint.token, "expires_at": mint.expires_at.isoformat(), "subject_type": SubjectType.ACCOUNT, @@ -364,6 +367,7 @@ def _build_account_poll_payload(account, tenant, mint) -> dict: "default_workspace_id": default_ws_id, "token_id": str(mint.token_id), } + return payload def _emit_approve_audit(state, account, tenant, mint) -> None: diff --git a/api/controllers/openapi/oauth_device_sso.py b/api/controllers/openapi/oauth_device_sso.py index 6065cfe430..49866e1156 100644 --- a/api/controllers/openapi/oauth_device_sso.py +++ b/api/controllers/openapi/oauth_device_sso.py @@ -56,6 +56,7 @@ from services.oauth_device_flow import ( DeviceFlowRedis, DeviceFlowStatus, InvalidTransitionError, + PollPayload, StateNotFoundError, mint_oauth_token, oauth_ttl_days, @@ -258,7 +259,10 @@ def approve_external(): ttl_days=ttl_days, ) - poll_payload = { + # SSO branch of the shared PollPayload contract: account/workspace + # fields are zero-filled (`None` / `[]`) for parity with the account + # branch in `oauth_device._build_account_poll_payload`. + poll_payload: PollPayload = { "token": mint.token, "expires_at": mint.expires_at.isoformat(), "subject_type": SubjectType.EXTERNAL_SSO, diff --git a/api/services/oauth_device_flow.py b/api/services/oauth_device_flow.py index 11e92f8ae9..69bd295160 100644 --- a/api/services/oauth_device_flow.py +++ b/api/services/oauth_device_flow.py @@ -15,12 +15,13 @@ import uuid from dataclasses import asdict, dataclass, field from datetime import UTC, datetime, timedelta from enum import StrEnum +from typing import NotRequired, TypedDict from sqlalchemy import func, select from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.orm import Session, scoped_session -from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT +from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT, SubjectType from models.oauth import OAuthAccessToken logger = logging.getLogger(__name__) @@ -73,6 +74,36 @@ class SlowDownDecision(StrEnum): SLOW_DOWN = "slow_down" +class PollPayload(TypedDict): + """Body served by the unauthenticated poll endpoint + (`POST /openapi/v1/oauth/device/token`) once approve has run. + + A single shape across both branches so the CLI/SPA can parse one + contract: + + - ``account`` branch (built in `controllers.openapi.oauth_device. + _build_account_poll_payload`) populates ``account`` + ``workspaces`` + + ``default_workspace_id`` and omits the SSO-only fields. + - ``external_sso`` branch (built in + `controllers.openapi.oauth_device_sso.approve_external`) populates + ``subject_email`` + ``subject_issuer`` and zero-fills the + account/workspace fields (``None`` / ``[]``). + + Pre-rendering here means the unauthenticated poll handler doesn't + re-query accounts/tenants for authz data. + """ + + token: str + expires_at: str + subject_type: SubjectType + account: dict[str, object] | None + workspaces: list[dict[str, object]] + default_workspace_id: str | None + token_id: str + subject_email: NotRequired[str] + subject_issuer: NotRequired[str] + + @dataclass class DeviceFlowState: """``minted_token`` is plaintext between approve and the next poll; @@ -91,7 +122,7 @@ class DeviceFlowState: created_at: str = "" created_ip: str = "" last_poll_at: str = "" - poll_payload: dict | None = field(default=None) + poll_payload: PollPayload | None = field(default=None) def to_json(self) -> str: return json.dumps(asdict(self)) @@ -192,7 +223,7 @@ class DeviceFlowRedis: minted_token: str, token_id: str, subject_issuer: str | None = None, - poll_payload: dict | None = None, + poll_payload: PollPayload | None = None, ) -> None: state = self._load_state(device_code) if state is None: diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_surface_gate.py b/api/tests/unit_tests/controllers/openapi/auth/test_surface_gate.py index 389ea06dc1..f14ad0c815 100644 --- a/api/tests/unit_tests/controllers/openapi/auth/test_surface_gate.py +++ b/api/tests/unit_tests/controllers/openapi/auth/test_surface_gate.py @@ -21,7 +21,7 @@ from werkzeug.exceptions import Forbidden from controllers.openapi.auth.context import Context from controllers.openapi.auth.steps import SurfaceCheck -from controllers.openapi.auth.surface_gate import accept_subjects, check_surface +from controllers.openapi.auth.surface_gate import _coerce_subject_type, accept_subjects, check_surface from libs.oauth_bearer import AuthContext, Scope, SubjectType @@ -179,3 +179,51 @@ def test_surface_check_rejects_on_miss_and_emits_audit(): with pytest.raises(Forbidden): step(_pipeline_ctx()) emit.assert_called_once() + + +# --------------------------------------------------------------------------- +# _coerce_subject_type — normalises whatever sat on ctx.subject_type +# --------------------------------------------------------------------------- +# +# The gate reads `ctx.subject_type` via `getattr(..., None)`, so the value +# could be a real enum (happy path), a raw string (e.g. rehydrated from a +# dict-shaped context), `None` (attribute missing), or something unexpected +# from a buggy upstream. The coercer must collapse all of that to +# `SubjectType | None` so `check_surface` can do a clean set-membership +# check and emit a clean audit payload. + + +def test_coerce_subject_type_returns_none_for_none(): + assert _coerce_subject_type(None) is None + + +def test_coerce_subject_type_returns_enum_instance_unchanged(): + # Identity matters: we don't want to round-trip through the string + # constructor for an already-valid enum. + assert _coerce_subject_type(SubjectType.ACCOUNT) is SubjectType.ACCOUNT + assert _coerce_subject_type(SubjectType.EXTERNAL_SSO) is SubjectType.EXTERNAL_SSO + + +@pytest.mark.parametrize( + ("raw", "expected"), + [ + ("account", SubjectType.ACCOUNT), + ("external_sso", SubjectType.EXTERNAL_SSO), + ], +) +def test_coerce_subject_type_parses_known_strings(raw: str, expected: SubjectType): + assert _coerce_subject_type(raw) is expected + + +def test_coerce_subject_type_raises_on_unknown_string(): + # Unknown strings reach `SubjectType(raw)` which raises ValueError. + # We surface that loudly rather than silently returning None, because + # a string that *looks* like a subject type but isn't is almost + # certainly an upstream bug worth catching. + with pytest.raises(ValueError): + _coerce_subject_type("not_a_subject") + + +@pytest.mark.parametrize("raw", [123, 1.5, b"account", object(), ["account"], {"account"}]) +def test_coerce_subject_type_returns_none_for_non_string_non_enum(raw: object): + assert _coerce_subject_type(raw) is None