fix typings

This commit is contained in:
yunlu.wen
2026-05-22 18:15:28 +08:00
parent 2ff07b6311
commit d94e302045
8 changed files with 104 additions and 15 deletions

View File

@ -1,3 +1,4 @@
from typing import Literal
from pydantic import Field
from pydantic_settings import BaseSettings
@ -23,7 +24,7 @@ class DeploymentConfig(BaseSettings):
default=False,
)
EDITION: str = Field(
EDITION: Literal["SELF_HOSTED", "CLOUD"] = Field(
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
default="SELF_HOSTED",
)

View File

@ -5,6 +5,7 @@ compatibility without needing to be logged in. Mirrors the `_health` endpoint
in `index.py`.
"""
from typing import Literal
from flask_restx import Resource
from configs import dify_config

View File

@ -206,6 +206,7 @@ class AppListApi(Resource):
else:
parsed_uuid = None
tenant_name: str | None = None
if parsed_uuid is not None:
app: App | None = db.session.get(App, str(parsed_uuid))
if not app or app.status != "normal" or str(app.tenant_id) != workspace_id or not is_openapi_visible(app):
@ -250,7 +251,7 @@ class AppListApi(Resource):
if pagination is None:
return empty
tenant_name: str | None = None
tenant_name = None
if pagination.items:
tenant_name = db.session.execute(
sa.select(Tenant.name).where(Tenant.id == workspace_id)

View File

@ -77,10 +77,9 @@ def _coerce_subject_type(raw: object) -> SubjectType | None:
return None
if isinstance(raw, SubjectType):
return raw
try:
if isinstance(raw, str):
return SubjectType(raw)
except ValueError:
return None
return None
def _stringify(value: object) -> str | None:

View File

@ -16,6 +16,7 @@ SSO branch lives in oauth_device_sso.py.
from __future__ import annotations
import logging
from typing import Any
from flask import request
from flask_login import login_required
@ -56,6 +57,7 @@ from services.oauth_device_flow import (
DeviceFlowRedis,
DeviceFlowStatus,
InvalidTransitionError,
PollPayload,
SlowDownDecision,
StateNotFoundError,
mint_oauth_token,
@ -147,7 +149,7 @@ class OAuthDeviceTokenApi(Resource):
if terminal.status is DeviceFlowStatus.DENIED:
return {"error": "access_denied"}, 400
poll_payload = terminal.poll_payload or {}
poll_payload: PollPayload | dict[str, Any] = terminal.poll_payload or {}
if "token" not in poll_payload:
logger.error("device_flow: approved state missing poll_payload for %s", device_code)
return {"error": "expired_token"}, 400
@ -330,9 +332,10 @@ def _audit_cross_ip_if_needed(state) -> None:
)
def _build_account_poll_payload(account, tenant, mint) -> dict:
"""Pre-render the poll-response body so the unauthenticated poll
handler doesn't re-query accounts/tenants for authz data.
def _build_account_poll_payload(account, tenant, mint) -> PollPayload:
"""Account branch of the shared `PollPayload` contract. SSO-only fields
(`subject_email`, `subject_issuer`) are intentionally omitted; see the
`PollPayload` docstring in `services.oauth_device_flow`.
"""
from models import Tenant, TenantAccountJoin
@ -355,7 +358,7 @@ def _build_account_poll_payload(account, tenant, mint) -> dict:
if default_ws_id is None and workspaces:
default_ws_id = workspaces[0].id
return {
payload: PollPayload = {
"token": mint.token,
"expires_at": mint.expires_at.isoformat(),
"subject_type": SubjectType.ACCOUNT,
@ -364,6 +367,7 @@ def _build_account_poll_payload(account, tenant, mint) -> dict:
"default_workspace_id": default_ws_id,
"token_id": str(mint.token_id),
}
return payload
def _emit_approve_audit(state, account, tenant, mint) -> None:

View File

@ -56,6 +56,7 @@ from services.oauth_device_flow import (
DeviceFlowRedis,
DeviceFlowStatus,
InvalidTransitionError,
PollPayload,
StateNotFoundError,
mint_oauth_token,
oauth_ttl_days,
@ -258,7 +259,10 @@ def approve_external():
ttl_days=ttl_days,
)
poll_payload = {
# SSO branch of the shared PollPayload contract: account/workspace
# fields are zero-filled (`None` / `[]`) for parity with the account
# branch in `oauth_device._build_account_poll_payload`.
poll_payload: PollPayload = {
"token": mint.token,
"expires_at": mint.expires_at.isoformat(),
"subject_type": SubjectType.EXTERNAL_SSO,

View File

@ -15,12 +15,13 @@ import uuid
from dataclasses import asdict, dataclass, field
from datetime import UTC, datetime, timedelta
from enum import StrEnum
from typing import NotRequired, TypedDict
from sqlalchemy import func, select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session, scoped_session
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT, SubjectType
from models.oauth import OAuthAccessToken
logger = logging.getLogger(__name__)
@ -73,6 +74,36 @@ class SlowDownDecision(StrEnum):
SLOW_DOWN = "slow_down"
class PollPayload(TypedDict):
"""Body served by the unauthenticated poll endpoint
(`POST /openapi/v1/oauth/device/token`) once approve has run.
A single shape across both branches so the CLI/SPA can parse one
contract:
- ``account`` branch (built in `controllers.openapi.oauth_device.
_build_account_poll_payload`) populates ``account`` + ``workspaces``
+ ``default_workspace_id`` and omits the SSO-only fields.
- ``external_sso`` branch (built in
`controllers.openapi.oauth_device_sso.approve_external`) populates
``subject_email`` + ``subject_issuer`` and zero-fills the
account/workspace fields (``None`` / ``[]``).
Pre-rendering here means the unauthenticated poll handler doesn't
re-query accounts/tenants for authz data.
"""
token: str
expires_at: str
subject_type: SubjectType
account: dict[str, object] | None
workspaces: list[dict[str, object]]
default_workspace_id: str | None
token_id: str
subject_email: NotRequired[str]
subject_issuer: NotRequired[str]
@dataclass
class DeviceFlowState:
"""``minted_token`` is plaintext between approve and the next poll;
@ -91,7 +122,7 @@ class DeviceFlowState:
created_at: str = ""
created_ip: str = ""
last_poll_at: str = ""
poll_payload: dict | None = field(default=None)
poll_payload: PollPayload | None = field(default=None)
def to_json(self) -> str:
return json.dumps(asdict(self))
@ -192,7 +223,7 @@ class DeviceFlowRedis:
minted_token: str,
token_id: str,
subject_issuer: str | None = None,
poll_payload: dict | None = None,
poll_payload: PollPayload | None = None,
) -> None:
state = self._load_state(device_code)
if state is None:

View File

@ -21,7 +21,7 @@ from werkzeug.exceptions import Forbidden
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import SurfaceCheck
from controllers.openapi.auth.surface_gate import accept_subjects, check_surface
from controllers.openapi.auth.surface_gate import _coerce_subject_type, accept_subjects, check_surface
from libs.oauth_bearer import AuthContext, Scope, SubjectType
@ -179,3 +179,51 @@ def test_surface_check_rejects_on_miss_and_emits_audit():
with pytest.raises(Forbidden):
step(_pipeline_ctx())
emit.assert_called_once()
# ---------------------------------------------------------------------------
# _coerce_subject_type — normalises whatever sat on ctx.subject_type
# ---------------------------------------------------------------------------
#
# The gate reads `ctx.subject_type` via `getattr(..., None)`, so the value
# could be a real enum (happy path), a raw string (e.g. rehydrated from a
# dict-shaped context), `None` (attribute missing), or something unexpected
# from a buggy upstream. The coercer must collapse all of that to
# `SubjectType | None` so `check_surface` can do a clean set-membership
# check and emit a clean audit payload.
def test_coerce_subject_type_returns_none_for_none():
assert _coerce_subject_type(None) is None
def test_coerce_subject_type_returns_enum_instance_unchanged():
# Identity matters: we don't want to round-trip through the string
# constructor for an already-valid enum.
assert _coerce_subject_type(SubjectType.ACCOUNT) is SubjectType.ACCOUNT
assert _coerce_subject_type(SubjectType.EXTERNAL_SSO) is SubjectType.EXTERNAL_SSO
@pytest.mark.parametrize(
("raw", "expected"),
[
("account", SubjectType.ACCOUNT),
("external_sso", SubjectType.EXTERNAL_SSO),
],
)
def test_coerce_subject_type_parses_known_strings(raw: str, expected: SubjectType):
assert _coerce_subject_type(raw) is expected
def test_coerce_subject_type_raises_on_unknown_string():
# Unknown strings reach `SubjectType(raw)` which raises ValueError.
# We surface that loudly rather than silently returning None, because
# a string that *looks* like a subject type but isn't is almost
# certainly an upstream bug worth catching.
with pytest.raises(ValueError):
_coerce_subject_type("not_a_subject")
@pytest.mark.parametrize("raw", [123, 1.5, b"account", object(), ["account"], {"account"}])
def test_coerce_subject_type_returns_none_for_non_string_non_enum(raw: object):
assert _coerce_subject_type(raw) is None