mirror of
https://github.com/langgenius/dify.git
synced 2026-05-28 21:03:22 +08:00
fix typings
This commit is contained in:
@ -1,3 +1,4 @@
|
||||
from typing import Literal
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
@ -23,7 +24,7 @@ class DeploymentConfig(BaseSettings):
|
||||
default=False,
|
||||
)
|
||||
|
||||
EDITION: str = Field(
|
||||
EDITION: Literal["SELF_HOSTED", "CLOUD"] = Field(
|
||||
description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')",
|
||||
default="SELF_HOSTED",
|
||||
)
|
||||
|
||||
@ -5,6 +5,7 @@ compatibility without needing to be logged in. Mirrors the `_health` endpoint
|
||||
in `index.py`.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
from flask_restx import Resource
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
@ -206,6 +206,7 @@ class AppListApi(Resource):
|
||||
else:
|
||||
parsed_uuid = None
|
||||
|
||||
tenant_name: str | None = None
|
||||
if parsed_uuid is not None:
|
||||
app: App | None = db.session.get(App, str(parsed_uuid))
|
||||
if not app or app.status != "normal" or str(app.tenant_id) != workspace_id or not is_openapi_visible(app):
|
||||
@ -250,7 +251,7 @@ class AppListApi(Resource):
|
||||
if pagination is None:
|
||||
return empty
|
||||
|
||||
tenant_name: str | None = None
|
||||
tenant_name = None
|
||||
if pagination.items:
|
||||
tenant_name = db.session.execute(
|
||||
sa.select(Tenant.name).where(Tenant.id == workspace_id)
|
||||
|
||||
@ -77,10 +77,9 @@ def _coerce_subject_type(raw: object) -> SubjectType | None:
|
||||
return None
|
||||
if isinstance(raw, SubjectType):
|
||||
return raw
|
||||
try:
|
||||
if isinstance(raw, str):
|
||||
return SubjectType(raw)
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def _stringify(value: object) -> str | None:
|
||||
|
||||
@ -16,6 +16,7 @@ SSO branch lives in oauth_device_sso.py.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required
|
||||
@ -56,6 +57,7 @@ from services.oauth_device_flow import (
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
PollPayload,
|
||||
SlowDownDecision,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
@ -147,7 +149,7 @@ class OAuthDeviceTokenApi(Resource):
|
||||
if terminal.status is DeviceFlowStatus.DENIED:
|
||||
return {"error": "access_denied"}, 400
|
||||
|
||||
poll_payload = terminal.poll_payload or {}
|
||||
poll_payload: PollPayload | dict[str, Any] = terminal.poll_payload or {}
|
||||
if "token" not in poll_payload:
|
||||
logger.error("device_flow: approved state missing poll_payload for %s", device_code)
|
||||
return {"error": "expired_token"}, 400
|
||||
@ -330,9 +332,10 @@ def _audit_cross_ip_if_needed(state) -> None:
|
||||
)
|
||||
|
||||
|
||||
def _build_account_poll_payload(account, tenant, mint) -> dict:
|
||||
"""Pre-render the poll-response body so the unauthenticated poll
|
||||
handler doesn't re-query accounts/tenants for authz data.
|
||||
def _build_account_poll_payload(account, tenant, mint) -> PollPayload:
|
||||
"""Account branch of the shared `PollPayload` contract. SSO-only fields
|
||||
(`subject_email`, `subject_issuer`) are intentionally omitted; see the
|
||||
`PollPayload` docstring in `services.oauth_device_flow`.
|
||||
"""
|
||||
from models import Tenant, TenantAccountJoin
|
||||
|
||||
@ -355,7 +358,7 @@ def _build_account_poll_payload(account, tenant, mint) -> dict:
|
||||
if default_ws_id is None and workspaces:
|
||||
default_ws_id = workspaces[0].id
|
||||
|
||||
return {
|
||||
payload: PollPayload = {
|
||||
"token": mint.token,
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
"subject_type": SubjectType.ACCOUNT,
|
||||
@ -364,6 +367,7 @@ def _build_account_poll_payload(account, tenant, mint) -> dict:
|
||||
"default_workspace_id": default_ws_id,
|
||||
"token_id": str(mint.token_id),
|
||||
}
|
||||
return payload
|
||||
|
||||
|
||||
def _emit_approve_audit(state, account, tenant, mint) -> None:
|
||||
|
||||
@ -56,6 +56,7 @@ from services.oauth_device_flow import (
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
PollPayload,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
@ -258,7 +259,10 @@ def approve_external():
|
||||
ttl_days=ttl_days,
|
||||
)
|
||||
|
||||
poll_payload = {
|
||||
# SSO branch of the shared PollPayload contract: account/workspace
|
||||
# fields are zero-filled (`None` / `[]`) for parity with the account
|
||||
# branch in `oauth_device._build_account_poll_payload`.
|
||||
poll_payload: PollPayload = {
|
||||
"token": mint.token,
|
||||
"expires_at": mint.expires_at.isoformat(),
|
||||
"subject_type": SubjectType.EXTERNAL_SSO,
|
||||
|
||||
@ -15,12 +15,13 @@ import uuid
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import StrEnum
|
||||
from typing import NotRequired, TypedDict
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
|
||||
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT
|
||||
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT, SubjectType
|
||||
from models.oauth import OAuthAccessToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@ -73,6 +74,36 @@ class SlowDownDecision(StrEnum):
|
||||
SLOW_DOWN = "slow_down"
|
||||
|
||||
|
||||
class PollPayload(TypedDict):
|
||||
"""Body served by the unauthenticated poll endpoint
|
||||
(`POST /openapi/v1/oauth/device/token`) once approve has run.
|
||||
|
||||
A single shape across both branches so the CLI/SPA can parse one
|
||||
contract:
|
||||
|
||||
- ``account`` branch (built in `controllers.openapi.oauth_device.
|
||||
_build_account_poll_payload`) populates ``account`` + ``workspaces``
|
||||
+ ``default_workspace_id`` and omits the SSO-only fields.
|
||||
- ``external_sso`` branch (built in
|
||||
`controllers.openapi.oauth_device_sso.approve_external`) populates
|
||||
``subject_email`` + ``subject_issuer`` and zero-fills the
|
||||
account/workspace fields (``None`` / ``[]``).
|
||||
|
||||
Pre-rendering here means the unauthenticated poll handler doesn't
|
||||
re-query accounts/tenants for authz data.
|
||||
"""
|
||||
|
||||
token: str
|
||||
expires_at: str
|
||||
subject_type: SubjectType
|
||||
account: dict[str, object] | None
|
||||
workspaces: list[dict[str, object]]
|
||||
default_workspace_id: str | None
|
||||
token_id: str
|
||||
subject_email: NotRequired[str]
|
||||
subject_issuer: NotRequired[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceFlowState:
|
||||
"""``minted_token`` is plaintext between approve and the next poll;
|
||||
@ -91,7 +122,7 @@ class DeviceFlowState:
|
||||
created_at: str = ""
|
||||
created_ip: str = ""
|
||||
last_poll_at: str = ""
|
||||
poll_payload: dict | None = field(default=None)
|
||||
poll_payload: PollPayload | None = field(default=None)
|
||||
|
||||
def to_json(self) -> str:
|
||||
return json.dumps(asdict(self))
|
||||
@ -192,7 +223,7 @@ class DeviceFlowRedis:
|
||||
minted_token: str,
|
||||
token_id: str,
|
||||
subject_issuer: str | None = None,
|
||||
poll_payload: dict | None = None,
|
||||
poll_payload: PollPayload | None = None,
|
||||
) -> None:
|
||||
state = self._load_state(device_code)
|
||||
if state is None:
|
||||
|
||||
@ -21,7 +21,7 @@ from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import SurfaceCheck
|
||||
from controllers.openapi.auth.surface_gate import accept_subjects, check_surface
|
||||
from controllers.openapi.auth.surface_gate import _coerce_subject_type, accept_subjects, check_surface
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType
|
||||
|
||||
|
||||
@ -179,3 +179,51 @@ def test_surface_check_rejects_on_miss_and_emits_audit():
|
||||
with pytest.raises(Forbidden):
|
||||
step(_pipeline_ctx())
|
||||
emit.assert_called_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _coerce_subject_type — normalises whatever sat on ctx.subject_type
|
||||
# ---------------------------------------------------------------------------
|
||||
#
|
||||
# The gate reads `ctx.subject_type` via `getattr(..., None)`, so the value
|
||||
# could be a real enum (happy path), a raw string (e.g. rehydrated from a
|
||||
# dict-shaped context), `None` (attribute missing), or something unexpected
|
||||
# from a buggy upstream. The coercer must collapse all of that to
|
||||
# `SubjectType | None` so `check_surface` can do a clean set-membership
|
||||
# check and emit a clean audit payload.
|
||||
|
||||
|
||||
def test_coerce_subject_type_returns_none_for_none():
|
||||
assert _coerce_subject_type(None) is None
|
||||
|
||||
|
||||
def test_coerce_subject_type_returns_enum_instance_unchanged():
|
||||
# Identity matters: we don't want to round-trip through the string
|
||||
# constructor for an already-valid enum.
|
||||
assert _coerce_subject_type(SubjectType.ACCOUNT) is SubjectType.ACCOUNT
|
||||
assert _coerce_subject_type(SubjectType.EXTERNAL_SSO) is SubjectType.EXTERNAL_SSO
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("raw", "expected"),
|
||||
[
|
||||
("account", SubjectType.ACCOUNT),
|
||||
("external_sso", SubjectType.EXTERNAL_SSO),
|
||||
],
|
||||
)
|
||||
def test_coerce_subject_type_parses_known_strings(raw: str, expected: SubjectType):
|
||||
assert _coerce_subject_type(raw) is expected
|
||||
|
||||
|
||||
def test_coerce_subject_type_raises_on_unknown_string():
|
||||
# Unknown strings reach `SubjectType(raw)` which raises ValueError.
|
||||
# We surface that loudly rather than silently returning None, because
|
||||
# a string that *looks* like a subject type but isn't is almost
|
||||
# certainly an upstream bug worth catching.
|
||||
with pytest.raises(ValueError):
|
||||
_coerce_subject_type("not_a_subject")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("raw", [123, 1.5, b"account", object(), ["account"], {"account"}])
|
||||
def test_coerce_subject_type_returns_none_for_non_string_non_enum(raw: object):
|
||||
assert _coerce_subject_type(raw) is None
|
||||
|
||||
Reference in New Issue
Block a user