mirror of
https://github.com/langgenius/dify.git
synced 2026-05-10 12:26:15 +08:00
Compare commits
53 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4c0f33d0d7 | |||
| 668e9c7864 | |||
| a7c481ce87 | |||
| 507eb1f52f | |||
| 8e2ab1367b | |||
| 0c568623d7 | |||
| fb7b8dc151 | |||
| 4bc1046f14 | |||
| f2ec17be9b | |||
| 6532b4d161 | |||
| 4bfc4af590 | |||
| 1fb7329327 | |||
| 40ae39a3a3 | |||
| 35c08f7c3d | |||
| 7b6ceaebea | |||
| 35d9b6a0f8 | |||
| d1c1c04615 | |||
| 04ebf8a92f | |||
| 6f3c2fe97b | |||
| 03cd16fc44 | |||
| 3a6901e718 | |||
| 25034612b8 | |||
| 87620050d7 | |||
| e006eb7a4b | |||
| 305de57eff | |||
| 069fdd4894 | |||
| 783dfe38a0 | |||
| 86ba361ff1 | |||
| 591048d7c2 | |||
| 8a62c1d915 | |||
| b083c910b3 | |||
| 9b2a37ceff | |||
| cf5ebe9430 | |||
| 85c3f9cbf8 | |||
| d98fe7916a | |||
| 0b3b0b5ce8 | |||
| eb5ef3dba5 | |||
| a07b32274a | |||
| 2a38df2b7f | |||
| 71e9e8dda6 | |||
| 772f450b29 | |||
| 390f1f74db | |||
| b7bd9c19ed | |||
| e93821af46 | |||
| 9408759954 | |||
| fe9412af5d | |||
| 218ef6a447 | |||
| 501c0b8746 | |||
| 4214583ae5 | |||
| 73771cb58c | |||
| f5f224f49d | |||
| 813da349ec | |||
| fe8510ad1a |
@ -159,6 +159,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_logstore,
|
||||
ext_mail,
|
||||
ext_migrate,
|
||||
ext_oauth_bearer,
|
||||
ext_orjson,
|
||||
ext_otel,
|
||||
ext_proxy_fix,
|
||||
@ -203,6 +204,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_enterprise_telemetry,
|
||||
ext_request_logging,
|
||||
ext_session_factory,
|
||||
ext_oauth_bearer,
|
||||
]
|
||||
for ext in extensions:
|
||||
short_name = ext.__name__.split(".")[-1]
|
||||
|
||||
@ -499,6 +499,35 @@ class HttpConfig(BaseSettings):
|
||||
def WEB_API_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return self.inner_WEB_API_CORS_ALLOW_ORIGINS.split(",")
|
||||
|
||||
inner_OPENAPI_CORS_ALLOW_ORIGINS: str = Field(
|
||||
description=(
|
||||
"Comma-separated allowlist for /openapi/v1/* CORS. "
|
||||
"Default empty = same-origin only. Browser-cookie routes within "
|
||||
"the group reject cross-origin OPTIONS regardless of this list."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_CORS_ALLOW_ORIGINS"),
|
||||
default="",
|
||||
)
|
||||
|
||||
@computed_field
|
||||
def OPENAPI_CORS_ALLOW_ORIGINS(self) -> list[str]:
|
||||
return [o for o in self.inner_OPENAPI_CORS_ALLOW_ORIGINS.split(",") if o]
|
||||
|
||||
inner_OPENAPI_KNOWN_CLIENT_IDS: str = Field(
|
||||
description=(
|
||||
"Comma-separated client_id values accepted at "
|
||||
"POST /openapi/v1/oauth/device/code. New CLIs / SDKs added here "
|
||||
"without code changes. Unknown client_id returns 400 unsupported_client."
|
||||
),
|
||||
validation_alias=AliasChoices("OPENAPI_KNOWN_CLIENT_IDS"),
|
||||
default="difyctl",
|
||||
)
|
||||
|
||||
@computed_field # type: ignore[misc]
|
||||
@property
|
||||
def OPENAPI_KNOWN_CLIENT_IDS(self) -> frozenset[str]:
|
||||
return frozenset(c for c in self.inner_OPENAPI_KNOWN_CLIENT_IDS.split(",") if c)
|
||||
|
||||
HTTP_REQUEST_MAX_CONNECT_TIMEOUT: int = Field(
|
||||
ge=1, description="Maximum connection timeout in seconds for HTTP requests", default=10
|
||||
)
|
||||
@ -874,6 +903,17 @@ class AuthConfig(BaseSettings):
|
||||
default=86400,
|
||||
)
|
||||
|
||||
ENABLE_OAUTH_BEARER: bool = Field(
|
||||
description="Enable OAuth bearer authentication (device-flow + Service API /v1/* bearer middleware).",
|
||||
default=True,
|
||||
)
|
||||
|
||||
OPENAPI_RATE_LIMIT_PER_TOKEN: PositiveInt = Field(
|
||||
description="Per-token rate limit on /openapi/v1/* (requests per minute). "
|
||||
"Bucket keyed on sha256(token), shared across api replicas via Redis.",
|
||||
default=60,
|
||||
)
|
||||
|
||||
|
||||
class ModerationConfig(BaseSettings):
|
||||
"""
|
||||
@ -1148,6 +1188,14 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable scheduled workflow run cleanup task",
|
||||
default=False,
|
||||
)
|
||||
ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK: bool = Field(
|
||||
description="Enable scheduled cleanup of revoked/expired OAuth access-token rows past retention.",
|
||||
default=True,
|
||||
)
|
||||
OAUTH_ACCESS_TOKEN_RETENTION_DAYS: PositiveInt = Field(
|
||||
description="Days to retain revoked OAuth access-token rows before deletion.",
|
||||
default=30,
|
||||
)
|
||||
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
|
||||
description="Enable mail clean document notify task",
|
||||
default=False,
|
||||
|
||||
41
api/controllers/openapi/__init__.py
Normal file
41
api/controllers/openapi/__init__.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""Blueprint and flask-restx wiring for the user-scoped /openapi/v1 surface."""

from flask import Blueprint
from flask_restx import Namespace

from libs.device_flow_security import attach_anti_framing
from libs.external_api import ExternalApi

# All /openapi/v1/* routes hang off this blueprint; anti-framing protection
# is attached blueprint-wide before any route is registered.
bp = Blueprint("openapi", __name__, url_prefix="/openapi/v1")
attach_anti_framing(bp)

api = ExternalApi(
    bp,
    version="1.0",
    title="OpenAPI",
    description="User-scoped programmatic API (bearer auth)",
)

openapi_ns = Namespace("openapi", description="User-scoped operations", path="/")

# Imported for side effects: each controller module registers its routes on
# `openapi_ns` at import time, so these imports must run after the namespace
# is created and before `add_namespace` below.
from . import (
    account,
    app_run,
    apps,
    apps_permitted,
    index,
    oauth_device,
    oauth_device_sso,
    workspaces,
)

__all__ = [
    "account",
    "app_run",
    "apps",
    "apps_permitted",
    "index",
    "oauth_device",
    "oauth_device_sso",
    "workspaces",
]

api.add_namespace(openapi_ns)
|
||||
33
api/controllers/openapi/_audit.py
Normal file
33
api/controllers/openapi/_audit.py
Normal file
@ -0,0 +1,33 @@
|
||||
"""Audit emission for openapi app-run endpoints.
|
||||
|
||||
Pattern: logger.info with extra={"audit": True, "event": "app.run.openapi", ...}
|
||||
matches the existing oauth_device convention. The EE OTel exporter consults
|
||||
its own allowlist to decide whether to ship the line.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EVENT_APP_RUN_OPENAPI = "app.run.openapi"
|
||||
|
||||
|
||||
def emit_app_run(*, app_id: str, tenant_id: str, caller_kind: str, mode: str) -> None:
|
||||
logger.info(
|
||||
"audit: %s app_id=%s tenant_id=%s caller_kind=%s mode=%s",
|
||||
EVENT_APP_RUN_OPENAPI,
|
||||
app_id,
|
||||
tenant_id,
|
||||
caller_kind,
|
||||
mode,
|
||||
extra={
|
||||
"audit": True,
|
||||
"event": EVENT_APP_RUN_OPENAPI,
|
||||
"app_id": app_id,
|
||||
"tenant_id": tenant_id,
|
||||
"caller_kind": caller_kind,
|
||||
"mode": mode,
|
||||
},
|
||||
)
|
||||
143
api/controllers/openapi/_input_schema.py
Normal file
143
api/controllers/openapi/_input_schema.py
Normal file
@ -0,0 +1,143 @@
|
||||
"""Server-side JSON Schema derivation from Dify `user_input_form`."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from models import App
|
||||
from models.model import AppMode
|
||||
|
||||
JSON_SCHEMA_DRAFT = "https://json-schema.org/draft/2020-12/schema"
|
||||
|
||||
EMPTY_INPUT_SCHEMA: dict[str, Any] = {
|
||||
"$schema": JSON_SCHEMA_DRAFT,
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
}
|
||||
|
||||
_CHAT_FAMILY = frozenset({AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT})
|
||||
|
||||
|
||||
def _file_object_shape() -> dict[str, Any]:
|
||||
"""Single-file value shape. Forward-compat placeholder; refine when file-API contract pins."""
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"type": {"type": "string"},
|
||||
"transfer_method": {"type": "string"},
|
||||
"url": {"type": "string"},
|
||||
"upload_file_id": {"type": "string"},
|
||||
},
|
||||
"additionalProperties": True,
|
||||
}
|
||||
|
||||
|
||||
def _row_to_schema(row_type: str, row: dict[str, Any]) -> dict[str, Any] | None:
|
||||
label = row.get("label") or row.get("variable", "")
|
||||
base: dict[str, Any] = {"title": label} if label else {}
|
||||
|
||||
if row_type in ("text-input", "paragraph"):
|
||||
out = {"type": "string"} | base
|
||||
max_length = row.get("max_length")
|
||||
if isinstance(max_length, int) and max_length > 0:
|
||||
out["maxLength"] = max_length
|
||||
return out
|
||||
|
||||
if row_type == "select":
|
||||
return {"type": "string"} | base | {"enum": list(row.get("options") or [])}
|
||||
|
||||
if row_type == "number":
|
||||
return {"type": "number"} | base
|
||||
|
||||
if row_type == "file":
|
||||
return _file_object_shape() | base
|
||||
|
||||
if row_type == "file-list":
|
||||
return {
|
||||
"type": "array",
|
||||
"items": _file_object_shape(),
|
||||
} | base
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _form_to_jsonschema(form: list[dict[str, Any]]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""Translate a user_input_form row list into (properties, required-list).
|
||||
|
||||
Each row is a single-key dict: `{"text-input": {variable, label, required, ...}}`.
|
||||
Unknown variable types are skipped (forward-compat).
|
||||
"""
|
||||
properties: dict[str, Any] = {}
|
||||
required: list[str] = []
|
||||
for row in form:
|
||||
if not isinstance(row, dict) or len(row) != 1:
|
||||
continue
|
||||
((row_type, row_body),) = row.items()
|
||||
if not isinstance(row_body, dict):
|
||||
continue
|
||||
variable = row_body.get("variable")
|
||||
if not variable:
|
||||
continue
|
||||
schema = _row_to_schema(row_type, row_body)
|
||||
if schema is None:
|
||||
continue
|
||||
properties[variable] = schema
|
||||
if row_body.get("required"):
|
||||
required.append(variable)
|
||||
return properties, required
|
||||
|
||||
|
||||
def resolve_app_config(app: App) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Resolve `(features_dict, user_input_form)` for parameters / schema derivation.

    Raises `AppUnavailableError` on misconfigured apps.
    """
    # Workflow-backed modes read config from the app's workflow…
    if app.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
        workflow = app.workflow
        if workflow is None:
            raise AppUnavailableError()
        return (
            workflow.features_dict,
            # to_old_structure=True yields the single-key row dicts that
            # _form_to_jsonschema consumes.
            cast(list[dict[str, Any]], workflow.user_input_form(to_old_structure=True)),
        )

    # …all other modes read it from the app model config snapshot.
    app_model_config = app.app_model_config
    if app_model_config is None:
        raise AppUnavailableError()
    features_dict = cast(dict[str, Any], app_model_config.to_dict())
    return features_dict, cast(list[dict[str, Any]], features_dict.get("user_input_form", []))
|
||||
|
||||
|
||||
def build_input_schema(app: App) -> dict[str, Any]:
    """Build a Draft 2020-12 JSON Schema describing the app's run payload.

    Chat-family apps (chat / agent-chat / advanced-chat) expose a required
    top-level `query` string plus the `inputs` object; completion and
    workflow apps expose `inputs` only.

    Raises `AppUnavailableError` on misconfigured apps.
    """
    _, form = resolve_app_config(app)
    form_props, form_required = _form_to_jsonschema(form)

    props: dict[str, Any] = {}
    top_required: list[str] = []

    # Chat-like modes take the user's utterance as a separate `query` field.
    if app.mode in _CHAT_FAMILY:
        props["query"] = {"type": "string", "minLength": 1}
        top_required.append("query")

    # `inputs` is always present and closed: unknown variables are rejected.
    props["inputs"] = {
        "type": "object",
        "properties": form_props,
        "required": form_required,
        "additionalProperties": False,
    }
    top_required.append("inputs")

    return {
        "$schema": JSON_SCHEMA_DRAFT,
        "type": "object",
        "properties": props,
        "required": top_required,
    }
|
||||
112
api/controllers/openapi/_models.py
Normal file
112
api/controllers/openapi/_models.py
Normal file
@ -0,0 +1,112 @@
|
||||
"""Shared response substructures for openapi endpoints."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Server-side cap on `limit` query param for any /openapi/v1/* list endpoint.
|
||||
# Sibling endpoints (`/apps`, `/account/sessions`, future routes) all clamp to
|
||||
# this; do not introduce per-endpoint caps without raising the constant.
|
||||
MAX_PAGE_LIMIT = 200
|
||||
|
||||
|
||||
class UsageInfo(BaseModel):
    """Token counters reported in response metadata; all zero when unknown."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
|
||||
|
||||
|
||||
class MessageMetadata(BaseModel):
    """Optional usage / retrieval metadata on chat and completion responses."""

    usage: UsageInfo | None = None
    # default_factory (rather than a bare `= []`) avoids the mutable-default
    # smell and matches the Field-based defaults used elsewhere in this module.
    retriever_resources: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class PaginationEnvelope[T](BaseModel):
    """Canonical pagination envelope for `/openapi/v1/*` list endpoints."""

    page: int  # 1-based page index echoed from the request
    limit: int  # page size actually applied
    total: int  # total matching rows across all pages
    has_more: bool  # True when rows exist beyond this page's window
    data: list[T]

    @classmethod
    def build(cls, *, page: int, limit: int, total: int, items: list[T]) -> PaginationEnvelope[T]:
        # has_more: rows remain past the slice [ (page-1)*limit : page*limit ).
        return cls(page=page, limit=limit, total=total, has_more=page * limit < total, data=items)
|
||||
|
||||
|
||||
class AppListRow(BaseModel):
    """One row of the GET /openapi/v1/apps listing."""

    id: str
    name: str
    description: str | None = None
    mode: str
    # default_factory keeps mutable defaults consistent with the Field-based
    # defaults used elsewhere in this module (no bare `= []`).
    tags: list[dict[str, str]] = Field(default_factory=list)
    updated_at: str | None = None
    created_by_name: str | None = None
    workspace_id: str | None = None
    workspace_name: str | None = None


class AppInfoResponse(BaseModel):
    """Basic app identity block; also the base class for AppDescribeInfo."""

    id: str
    name: str
    description: str | None = None
    mode: str
    author: str | None = None
    tags: list[dict[str, str]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AppDescribeInfo(AppInfoResponse):
    """`info` block of the describe payload: identity plus API availability."""

    updated_at: str | None = None
    service_api_enabled: bool


class AppDescribeResponse(BaseModel):
    """Describe payload; blocks not requested via `?fields=` remain None."""

    info: AppDescribeInfo | None = None
    parameters: dict[str, Any] | None = None
    input_schema: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ChatMessageResponse(BaseModel):
    """Blocking-mode response body for chat-family runs."""

    event: str
    task_id: str
    id: str
    message_id: str
    conversation_id: str
    mode: str
    answer: str
    metadata: MessageMetadata = Field(default_factory=MessageMetadata)
    created_at: int  # unix timestamp — presumably seconds, TODO confirm


class CompletionMessageResponse(BaseModel):
    """Blocking-mode response body for completion runs (no conversation id)."""

    event: str
    task_id: str
    id: str
    message_id: str
    mode: str
    answer: str
    metadata: MessageMetadata = Field(default_factory=MessageMetadata)
    created_at: int  # unix timestamp — presumably seconds, TODO confirm
|
||||
|
||||
|
||||
class WorkflowRunData(BaseModel):
    """Inner `data` block of a blocking workflow run result."""

    id: str
    workflow_id: str
    status: str
    outputs: dict[str, Any] = Field(default_factory=dict)
    error: str | None = None
    elapsed_time: float | None = None
    total_tokens: int | None = None
    total_steps: int | None = None
    created_at: int | None = None
    finished_at: int | None = None


class WorkflowRunResponse(BaseModel):
    """Blocking-mode response body for workflow runs; `mode` is pinned."""

    workflow_run_id: str
    task_id: str
    mode: Literal["workflow"] = "workflow"
    data: WorkflowRunData
|
||||
236
api/controllers/openapi/account.py
Normal file
236
api/controllers/openapi/account.py
Normal file
@ -0,0 +1,236 @@
|
||||
"""User-scoped account endpoints. /account is the bearer-authed
|
||||
identity read; /account/sessions and /account/sessions/<id> manage
|
||||
the user's active OAuth tokens.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from flask import g, request
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import and_, select, update
|
||||
from werkzeug.exceptions import BadRequest, NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import MAX_PAGE_LIMIT, PaginationEnvelope
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
TOKEN_CACHE_KEY_FMT,
|
||||
AuthContext,
|
||||
SubjectType,
|
||||
validate_bearer,
|
||||
)
|
||||
from libs.rate_limit import (
|
||||
LIMIT_ME_PER_ACCOUNT,
|
||||
LIMIT_ME_PER_EMAIL,
|
||||
enforce,
|
||||
)
|
||||
from models import Account, OAuthAccessToken, Tenant, TenantAccountJoin
|
||||
|
||||
|
||||
@openapi_ns.route("/account")
|
||||
class AccountApi(Resource):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = g.auth_ctx
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
|
||||
else:
|
||||
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
|
||||
|
||||
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
|
||||
return {
|
||||
"subject_type": ctx.subject_type,
|
||||
"subject_email": ctx.subject_email,
|
||||
"subject_issuer": ctx.subject_issuer,
|
||||
"account": None,
|
||||
"workspaces": [],
|
||||
"default_workspace_id": None,
|
||||
}
|
||||
|
||||
account = (
|
||||
db.session.query(Account).filter(Account.id == ctx.account_id).one_or_none() if ctx.account_id else None
|
||||
)
|
||||
memberships = _load_memberships(ctx.account_id) if ctx.account_id else []
|
||||
default_ws_id = _pick_default_workspace(memberships)
|
||||
|
||||
return {
|
||||
"subject_type": ctx.subject_type,
|
||||
"subject_email": ctx.subject_email or (account.email if account else None),
|
||||
"account": _account_payload(account) if account else None,
|
||||
"workspaces": [_workspace_payload(m) for m in memberships],
|
||||
"default_workspace_id": default_ws_id,
|
||||
}
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/self")
|
||||
class AccountSessionsSelfApi(Resource):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self):
|
||||
ctx = g.auth_ctx
|
||||
_require_oauth_subject(ctx)
|
||||
_revoke_token_by_id(str(ctx.token_id))
|
||||
return {"status": "revoked"}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions")
|
||||
class AccountSessionsApi(Resource):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def get(self):
|
||||
ctx = g.auth_ctx
|
||||
now = datetime.now(UTC)
|
||||
page = int(request.args.get("page", "1"))
|
||||
limit = min(int(request.args.get("limit", "100")), MAX_PAGE_LIMIT)
|
||||
|
||||
all_rows = db.session.execute(
|
||||
select(
|
||||
OAuthAccessToken.id,
|
||||
OAuthAccessToken.prefix,
|
||||
OAuthAccessToken.client_id,
|
||||
OAuthAccessToken.device_label,
|
||||
OAuthAccessToken.created_at,
|
||||
OAuthAccessToken.last_used_at,
|
||||
OAuthAccessToken.expires_at,
|
||||
)
|
||||
.where(
|
||||
and_(
|
||||
*_subject_match(ctx),
|
||||
OAuthAccessToken.revoked_at.is_(None),
|
||||
OAuthAccessToken.token_hash.is_not(None),
|
||||
OAuthAccessToken.expires_at > now,
|
||||
)
|
||||
)
|
||||
.order_by(OAuthAccessToken.created_at.desc())
|
||||
).all()
|
||||
|
||||
total = len(all_rows)
|
||||
sliced = all_rows[(page - 1) * limit : page * limit]
|
||||
|
||||
items = [
|
||||
{
|
||||
"id": str(r.id),
|
||||
"prefix": r.prefix,
|
||||
"client_id": r.client_id,
|
||||
"device_label": r.device_label,
|
||||
"created_at": _iso(r.created_at),
|
||||
"last_used_at": _iso(r.last_used_at),
|
||||
"expires_at": _iso(r.expires_at),
|
||||
}
|
||||
for r in sliced
|
||||
]
|
||||
|
||||
return (
|
||||
PaginationEnvelope.build(page=page, limit=limit, total=total, items=items).model_dump(mode="json"),
|
||||
200,
|
||||
)
|
||||
|
||||
|
||||
@openapi_ns.route("/account/sessions/<string:session_id>")
|
||||
class AccountSessionByIdApi(Resource):
|
||||
@validate_bearer(accept=ACCEPT_USER_ANY)
|
||||
def delete(self, session_id: str):
|
||||
ctx = g.auth_ctx
|
||||
_require_oauth_subject(ctx)
|
||||
|
||||
# Subject-match guard. 404 (not 403) on cross-subject so the
|
||||
# endpoint doesn't leak token IDs that belong to other subjects.
|
||||
owns = db.session.execute(
|
||||
select(OAuthAccessToken.id).where(
|
||||
and_(
|
||||
OAuthAccessToken.id == session_id,
|
||||
*_subject_match(ctx),
|
||||
)
|
||||
)
|
||||
).first()
|
||||
if owns is None:
|
||||
raise NotFound("session not found")
|
||||
|
||||
_revoke_token_by_id(session_id)
|
||||
return {"status": "revoked"}, 200
|
||||
|
||||
|
||||
def _subject_match(ctx: AuthContext) -> tuple:
    """Where-clauses that scope a query to the bearer's subject. Works
    for both account (account_id) and external_sso (email + issuer).
    """
    if ctx.subject_type == SubjectType.ACCOUNT:
        return (OAuthAccessToken.account_id == str(ctx.account_id),)
    # External-SSO tokens match on (email, issuer); account_id IS NULL keeps
    # the match from picking up rows bound to a local account.
    return (
        OAuthAccessToken.subject_email == ctx.subject_email,
        OAuthAccessToken.subject_issuer == ctx.subject_issuer,
        OAuthAccessToken.account_id.is_(None),
    )
|
||||
|
||||
|
||||
def _require_oauth_subject(ctx: AuthContext) -> None:
    """400 unless the bearer came from the OAuth flow; session endpoints only
    manage OAuth access-token rows (PATs have their own revocation route).
    """
    if not ctx.source.startswith("oauth"):
        raise BadRequest(
            "this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
        )
|
||||
|
||||
|
||||
def _revoke_token_by_id(token_id: str) -> None:
    """Revoke one token row and drop its Redis validation-cache entry.

    Ordering matters: the hash must be read *before* the UPDATE nulls it,
    otherwise the cache key can no longer be reconstructed.
    """
    # Snapshot pre-revoke hash for cache invalidation; UPDATE WHERE
    # makes double-revoke idempotent.
    row = (
        db.session.query(OAuthAccessToken.token_hash)
        .filter(
            OAuthAccessToken.id == token_id,
            OAuthAccessToken.revoked_at.is_(None),
        )
        .one_or_none()
    )
    pre_revoke_hash = row[0] if row else None

    stmt = (
        update(OAuthAccessToken)
        .where(
            OAuthAccessToken.id == token_id,
            OAuthAccessToken.revoked_at.is_(None),
        )
        .values(revoked_at=datetime.now(UTC), token_hash=None)
    )
    db.session.execute(stmt)
    db.session.commit()

    # Cache invalidation runs only after the DB commit has succeeded.
    if pre_revoke_hash:
        redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=pre_revoke_hash))
|
||||
|
||||
|
||||
def _iso(dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
return None
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.replace(tzinfo=UTC)
|
||||
return dt.isoformat().replace("+00:00", "Z")
|
||||
|
||||
|
||||
def _load_memberships(account_id) -> list:
    """All (TenantAccountJoin, Tenant) pairs for the given account id."""
    return (
        db.session.query(TenantAccountJoin, Tenant)
        .join(Tenant, Tenant.id == TenantAccountJoin.tenant_id)
        .filter(TenantAccountJoin.account_id == account_id)
        .all()
    )
|
||||
|
||||
|
||||
def _pick_default_workspace(memberships) -> str | None:
|
||||
if not memberships:
|
||||
return None
|
||||
for join, tenant in memberships:
|
||||
if getattr(join, "current", False):
|
||||
return str(tenant.id)
|
||||
return str(memberships[0][1].id)
|
||||
|
||||
|
||||
def _workspace_payload(row) -> dict:
|
||||
join, tenant = row
|
||||
return {"id": str(tenant.id), "name": tenant.name, "role": getattr(join, "role", "")}
|
||||
|
||||
|
||||
def _account_payload(account) -> dict:
|
||||
return {"id": str(account.id), "email": account.email, "name": account.name}
|
||||
198
api/controllers/openapi/app_run.py
Normal file
198
api/controllers/openapi/app_run.py
Normal file
@ -0,0 +1,198 @@
|
||||
"""POST /openapi/v1/apps/<app_id>/run — mode-agnostic runner."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable, Iterator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Literal
|
||||
from uuid import UUID
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError, field_validator
|
||||
from werkzeug.exceptions import BadRequest, HTTPException, InternalServerError, NotFound, UnprocessableEntity
|
||||
|
||||
import services
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._audit import emit_app_run
|
||||
from controllers.openapi._models import (
|
||||
ChatMessageResponse,
|
||||
CompletionMessageResponse,
|
||||
WorkflowRunResponse,
|
||||
)
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
from controllers.service_api.app.error import (
|
||||
AppUnavailableError,
|
||||
CompletionRequestError,
|
||||
ConversationCompletedError,
|
||||
ProviderModelCurrentlyNotSupportError,
|
||||
ProviderNotInitializeError,
|
||||
ProviderQuotaExceededError,
|
||||
)
|
||||
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.errors.error import (
|
||||
ModelCurrentlyNotSupportError,
|
||||
ProviderTokenNotInitError,
|
||||
QuotaExceededError,
|
||||
)
|
||||
from graphon.model_runtime.errors.invoke import InvokeError
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.oauth_bearer import Scope
|
||||
from models.model import App, AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import (
|
||||
IsDraftWorkflowError,
|
||||
WorkflowIdFormatError,
|
||||
WorkflowNotFoundError,
|
||||
)
|
||||
from services.errors.llm import InvokeRateLimitError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppRunRequest(BaseModel):
    """Request body for POST /apps/<app_id>/run (mode-agnostic).

    Mode-specific rules (query required for chat, rejected for workflow)
    are enforced by the per-mode runners, not by this model.
    """

    inputs: dict[str, Any]
    query: str | None = None
    files: list[dict[str, Any]] | None = None
    response_mode: Literal["blocking", "streaming"] | None = None
    conversation_id: UUIDStrOrEmpty | None = None
    auto_generate_name: bool = True
    workflow_id: str | None = None

    @field_validator("conversation_id", mode="before")
    @classmethod
    def _normalize_conv(cls, value: str | UUID | None) -> str | None:
        # Blank / whitespace-only ids mean "start a new conversation".
        if isinstance(value, str):
            value = value.strip()
        if not value:
            return None
        try:
            return helper.uuid_value(value)
        except ValueError as exc:
            raise ValueError("conversation_id must be a valid UUID") from exc
|
||||
|
||||
|
||||
@contextmanager
def _translate_service_errors() -> Iterator[None]:
    """Map service/provider exceptions raised in the body to HTTP errors.

    Keeps the /run HTTP error surface aligned with the sibling service_api
    endpoints. Exceptions are re-raised with `from` so the original cause
    is preserved for logs and tracebacks (previously the chain was implicit).
    """
    try:
        yield
    except WorkflowNotFoundError as ex:
        raise NotFound(str(ex)) from ex
    except (IsDraftWorkflowError, WorkflowIdFormatError) as ex:
        raise BadRequest(str(ex)) from ex
    except services.errors.conversation.ConversationNotExistsError as ex:
        raise NotFound("Conversation Not Exists.") from ex
    except services.errors.conversation.ConversationCompletedError as ex:
        raise ConversationCompletedError() from ex
    except services.errors.app_model_config.AppModelConfigBrokenError as ex:
        logger.exception("App model config broken.")
        raise AppUnavailableError() from ex
    except ProviderTokenNotInitError as ex:
        raise ProviderNotInitializeError(ex.description) from ex
    except QuotaExceededError as ex:
        raise ProviderQuotaExceededError() from ex
    except ModelCurrentlyNotSupportError as ex:
        raise ProviderModelCurrentlyNotSupportError() from ex
    except InvokeRateLimitError as ex:
        raise InvokeRateLimitHttpError(ex.description) from ex
    except InvokeError as e:
        raise CompletionRequestError(e.description) from e
|
||||
|
||||
|
||||
def _unpack_blocking(response: Any) -> Mapping[str, Any]:
|
||||
if isinstance(response, tuple):
|
||||
response = response[0]
|
||||
if not isinstance(response, Mapping):
|
||||
raise InternalServerError("blocking generate returned non-mapping response")
|
||||
return response
|
||||
|
||||
|
||||
def _generate(app: App, caller: Any, args: dict[str, Any], streaming: bool):
    """Thin wrapper over AppGenerateService.generate pinning InvokeFrom.OPENAPI."""
    return AppGenerateService.generate(
        app_model=app,
        user=caller,
        args=args,
        invoke_from=InvokeFrom.OPENAPI,
        streaming=streaming,
    )
|
||||
|
||||
|
||||
def _run_chat(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
    """Chat-family runner: requires a non-blank `query`.

    Returns (stream, None) when streaming, (None, validated body) when blocking.
    """
    if not payload.query or not payload.query.strip():
        raise UnprocessableEntity("query_required_for_chat")
    args = payload.model_dump(exclude_none=True)
    with _translate_service_errors():
        response = _generate(app, caller, args, streaming)
    if streaming:
        return response, None
    return None, ChatMessageResponse.model_validate(_unpack_blocking(response)).model_dump(mode="json")


def _run_completion(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
    """Completion runner; same (stream, blocking-body) contract as _run_chat."""
    args = payload.model_dump(exclude_none=True)
    # Force auto-naming off and guarantee a `query` key for the service layer.
    args["auto_generate_name"] = False
    args.setdefault("query", "")
    with _translate_service_errors():
        response = _generate(app, caller, args, streaming)
    if streaming:
        return response, None
    return None, CompletionMessageResponse.model_validate(_unpack_blocking(response)).model_dump(mode="json")


def _run_workflow(app: App, caller: Any, payload: AppRunRequest, streaming: bool):
    """Workflow runner: rejects `query` and strips chat-only fields from args."""
    if payload.query is not None:
        raise UnprocessableEntity("query_not_supported_for_workflow")
    args = payload.model_dump(exclude={"query", "conversation_id", "auto_generate_name"}, exclude_none=True)
    with _translate_service_errors():
        response = _generate(app, caller, args, streaming)
    if streaming:
        return response, None
    return None, WorkflowRunResponse.model_validate(_unpack_blocking(response)).model_dump(mode="json")


# Per-mode dispatch used by AppRunApi.post; modes absent here are not runnable.
_DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest, bool], tuple[Any, dict[str, Any] | None]]] = {
    AppMode.CHAT: _run_chat,
    AppMode.AGENT_CHAT: _run_chat,
    AppMode.ADVANCED_CHAT: _run_chat,
    AppMode.COMPLETION: _run_completion,
    AppMode.WORKFLOW: _run_workflow,
}
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/run")
|
||||
class AppRunApi(Resource):
|
||||
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
|
||||
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
|
||||
body = request.get_json(silent=True) or {}
|
||||
body.pop("user", None)
|
||||
try:
|
||||
payload = AppRunRequest.model_validate(body)
|
||||
except ValidationError as exc:
|
||||
raise UnprocessableEntity(exc.json())
|
||||
|
||||
handler = _DISPATCH.get(app_model.mode)
|
||||
if handler is None:
|
||||
raise UnprocessableEntity("mode_not_runnable")
|
||||
|
||||
streaming = payload.response_mode == "streaming"
|
||||
try:
|
||||
stream_obj, blocking_body = handler(app_model, caller, payload, streaming)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("internal server error.")
|
||||
raise InternalServerError()
|
||||
|
||||
emit_app_run(
|
||||
app_id=app_model.id,
|
||||
tenant_id=app_model.tenant_id,
|
||||
caller_kind=caller_kind,
|
||||
mode=str(app_model.mode),
|
||||
)
|
||||
|
||||
if streaming:
|
||||
return helper.compact_generate_response(stream_obj)
|
||||
return blocking_body, 200
|
||||
315
api/controllers/openapi/apps.py
Normal file
315
api/controllers/openapi/apps.py
Normal file
@ -0,0 +1,315 @@
|
||||
"""GET /openapi/v1/apps and per-app reads.
|
||||
|
||||
Decorator order: `method_decorators` is innermost-first. `validate_bearer`
|
||||
is last → outermost → sets `g.auth_ctx` before `require_scope` reads it.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid as _uuid
|
||||
from typing import Any
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import g, request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError, field_validator
|
||||
from werkzeug.exceptions import Conflict, NotFound, UnprocessableEntity
|
||||
|
||||
from controllers.common.fields import Parameters
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema, resolve_app_config
|
||||
from controllers.openapi._models import (
|
||||
MAX_PAGE_LIMIT,
|
||||
AppDescribeInfo,
|
||||
AppDescribeResponse,
|
||||
AppListRow,
|
||||
PaginationEnvelope,
|
||||
)
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
AuthContext,
|
||||
Scope,
|
||||
SubjectType,
|
||||
require_scope,
|
||||
require_workspace_member,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App, Tenant
|
||||
from models.model import AppMode
|
||||
from services.app_service import AppService
|
||||
from services.tag_service import TagService
|
||||
|
||||
_APPS_READ_DECORATORS = [
|
||||
require_scope(Scope.APPS_READ),
|
||||
validate_bearer(accept=ACCEPT_USER_ANY),
|
||||
]
|
||||
|
||||
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
|
||||
|
||||
|
||||
class AppDescribeQuery(BaseModel):
    """Query model for GET /apps/<id>/describe.

    ``fields`` is an allow-list: empty / omitted selects every block, and an
    unknown member raises a ValidationError (surfaced as 422).
    """

    model_config = ConfigDict(extra="forbid")

    fields: set[str] | None = None
    workspace_id: str | None = None

    @field_validator("workspace_id", mode="before")
    @classmethod
    def _validate_workspace_id(cls, v: object) -> str | None:
        # Empty string is treated the same as an absent parameter.
        if v is None or v == "":
            return None
        if not isinstance(v, str):
            raise ValueError("workspace_id must be a string")
        try:
            _uuid.UUID(v)
        except ValueError:
            raise ValueError("workspace_id must be a valid UUID")
        return v

    @field_validator("fields", mode="before")
    @classmethod
    def _parse_fields(cls, v: object) -> set[str] | None:
        # Empty string is treated the same as an absent parameter.
        if v is None or v == "":
            return None
        if not isinstance(v, str):
            raise ValueError("fields must be a comma-separated string")
        requested = {part.strip() for part in v.split(",") if part.strip()}
        unknown = requested - _ALLOWED_DESCRIBE_FIELDS
        if unknown:
            raise ValueError(f"unknown field(s): {sorted(unknown)}")
        return requested
|
||||
|
||||
|
||||
# Fallback /parameters body used when the app's configuration cannot be
# resolved (AppUnavailableError) — callers still get a well-formed shape.
_EMPTY_PARAMETERS: dict[str, Any] = {
    "opening_statement": None,
    "suggested_questions": [],
    "user_input_form": [],
    "file_upload": None,
    "system_parameters": {},
}
|
||||
|
||||
|
||||
class AppReadResource(Resource):
    """Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""

    method_decorators = _APPS_READ_DECORATORS

    def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
        # Resolve *app_id* — a UUID, or (with workspace_id) an app name — to
        # an App the authenticated account may read.
        # Raises NotFound / UnprocessableEntity / Conflict on failure.
        ctx = g.auth_ctx
        # Non-account subjects get 404 rather than 403, so app existence is
        # not leaked to callers that can never read it.
        if ctx.subject_type != SubjectType.ACCOUNT or ctx.account_id is None:
            raise NotFound("app not found")

        try:
            parsed_uuid = _uuid.UUID(app_id)
            is_uuid = True
        except ValueError:
            parsed_uuid = None
            is_uuid = False

        if is_uuid:
            app = db.session.get(App, str(parsed_uuid))  # normalised dashed form
            if not app or app.status != "normal":
                raise NotFound("app not found")
        else:
            # Name-based lookup is only unambiguous within one workspace.
            if not workspace_id:
                raise UnprocessableEntity("workspace_id is required for name-based lookup")
            matches = list(
                db.session.execute(
                    sa.select(App).where(
                        App.name == app_id,
                        App.tenant_id == workspace_id,
                        App.status == "normal",
                    )
                ).scalars()
            )
            if len(matches) == 0:
                raise NotFound("app not found")
            if len(matches) > 1:
                # 409 with a human-readable table so the caller can retry
                # with an exact UUID.
                lines = [f"app name {app_id!r} is ambiguous — re-run with a UUID:\n\n"]
                lines.append(f"  {'ID':<36} {'MODE':<12} NAME\n")
                for m in matches:
                    lines.append(f"  {str(m.id):<36} {str(m.mode.value):<12} {m.name}\n")
                raise Conflict("".join(lines))
            app = matches[0]

        # Membership is checked against the app's own tenant, regardless of
        # which lookup path was taken.
        require_workspace_member(ctx, str(app.tenant_id))
        return app, ctx
|
||||
|
||||
|
||||
def parameters_payload(app: App) -> dict:
    """Mirrors service_api/app/app.py::AppParameterApi response body."""
    features_dict, user_input_form = resolve_app_config(app)
    raw_parameters = get_parameters_from_feature_dict(
        features_dict=features_dict, user_input_form=user_input_form
    )
    validated = Parameters.model_validate(raw_parameters)
    return validated.model_dump(mode="json")
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/<string:app_id>/describe")
class AppDescribeApi(AppReadResource):
    """GET /apps/<id>/describe — selectable info / parameters / input_schema blocks."""

    def get(self, app_id: str):
        try:
            query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
        except ValidationError as exc:
            raise UnprocessableEntity(exc.json())

        app, _ = self._load(app_id, workspace_id=query.workspace_id)

        selected = query.fields

        def wanted(block: str) -> bool:
            # `fields` omitted means "no filter": every block is included.
            return selected is None or block in selected

        info: AppDescribeInfo | None = None
        if wanted("info"):
            info = AppDescribeInfo(
                id=str(app.id),
                name=app.name,
                mode=app.mode,
                description=app.description,
                tags=[{"name": t.name} for t in app.tags],
                author=app.author_name,
                updated_at=app.updated_at.isoformat() if app.updated_at else None,
                service_api_enabled=bool(app.enable_api),
            )

        parameters: dict[str, Any] | None = None
        if wanted("parameters"):
            try:
                parameters = parameters_payload(app)
            except AppUnavailableError:
                # Unconfigured app: serve the empty-but-well-formed shape.
                parameters = dict(_EMPTY_PARAMETERS)

        input_schema: dict[str, Any] | None = None
        if wanted("input_schema"):
            try:
                input_schema = build_input_schema(app)
            except AppUnavailableError:
                input_schema = dict(EMPTY_INPUT_SCHEMA)

        body = AppDescribeResponse(
            info=info,
            parameters=parameters,
            input_schema=input_schema,
        ).model_dump(mode="json", exclude_none=False)
        return body, 200
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
    """`mode` is a closed enum — unknown values 422 instead of silently-empty data."""

    # Required: the workspace to list apps from.
    workspace_id: str
    # 1-based page number.
    page: int = Field(1, ge=1)
    # Page size, capped at MAX_PAGE_LIMIT.
    limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
    # Optional filter by app mode; None means all modes.
    mode: AppMode | None = None
    # Optional name filter; a UUID here triggers a direct-ID lookup in AppListApi.
    name: str | None = Field(None, max_length=200)
    # Optional tag-name filter, resolved via TagService.
    tag: str | None = Field(None, max_length=100)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps")
class AppListApi(Resource):
    """GET /openapi/v1/apps — paginated app listing for an account subject."""

    method_decorators = _APPS_READ_DECORATORS

    def get(self):
        ctx = g.auth_ctx
        # Non-account subjects get an empty page rather than an error, so
        # the endpoint never leaks whether apps exist.
        if ctx.subject_type != SubjectType.ACCOUNT or ctx.account_id is None:
            return PaginationEnvelope[AppListRow].build(page=1, limit=0, total=0, items=[]).model_dump(mode="json"), 200

        try:
            query = AppListQuery.model_validate(request.args.to_dict(flat=True))
        except ValidationError as exc:
            raise UnprocessableEntity(exc.json())

        workspace_id = query.workspace_id
        require_workspace_member(ctx, workspace_id)

        # Pre-built empty response reused by the early-return branches below.
        empty = (
            PaginationEnvelope[AppListRow]
            .build(page=query.page, limit=query.limit, total=0, items=[])
            .model_dump(mode="json"),
            200,
        )

        # If `name` parses as a UUID, treat it as a direct ID lookup.
        if query.name:
            try:
                parsed_uuid = _uuid.UUID(query.name)
            except ValueError:
                parsed_uuid = None
        else:
            parsed_uuid = None

        if parsed_uuid is not None:
            app = db.session.get(App, str(parsed_uuid))
            # The app must exist, be active, and belong to the requested workspace.
            if not app or app.status != "normal" or str(app.tenant_id) != workspace_id:
                return empty
            tenant_name = db.session.execute(
                sa.select(Tenant.name).where(Tenant.id == workspace_id)
            ).scalar_one_or_none()
            item = AppListRow(
                id=str(app.id),
                name=app.name,
                description=app.description,
                mode=app.mode,
                tags=[{"name": t.name} for t in app.tags],
                updated_at=app.updated_at.isoformat() if app.updated_at else None,
                created_by_name=getattr(app, "author_name", None),
                workspace_id=str(workspace_id),
                workspace_name=tenant_name,
            )
            # Direct lookup always yields a single-item first page.
            env = PaginationEnvelope[AppListRow].build(page=1, limit=1, total=1, items=[item])
            return env.model_dump(mode="json"), 200

        # Resolve the tag-name filter to tag IDs; no match short-circuits to empty.
        tag_ids: list[str] | None = None
        if query.tag:
            tags = TagService.get_tag_by_tag_name("app", workspace_id, query.tag)
            if not tags:
                return empty
            tag_ids = [tag.id for tag in tags]

        args: dict[str, Any] = {
            "page": query.page,
            "limit": query.limit,
            # AppService expects "" (not None) for "all modes".
            "mode": query.mode.value if query.mode else "",
            "name": query.name,
            "status": "normal",
        }
        if tag_ids:
            args["tag_ids"] = tag_ids

        pagination = AppService().get_paginate_apps(ctx.account_id, workspace_id, args)
        if pagination is None:
            return empty

        # Fetch the workspace name once; all rows share the same workspace.
        tenant_name: str | None = None
        if pagination.items:
            tenant_name = db.session.execute(
                sa.select(Tenant.name).where(Tenant.id == workspace_id)
            ).scalar_one_or_none()

        items = [
            AppListRow(
                id=str(r.id),
                name=r.name,
                description=r.description,
                mode=r.mode,
                tags=[{"name": t.name} for t in r.tags],
                updated_at=r.updated_at.isoformat() if r.updated_at else None,
                created_by_name=getattr(r, "author_name", None),
                workspace_id=str(workspace_id),
                workspace_name=tenant_name,
            )
            for r in pagination.items
        ]
        env = PaginationEnvelope[AppListRow].build(
            page=query.page, limit=query.limit, total=int(pagination.total), items=items
        )
        return env.model_dump(mode="json"), 200
|
||||
101
api/controllers/openapi/apps_permitted.py
Normal file
101
api/controllers/openapi/apps_permitted.py
Normal file
@ -0,0 +1,101 @@
|
||||
"""GET /openapi/v1/apps/permitted — external-subject app discovery (EE only)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ConfigDict, Field, ValidationError
|
||||
from werkzeug.exceptions import UnprocessableEntity
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from controllers.openapi._models import (
|
||||
MAX_PAGE_LIMIT,
|
||||
AppListRow,
|
||||
PaginationEnvelope,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.device_flow_security import enterprise_only
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_EXT_SSO,
|
||||
Scope,
|
||||
require_scope,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import App, Tenant
|
||||
from models.model import AppMode
|
||||
from services.enterprise.app_permitted_service import list_permitted_apps
|
||||
|
||||
|
||||
class AppPermittedListQuery(BaseModel):
    """Strict (`extra='forbid'`) — rejects `workspace_id`/`tag`/etc. that are valid on /apps but not here."""

    model_config = ConfigDict(extra="forbid")

    # 1-based page number.
    page: int = Field(1, ge=1)
    # Page size, capped at MAX_PAGE_LIMIT.
    limit: int = Field(20, ge=1, le=MAX_PAGE_LIMIT)
    # Optional mode filter; None means all modes.
    mode: AppMode | None = None
    # Optional name filter forwarded to the EE service.
    name: str | None = Field(None, max_length=200)
|
||||
|
||||
|
||||
@openapi_ns.route("/apps/permitted")
class AppPermittedListApi(Resource):
    """GET /apps/permitted — apps an external SSO subject may access (EE only)."""

    # Outermost-first at runtime: enterprise gate, then bearer auth, then scope.
    method_decorators = [
        require_scope(Scope.APPS_READ_PERMITTED),
        validate_bearer(accept=ACCEPT_USER_EXT_SSO),
        enterprise_only,
    ]

    def get(self):
        try:
            query = AppPermittedListQuery.model_validate(request.args.to_dict(flat=True))
        except ValidationError as exc:
            raise UnprocessableEntity(exc.json())

        # The EE service owns the allow-list and the pagination arithmetic.
        page_result = list_permitted_apps(
            page=query.page,
            limit=query.limit,
            mode=query.mode.value if query.mode else None,
            name=query.name,
        )

        if not page_result.app_ids:
            env = PaginationEnvelope[AppListRow].build(
                page=query.page, limit=query.limit, total=page_result.total, items=[]
            )
            return env.model_dump(mode="json"), 200

        # Bulk-load the local App / Tenant rows behind the permitted IDs.
        apps_by_id = {
            str(a.id): a
            for a in db.session.execute(sa.select(App).where(App.id.in_(page_result.app_ids))).scalars().all()
        }
        tenant_ids = list({a.tenant_id for a in apps_by_id.values()})
        tenants_by_id = {
            str(t.id): t for t in db.session.execute(sa.select(Tenant).where(Tenant.id.in_(tenant_ids))).scalars().all()
        }

        # Preserve the EE service's ordering; drop IDs with no active local row.
        items: list[AppListRow] = []
        for app_id in page_result.app_ids:
            app = apps_by_id.get(app_id)
            if not app or app.status != "normal":
                continue
            tenant = tenants_by_id.get(str(app.tenant_id))
            items.append(
                AppListRow(
                    id=str(app.id),
                    name=app.name,
                    description=app.description,
                    mode=app.mode,
                    tags=[],  # tenant-scoped; not surfaced cross-tenant
                    updated_at=app.updated_at.isoformat() if app.updated_at else None,
                    created_by_name=None,  # cross-tenant author leak prevention
                    workspace_id=str(app.tenant_id),
                    workspace_name=tenant.name if tenant else None,
                )
            )

        # total/has_more reflect the EE-side allow-list; len(items) may be < limit when local rows are dropped.
        env = PaginationEnvelope[AppListRow].build(
            page=query.page, limit=query.limit, total=page_result.total, items=items
        )
        return env.model_dump(mode="json"), 200
|
||||
3
api/controllers/openapi/auth/__init__.py
Normal file
3
api/controllers/openapi/auth/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
|
||||
|
||||
__all__ = ["OAUTH_BEARER_PIPELINE"]
|
||||
43
api/controllers/openapi/auth/composition.py
Normal file
43
api/controllers/openapi/auth/composition.py
Normal file
@ -0,0 +1,43 @@
|
||||
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
|
||||
|
||||
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative
|
||||
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
|
||||
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
|
||||
inline — they don't need `AppAuthzCheck`/`CallerMount`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
WorkspaceMembershipCheck,
|
||||
)
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
AppAuthzStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
    """Pick the app-authorization strategy for the current deployment:
    per-app ACL when webapp-auth is enabled, tenant membership otherwise."""
    webapp_auth_enabled = FeatureService.get_system_features().webapp_auth.enabled
    return AclStrategy() if webapp_auth_enabled else MembershipStrategy()
|
||||
|
||||
|
||||
# Canonical auth pipeline for /run endpoints — steps execute in this order;
# each may raise to short-circuit the request.
OAUTH_BEARER_PIPELINE = Pipeline(
    BearerCheck(),           # resolve token → identity on ctx
    ScopeCheck(),            # required scope covered?
    AppResolver(),           # path app_id → ctx.app / ctx.tenant
    WorkspaceMembershipCheck(),  # CE-only account membership gate
    AppAuthzCheck(_resolve_app_authz_strategy),  # ACL (EE) or membership (CE), chosen per request
    CallerMount(AccountMounter(), EndUserMounter()),  # bind Flask-Login user
)
|
||||
46
api/controllers/openapi/auth/context.py
Normal file
46
api/controllers/openapi/auth/context.py
Normal file
@ -0,0 +1,46 @@
|
||||
"""Mutable per-request context for the openapi auth pipeline.
|
||||
|
||||
Every field starts None / empty and is filled in by a step. The pipeline
|
||||
is the only thing that should construct or mutate Context — handlers
|
||||
read populated values via the decorator's kwargs unpacking.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Literal, Protocol
|
||||
|
||||
from flask import Request
|
||||
|
||||
from libs.oauth_bearer import Scope, SubjectType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models import App, Tenant
|
||||
|
||||
|
||||
@dataclass
class Context:
    # Inputs set at construction by Pipeline.guard.
    request: Request
    required_scope: Scope
    # Identity fields — populated by BearerCheck.
    subject_type: SubjectType | None = None
    subject_email: str | None = None
    subject_issuer: str | None = None
    account_id: uuid.UUID | None = None
    scopes: frozenset[Scope] = field(default_factory=frozenset)
    token_id: uuid.UUID | None = None
    token_hash: str | None = None
    # Per-token membership verdict cache, keyed by tenant id (BearerCheck).
    cached_verified_tenants: dict[str, bool] | None = None
    source: str | None = None
    expires_at: datetime | None = None
    # Resolution results — populated by AppResolver.
    app: App | None = None
    tenant: Tenant | None = None
    # Mounted caller — populated by CallerMount.
    caller: object | None = None
    caller_kind: Literal["account", "end_user"] | None = None
||||
|
||||
|
||||
class Step(Protocol):
    """One responsibility. Mutate ctx or raise to short-circuit."""

    # Structural protocol: any callable taking a Context qualifies.
    def __call__(self, ctx: Context) -> None: ...
|
||||
41
api/controllers/openapi/auth/pipeline.py
Normal file
41
api/controllers/openapi/auth/pipeline.py
Normal file
@ -0,0 +1,41 @@
|
||||
"""Pipeline IS the auth scheme.
|
||||
|
||||
`Pipeline.guard(scope=…)` is the only attachment point for endpoints —
|
||||
that is the design lock-in: forgetting an auth layer is structurally
|
||||
impossible because there is no "sometimes wrap, sometimes don't" choice.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from flask import request
|
||||
|
||||
from controllers.openapi.auth.context import Context, Step
|
||||
from libs.oauth_bearer import Scope
|
||||
|
||||
|
||||
class Pipeline:
    """Ordered sequence of auth steps; `guard` is the sole endpoint hook."""

    def __init__(self, *steps: Step) -> None:
        self._steps = steps

    def run(self, ctx: Context) -> None:
        """Execute every step in order; any step may raise to short-circuit."""
        for current in self._steps:
            current(ctx)

    def guard(self, *, scope: Scope):
        """Decorator factory: run the pipeline, then inject resolved kwargs."""

        def decorator(view):
            @wraps(view)
            def decorated(*args, **kwargs):
                ctx = Context(request=request, required_scope=scope)
                self.run(ctx)
                resolved = {
                    "app_model": ctx.app,
                    "caller": ctx.caller,
                    "caller_kind": ctx.caller_kind,
                }
                kwargs.update(resolved)
                return view(*args, **kwargs)

            return decorated

        return decorator
|
||||
131
api/controllers/openapi/auth/steps.py
Normal file
131
api/controllers/openapi/auth/steps.py
Normal file
@ -0,0 +1,131 @@
|
||||
"""Pipeline steps. Each is one responsibility.
|
||||
|
||||
`BearerCheck` is the only step that touches the token registry; downstream
|
||||
steps see only the populated `Context`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
InvalidBearerError,
|
||||
Scope,
|
||||
SubjectType,
|
||||
_extract_bearer, # type: ignore[attr-defined]
|
||||
check_workspace_membership,
|
||||
get_authenticator,
|
||||
)
|
||||
from models import App, Tenant, TenantStatus
|
||||
|
||||
|
||||
class BearerCheck:
    """Resolve bearer → populate identity fields. Rate-limit is enforced
    inside `BearerAuthenticator.authenticate`, so no separate step here."""

    def __call__(self, ctx: Context) -> None:
        token = _extract_bearer(ctx.request)
        if not token:
            raise Unauthorized("bearer required")

        try:
            authn = get_authenticator().authenticate(token)
        except InvalidBearerError as e:
            # Map authenticator failures onto 401 without leaking internals.
            raise Unauthorized(str(e))

        # Copy the authenticated identity onto the pipeline context; all
        # downstream steps read only from ctx.
        ctx.subject_type = authn.subject_type
        ctx.subject_email = authn.subject_email
        ctx.subject_issuer = authn.subject_issuer
        ctx.account_id = authn.account_id
        ctx.scopes = frozenset(authn.scopes)
        ctx.source = authn.source
        ctx.token_id = authn.token_id
        ctx.expires_at = authn.expires_at
        ctx.token_hash = authn.token_hash
        # Copied defensively so later mutation cannot leak into the authenticator.
        ctx.cached_verified_tenants = dict(authn.verified_tenants)
|
||||
|
||||
|
||||
class ScopeCheck:
    """Ensure the granted scopes (populated by BearerCheck) cover the required one."""

    def __call__(self, ctx: Context) -> None:
        granted = ctx.scopes
        # Scope.FULL is a wildcard covering every required scope.
        covered = Scope.FULL in granted or ctx.required_scope in granted
        if not covered:
            raise Forbidden("insufficient_scope")
|
||||
|
||||
|
||||
class AppResolver:
    """Read app_id from request.view_args, populate ctx.app + ctx.tenant.

    Every endpoint using the OAuth bearer pipeline must declare
    ``<string:app_id>`` in its route — that is the design lock-in (no body /
    header coupling).
    """

    def __call__(self, ctx: Context) -> None:
        app_id = (ctx.request.view_args or {}).get("app_id")
        if not app_id:
            raise BadRequest("app_id is required in path")
        app = db.session.get(App, app_id)
        if not app or app.status != "normal":
            raise NotFound("app not found")
        # /run requires the Service API surface to be switched on for the app.
        if not app.enable_api:
            raise Forbidden("service_api_disabled")
        tenant = db.session.get(Tenant, app.tenant_id)
        # Archived or missing workspaces make the app unreachable.
        if tenant is None or tenant.status == TenantStatus.ARCHIVE:
            raise Forbidden("workspace unavailable")
        ctx.app, ctx.tenant = app, tenant
|
||||
|
||||
|
||||
class WorkspaceMembershipCheck:
    """Layer 0 — workspace membership gate.

    CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
    (dfoa_) only — SSO subjects skip.
    """

    def __call__(self, ctx: Context) -> None:
        # EE deployments delegate this gate to the enterprise side.
        if dify_config.ENTERPRISE_ENABLED:
            return
        # Only account subjects carry workspace membership; SSO subjects pass.
        if ctx.subject_type != SubjectType.ACCOUNT:
            return
        # Ordering invariants: these fields must have been set upstream.
        if ctx.account_id is None or ctx.tenant is None:
            raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
        if ctx.token_hash is None:
            raise Unauthorized("token_hash unset — BearerCheck did not run")

        # token_hash + cached verdicts let the helper skip repeated DB checks.
        check_workspace_membership(
            account_id=ctx.account_id,
            tenant_id=ctx.tenant.id,
            token_hash=ctx.token_hash,
            cached_verdicts=ctx.cached_verified_tenants or {},
        )
|
||||
|
||||
|
||||
class AppAuthzCheck:
    """Delegate per-app authorization to a lazily resolved strategy."""

    def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
        # Stored as a factory so the EE/CE choice is made per request.
        self._resolve = resolve_strategy

    def __call__(self, ctx: Context) -> None:
        strategy = self._resolve()
        if strategy.authorize(ctx):
            return
        raise Forbidden("subject_no_app_access")
|
||||
|
||||
|
||||
class CallerMount:
    """Mount the request caller via the first mounter matching the subject type."""

    def __init__(self, *mounters: CallerMounter) -> None:
        self._mounters = mounters

    def __call__(self, ctx: Context) -> None:
        subject = ctx.subject_type
        if subject is None:
            raise Unauthorized("subject_type unset — BearerCheck did not run")
        for candidate in self._mounters:
            if not candidate.applies_to(subject):
                continue
            candidate.mount(ctx)
            return
        raise Unauthorized("no caller mounter for subject type")
|
||||
115
api/controllers/openapi/auth/strategies.py
Normal file
115
api/controllers/openapi/auth/strategies.py
Normal file
@ -0,0 +1,115 @@
|
||||
"""Strategy classes for the openapi auth pipeline.
|
||||
|
||||
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
|
||||
vary along independent axes; each strategy is one class so the pipeline
|
||||
composition stays a flat list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Protocol
|
||||
|
||||
from flask import current_app
|
||||
from flask_login import user_logged_in
|
||||
from sqlalchemy import select
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from models import Account, TenantAccountJoin
|
||||
from services.end_user_service import EndUserService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
|
||||
|
||||
class AppAuthzStrategy(Protocol):
    """Structural interface: return True iff the ctx subject may use ctx.app."""

    def authorize(self, ctx: Context) -> bool: ...
|
||||
|
||||
|
||||
class AclStrategy:
    """Per-app ACL via the workspace-auth inner API.

    Used when webapp-auth is enabled (EE deployment). The inner-API
    allowlist is the source of truth.
    """

    def authorize(self, ctx: Context) -> bool:
        # Without an email or a resolved app there is nothing to check.
        if ctx.subject_email is None or ctx.app is None:
            return False
        # NOTE(review): the inner API keys users by email here — confirm
        # that is the intended user_id representation.
        return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
            user_id=ctx.subject_email,
            app_id=ctx.app.id,
        )
|
||||
|
||||
|
||||
class MembershipStrategy:
    """Tenant-membership fallback.

    Used when webapp-auth is disabled (CE deployment). Account-bearing
    subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
    denied (it requires the webapp-auth surface).
    """

    def authorize(self, ctx: Context) -> bool:
        # SSO subjects cannot be authorized by membership — deny outright.
        if ctx.subject_type == SubjectType.EXTERNAL_SSO:
            return False
        if ctx.tenant is None:
            return False
        return _has_tenant_membership(ctx.account_id, ctx.tenant.id)
|
||||
|
||||
|
||||
def _has_tenant_membership(account_id: uuid.UUID | str | None, tenant_id: str) -> bool:
    """Return True iff *account_id* has a TenantAccountJoin row in *tenant_id*.

    A falsy *account_id* (None / empty) is never a member. The query is
    capped with ``LIMIT 1``: membership is a pure existence check, and
    ``scalar_one_or_none()`` without the limit would raise
    ``MultipleResultsFound`` if duplicate join rows ever existed, turning
    an authorization check into a 500.
    """
    if not account_id:
        return False
    row = db.session.execute(
        select(TenantAccountJoin.id)
        .where(
            TenantAccountJoin.tenant_id == tenant_id,
            TenantAccountJoin.account_id == account_id,
        )
        .limit(1)
    ).scalar_one_or_none()
    return row is not None
|
||||
|
||||
|
||||
def _login_as(user) -> None:
    """Set Flask-Login request user so downstream services see the caller."""
    # NOTE(review): `_update_request_context_with_user` is a private
    # Flask-Login API — re-verify on Flask-Login upgrades.
    current_app.login_manager._update_request_context_with_user(user)
    # Fire the standard signal so listeners (audit, last-seen, …) run.
    user_logged_in.send(current_app._get_current_object(), user=user)
|
||||
|
||||
|
||||
class CallerMounter(Protocol):
    """Structural interface for CallerMount strategies."""

    # True iff this mounter handles the given subject type.
    def applies_to(self, subject_type: SubjectType) -> bool: ...

    # Bind the caller onto ctx (sets ctx.caller / ctx.caller_kind).
    def mount(self, ctx: Context) -> None: ...
|
||||
|
||||
|
||||
class AccountMounter:
    """Mount an Account caller for ACCOUNT-subject bearers."""

    def applies_to(self, subject_type: SubjectType) -> bool:
        return subject_type == SubjectType.ACCOUNT

    def mount(self, ctx: Context) -> None:
        # RuntimeError (not 401): these are pipeline-ordering bugs, not
        # client-correctable auth failures.
        if ctx.account_id is None:
            raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
        account = db.session.get(Account, ctx.account_id)
        if account is None:
            raise RuntimeError("AccountMounter: account row missing for resolved bearer")
        # Pin the tenant so downstream tenant-scoped services resolve correctly.
        account.current_tenant = ctx.tenant
        _login_as(account)
        ctx.caller, ctx.caller_kind = account, "account"
|
||||
|
||||
|
||||
class EndUserMounter:
    """Mount an EndUser caller for EXTERNAL_SSO-subject bearers."""

    def applies_to(self, subject_type: SubjectType) -> bool:
        return subject_type == SubjectType.EXTERNAL_SSO

    def mount(self, ctx: Context) -> None:
        # RuntimeError (not 401): pipeline-ordering bug, not a client error.
        if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
            raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
        # Idempotent: reuses the existing end-user row keyed by the SSO email.
        end_user = EndUserService.get_or_create_end_user_by_type(
            InvokeFrom.OPENAPI,
            tenant_id=ctx.tenant.id,
            app_id=ctx.app.id,
            user_id=ctx.subject_email,
        )
        _login_as(end_user)
        ctx.caller, ctx.caller_kind = end_user, "end_user"
|
||||
9
api/controllers/openapi/index.py
Normal file
9
api/controllers/openapi/index.py
Normal file
@ -0,0 +1,9 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
|
||||
|
||||
@openapi_ns.route("/_health")
class HealthApi(Resource):
    """Liveness probe for the /openapi/v1 surface."""

    def get(self):
        payload = {"ok": True}
        return payload
|
||||
392
api/controllers/openapi/oauth_device.py
Normal file
392
api/controllers/openapi/oauth_device.py
Normal file
@ -0,0 +1,392 @@
|
||||
"""Device-flow endpoints under /openapi/v1/oauth/device/*. Two
|
||||
sub-groups in one module:
|
||||
|
||||
Protocol (RFC 8628, public + rate-limited):
|
||||
POST /oauth/device/code
|
||||
POST /oauth/device/token
|
||||
GET /oauth/device/lookup
|
||||
|
||||
Approval (account branch, console-cookie authed):
|
||||
POST /oauth/device/approve
|
||||
POST /oauth/device/deny
|
||||
|
||||
SSO branch lives in oauth_device_sso.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import login_required
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.openapi import openapi_ns
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.helper import extract_remote_ip
|
||||
from libs.login import current_account_with_tenant
|
||||
from libs.oauth_bearer import SubjectType, bearer_feature_required
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_CONSOLE,
|
||||
LIMIT_DEVICE_CODE_PER_IP,
|
||||
LIMIT_LOOKUP_PUBLIC,
|
||||
rate_limit,
|
||||
)
|
||||
from services.oauth_device_flow import (
|
||||
ACCOUNT_ISSUER_SENTINEL,
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
DEVICE_FLOW_TTL_SECONDS,
|
||||
PREFIX_OAUTH_ACCOUNT,
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
SlowDownDecision,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Request / query schemas
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class DeviceCodeRequest(BaseModel):
    # POST /oauth/device/code body.
    client_id: str
    device_label: str


class DevicePollRequest(BaseModel):
    # POST /oauth/device/token body (RFC 8628 poll).
    device_code: str
    client_id: str


class DeviceLookupQuery(BaseModel):
    # GET /oauth/device/lookup query string.
    user_code: str


class DeviceMutateRequest(BaseModel):
    # POST /oauth/device/approve and /deny body.
    user_code: str
||||
|
||||
def _validate_json[M: BaseModel](model: type[M]) -> M:
    """Parse the request JSON body into *model*; malformed input → 400.

    Missing / non-JSON bodies validate as an empty dict, so required
    fields still surface as BadRequest.
    """
    body = request.get_json(silent=True) or {}
    try:
        return model.model_validate(body)
    except ValidationError as exc:
        raise BadRequest(str(exc))
|
||||
|
||||
|
||||
def _validate_query[M: BaseModel](model: type[M]) -> M:
    """Parse the request query string into *model*; malformed input → 400."""
    try:
        # flat=True keeps only the first value of repeated parameters.
        return model.model_validate(request.args.to_dict(flat=True))
    except ValidationError as exc:
        raise BadRequest(str(exc))
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Protocol endpoints — RFC 8628 (public + per-IP rate limit)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/code")
class OAuthDeviceCodeApi(Resource):
    """RFC 8628 device-authorization request — public, per-IP rate limited."""

    @rate_limit(LIMIT_DEVICE_CODE_PER_IP)
    def post(self):
        payload = _validate_json(DeviceCodeRequest)
        client_id = payload.client_id
        device_label = payload.device_label

        # Only clients on the configured allow-list may start a flow.
        if client_id not in dify_config.OPENAPI_KNOWN_CLIENT_IDS:
            return {"error": "unsupported_client"}, 400

        store = DeviceFlowRedis(redis_client)
        # Creator IP is recorded alongside the flow state.
        ip = extract_remote_ip(request)
        device_code, user_code, expires_in = store.start(client_id, device_label, created_ip=ip)

        return {
            "device_code": device_code,
            "user_code": user_code,
            # _verification_uri() is defined elsewhere in this module —
            # presumably derives the approval-page URL from config; verify.
            "verification_uri": _verification_uri(),
            "expires_in": expires_in,
            "interval": DEFAULT_POLL_INTERVAL_SECONDS,
        }, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/token")
class OAuthDeviceTokenApi(Resource):
    """RFC 8628 poll."""

    # Unauthenticated by design: the device_code itself is the credential.
    def post(self):
        payload = _validate_json(DevicePollRequest)
        device_code = payload.device_code

        store = DeviceFlowRedis(redis_client)

        # slow_down beats every other branch — polling-too-fast clients
        # see only that response regardless of underlying state.
        if store.record_poll(device_code, DEFAULT_POLL_INTERVAL_SECONDS) is SlowDownDecision.SLOW_DOWN:
            return {"error": "slow_down"}, 400

        state = store.load_by_device_code(device_code)
        if state is None:
            # Unknown and expired codes are indistinguishable on purpose.
            return {"error": "expired_token"}, 400

        if state.status is DeviceFlowStatus.PENDING:
            return {"error": "authorization_pending"}, 400

        # NOTE(review): consume_on_poll appears to be a one-shot read of the
        # terminal (approved/denied) state — confirm it invalidates the record
        # so the minted token cannot be fetched twice.
        terminal = store.consume_on_poll(device_code)
        if terminal is None:
            return {"error": "expired_token"}, 400

        if terminal.status is DeviceFlowStatus.DENIED:
            return {"error": "access_denied"}, 400

        poll_payload = terminal.poll_payload or {}
        if "token" not in poll_payload:
            # Approved states are written together with a poll_payload
            # containing "token" (see DeviceApproveApi); reaching here means
            # the stored state is corrupt — fail closed as expired.
            logger.error("device_flow: approved state missing poll_payload for %s", device_code)
            return {"error": "expired_token"}, 400

        _audit_cross_ip_if_needed(state)
        return poll_payload, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/lookup")
class OAuthDeviceLookupApi(Resource):
    """Read-only — public for pre-validate before login. user_code is
    high-entropy + short-TTL; per-IP rate limit blocks enumeration.
    """

    @rate_limit(LIMIT_LOOKUP_PUBLIC)
    def get(self):
        query = _validate_query(DeviceLookupQuery)
        normalized_code = query.user_code.strip().upper()

        match = DeviceFlowRedis(redis_client).load_by_user_code(normalized_code)
        if match is None:
            return {"valid": False, "expires_in_remaining": 0, "client_id": None}, 200

        _device_code, state = match
        if state.status is DeviceFlowStatus.PENDING:
            # NOTE(review): reports the flow's maximum TTL rather than the
            # seconds actually remaining — confirm this is intentional.
            return {
                "valid": True,
                "expires_in_remaining": DEVICE_FLOW_TTL_SECONDS,
                "client_id": state.client_id,
            }, 200

        # Already approved/denied: report invalid but still name the client.
        return {"valid": False, "expires_in_remaining": 0, "client_id": state.client_id}, 200
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Approval endpoints — account branch (cookie-authed)
|
||||
# =========================================================================
|
||||
|
||||
|
||||
# Redis SET-NX guard key serializing concurrent approvals of one device_code.
_APPROVE_GUARD_KEY_FMT = "device_code:{code}:approving"
# Safety-net expiry so a crashed approver cannot wedge the code forever;
# the normal path releases the guard explicitly in a `finally`.
_APPROVE_GUARD_TTL_SECONDS = 10
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/approve")
class DeviceApproveApi(Resource):
    """Cookie-authed approval of a pending device-flow request.

    Mints an account-branch (PREFIX_OAUTH_ACCOUNT) bearer token, stores the
    pre-rendered poll payload on the flow state, and emits an audit record.
    """

    @setup_required
    @login_required
    @account_initialization_required
    @bearer_feature_required
    @rate_limit(LIMIT_APPROVE_CONSOLE)
    def post(self):
        payload = _validate_json(DeviceMutateRequest)
        # Normalize the user-typed code the same way lookup/deny do.
        user_code = payload.user_code.strip().upper()

        account, tenant = current_account_with_tenant()
        store = DeviceFlowRedis(redis_client)

        found = store.load_by_user_code(user_code)
        if found is None:
            return {"error": "expired_or_unknown"}, 404
        device_code, state = found
        if state.status is not DeviceFlowStatus.PENDING:
            return {"error": "already_resolved"}, 409

        # SET NX guard — without it, two in-flight approves both pass
        # PENDING, both mint, and the second upsert silently rotates the
        # first caller into an already-revoked token.
        guard_key = _APPROVE_GUARD_KEY_FMT.format(code=device_code)
        if not redis_client.set(guard_key, "1", nx=True, ex=_APPROVE_GUARD_TTL_SECONDS):
            return {"error": "approve_in_progress"}, 409

        try:
            # Token TTL is tenant-configurable; resolved before minting.
            ttl_days = oauth_ttl_days(tenant_id=tenant)
            mint = mint_oauth_token(
                db.session,
                redis_client,
                subject_email=account.email,
                subject_issuer=ACCOUNT_ISSUER_SENTINEL,
                account_id=str(account.id),
                client_id=state.client_id,
                device_label=state.device_label,
                prefix=PREFIX_OAUTH_ACCOUNT,
                ttl_days=ttl_days,
            )

            # Pre-render the poll response so the unauthenticated poll
            # endpoint never has to touch account/tenant tables.
            poll_payload = _build_account_poll_payload(account, tenant, mint)
            try:
                store.approve(
                    device_code,
                    subject_email=account.email,
                    account_id=str(account.id),
                    subject_issuer=ACCOUNT_ISSUER_SENTINEL,
                    minted_token=mint.token,
                    token_id=str(mint.token_id),
                    poll_payload=poll_payload,
                )
            except (StateNotFoundError, InvalidTransitionError):
                # Row minted but state vanished — roll forward; the orphan
                # token is revocable via auth devices list / Authorized Apps.
                logger.exception("device_flow: approve raced on %s", device_code)
                return {"error": "state_lost"}, 409
        finally:
            # Release the guard on every path, including early returns.
            redis_client.delete(guard_key)

        _emit_approve_audit(state, account, tenant, mint)
        return {"status": "approved"}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/oauth/device/deny")
class DeviceDenyApi(Resource):
    """Cookie-authed denial of a pending device-flow request."""

    @setup_required
    @login_required
    @account_initialization_required
    @bearer_feature_required
    @rate_limit(LIMIT_APPROVE_CONSOLE)
    def post(self):
        body = _validate_json(DeviceMutateRequest)
        normalized_code = body.user_code.strip().upper()

        flow_store = DeviceFlowRedis(redis_client)
        match = flow_store.load_by_user_code(normalized_code)
        if match is None:
            return {"error": "expired_or_unknown"}, 404

        device_code, state = match
        if state.status is not DeviceFlowStatus.PENDING:
            return {"error": "already_resolved"}, 409

        try:
            flow_store.deny(device_code)
        except (StateNotFoundError, InvalidTransitionError):
            # Another caller resolved the code between load and deny.
            logger.exception("device_flow: deny raced on %s", device_code)
            return {"error": "state_lost"}, 409

        _emit_deny_audit(state)
        return {"status": "denied"}, 200
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Helpers
|
||||
# =========================================================================
|
||||
|
||||
|
||||
def _verification_uri() -> str:
    """Absolute URL of the browser verification page (`/device`).

    Prefers the configured console URL; falls back to the request's host.
    """
    configured = getattr(dify_config, "CONSOLE_WEB_URL", None)
    root = configured.rstrip("/") if configured else request.host_url.rstrip("/")
    return f"{root}/device"
|
||||
|
||||
|
||||
def _audit_cross_ip_if_needed(state) -> None:
    """Log an audit record when the polling IP differs from the creating IP.

    No-op unless both IPs are known and they disagree.
    """
    current_ip = extract_remote_ip(request)
    if not (state.created_ip and current_ip) or current_ip == state.created_ip:
        return
    logger.warning(
        "audit: oauth.device_code_cross_ip_poll token_id=%s creation_ip=%s poll_ip=%s",
        state.token_id,
        state.created_ip,
        current_ip,
        extra={
            "audit": True,
            "token_id": state.token_id,
            "creation_ip": state.created_ip,
            "poll_ip": current_ip,
        },
    )
|
||||
|
||||
|
||||
def _build_account_poll_payload(account, tenant, mint) -> dict:
    """Pre-render the poll-response body so the unauthenticated poll
    handler doesn't re-query accounts/tenants for authz data.

    Args:
        account: the approving console account (id / email / name are read).
        tenant: the session's active tenant id (may be falsy) — compared
            against the account's memberships when picking a default.
        mint: mint_oauth_token result (token, token_id, expires_at).
    """
    # Local import — presumably to avoid a module-level import cycle with
    # models; TODO confirm.
    from models import Tenant, TenantAccountJoin

    # Every tenant the account is a member of, paired with the join row.
    rows = (
        db.session.query(Tenant, TenantAccountJoin)
        .join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
        .filter(TenantAccountJoin.account_id == account.id)
        .all()
    )
    workspaces = [{"id": str(t.id), "name": t.name, "role": getattr(m, "role", "")} for t, m in rows]
    # Prefer active session tenant → DB-flagged current join → first membership.
    default_ws_id = None
    if tenant and any(w["id"] == str(tenant) for w in workspaces):
        default_ws_id = str(tenant)
    if default_ws_id is None:
        for _t, m in rows:
            if getattr(m, "current", False):
                default_ws_id = str(m.tenant_id)
                break
    if default_ws_id is None and workspaces:
        default_ws_id = workspaces[0]["id"]

    return {
        "token": mint.token,
        "expires_at": mint.expires_at.isoformat(),
        "subject_type": SubjectType.ACCOUNT,
        "account": {"id": str(account.id), "email": account.email, "name": account.name},
        "workspaces": workspaces,
        "default_workspace_id": default_ws_id,
        "token_id": str(mint.token_id),
    }
|
||||
|
||||
|
||||
def _emit_approve_audit(state, account, tenant, mint) -> None:
    """Emit the structured audit record for an account-branch approval.

    WARNING level so the record survives default log filtering; the
    ``extra`` dict carries the machine-readable fields.
    """
    logger.warning(
        "audit: oauth.device_flow_approved token_id=%s subject=%s client_id=%s device_label=%s rotated=? expires_at=%s",
        mint.token_id,
        account.email,
        state.client_id,
        state.device_label,
        mint.expires_at,
        extra={
            "audit": True,
            "event": "oauth.device_flow_approved",
            "token_id": str(mint.token_id),
            "subject_type": SubjectType.ACCOUNT,
            "subject_email": account.email,
            "account_id": str(account.id),
            "tenant_id": tenant,
            "client_id": state.client_id,
            "device_label": state.device_label,
            "scopes": ["full"],
            "expires_at": mint.expires_at.isoformat(),
        },
    )
|
||||
|
||||
|
||||
def _emit_deny_audit(state) -> None:
    """Emit the structured audit record for a denied device-flow request."""
    logger.warning(
        "audit: oauth.device_flow_denied client_id=%s device_label=%s",
        state.client_id,
        state.device_label,
        extra={
            "audit": True,
            "event": "oauth.device_flow_denied",
            "client_id": state.client_id,
            "device_label": state.device_label,
        },
    )
|
||||
287
api/controllers/openapi/oauth_device_sso.py
Normal file
287
api/controllers/openapi/oauth_device_sso.py
Normal file
@ -0,0 +1,287 @@
|
||||
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/*.
|
||||
EE-only. Browser flow:
|
||||
|
||||
GET /oauth/device/sso-initiate → 302 to IdP authorize URL
|
||||
GET /oauth/device/sso-complete → ACS callback, sets approval-grant cookie
|
||||
GET /oauth/device/approval-context → SPA reads cookie claims (idempotent)
|
||||
POST /oauth/device/approve-external → mints dfoe_ token + clears cookie
|
||||
|
||||
Function-based (raw @bp.route) rather than Resource classes because the
|
||||
handlers do redirects + cookie kwargs that don't fit the Resource shape.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
|
||||
from flask import jsonify, make_response, redirect, request
|
||||
from werkzeug.exceptions import (
|
||||
BadGateway,
|
||||
BadRequest,
|
||||
Conflict,
|
||||
Forbidden,
|
||||
NotFound,
|
||||
Unauthorized,
|
||||
)
|
||||
|
||||
from controllers.openapi import bp
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs import jws
|
||||
from libs.device_flow_security import (
|
||||
APPROVAL_GRANT_COOKIE_NAME,
|
||||
ApprovalGrantClaims,
|
||||
approval_grant_cleared_cookie_kwargs,
|
||||
approval_grant_cookie_kwargs,
|
||||
consume_approval_grant_nonce,
|
||||
consume_sso_assertion_nonce,
|
||||
enterprise_only,
|
||||
mint_approval_grant,
|
||||
verify_approval_grant,
|
||||
)
|
||||
from libs.oauth_bearer import SubjectType
|
||||
from libs.rate_limit import (
|
||||
LIMIT_APPROVE_EXT_PER_EMAIL,
|
||||
LIMIT_SSO_INITIATE_PER_IP,
|
||||
enforce,
|
||||
rate_limit,
|
||||
)
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.oauth_device_flow import (
|
||||
PREFIX_OAUTH_EXTERNAL_SSO,
|
||||
DeviceFlowRedis,
|
||||
DeviceFlowStatus,
|
||||
InvalidTransitionError,
|
||||
StateNotFoundError,
|
||||
mint_oauth_token,
|
||||
oauth_ttl_days,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Matches DEVICE_FLOW_TTL_SECONDS so the signed state can't outlive the
# device_code it references.
STATE_ENVELOPE_TTL_SECONDS: int = 15 * 60

# Canonical sso-complete path. IdP-side ACS callback URL must point here.
_SSO_COMPLETE_PATH: str = "/openapi/v1/oauth/device/sso-complete"
|
||||
|
||||
|
||||
@bp.route("/oauth/device/sso-initiate", methods=["GET"])
@enterprise_only
@rate_limit(LIMIT_SSO_INITIATE_PER_IP)
def sso_initiate():
    """Kick off the SSO branch: sign a state envelope for the pending
    user_code and 302 the browser to the IdP authorize URL returned by
    the enterprise service.
    """
    user_code = (request.args.get("user_code") or "").strip().upper()
    if not user_code:
        raise BadRequest("user_code required")

    # Reject codes that are unknown or already resolved before touching
    # the enterprise service.
    store = DeviceFlowRedis(redis_client)
    found = store.load_by_user_code(user_code)
    if found is None:
        raise BadRequest("invalid_user_code")
    _, state = found
    if state.status is not DeviceFlowStatus.PENDING:
        raise BadRequest("invalid_user_code")

    keyset = jws.KeySet.from_shared_secret()
    # The empty redirect_url/app_code/return_to fields keep the envelope
    # shape shared with other SSO intents; only the device_flow fields
    # are meaningful here.
    signed_state = jws.sign(
        keyset,
        payload={
            "redirect_url": "",
            "app_code": "",
            "intent": "device_flow",
            "user_code": user_code,
            "nonce": secrets.token_urlsafe(16),
            "return_to": "",
            "idp_callback_url": f"{request.host_url.rstrip('/')}{_SSO_COMPLETE_PATH}",
        },
        aud=jws.AUD_STATE_ENVELOPE,
        ttl_seconds=STATE_ENVELOPE_TTL_SECONDS,
    )

    try:
        reply = EnterpriseService.initiate_device_flow_sso(signed_state)
    except Exception as e:
        # Upstream failure surfaces as 502, not a raw traceback.
        logger.warning("sso-initiate: enterprise call failed: %s", e)
        raise BadGateway("sso_initiate_failed") from e

    url = (reply or {}).get("url")
    if not url:
        raise BadGateway("sso_initiate_missing_url")

    # Clear stale approval-grant — defends against cross-tab/back-button mixing.
    resp = redirect(url, code=302)
    resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
    return resp
|
||||
|
||||
|
||||
@bp.route("/oauth/device/sso-complete", methods=["GET"])
@enterprise_only
def sso_complete():
    """ACS callback from the IdP: verify the signed assertion, set the
    short-lived approval-grant cookie, and bounce the browser back to
    the device SPA.

    Raises:
        BadRequest: missing, invalid, replayed, or malformed assertion.
        Conflict: the referenced user_code is no longer PENDING.
    """
    blob = request.args.get("sso_assertion")
    if not blob:
        raise BadRequest("sso_assertion required")

    keyset = jws.KeySet.from_shared_secret()

    try:
        claims = jws.verify(keyset, blob, expected_aud=jws.AUD_EXT_SUBJECT_ASSERTION)
    except jws.VerifyError as e:
        logger.warning("sso-complete: rejected assertion: %s", e)
        raise BadRequest("invalid_sso_assertion") from e

    # Single-use: a replayed assertion fails nonce consumption.
    if not consume_sso_assertion_nonce(redis_client, claims.get("nonce", "")):
        raise BadRequest("invalid_sso_assertion")

    # Defensive: a correctly-signed assertion could still omit the subject
    # fields — fail as 400 rather than letting a KeyError become a 500.
    subject_email = claims.get("email")
    subject_issuer = claims.get("issuer")
    if not subject_email or not subject_issuer:
        logger.warning("sso-complete: assertion missing subject claims")
        raise BadRequest("invalid_sso_assertion")

    user_code = (claims.get("user_code") or "").strip().upper()
    store = DeviceFlowRedis(redis_client)
    found = store.load_by_user_code(user_code)
    if found is None:
        raise Conflict("user_code_not_pending")
    _, state = found
    if state.status is not DeviceFlowStatus.PENDING:
        raise Conflict("user_code_not_pending")

    iss = request.host_url.rstrip("/")
    cookie_value, _ = mint_approval_grant(
        keyset=keyset,
        iss=iss,
        subject_email=subject_email,
        subject_issuer=subject_issuer,
        user_code=user_code,
    )

    # sso_verified=1 tells the SPA to fetch /approval-context immediately.
    resp = redirect("/device?sso_verified=1", code=302)
    resp.set_cookie(**approval_grant_cookie_kwargs(cookie_value))
    return resp
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approval-context", methods=["GET"])
@enterprise_only
def approval_context():
    """Idempotent read: return the approval-grant cookie's claims to the SPA."""
    grant = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
    if not grant:
        raise Unauthorized("no_session")

    try:
        claims = verify_approval_grant(jws.KeySet.from_shared_secret(), grant)
    except jws.VerifyError as e:
        # Invalid and absent cookies look identical to the client.
        logger.warning("approval-context: bad cookie: %s", e)
        raise Unauthorized("no_session") from e

    body = {
        "subject_email": claims.subject_email,
        "subject_issuer": claims.subject_issuer,
        "user_code": claims.user_code,
        "csrf_token": claims.csrf_token,
        "expires_at": claims.expires_at.isoformat(),
    }
    return jsonify(body), 200
|
||||
|
||||
|
||||
@bp.route("/oauth/device/approve-external", methods=["POST"])
@enterprise_only
def approve_external():
    """SSO-branch approval: exchange a valid approval-grant cookie for a
    minted dfoe_ token attached to the pending device_code, then clear
    the cookie so the grant cannot be reused.
    """
    token = request.cookies.get(APPROVAL_GRANT_COOKIE_NAME)
    if not token:
        raise Unauthorized("invalid_session")

    keyset = jws.KeySet.from_shared_secret()
    try:
        claims: ApprovalGrantClaims = verify_approval_grant(keyset, token)
    except jws.VerifyError as e:
        logger.warning("approve-external: bad cookie: %s", e)
        raise Unauthorized("invalid_session") from e

    # Rate-limit per subject *after* signature verification so the key
    # can't be spoofed by an unauthenticated caller.
    enforce(LIMIT_APPROVE_EXT_PER_EMAIL, key=f"subject:{claims.subject_email}")

    # Double-submit CSRF: the SPA echoes the token it read from
    # /approval-context. NOTE(review): plain != comparison — consider
    # secrets.compare_digest for a constant-time check.
    csrf_header = request.headers.get("X-CSRF-Token", "")
    if not csrf_header or csrf_header != claims.csrf_token:
        raise Forbidden("csrf_mismatch")

    # The body must name the same user_code the grant was minted for.
    data = request.get_json(silent=True) or {}
    body_user_code = (data.get("user_code") or "").strip().upper()
    if body_user_code != claims.user_code:
        raise BadRequest("user_code_mismatch")

    store = DeviceFlowRedis(redis_client)
    found = store.load_by_user_code(claims.user_code)
    if found is None:
        raise NotFound("user_code_not_pending")
    device_code, state = found
    if state.status is not DeviceFlowStatus.PENDING:
        raise Conflict("user_code_not_pending")

    # Single-use nonce consumed only after all cheap checks pass, so a
    # rejected request doesn't burn the grant.
    if not consume_approval_grant_nonce(redis_client, claims.nonce):
        raise Unauthorized("session_already_consumed")

    # External subjects have no tenant, so TTL uses the global default.
    ttl_days = oauth_ttl_days(tenant_id=None)
    mint = mint_oauth_token(
        db.session,
        redis_client,
        subject_email=claims.subject_email,
        subject_issuer=claims.subject_issuer,
        account_id=None,
        client_id=state.client_id,
        device_label=state.device_label,
        prefix=PREFIX_OAUTH_EXTERNAL_SSO,
        ttl_days=ttl_days,
    )

    # Pre-rendered body returned verbatim by the unauthenticated poll
    # endpoint; external subjects carry no account/workspace data.
    poll_payload = {
        "token": mint.token,
        "expires_at": mint.expires_at.isoformat(),
        "subject_type": SubjectType.EXTERNAL_SSO,
        "subject_email": claims.subject_email,
        "subject_issuer": claims.subject_issuer,
        "account": None,
        "workspaces": [],
        "default_workspace_id": None,
        "token_id": str(mint.token_id),
    }

    try:
        store.approve(
            device_code,
            subject_email=claims.subject_email,
            account_id=None,
            subject_issuer=claims.subject_issuer,
            minted_token=mint.token,
            token_id=str(mint.token_id),
            poll_payload=poll_payload,
        )
    except (StateNotFoundError, InvalidTransitionError) as e:
        # Token row exists but the flow state raced away; caller can retry
        # from lookup. The orphan token remains revocable.
        logger.exception("approve-external: state transition raced")
        raise Conflict("state_lost") from e

    _emit_approve_external_audit(state, claims, mint)

    # Clear the grant cookie so the browser can't replay the approval.
    resp = make_response(jsonify({"status": "approved"}), 200)
    resp.set_cookie(**approval_grant_cleared_cookie_kwargs())
    return resp
|
||||
|
||||
|
||||
def _emit_approve_external_audit(state, claims, mint) -> None:
    """Emit the structured audit record for an external-SSO approval."""
    logger.warning(
        "audit: oauth.device_flow_approved subject_type=%s subject_email=%s subject_issuer=%s token_id=%s",
        SubjectType.EXTERNAL_SSO,
        claims.subject_email,
        claims.subject_issuer,
        mint.token_id,
        extra={
            "audit": True,
            "event": "oauth.device_flow_approved",
            "subject_type": SubjectType.EXTERNAL_SSO,
            "subject_email": claims.subject_email,
            "subject_issuer": claims.subject_issuer,
            "token_id": str(mint.token_id),
            "client_id": state.client_id,
            "device_label": state.device_label,
            "scopes": ["apps:run"],
            "expires_at": mint.expires_at.isoformat(),
        },
    )
|
||||
89
api/controllers/openapi/workspaces.py
Normal file
89
api/controllers/openapi/workspaces.py
Normal file
@ -0,0 +1,89 @@
|
||||
"""User-scoped workspace reads under /openapi/v1/workspaces. Bearer-authed
|
||||
counterparts to the cookie-authed /console/api/workspaces endpoints.
|
||||
|
||||
Account bearers (dfoa_) see every tenant they're a member of. External
|
||||
SSO bearers (dfoe_) have no account_id and so see an empty list — that
|
||||
matches /openapi/v1/account.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from itertools import starmap
|
||||
|
||||
from flask import g
|
||||
from flask_restx import Resource
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.openapi import openapi_ns
|
||||
from extensions.ext_database import db
|
||||
from libs.oauth_bearer import (
|
||||
ACCEPT_USER_ANY,
|
||||
SubjectType,
|
||||
validate_bearer,
|
||||
)
|
||||
from models import Tenant, TenantAccountJoin
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces")
class WorkspacesApi(Resource):
    """List every workspace the bearer's account belongs to."""

    @validate_bearer(accept=ACCEPT_USER_ANY)
    def get(self):
        ctx = g.auth_ctx
        # External SSO bearers carry no account_id → empty list, matching /account.
        if ctx.subject_type != SubjectType.ACCOUNT or not ctx.account_id:
            return {"workspaces": []}, 200

        stmt = (
            select(Tenant, TenantAccountJoin)
            .join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
            .where(TenantAccountJoin.account_id == str(ctx.account_id))
            .order_by(Tenant.created_at.asc())
        )
        memberships = db.session.execute(stmt).all()

        return {"workspaces": [_workspace_summary(t, j) for t, j in memberships]}, 200
|
||||
|
||||
|
||||
@openapi_ns.route("/workspaces/<string:workspace_id>")
class WorkspaceByIdApi(Resource):
    """Fetch one workspace the bearer's account is a member of."""

    @validate_bearer(accept=ACCEPT_USER_ANY)
    def get(self, workspace_id: str):
        ctx = g.auth_ctx
        # External SSO + missing account → never a member of anything; 404.
        if ctx.subject_type != SubjectType.ACCOUNT or not ctx.account_id:
            raise NotFound("workspace not found")

        stmt = (
            select(Tenant, TenantAccountJoin)
            .join(TenantAccountJoin, TenantAccountJoin.tenant_id == Tenant.id)
            .where(
                Tenant.id == workspace_id,
                TenantAccountJoin.account_id == str(ctx.account_id),
            )
        )
        membership_row = db.session.execute(stmt).first()
        # 404 (not 403) on non-member so workspace IDs don't leak across tenants.
        if membership_row is None:
            raise NotFound("workspace not found")

        tenant, join = membership_row
        return _workspace_detail(tenant, join), 200
|
||||
|
||||
|
||||
def _workspace_summary(tenant: Tenant, membership: TenantAccountJoin) -> dict:
    """Compact list-item view of one workspace membership."""
    summary = {
        "id": str(tenant.id),
        "name": tenant.name,
        # getattr with defaults: join rows may lack these attributes.
        "role": getattr(membership, "role", ""),
        "status": tenant.status,
        "current": getattr(membership, "current", False),
    }
    return summary
|
||||
|
||||
|
||||
def _workspace_detail(tenant: Tenant, membership: TenantAccountJoin) -> dict:
    """Single-workspace payload: summary fields plus creation timestamp."""
    created_at = tenant.created_at.isoformat() if tenant.created_at else None
    return {
        "id": str(tenant.id),
        "name": tenant.name,
        "role": getattr(membership, "role", ""),
        "status": tenant.status,
        "current": getattr(membership, "current", False),
        "created_at": created_at,
    }
|
||||
@ -685,6 +685,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
match invoke_from:
|
||||
case InvokeFrom.SERVICE_API:
|
||||
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
|
||||
case InvokeFrom.OPENAPI:
|
||||
created_from = WorkflowAppLogCreatedFrom.OPENAPI
|
||||
case InvokeFrom.EXPLORE:
|
||||
created_from = WorkflowAppLogCreatedFrom.INSTALLED_APP
|
||||
case InvokeFrom.WEB_APP:
|
||||
|
||||
@ -24,6 +24,7 @@ class UserFrom(StrEnum):
|
||||
|
||||
class InvokeFrom(StrEnum):
|
||||
SERVICE_API = "service-api"
|
||||
OPENAPI = "openapi"
|
||||
WEB_APP = "web-app"
|
||||
TRIGGER = "trigger"
|
||||
EXPLORE = "explore"
|
||||
|
||||
@ -8,6 +8,8 @@ AUTHENTICATED_HEADERS: tuple[str, ...] = (*SERVICE_API_HEADERS, HEADER_NAME_CSRF
|
||||
FILES_HEADERS: tuple[str, ...] = (*BASE_CORS_HEADERS, HEADER_NAME_CSRF_TOKEN)
|
||||
EMBED_HEADERS: tuple[str, ...] = ("Content-Type", HEADER_NAME_APP_CODE)
|
||||
EXPOSED_HEADERS: tuple[str, ...] = ("X-Version", "X-Env", "X-Trace-Id")
|
||||
OPENAPI_HEADERS: tuple[str, ...] = ("Authorization", "Content-Type", HEADER_NAME_CSRF_TOKEN)
|
||||
OPENAPI_MAX_AGE_SECONDS: int = 600
|
||||
|
||||
|
||||
def _apply_cors_once(bp, /, **cors_kwargs):
|
||||
@ -29,6 +31,7 @@ def init_app(app: DifyApp):
|
||||
from controllers.files import bp as files_bp
|
||||
from controllers.inner_api import bp as inner_api_bp
|
||||
from controllers.mcp import bp as mcp_bp
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.service_api import bp as service_api_bp
|
||||
from controllers.trigger import bp as trigger_bp
|
||||
from controllers.web import bp as web_bp
|
||||
@ -41,6 +44,22 @@ def init_app(app: DifyApp):
|
||||
)
|
||||
app.register_blueprint(service_api_bp)
|
||||
|
||||
# User-scoped programmatic API. Default empty allowlist = same-origin
|
||||
# only; expand via OPENAPI_CORS_ALLOW_ORIGINS for third-party
|
||||
# integrations. supports_credentials so cookie-authed approve/deny
|
||||
# work; cross-origin OPTIONS without an allowed origin will fail
|
||||
# the same as on the console blueprint.
|
||||
_apply_cors_once(
|
||||
openapi_bp,
|
||||
resources={r"/*": {"origins": dify_config.OPENAPI_CORS_ALLOW_ORIGINS}},
|
||||
supports_credentials=True,
|
||||
allow_headers=list(OPENAPI_HEADERS),
|
||||
methods=["GET", "POST", "PATCH", "DELETE", "OPTIONS"],
|
||||
expose_headers=list(EXPOSED_HEADERS),
|
||||
max_age=OPENAPI_MAX_AGE_SECONDS,
|
||||
)
|
||||
app.register_blueprint(openapi_bp)
|
||||
|
||||
_apply_cors_once(
|
||||
web_bp,
|
||||
resources={
|
||||
|
||||
@ -222,6 +222,12 @@ def init_app(app: DifyApp) -> Celery:
|
||||
"task": "schedule.clean_workflow_runs_task.clean_workflow_runs_task",
|
||||
"schedule": crontab(minute="0", hour="0"),
|
||||
}
|
||||
if dify_config.ENABLE_CLEAN_OAUTH_ACCESS_TOKENS_TASK:
|
||||
imports.append("schedule.clean_oauth_access_tokens_task")
|
||||
beat_schedule["clean_oauth_access_tokens_task"] = {
|
||||
"task": "schedule.clean_oauth_access_tokens_task.clean_oauth_access_tokens_task",
|
||||
"schedule": crontab(minute="0", hour="5", day_of_month=f"*/{day}"),
|
||||
}
|
||||
if dify_config.ENABLE_WORKFLOW_SCHEDULE_POLLER_TASK:
|
||||
imports.append("schedule.workflow_schedule_task")
|
||||
beat_schedule["workflow_schedule_task"] = {
|
||||
|
||||
@ -12,7 +12,7 @@ from constants import HEADER_NAME_APP_CODE
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_database import db
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token, extract_webapp_passport
|
||||
from libs.token import extract_access_token, extract_console_cookie_token, extract_webapp_passport
|
||||
from models import Account, Tenant, TenantAccountJoin
|
||||
from models.model import AppMCPServer, EndUser
|
||||
from services.account_service import AccountService
|
||||
@ -84,6 +84,24 @@ def load_user_from_request(request_from_flask_login: Request) -> LoginUser | Non
|
||||
|
||||
logged_in_account = AccountService.load_logged_in_account(account_id=user_id)
|
||||
return logged_in_account
|
||||
elif request.blueprint == "openapi":
|
||||
# Account-branch device-flow approval routes (approve / deny /
|
||||
# approval-context) sit under @login_required and authenticate via
|
||||
# the console session cookie. Cookie-only on purpose — bearer
|
||||
# tokens (dfoa_/dfoe_) live on the Authorization header and are
|
||||
# validated by AppPipeline, not flask-login.
|
||||
cookie_token = extract_console_cookie_token(request)
|
||||
if not cookie_token:
|
||||
return None
|
||||
try:
|
||||
decoded = PassportService().verify(cookie_token)
|
||||
except Exception:
|
||||
return None
|
||||
user_id = decoded.get("user_id")
|
||||
source = decoded.get("token_source")
|
||||
if source or not user_id:
|
||||
return None
|
||||
return AccountService.load_logged_in_account(account_id=user_id)
|
||||
elif request.blueprint == "web":
|
||||
app_code = request.headers.get(HEADER_NAME_APP_CODE)
|
||||
webapp_token = extract_webapp_passport(app_code, request) if app_code else None
|
||||
|
||||
23
api/extensions/ext_oauth_bearer.py
Normal file
23
api/extensions/ext_oauth_bearer.py
Normal file
@ -0,0 +1,23 @@
|
||||
"""Bind the bearer authenticator at startup. Must run after ext_database
|
||||
and ext_redis (needs both factories).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.oauth_bearer import build_and_bind
|
||||
|
||||
|
||||
def is_enabled() -> bool:
    """Extension toggle read by the loader; bearer auth is opt-in via config."""
    return dify_config.ENABLE_OAUTH_BEARER
|
||||
|
||||
|
||||
def init_app(app: DifyApp) -> None:
    """Bind the module-level bearer authenticator to this app's DB and Redis."""

    # scoped_session isn't a context manager; request teardown closes it.
    def session_factory():
        return db.session

    build_and_bind(session_factory=session_factory, redis_client=redis_client)
|
||||
196
api/libs/device_flow_security.py
Normal file
196
api/libs/device_flow_security.py
Normal file
@ -0,0 +1,196 @@
|
||||
"""Device-flow security primitives: enterprise_only gate, approval-grant
|
||||
cookie mint/verify/consume, and anti-framing headers.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from functools import wraps
|
||||
|
||||
from flask import Blueprint
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from libs import jws
|
||||
from libs.token import is_secure
|
||||
from services.feature_service import FeatureService, LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# enterprise_only decorator
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Fail-closed: any non-EE-active status (default NONE on CE, plus INACTIVE / EXPIRED / LOST)
|
||||
# is denied. Future LicenseStatus values default to denial unless explicitly admitted.
|
||||
_EE_ENABLED_STATUSES = {LicenseStatus.ACTIVE, LicenseStatus.EXPIRING}
|
||||
|
||||
|
||||
def enterprise_only[**P, R](view: Callable[P, R]) -> Callable[P, R]:
|
||||
"""404 on CE, passthrough on EE. Apply before rate-limit so CE
|
||||
responses don't consume the bucket.
|
||||
"""
|
||||
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
settings = FeatureService.get_system_features()
|
||||
if settings.license.status not in _EE_ENABLED_STATUSES:
|
||||
raise NotFound()
|
||||
return view(*args, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# approval_grant cookie
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Cookie carrying the signed approval grant between sso-complete and
# approve-external; path-scoped to the device-flow endpoints only.
APPROVAL_GRANT_COOKIE_NAME = "device_approval_grant"
APPROVAL_GRANT_COOKIE_PATH = "/openapi/v1/oauth/device"
APPROVAL_GRANT_COOKIE_TTL_SECONDS = 300  # 5 min
NONCE_TTL_SECONDS = 600  # 2x cookie TTL — defeats clock-skew late replay
# Redis key formats for single-use nonce bookkeeping.
NONCE_KEY_FMT = "device_approval_grant_nonce:{nonce}"
SSO_ASSERTION_NONCE_KEY_FMT = "sso_assertion_nonce:{nonce}"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class ApprovalGrantClaims:
    """Decoded claims of a device-approval grant.

    Minted by ``mint_approval_grant`` and re-derived from the cookie by
    ``verify_approval_grant``; frozen value object, never mutated.
    """

    # Email of the SSO subject that approved the device code.
    subject_email: str
    # Issuer the subject authenticated with.
    subject_issuer: str
    # The device-flow user code being approved.
    user_code: str
    # Single-use replay guard; consumed via consume_approval_grant_nonce().
    nonce: str
    # CSRF token paired with the cookie.
    csrf_token: str
    # Grant expiry (UTC); mirrors the JWS exp claim.
    expires_at: datetime
|
||||
|
||||
|
||||
def mint_approval_grant(
    *,
    keyset: jws.KeySet,
    iss: str,
    subject_email: str,
    subject_issuer: str,
    user_code: str,
) -> tuple[str, ApprovalGrantClaims]:
    """Sign a fresh approval-grant token and return it with its claims.

    Use ``approval_grant_cookie_kwargs`` to set the cookie — single source
    of truth for Path/HttpOnly/Secure/SameSite.
    """
    minted_at = datetime.now(UTC)
    claims = ApprovalGrantClaims(
        subject_email=subject_email,
        subject_issuer=subject_issuer,
        user_code=user_code,
        nonce=_random_opaque(),
        csrf_token=_random_opaque(),
        expires_at=minted_at + timedelta(seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS),
    )
    token = jws.sign(
        keyset,
        {
            "iss": iss,
            "subject_email": claims.subject_email,
            "subject_issuer": claims.subject_issuer,
            "user_code": claims.user_code,
            "nonce": claims.nonce,
            "csrf_token": claims.csrf_token,
        },
        aud=jws.AUD_APPROVAL_GRANT,
        ttl_seconds=APPROVAL_GRANT_COOKIE_TTL_SECONDS,
    )
    return token, claims
|
||||
|
||||
|
||||
def verify_approval_grant(keyset: jws.KeySet, token: str) -> ApprovalGrantClaims:
    """Check signature + audience + expiry and decode the grant claims.

    Nonce consumption is the caller's job (see consume_approval_grant_nonce).
    """
    data = jws.verify(keyset, token, expected_aud=jws.AUD_APPROVAL_GRANT)
    string_fields = ("subject_email", "subject_issuer", "user_code", "nonce", "csrf_token")
    kwargs = {name: data[name] for name in string_fields}
    return ApprovalGrantClaims(
        expires_at=datetime.fromtimestamp(data["exp"], tz=UTC),
        **kwargs,
    )
|
||||
|
||||
|
||||
def consume_approval_grant_nonce(redis_client, nonce: str) -> bool:
    """Atomically claim *nonce*; True only for the first caller.

    SET NX+EX makes claim-and-expire a single Redis round trip, so a
    replayed grant carrying the same nonce is rejected.
    """
    if not nonce:
        return False
    key = NONCE_KEY_FMT.format(nonce=nonce)
    claimed = redis_client.set(key, "1", nx=True, ex=NONCE_TTL_SECONDS)
    return bool(claimed)
|
||||
|
||||
|
||||
def consume_sso_assertion_nonce(redis_client, nonce: str) -> bool:
    """Atomically claim an SSO-assertion *nonce*; True only on first use.

    Same SET NX+EX pattern as consume_approval_grant_nonce, separate
    key namespace.
    """
    if not nonce:
        return False
    key = SSO_ASSERTION_NONCE_KEY_FMT.format(nonce=nonce)
    claimed = redis_client.set(key, "1", nx=True, ex=NONCE_TTL_SECONDS)
    return bool(claimed)
|
||||
|
||||
|
||||
def approval_grant_cookie_kwargs(value: str) -> dict:
    """Keyword arguments for setting the approval-grant cookie.

    ``secure`` follows is_secure() so HTTP-only deployments don't silently
    drop the cookie.
    """
    return dict(
        key=APPROVAL_GRANT_COOKIE_NAME,
        value=value,
        max_age=APPROVAL_GRANT_COOKIE_TTL_SECONDS,
        path=APPROVAL_GRANT_COOKIE_PATH,
        secure=is_secure(),
        httponly=True,
        samesite="Lax",
    )
|
||||
|
||||
|
||||
def approval_grant_cleared_cookie_kwargs() -> dict:
    """Keyword arguments that clear the approval-grant cookie (max_age=0).

    Attributes mirror approval_grant_cookie_kwargs so the delete targets
    the same Path/Secure/SameSite slot.
    """
    return dict(
        key=APPROVAL_GRANT_COOKIE_NAME,
        value="",
        max_age=0,
        path=APPROVAL_GRANT_COOKIE_PATH,
        secure=is_secure(),
        httponly=True,
        samesite="Lax",
    )
|
||||
|
||||
|
||||
def _random_opaque() -> str:
|
||||
return secrets.token_urlsafe(16)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Anti-framing headers
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Defense-in-depth: the legacy X-Frame-Options header and the CSP
# frame-ancestors directive both deny framing, covering older and modern
# browsers. Applied to every response by attach_anti_framing().
_ANTI_FRAMING_HEADERS = {
    "X-Frame-Options": "DENY",
    "Content-Security-Policy": "frame-ancestors 'none'",
}
|
||||
|
||||
|
||||
def attach_anti_framing(bp: Blueprint) -> None:
    """Register an after_request hook on *bp* that adds X-Frame-Options +
    CSP to every response (CI invariant #4). Headers a view already set are
    left untouched.
    """

    @bp.after_request
    def _apply_headers(response):  # pyright: ignore[reportUnusedFunction]
        headers = response.headers
        for name, value in _ANTI_FRAMING_HEADERS.items():
            if name not in headers:
                headers[name] = value
        return response
|
||||
@ -542,3 +542,18 @@ class RateLimiter:
|
||||
|
||||
self._redis_client.zadd(key, {member: current_time})
|
||||
self._redis_client.expire(key, self.time_window * 2)
|
||||
|
||||
def seconds_until_available(self, email: str) -> int:
    """Seconds until the oldest in-window entry expires, freeing a slot.

    Defensive floor of 1 second. Caller should only invoke this after
    is_rate_limited() returned True.
    """
    entries = cast(Any, self._redis_client).zrange(self._get_key(email), 0, 0, withscores=True)
    if not entries:
        return 1
    _entry, oldest_score = entries[0]
    wait = int(oldest_score) + self.time_window - int(time.time())
    return wait if wait > 1 else 1
|
||||
|
||||
108
api/libs/jws.py
Normal file
108
api/libs/jws.py
Normal file
@ -0,0 +1,108 @@
|
||||
"""HS256 compact JWS keyed on the shared Dify SECRET_KEY. Used by the SSO
|
||||
state envelope, external subject assertion, and approval-grant cookie —
|
||||
all three share one key-set so api ↔ enterprise can verify each other.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
# Audience claims — one per token purpose, so a token minted for one surface
# can never verify on another (verify() enforces expected_aud).
AUD_STATE_ENVELOPE = "api.sso.state_envelope"
AUD_EXT_SUBJECT_ASSERTION = "api.device_flow.external_subject_assertion"
AUD_APPROVAL_GRANT = "api.device_flow.approval_grant"

# Currently-active signing key id; additional kids only appear during rotation.
ACTIVE_KID_V1 = "dify-shared-v1"
|
||||
|
||||
|
||||
class KeySetError(Exception):
    """Key-set misconfiguration: missing active kid, empty secret, or lookup miss."""

    pass
|
||||
|
||||
|
||||
class KeySet:
    """Mapping of key ids (kids) to HMAC secrets plus the active signing kid.

    ``from_shared_secret`` is the single-key production path;
    ``from_entries`` reserves multi-kid construction for rotation slots.
    """

    def __init__(self, entries: dict[str, bytes], active_kid: str) -> None:
        if active_kid not in entries:
            raise KeySetError(f"active kid {active_kid!r} missing from key-set")
        if not entries[active_kid]:
            raise KeySetError(f"active kid {active_kid!r} has empty secret")
        # Defensive copy: normalise every secret to immutable bytes.
        self._entries: dict[str, bytes] = {kid: bytes(secret) for kid, secret in entries.items()}
        self._active_kid = active_kid

    @classmethod
    def from_shared_secret(cls) -> KeySet:
        """Single-kid key-set derived from the shared dify_config.SECRET_KEY."""
        shared = dify_config.SECRET_KEY
        if not shared:
            raise KeySetError("dify_config.SECRET_KEY is empty; cannot build key-set")
        return cls({ACTIVE_KID_V1: shared.encode("utf-8")}, ACTIVE_KID_V1)

    @classmethod
    def from_entries(cls, entries: dict[str, bytes], active_kid: str) -> KeySet:
        """Multi-kid constructor reserved for key-rotation windows."""
        return cls(entries, active_kid)

    @property
    def active_kid(self) -> str:
        """Kid used when signing new tokens."""
        return self._active_kid

    def lookup(self, kid: str) -> bytes | None:
        """Secret for *kid*, or None when unknown (verification must reject it)."""
        return self._entries.get(kid)
|
||||
|
||||
|
||||
def sign(keyset: KeySet, payload: dict, aud: str, ttl_seconds: int) -> str:
    """Sign *payload* as a compact HS256 JWS under the key-set's active kid.

    ``aud`` / ``iat`` / ``exp`` are injected here; callers must not set them.

    Raises ValueError on reserved claims or non-positive TTL, KeySetError
    when the active kid cannot be resolved to a secret.
    """
    if {"aud", "iat", "exp"} & payload.keys():
        raise ValueError("reserved claim present in payload (aud/iat/exp)")
    if ttl_seconds <= 0:
        raise ValueError("ttl_seconds must be positive")

    active = keyset.active_kid
    secret = keyset.lookup(active)
    if secret is None:
        raise KeySetError(f"active kid {active!r} lookup miss")

    issued_at = datetime.now(UTC)
    claims = dict(payload)
    claims.update(aud=aud, iat=issued_at, exp=issued_at + timedelta(seconds=ttl_seconds))
    return jwt.encode(
        claims,
        secret,
        algorithm="HS256",
        headers={"kid": active, "typ": "JWT"},
    )
|
||||
|
||||
|
||||
class VerifyError(Exception):
    """Any verification failure: malformed token, unknown kid, expired, aud mismatch."""

    pass
|
||||
|
||||
|
||||
def verify(keyset: KeySet, token: str, expected_aud: str) -> dict:
    """Verify a compact HS256 JWS and return its claims.

    Unknown kid is rejected — never fall back to the active kid, since
    a past kid value would otherwise be forgeable by anyone who saw it.

    Raises VerifyError for every failure mode.
    """
    try:
        unverified_header = jwt.get_unverified_header(token)
    except jwt.PyJWTError as e:
        raise VerifyError(f"decode header: {e}") from e

    kid = unverified_header.get("kid")
    if not kid:
        raise VerifyError("no kid in header")

    secret = keyset.lookup(kid)
    if secret is None:
        raise VerifyError(f"unknown kid {kid!r}")

    try:
        return jwt.decode(token, secret, algorithms=["HS256"], audience=expected_aud)
    except jwt.ExpiredSignatureError as e:
        raise VerifyError("token expired") from e
    except jwt.InvalidAudienceError as e:
        raise VerifyError("aud mismatch") from e
    except jwt.PyJWTError as e:
        raise VerifyError(f"decode: {e}") from e
|
||||
608
api/libs/oauth_bearer.py
Normal file
608
api/libs/oauth_bearer.py
Normal file
@ -0,0 +1,608 @@
|
||||
"""OAuth bearer primitives.
|
||||
|
||||
To add a token kind: write a Resolver, add a SubjectType + Accepts member,
|
||||
append a TokenKind to build_registry, and update _SUBJECT_TO_ACCEPT.
|
||||
Authenticator + validate_bearer stay untouched.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Callable, Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import UTC, datetime
|
||||
from enum import StrEnum
|
||||
from functools import wraps
|
||||
from typing import Literal, ParamSpec, Protocol, TypeVar
|
||||
|
||||
from flask import g, request
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, ServiceUnavailable, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.rate_limit import enforce_bearer_rate_limit
|
||||
from models import Account, OAuthAccessToken, TenantAccountJoin
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Contract — types, enums, protocols
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class SubjectType(StrEnum):
    """Who a bearer token represents (carried in AuthContext.subject_type)."""

    # Local Dify account (dfoa_ tokens — see build_registry).
    ACCOUNT = "account"
    # External SSO subject with no local account row (dfoe_ tokens).
    EXTERNAL_SSO = "external_sso"
|
||||
|
||||
|
||||
class Scope(StrEnum):
    """Catalog of bearer scopes recognised by the openapi surface.

    `FULL` is the catch-all carried by `dfoa_` account tokens — it satisfies
    any per-route `require_scope`. `dfoe_` tokens carry the per-feature scopes
    (`APPS_RUN`, `APPS_READ_PERMITTED`).
    """

    # Catch-all; require_scope() treats it as satisfying any scope.
    FULL = "full"
    APPS_READ = "apps:read"
    # Read limited to apps the subject is permitted to see.
    APPS_READ_PERMITTED = "apps:read:permitted"
    APPS_RUN = "apps:run"
||||
|
||||
|
||||
class Accepts(StrEnum):
    """Subject types a route is willing to accept as caller.

    Routes pass a frozenset of these to validate_bearer(accept=...).
    """

    # Local-account callers (SubjectType.ACCOUNT).
    USER_ACCOUNT = "user_account"
    # External-SSO callers (SubjectType.EXTERNAL_SSO).
    USER_EXT_SSO = "user_ext_sso"
|
||||
|
||||
# Pre-built accept sets for the common route configurations.
ACCEPT_USER_ANY: frozenset[Accepts] = frozenset({Accepts.USER_ACCOUNT, Accepts.USER_EXT_SSO})
ACCEPT_USER_EXT_SSO: frozenset[Accepts] = frozenset({Accepts.USER_EXT_SSO})

# Maps a token's subject type to the Accepts member a route must list;
# consulted by validate_bearer after authentication.
_SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
    SubjectType.ACCOUNT: Accepts.USER_ACCOUNT,
    SubjectType.EXTERNAL_SSO: Accepts.USER_EXT_SSO,
}
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class AuthContext:
    """Attached to ``g.auth_ctx``. ``scopes`` / ``subject_type`` / ``source``
    come from the TokenKind, not the DB — corrupt rows can't elevate scope.

    `verified_tenants` is a snapshot of the Layer-0 verdict cache at
    authenticate time. Per-request mutations write through to Redis via
    `record_layer0_verdict`; this snapshot is not updated in place (frozen).
    """

    # From the TokenKind (code-defined, trusted).
    subject_type: SubjectType
    # Identity fields from the resolved DB row; None where not applicable.
    subject_email: str | None
    subject_issuer: str | None
    account_id: uuid.UUID | None
    # From the TokenKind — never from the DB row.
    scopes: frozenset[Scope]
    token_id: uuid.UUID
    # TokenKind.source label (e.g. "oauth_account"), for logging/audit.
    source: str
    expires_at: datetime | None
    # sha256 hex of the raw bearer; used as the cache / rate-limit key.
    token_hash: str
    # tenant_id -> allow/deny snapshot; see class docstring.
    verified_tenants: dict[str, bool] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class ResolvedRow:
|
||||
subject_email: str | None
|
||||
subject_issuer: str | None
|
||||
account_id: uuid.UUID | None
|
||||
token_id: uuid.UUID
|
||||
expires_at: datetime | None
|
||||
verified_tenants: dict[str, bool] = field(default_factory=dict)
|
||||
|
||||
def to_cache(self) -> dict:
|
||||
return {
|
||||
"subject_email": self.subject_email,
|
||||
"subject_issuer": self.subject_issuer,
|
||||
"account_id": str(self.account_id) if self.account_id else None,
|
||||
"token_id": str(self.token_id),
|
||||
"expires_at": self.expires_at.isoformat() if self.expires_at else None,
|
||||
"verified_tenants": dict(self.verified_tenants),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_cache(cls, data: dict) -> ResolvedRow:
|
||||
return cls(
|
||||
subject_email=data["subject_email"],
|
||||
subject_issuer=data["subject_issuer"],
|
||||
account_id=uuid.UUID(data["account_id"]) if data["account_id"] else None,
|
||||
token_id=uuid.UUID(data["token_id"]),
|
||||
expires_at=datetime.fromisoformat(data["expires_at"]) if data["expires_at"] else None,
|
||||
verified_tenants=_coerce_verified_tenants(data.get("verified_tenants")),
|
||||
)
|
||||
|
||||
|
||||
def _coerce_verified_tenants(raw: object) -> dict[str, bool]:
|
||||
"""Tolerate legacy entries that stored 'ok'/'denied' string verdicts.
|
||||
|
||||
TODO(post-v1.0): remove once the AuthContext cache TTL has fully cycled
|
||||
on all live deployments (60s TTL → safe to drop one release after rollout).
|
||||
"""
|
||||
if not isinstance(raw, dict):
|
||||
return {}
|
||||
out: dict[str, bool] = {}
|
||||
for k, v in raw.items():
|
||||
if isinstance(v, bool):
|
||||
out[k] = v
|
||||
elif v == "ok":
|
||||
out[k] = True
|
||||
elif v == "denied":
|
||||
out[k] = False
|
||||
return out
|
||||
|
||||
|
||||
class Resolver(Protocol):
    """Structural interface a TokenKind's resolver must satisfy."""

    def resolve(self, token_hash: str) -> ResolvedRow | None:  # pragma: no cover - contract
        """Return the live row for *token_hash*, or None when unknown/revoked."""
        ...
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class TokenKind:
|
||||
prefix: str
|
||||
subject_type: SubjectType
|
||||
scopes: frozenset[Scope]
|
||||
source: str
|
||||
resolver: Resolver
|
||||
|
||||
def matches(self, token: str) -> bool:
|
||||
return token.startswith(self.prefix)
|
||||
|
||||
|
||||
class InvalidBearerError(Exception):
    """Token missing, unknown prefix, or no live row (validate_bearer maps it to 401)."""


class TokenExpiredError(Exception):
    """Row exists but is past expiry. Hard-expire bookkeeping is the
    resolver's job before raising."""
|
||||
|
||||
# ============================================================================
|
||||
# Registry
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TokenKindRegistry:
    """Immutable prefix → TokenKind dispatch table.

    Duplicate prefixes are rejected at construction so dispatch is
    unambiguous.
    """

    def __init__(self, kinds: Iterable[TokenKind]) -> None:
        self._kinds: tuple[TokenKind, ...] = tuple(kinds)
        prefixes = [kind.prefix for kind in self._kinds]
        if len(prefixes) != len(set(prefixes)):
            raise ValueError(f"duplicate prefix in registry: {prefixes}")

    def find(self, token: str) -> TokenKind | None:
        """First kind whose prefix matches *token*, or None."""
        return next((kind for kind in self._kinds if kind.matches(token)), None)

    def kinds(self) -> tuple[TokenKind, ...]:
        """All registered kinds, in registration order."""
        return self._kinds
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Authenticator
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def sha256_hex(token: str) -> str:
    """Hex SHA-256 of *token* — raw tokens are stored/cached only as hashes."""
    digest = hashlib.sha256(token.encode("utf-8"))
    return digest.hexdigest()
|
||||
|
||||
|
||||
class BearerAuthenticator:
    """Turns a raw bearer string into an AuthContext via the kind registry."""

    def __init__(self, registry: TokenKindRegistry) -> None:
        self._registry = registry

    @property
    def registry(self) -> TokenKindRegistry:
        """The TokenKindRegistry this authenticator dispatches on."""
        return self._registry

    def authenticate(self, token: str) -> AuthContext:
        """Identity + per-token rate limit (single source).

        Both the openapi pipeline (`BearerCheck`) and the decorator
        (`validate_bearer`) call this — rate-limit fires exactly once per
        request regardless of which path hosts the route.

        Raises InvalidBearerError for unknown prefixes or dead rows.
        """
        kind = self._registry.find(token)
        if kind is None:
            raise InvalidBearerError("unknown token prefix")

        token_hash = sha256_hex(token)
        resolved = kind.resolver.resolve(token_hash)
        if resolved is None:
            raise InvalidBearerError("token unknown or revoked")

        enforce_bearer_rate_limit(token_hash)

        return AuthContext(
            subject_type=kind.subject_type,
            subject_email=resolved.subject_email,
            subject_issuer=resolved.subject_issuer,
            account_id=resolved.account_id,
            scopes=kind.scopes,
            token_id=resolved.token_id,
            source=kind.source,
            expires_at=resolved.expires_at,
            token_hash=token_hash,
            verified_tenants=dict(resolved.verified_tenants),
        )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# OAuth access token resolver (PAT resolver would be a sibling class)
|
||||
# ============================================================================
|
||||
|
||||
# Redis key for the resolved-token cache, keyed by sha256 hex of the raw token.
TOKEN_CACHE_KEY_FMT = "auth:token:{hash}"
# Live rows cache for 60s; misses/revocations only 10s so revocation
# propagates quickly.
POSITIVE_TTL_SECONDS = 60
NEGATIVE_TTL_SECONDS = 10
# Audit-log event name emitted by hard_expire().
AUDIT_OAUTH_EXPIRED = "oauth.token_expired"

# Which token family a _VariantResolver view serves.
ScopeVariant = Literal["account", "external_sso"]
|
||||
|
||||
|
||||
class OAuthAccessTokenResolver:
    """Resolves OAuth access-token hashes to identities, with Redis caching.

    ``.for_account()`` / ``.for_external_sso()`` are variant-scoped views
    sharing DB + cache plumbing.
    """

    def __init__(
        self,
        session_factory,
        redis_client,
        positive_ttl: int = POSITIVE_TTL_SECONDS,
        negative_ttl: int = NEGATIVE_TTL_SECONDS,
    ) -> None:
        self.session_factory = session_factory
        self._redis = redis_client
        self._positive_ttl = positive_ttl
        self._negative_ttl = negative_ttl

    def for_account(self) -> Resolver:
        """Resolver view for account-backed ("account" variant) tokens."""
        return _VariantResolver(self, variant="account")

    def for_external_sso(self) -> Resolver:
        """Resolver view for external-SSO ("external_sso" variant) tokens."""
        return _VariantResolver(self, variant="external_sso")

    def _cache_key(self, token_hash: str) -> str:
        # One cache slot per token hash, shared by both variants.
        return TOKEN_CACHE_KEY_FMT.format(hash=token_hash)

    def cache_get(self, token_hash: str) -> ResolvedRow | None | Literal["invalid"]:
        """Cache lookup: ResolvedRow on hit, "invalid" negative hit, None miss.

        A malformed entry is logged and treated as a miss so one bad write
        can never wedge authentication.
        """
        raw = self._redis.get(self._cache_key(token_hash))
        if raw is None:
            return None
        text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
        if text == "invalid":
            return "invalid"
        try:
            return ResolvedRow.from_cache(json.loads(text))
        except (ValueError, KeyError):
            logger.warning("auth:token cache entry malformed; treating as miss")
            return None

    def cache_set_positive(self, token_hash: str, row: ResolvedRow) -> None:
        """Cache a live row for the positive TTL."""
        self._redis.setex(
            self._cache_key(token_hash),
            self._positive_ttl,
            json.dumps(row.to_cache()),
        )

    def cache_set_negative(self, token_hash: str) -> None:
        """Cache a miss/revocation so repeated bad tokens skip the DB."""
        self._redis.setex(self._cache_key(token_hash), self._negative_ttl, "invalid")

    def hard_expire(self, session: Session, row_id: uuid.UUID | str, token_hash: str) -> None:
        """Revoke an expired row and flip its cache entry to negative.

        Atomic CAS — only the worker that flips revoked_at emits audit;
        replays are idempotent.
        """
        stmt = (
            update(OAuthAccessToken)
            .where(OAuthAccessToken.id == row_id, OAuthAccessToken.revoked_at.is_(None))
            .values(revoked_at=datetime.now(UTC), token_hash=None)
        )
        result = session.execute(stmt)
        session.commit()
        if result.rowcount == 1:
            # Exactly one worker wins the CAS and emits the audit line.
            logger.warning(
                "audit: %s token_id=%s",
                AUDIT_OAUTH_EXPIRED,
                row_id,
                extra={"audit": True, "token_id": str(row_id)},
            )
        # Drop any stale positive entry, then pin the short negative entry.
        self._redis.delete(self._cache_key(token_hash))
        self.cache_set_negative(token_hash)
|
||||
|
||||
|
||||
class _VariantResolver:
    """Variant-scoped view over OAuthAccessTokenResolver.

    Implements the Resolver protocol for one ScopeVariant, delegating DB
    and cache work to the parent.
    """

    def __init__(self, parent: OAuthAccessTokenResolver, variant: ScopeVariant) -> None:
        self._parent = parent
        self._variant = variant

    def resolve(self, token_hash: str) -> ResolvedRow | None:
        """Cache-first resolve; falls back to DB and refreshes the cache.

        Returns None (never raises) for unknown, revoked, expired, or
        wrong-variant tokens.
        """
        cached = self._parent.cache_get(token_hash)
        if cached == "invalid":
            return None
        if cached is not None and not isinstance(cached, str):
            # Positive cache hit — still enforce the variant before trusting it.
            if not self._matches_variant(cached):
                return None
            return cached

        # Flask-SQLAlchemy's scoped_session is request-bound and not a
        # context manager; use it directly.
        session = self._parent.session_factory()
        row = self._load_from_db(session, token_hash)
        if row is None:
            self._parent.cache_set_negative(token_hash)
            return None

        now = datetime.now(UTC)
        if row.expires_at is not None and row.expires_at <= now:
            # Expired: revoke row + poison cache, then report as unknown.
            self._parent.hard_expire(session, row.id, token_hash)
            return None

        if not self._matches_variant_model(row):
            # A row whose prefix and account_id disagree with this variant
            # indicates corrupted state — refuse rather than guess.
            logger.error(
                "internal_state_invariant: account_id/prefix mismatch token_id=%s prefix=%s",
                row.id,
                row.prefix,
            )
            return None

        resolved = ResolvedRow(
            subject_email=row.subject_email,
            subject_issuer=row.subject_issuer,
            account_id=uuid.UUID(str(row.account_id)) if row.account_id else None,
            token_id=uuid.UUID(str(row.id)),
            expires_at=row.expires_at,
        )
        self._parent.cache_set_positive(token_hash, resolved)
        return resolved

    def _matches_variant(self, row: ResolvedRow) -> bool:
        # Cached rows carry no prefix, so only the account_id presence is checked.
        has_account = row.account_id is not None
        if self._variant == "account":
            return has_account
        return not has_account

    def _matches_variant_model(self, row: OAuthAccessToken) -> bool:
        # DB rows allow the stronger check: account presence AND prefix must agree.
        has_account = row.account_id is not None
        if self._variant == "account":
            return has_account and row.prefix == "dfoa_"
        return (not has_account) and row.prefix == "dfoe_"

    def _load_from_db(self, session: Session, token_hash: str) -> OAuthAccessToken | None:
        # Only non-revoked rows are candidates; revoked rows behave like misses.
        return (
            session.query(OAuthAccessToken)
            .filter(
                OAuthAccessToken.token_hash == token_hash,
                OAuthAccessToken.revoked_at.is_(None),
            )
            .one_or_none()
        )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Layer 0 — workspace membership cache + helper
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def record_layer0_verdict(token_hash: str, tenant_id: str, verdict: bool) -> None:
    """Merge a Layer-0 membership verdict into the AuthContext cache entry at
    `auth:token:{hash}`. No-op if entry missing/expired/invalid — next request
    rebuilds via authenticate() and re-runs Layer 0.
    """
    cache_key = TOKEN_CACHE_KEY_FMT.format(hash=token_hash)
    raw = redis_client.get(cache_key)
    if raw is None:
        return
    text = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
    if text == "invalid":
        return
    try:
        entry = json.loads(text)
    except (ValueError, KeyError):
        return
    remaining_ttl = redis_client.ttl(cache_key)
    if remaining_ttl <= 0:
        # Entry vanished (or lost its TTL) between GET and TTL — skip the write.
        return
    verdicts = entry.setdefault("verified_tenants", {})
    verdicts[tenant_id] = verdict
    redis_client.setex(cache_key, remaining_ttl, json.dumps(entry))
|
||||
|
||||
|
||||
def check_workspace_membership(
    *,
    account_id: uuid.UUID | str,
    tenant_id: str,
    token_hash: str,
    cached_verdicts: dict[str, bool],
) -> None:
    """Layer-0 enforcement core. Raises `Forbidden` on deny, returns on allow.

    Shared by the pipeline step (`WorkspaceMembershipCheck`) and the
    inline helper (`require_workspace_member`). Caller is responsible for
    short-circuiting on EE / SSO subjects before invoking — this function
    runs the membership + active-status checks unconditionally.
    """
    known = cached_verdicts.get(tenant_id)
    if known is True:
        return
    if known is False:
        raise Forbidden("workspace_membership_revoked")

    membership = db.session.execute(
        select(TenantAccountJoin.id).where(
            TenantAccountJoin.account_id == account_id,
            TenantAccountJoin.tenant_id == tenant_id,
        )
    ).scalar_one_or_none()
    if membership is None:
        # No join row: record the deny so the next request short-circuits.
        record_layer0_verdict(token_hash, tenant_id, False)
        raise Forbidden("workspace_membership_revoked")

    account_status = db.session.execute(select(Account.status).where(Account.id == account_id)).scalar_one_or_none()
    if account_status != "active":
        record_layer0_verdict(token_hash, tenant_id, False)
        raise Forbidden("workspace_membership_revoked")

    record_layer0_verdict(token_hash, tenant_id, True)
|
||||
|
||||
|
||||
def require_workspace_member(ctx: AuthContext, tenant_id: str) -> None:
    """AuthContext-flavoured wrapper around `check_workspace_membership`.

    No-op on EE (gateway RBAC owns tenant isolation) and for SSO subjects
    (no `tenant_account_joins` row by definition).
    """
    if dify_config.ENTERPRISE_ENABLED:
        return
    is_account_subject = ctx.subject_type == SubjectType.ACCOUNT and ctx.account_id is not None
    if not is_account_subject:
        return
    check_workspace_membership(
        account_id=ctx.account_id,
        tenant_id=tenant_id,
        token_hash=ctx.token_hash,
        cached_verdicts=ctx.verified_tenants,
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Decorator — route-level bearer gate
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Process-wide authenticator; bound once at startup by build_and_bind().
_authenticator: BearerAuthenticator | None = None


def bind_authenticator(authenticator: BearerAuthenticator) -> None:
    """Install the process-wide authenticator (called once from the app factory)."""
    global _authenticator
    _authenticator = authenticator


def get_authenticator() -> BearerAuthenticator:
    """Return the bound authenticator; fail loudly if wiring was skipped."""
    bound = _authenticator
    if bound is None:
        raise RuntimeError("BearerAuthenticator not bound; call bind_authenticator at startup")
    return bound
|
||||
|
||||
|
||||
def _extract_bearer(req) -> str | None:
|
||||
header = req.headers.get("Authorization", "")
|
||||
scheme, _, value = header.partition(" ")
|
||||
if scheme.lower() != "bearer" or not value:
|
||||
return None
|
||||
return value.strip()
|
||||
|
||||
|
||||
# Typing helpers so the decorators below preserve the wrapped signature.
_DP = ParamSpec("_DP")
_DR = TypeVar("_DR")
|
||||
|
||||
|
||||
def validate_bearer(*, accept: frozenset[Accepts]) -> Callable[[Callable[_DP, _DR]], Callable[_DP, _DR]]:
    """Route-level bearer gate. Opt-in: omitting it leaves the route
    unauthenticated.

    Resolves user-level OAuth bearers (``dfoa_`` / ``dfoe_``). Legacy
    ``app-`` keys belong to ``service_api/wraps.py:validate_app_token``
    and are rejected here as the wrong auth scheme for this surface.

    On success the AuthContext is published as ``g.auth_ctx`` for
    downstream decorators (e.g. ``require_scope``).

    Raises:
        Unauthorized: missing bearer token or authentication failure.
        ServiceUnavailable: bearer auth feature not wired up.
        Forbidden: token subject type not in *accept*.
    """

    def wrap(fn: Callable[_DP, _DR]) -> Callable[_DP, _DR]:
        @wraps(fn)
        def inner(*args: _DP.args, **kwargs: _DP.kwargs) -> _DR:
            token = _extract_bearer(request)
            if token is None:
                raise Unauthorized("missing bearer token")

            if _authenticator is None:
                raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")

            try:
                ctx = get_authenticator().authenticate(token)
            except InvalidBearerError as e:
                # Chain the cause (PEP 3134) so logs show why the bearer was rejected.
                raise Unauthorized(str(e)) from e

            if _SUBJECT_TO_ACCEPT[ctx.subject_type] not in accept:
                raise Forbidden("token subject type not accepted here")

            g.auth_ctx = ctx
            return fn(*args, **kwargs)

        return inner

    return wrap
|
||||
|
||||
|
||||
def bearer_feature_required[**P, R](fn: Callable[P, R]) -> Callable[P, R]:
|
||||
"""503 if ENABLE_OAUTH_BEARER is off — minted tokens would be unusable
|
||||
without the authenticator, so fail fast instead of approving silently.
|
||||
"""
|
||||
|
||||
@wraps(fn)
|
||||
def inner(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
if not dify_config.ENABLE_OAUTH_BEARER:
|
||||
raise ServiceUnavailable("bearer_auth_disabled: set ENABLE_OAUTH_BEARER=true to enable")
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
def require_scope(scope: Scope) -> Callable:
    """Route-level scope gate — must run AFTER validate_bearer so that
    g.auth_ctx is set. Raises Forbidden('insufficient_scope: <scope>')
    when the bearer lacks both the requested scope and `Scope.FULL`.
    """

    def wrap(fn: Callable) -> Callable:
        @wraps(fn)
        def inner(*args, **kwargs):
            ctx = getattr(g, "auth_ctx", None)
            if ctx is None:
                raise RuntimeError(
                    "require_scope used without validate_bearer; stack @validate_bearer above @require_scope"
                )
            granted = ctx.scopes
            if scope in granted or Scope.FULL in granted:
                return fn(*args, **kwargs)
            raise Forbidden(f"insufficient_scope: {scope}")

        return inner

    return wrap
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Wiring — called once from the app factory
|
||||
# ============================================================================
|
||||
|
||||
|
||||
def build_registry(session_factory, redis_client) -> TokenKindRegistry:
    """Construct the token-kind registry wired to the OAuth resolver.

    dfoa_ (console account, full scope) and dfoe_ (external SSO,
    per-feature scopes) share one OAuthAccessTokenResolver for DB + cache
    plumbing.
    """
    oauth = OAuthAccessTokenResolver(session_factory, redis_client)
    account_kind = TokenKind(
        prefix="dfoa_",
        subject_type=SubjectType.ACCOUNT,
        scopes=frozenset({Scope.FULL}),
        source="oauth_account",
        resolver=oauth.for_account(),
    )
    sso_kind = TokenKind(
        prefix="dfoe_",
        subject_type=SubjectType.EXTERNAL_SSO,
        scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED}),
        source="oauth_external_sso",
        resolver=oauth.for_external_sso(),
    )
    return TokenKindRegistry([account_kind, sso_kind])
|
||||
|
||||
|
||||
def build_and_bind(session_factory, redis_client) -> BearerAuthenticator:
    """App-factory entry point: build the registry, wrap it, bind it globally."""
    authenticator = BearerAuthenticator(build_registry(session_factory, redis_client))
    bind_authenticator(authenticator)
    return authenticator
|
||||
140
api/libs/rate_limit.py
Normal file
140
api/libs/rate_limit.py
Normal file
@ -0,0 +1,140 @@
|
||||
"""Typed rate-limit decorator over ``libs.helper.RateLimiter`` (sliding-
|
||||
window Redis ZSET). Apply after auth decorators so scopes can read
|
||||
``g.auth_ctx``. Use :func:`enforce` when the bucket key is computed
|
||||
in-handler. RFC-8628 ``slow_down`` is inline — its response shape isn't
|
||||
generic 429.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from enum import StrEnum
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
|
||||
from flask import g, jsonify, make_response, request, session
|
||||
from werkzeug.exceptions import TooManyRequests
|
||||
|
||||
from configs import dify_config
|
||||
from libs.helper import RateLimiter, extract_remote_ip
|
||||
|
||||
|
||||
class RateLimitScope(StrEnum):
    """Dimension a rate-limit bucket is keyed on; see _one_key() for the
    key format each scope produces."""

    # Remote client IP (extract_remote_ip).
    IP = "ip"
    # Flask session id.
    SESSION = "session"
    # AuthContext.account_id.
    ACCOUNT = "account"
    # AuthContext.subject_email.
    SUBJECT_EMAIL = "subject_email"
    # AuthContext.token_id.
    TOKEN_ID = "token_id"
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class RateLimit:
    """Declarative rate-limit policy: *limit* events per *window*, bucketed
    by the composite of *scopes*."""

    # Max events allowed inside one window.
    limit: int
    # Sliding-window length.
    window: timedelta
    # Scope dimensions composed into the bucket key (see _composite_key).
    scopes: tuple[RateLimitScope, ...]
|
||||
|
||||
|
||||
# Pre-built specs applied by route decorators; names encode route + bucket dimension.
LIMIT_DEVICE_CODE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
LIMIT_SSO_INITIATE_PER_IP = RateLimit(60, timedelta(hours=1), (RateLimitScope.IP,))
LIMIT_APPROVE_EXT_PER_EMAIL = RateLimit(10, timedelta(hours=1), (RateLimitScope.SUBJECT_EMAIL,))
LIMIT_APPROVE_CONSOLE = RateLimit(10, timedelta(hours=1), (RateLimitScope.SESSION,))
LIMIT_LOOKUP_PUBLIC = RateLimit(60, timedelta(minutes=5), (RateLimitScope.IP,))
LIMIT_ME_PER_ACCOUNT = RateLimit(60, timedelta(minutes=1), (RateLimitScope.ACCOUNT,))
LIMIT_ME_PER_EMAIL = RateLimit(60, timedelta(minutes=1), (RateLimitScope.SUBJECT_EMAIL,))
# Limit is operator-configurable; enforced via enforce_bearer_rate_limit, not the decorator.
LIMIT_BEARER_PER_TOKEN = RateLimit(
    limit=dify_config.OPENAPI_RATE_LIMIT_PER_TOKEN,
    window=timedelta(minutes=1),
    scopes=(RateLimitScope.TOKEN_ID,),  # bucket key composed by caller from sha256(token)
)
|
||||
|
||||
|
||||
def _one_key(scope: RateLimitScope) -> str:
    """Render one bucket-key fragment for *scope* from the current Flask
    request context.

    Missing auth context falls back to a shared ``anon`` bucket so that
    unauthenticated traffic is still throttled collectively.
    """
    if scope is RateLimitScope.IP:
        return f"ip:{extract_remote_ip(request) or 'unknown'}"
    if scope is RateLimitScope.SESSION:
        return f"session:{session.get('_id', 'anon')}"
    if scope is RateLimitScope.ACCOUNT:
        auth_ctx = getattr(g, "auth_ctx", None)
        if auth_ctx and auth_ctx.account_id:
            return f"account:{auth_ctx.account_id}"
        return "account:anon"
    if scope is RateLimitScope.SUBJECT_EMAIL:
        auth_ctx = getattr(g, "auth_ctx", None)
        if auth_ctx and auth_ctx.subject_email:
            return f"subject:{auth_ctx.subject_email}"
        return "subject:anon"
    if scope is RateLimitScope.TOKEN_ID:
        auth_ctx = getattr(g, "auth_ctx", None)
        if auth_ctx and auth_ctx.token_id:
            return f"token:{auth_ctx.token_id}"
        return "token:anon"
|
||||
|
||||
|
||||
def _composite_key(scopes: tuple[RateLimitScope, ...]) -> str:
|
||||
return "|".join(_one_key(s) for s in scopes)
|
||||
|
||||
|
||||
def _limiter_prefix(scopes: tuple[RateLimitScope, ...]) -> str:
|
||||
return "rl:" + "+".join(s.value for s in scopes)
|
||||
|
||||
|
||||
def _build_limiter(spec: RateLimit) -> RateLimiter:
    """Translate a declarative :class:`RateLimit` into the shared
    Redis-backed sliding-window ``RateLimiter`` from ``libs.helper``."""
    return RateLimiter(
        prefix=_limiter_prefix(spec.scopes),
        max_attempts=spec.limit,
        # RateLimiter takes whole seconds; sub-second windows truncate.
        time_window=int(spec.window.total_seconds()),
    )
|
||||
|
||||
|
||||
# Decorator typing helpers so rate_limit preserves the wrapped view's signature.
_P = ParamSpec("_P")
_R = TypeVar("_R")
|
||||
|
||||
|
||||
def rate_limit(spec: RateLimit) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
    """Decorator: throttle the wrapped view by *spec*.

    Apply after auth decorators that the scopes read from, so the bucket
    key can see ``g.auth_ctx``.  Raises ``TooManyRequests`` when the
    bucket is over its limit; otherwise records the hit and calls through.
    """
    limiter = _build_limiter(spec)

    def decorator(view: Callable[_P, _R]) -> Callable[_P, _R]:
        @wraps(view)
        def wrapped(*args: _P.args, **kwargs: _P.kwargs) -> _R:
            # Key is computed per request — scope fragments read request state.
            bucket = _composite_key(spec.scopes)
            if limiter.is_rate_limited(bucket):
                raise TooManyRequests("rate_limited")
            limiter.increment_rate_limit(bucket)
            return view(*args, **kwargs)

        return wrapped

    return decorator
|
||||
|
||||
|
||||
def enforce(spec: RateLimit, *, key: str) -> None:
    """Imperative counterpart of :func:`rate_limit`.

    The caller composes the bucket *key* to match the scope semantics;
    the key is opaque here.  Raises ``TooManyRequests`` when the bucket
    is over its limit, otherwise records the hit.
    """
    limiter = _build_limiter(spec)
    if limiter.is_rate_limited(key):
        raise TooManyRequests("rate_limited")
    limiter.increment_rate_limit(key)
||||
|
||||
|
||||
def enforce_bearer_rate_limit(token_hash: str) -> None:
    """Per-token rate limit on /openapi/v1/* bearer-authed routes.

    Bucket key = ``token:<sha256_hex>`` so the same token shares one
    bucket across api replicas (Redis-backed sliding window).

    Raises:
        TooManyRequests: with a JSON body carrying ``retry_after_ms`` and
            a ``Retry-After`` header when the bucket is over
            ``LIMIT_BEARER_PER_TOKEN``.
    """
    limiter = _build_limiter(LIMIT_BEARER_PER_TOKEN)
    key = f"token:{token_hash}"
    if limiter.is_rate_limited(key):
        retry_after = limiter.seconds_until_available(key)
        # Custom 429 body — this route's response shape is not the generic
        # werkzeug 429, hence the hand-built response object.
        response = make_response(
            jsonify({"error": "rate_limited", "retry_after_ms": retry_after * 1000}),
            429,
        )
        response.headers["Retry-After"] = str(retry_after)
        # Passing response= makes werkzeug emit exactly this payload.
        raise TooManyRequests(response=response)
    limiter.increment_rate_limit(key)
|
||||
@ -72,11 +72,15 @@ def extract_csrf_token_from_cookie(request: Request) -> str | None:
|
||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_CSRF_TOKEN))
|
||||
|
||||
|
||||
def extract_access_token(request: Request) -> str | None:
|
||||
def _try_extract_from_cookie(request: Request) -> str | None:
|
||||
return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
|
||||
def extract_console_cookie_token(request: Request) -> str | None:
    """Cookie-only console session token. Used by /openapi/v1/oauth/device/*
    approval routes, which must not fall through to the Authorization header
    (that's where dfoa_/dfoe_ bearers live — they aren't JWTs).

    Returns the raw cookie value, or None when the cookie is absent."""
    return request.cookies.get(_real_cookie_name(COOKIE_NAME_ACCESS_TOKEN))
|
||||
|
||||
return _try_extract_from_cookie(request) or _try_extract_from_header(request)
|
||||
|
||||
def extract_access_token(request: Request) -> str | None:
    """Console access token: prefer the session cookie, then fall back to
    the Authorization header."""
    cookie_token = extract_console_cookie_token(request)
    if cookie_token:
        return cookie_token
    return _try_extract_from_header(request)
|
||||
|
||||
|
||||
def extract_webapp_access_token(request: Request) -> str | None:
|
||||
|
||||
@ -0,0 +1,104 @@
|
||||
"""add oauth_access_tokens table
|
||||
|
||||
Revision ID: d4a5e1f3c9b7
|
||||
Revises: 227822d22895, b69ca54b9208, 2a3aebbbf4bb
|
||||
Create Date: 2026-04-23 22:00:00.000000
|
||||
|
||||
Merges the three open heads at time of authoring (add_workflow_comments_table,
|
||||
add_chatbot_color_theme, add_app_tracing) into a single parent so the new
|
||||
oauth_access_tokens table sits on a definite linear chain thereafter.
|
||||
|
||||
Table stores user-level OAuth bearer tokens minted via the device-flow grant
|
||||
(difyctl auth login). PAT storage (personal_access_tokens) is a separate
|
||||
table not added in this migration.
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "d4a5e1f3c9b7"
|
||||
down_revision = ("227822d22895", "b69ca54b9208", "2a3aebbbf4bb")
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
    """Create oauth_access_tokens plus its partial (live-rows-only) indexes."""
    op.create_table(
        "oauth_access_tokens",
        sa.Column(
            "id",
            postgresql.UUID(as_uuid=True),
            server_default=sa.text("gen_random_uuid()"),
            nullable=False,
            primary_key=True,
        ),
        sa.Column("subject_email", sa.Text(), nullable=False),
        sa.Column("subject_issuer", sa.Text(), nullable=True),
        sa.Column("account_id", postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column("client_id", sa.String(length=64), nullable=False),
        sa.Column("device_label", sa.Text(), nullable=False),
        sa.Column("prefix", sa.String(length=8), nullable=False),
        # sha256 hex digest is 64 chars; unique across live and revoked rows.
        sa.Column("token_hash", sa.String(length=64), nullable=True, unique=True),
        sa.Column(
            "created_at",
            sa.TIMESTAMP(timezone=True),
            server_default=sa.text("NOW()"),
            nullable=False,
        ),
        sa.Column("last_used_at", sa.TIMESTAMP(timezone=True), nullable=True),
        sa.Column("expires_at", sa.TIMESTAMP(timezone=True), nullable=False),
        sa.Column("revoked_at", sa.TIMESTAMP(timezone=True), nullable=True),
        sa.ForeignKeyConstraint(
            ["account_id"],
            ["accounts.id"],
            name="fk_oauth_access_tokens_account_id",
            # Keep the token row for audit even if the account is deleted.
            ondelete="SET NULL",
        ),
    )

    # All lookup indexes are partial on revoked_at IS NULL — only live tokens
    # are queried on the hot path.
    op.create_index(
        "idx_oauth_subject_email",
        "oauth_access_tokens",
        ["subject_email"],
        postgresql_where=sa.text("revoked_at IS NULL"),
    )
    op.create_index(
        "idx_oauth_account",
        "oauth_access_tokens",
        ["account_id"],
        postgresql_where=sa.text("revoked_at IS NULL AND account_id IS NOT NULL"),
    )
    op.create_index(
        "idx_oauth_client",
        "oauth_access_tokens",
        ["subject_email", "client_id"],
        postgresql_where=sa.text("revoked_at IS NULL"),
    )
    op.create_index(
        "idx_oauth_token_hash",
        "oauth_access_tokens",
        ["token_hash"],
        postgresql_where=sa.text("revoked_at IS NULL"),
    )
    # Partial unique index — rotate-in-place keyed on (subject, client, device).
    # The app always writes a non-NULL subject_issuer (account flow uses a
    # sentinel, external-SSO uses the verified IdP issuer); without that the
    # composite key would never collide because Postgres treats NULLs as
    # distinct in unique indices.
    op.create_index(
        "uq_oauth_active_per_device",
        "oauth_access_tokens",
        ["subject_email", "subject_issuer", "client_id", "device_label"],
        unique=True,
        postgresql_where=sa.text("revoked_at IS NULL"),
    )
|
||||
|
||||
|
||||
def downgrade():
    """Drop oauth_access_tokens and its indexes (reverse creation order)."""
    for index_name in (
        "uq_oauth_active_per_device",
        "idx_oauth_token_hash",
        "idx_oauth_client",
        "idx_oauth_account",
        "idx_oauth_subject_email",
    ):
        op.drop_index(index_name, table_name="oauth_access_tokens")
    op.drop_table("oauth_access_tokens")
|
||||
@ -73,7 +73,7 @@ from .model import (
|
||||
TrialApp,
|
||||
UploadFile,
|
||||
)
|
||||
from .oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||
from .oauth import DatasourceOauthParamConfig, DatasourceProvider, OAuthAccessToken
|
||||
from .provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
@ -177,6 +177,7 @@ __all__ = [
|
||||
"MessageChain",
|
||||
"MessageFeedback",
|
||||
"MessageFile",
|
||||
"OAuthAccessToken",
|
||||
"OperationLog",
|
||||
"PinnedConversation",
|
||||
"Provider",
|
||||
|
||||
@ -84,3 +84,35 @@ class DatasourceOauthTenantParamConfig(TypeBase):
|
||||
onupdate=func.current_timestamp(),
|
||||
init=False,
|
||||
)
|
||||
|
||||
|
||||
class OAuthAccessToken(TypeBase):
    """Device-flow bearer. account_id NOT NULL ⇒ dfoa_ (Dify account,
    subject_issuer = "dify:account" sentinel); account_id NULL +
    subject_issuer = verified IdP issuer ⇒ dfoe_ (external SSO, EE-only).
    subject_issuer is non-NULL for all rows the app writes — Postgres
    treats NULLs as distinct in unique indices, so the partial unique
    index on (subject_email, subject_issuer, client_id, device_label)
    WHERE revoked_at IS NULL would otherwise fail to rotate in place.
    """

    __tablename__ = "oauth_access_tokens"
    __table_args__ = (sa.PrimaryKeyConstraint("id", name="oauth_access_tokens_pkey"),)

    id: Mapped[str] = mapped_column(
        StringUUID, insert_default=lambda: str(uuidv7()), default_factory=lambda: str(uuidv7()), init=False
    )
    # Required, keyword-less fields first (dataclass-style mapped model:
    # declaration order fixes __init__ parameter order).
    subject_email: Mapped[str] = mapped_column(sa.Text, nullable=False)
    client_id: Mapped[str] = mapped_column(sa.String(64), nullable=False)
    device_label: Mapped[str] = mapped_column(sa.Text, nullable=False)
    # Token prefix, e.g. "dfoa_" / "dfoe_" — distinguishes the token kind.
    prefix: Mapped[str] = mapped_column(sa.String(8), nullable=False)
    expires_at: Mapped[datetime] = mapped_column(sa.DateTime(timezone=True), nullable=False)
    subject_issuer: Mapped[str | None] = mapped_column(sa.Text, nullable=True, default=None)
    account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
    # sha256 hex of the plaintext bearer; plaintext is never stored.
    token_hash: Mapped[str | None] = mapped_column(sa.String(64), nullable=True, default=None)
    last_used_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None)
    # Revocation is an UPDATE, not a DELETE — rows stay for audit until pruned.
    revoked_at: Mapped[datetime | None] = mapped_column(sa.DateTime(timezone=True), nullable=True, default=None)

    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime(timezone=True), nullable=False, server_default=func.now(), init=False
    )
|
||||
|
||||
@ -1206,6 +1206,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
|
||||
SERVICE_API = "service-api"
|
||||
WEB_APP = "web-app"
|
||||
INSTALLED_APP = "installed-app"
|
||||
OPENAPI = "openapi"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
|
||||
|
||||
54
api/schedule/clean_oauth_access_tokens_task.py
Normal file
54
api/schedule/clean_oauth_access_tokens_task.py
Normal file
@ -0,0 +1,54 @@
|
||||
"""DELETE oauth_access_tokens past retention. Revocation is UPDATE
|
||||
(token_id stays for audits) so rows accumulate across re-logins, and
|
||||
expired-but-never-presented rows have no hard-expire trigger — both get
|
||||
pruned here. Spec: docs/specs/v1.0/server/tokens.md §Hard-expire.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import click
|
||||
from sqlalchemy import delete, or_, select
|
||||
|
||||
import app
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from models.oauth import OAuthAccessToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DELETE_BATCH_SIZE = 500
|
||||
|
||||
|
||||
@app.celery.task(queue="retention")
def clean_oauth_access_tokens_task():
    """Hard-delete oauth_access_tokens rows past the retention window.

    Two classes of candidates: rows revoked before the cutoff, and
    "zombie" rows that expired before the cutoff but were never revoked
    (no middleware ever saw them again).  Deletes in batches so a large
    backlog cannot hold one long transaction.
    """
    click.echo(click.style("Start clean oauth_access_tokens.", fg="green"))
    retention_days = int(dify_config.OAUTH_ACCESS_TOKEN_RETENTION_DAYS)
    cutoff = datetime.now(UTC) - timedelta(days=retention_days)
    start_at = time.perf_counter()

    candidates = or_(
        OAuthAccessToken.revoked_at < cutoff,
        # Zombies: expired but never re-presented, so middleware never flipped them.
        (OAuthAccessToken.revoked_at.is_(None)) & (OAuthAccessToken.expires_at < cutoff),
    )

    total = 0
    while True:
        # Select ids first, then delete by id — keeps each DELETE bounded
        # and commits per batch.
        ids = db.session.scalars(select(OAuthAccessToken.id).where(candidates).limit(DELETE_BATCH_SIZE)).all()
        if not ids:
            break
        db.session.execute(delete(OAuthAccessToken).where(OAuthAccessToken.id.in_(ids)))
        db.session.commit()
        total += len(ids)

    end_at = time.perf_counter()
    click.echo(
        click.style(
            f"Cleaned {total} oauth_access_tokens rows older than {retention_days}d in {end_at - start_at:.2f}s",
            fg="green",
        )
    )
|
||||
@ -37,7 +37,7 @@ class AppService:
|
||||
Get app list with pagination
|
||||
:param user_id: user id
|
||||
:param tenant_id: tenant id
|
||||
:param args: request args
|
||||
:param args: request args. Optional keys: status (e.g. "normal") restricts App.status.
|
||||
:return:
|
||||
"""
|
||||
filters = [App.tenant_id == tenant_id, App.is_universal == False]
|
||||
@ -53,6 +53,8 @@ class AppService:
|
||||
elif args["mode"] == "agent-chat":
|
||||
filters.append(App.mode == AppMode.AGENT_CHAT)
|
||||
|
||||
if args.get("status"):
|
||||
filters.append(App.status == args["status"])
|
||||
if args.get("is_created_by_me", False):
|
||||
filters.append(App.created_by == user_id)
|
||||
if args.get("name"):
|
||||
|
||||
44
api/services/enterprise/app_permitted_service.py
Normal file
44
api/services/enterprise/app_permitted_service.py
Normal file
@ -0,0 +1,44 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
from werkzeug.exceptions import ServiceUnavailable
|
||||
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.errors.enterprise import EnterpriseAPIError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class PermittedAppsPage:
    """One page of externally-accessible app ids as reported by the EE API."""

    app_ids: list[str]  # appId values, in EE response order
    total: int  # total matching apps across all pages
    has_more: bool  # whether another page exists
|
||||
|
||||
|
||||
def list_permitted_apps(
    *,
    page: int,
    limit: int,
    mode: str | None = None,
    name: str | None = None,
) -> PermittedAppsPage:
    """Fetch one page of externally-accessible apps from the EE backend.

    Raises ``ServiceUnavailable`` when the EE call fails; the upstream
    status code and message are logged first.
    """
    try:
        payload = EnterpriseService.WebAppAuth.list_externally_accessible_apps(
            page=page, limit=limit, mode=mode, name=name
        )
    except EnterpriseAPIError as exc:
        logger.warning(
            "permitted_apps EE call failed: status=%s message=%s",
            getattr(exc, "status_code", None),
            str(exc),
        )
        raise ServiceUnavailable("permitted_apps_unavailable") from exc

    rows = payload.get("data", [])
    return PermittedAppsPage(
        app_ids=[row["appId"] for row in rows],
        total=int(payload.get("total", 0)),
        has_more=bool(payload.get("hasMore", False)),
    )
|
||||
@ -106,6 +106,15 @@ class EnterpriseService:
|
||||
def get_workspace_info(cls, tenant_id: str):
|
||||
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
|
||||
|
||||
@classmethod
|
||||
def initiate_device_flow_sso(cls, signed_state: str) -> dict:
|
||||
return EnterpriseRequest.send_request(
|
||||
"POST",
|
||||
"/device-flow/sso-initiate",
|
||||
json={"signed_state": signed_state},
|
||||
raise_for_status=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult:
|
||||
"""
|
||||
@ -234,6 +243,32 @@ class EnterpriseService:
|
||||
params = {"appId": app_id}
|
||||
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
|
||||
|
||||
@classmethod
|
||||
def list_externally_accessible_apps(
|
||||
cls,
|
||||
*,
|
||||
page: int,
|
||||
limit: int,
|
||||
mode: str | None = None,
|
||||
name: str | None = None,
|
||||
) -> dict:
|
||||
"""Call EE InnerListExternallyAccessibleApps; returns raw camelCase response.
|
||||
|
||||
Response shape: ``{"data": [{"appId", "tenantId", "mode", "name", "updatedAt"}],
|
||||
"total": int, "hasMore": bool}``.
|
||||
"""
|
||||
body: dict[str, str | int] = {"page": page, "limit": limit}
|
||||
if mode is not None:
|
||||
body["mode"] = mode
|
||||
if name is not None:
|
||||
body["name"] = name
|
||||
return EnterpriseRequest.send_request(
|
||||
"POST",
|
||||
"/webapp/externally-accessible-apps",
|
||||
json=body,
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_cached_license_status(cls) -> LicenseStatus | None:
|
||||
"""Get enterprise license status with Redis caching to reduce HTTP calls.
|
||||
|
||||
467
api/services/oauth_device_flow.py
Normal file
467
api/services/oauth_device_flow.py
Normal file
@ -0,0 +1,467 @@
|
||||
"""Device-flow service layer: Redis state machine, OAuth token mint
|
||||
(DB upsert + plaintext generation), and TTL policy. Specs:
|
||||
docs/specs/v1.0/server/{device-flow.md, tokens.md}.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.orm import Session, scoped_session
|
||||
|
||||
from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT
|
||||
from models.oauth import OAuthAccessToken
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Redis state machine — device_code + user_code ephemeral state
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Redis key layout for ephemeral device-flow state.
_DEVICE_CODE_KEY_PREFIX = "device_code:"
_USER_CODE_KEY_PREFIX = "user_code:"
DEVICE_CODE_KEY_FMT = _DEVICE_CODE_KEY_PREFIX + "{code}"
USER_CODE_KEY_FMT = _USER_CODE_KEY_PREFIX + "{code}"

# Atomic GET → status-check → DEL(both keys). Two concurrent pollers must
# not both observe APPROVED — only the winner gets the plaintext token,
# the loser sees nil and the caller maps that to expired_token.
_CONSUME_ON_POLL_LUA = """
local raw = redis.call('GET', KEYS[1])
if not raw then return nil end
local ok, decoded = pcall(cjson.decode, raw)
if not ok then return nil end
if decoded.status == 'pending' then return nil end
if decoded.user_code then
    redis.call('DEL', ARGV[1] .. decoded.user_code)
end
redis.call('DEL', KEYS[1])
return raw
"""

DEVICE_FLOW_TTL_SECONDS = 15 * 60  # RFC 8628 expires_in
APPROVED_TTL_SECONDS_MIN = 60  # plaintext-token lifetime floor

USER_CODE_ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXY3456789"  # ambiguous chars dropped
USER_CODE_SEGMENT_LEN = 4
# Collision retries for NX-claiming a user_code before giving up.
USER_CODE_MAX_CLAIM_ATTEMPTS = 5

DEFAULT_POLL_INTERVAL_SECONDS = 5  # RFC 8628 minimum
|
||||
|
||||
|
||||
class DeviceFlowStatus(StrEnum):
    """Lifecycle of a device-flow grant as serialized into Redis state."""

    PENDING = "pending"  # awaiting user approval/denial
    APPROVED = "approved"  # token minted, waiting for the next poll to consume it
    DENIED = "denied"  # user rejected; terminal
|
||||
|
||||
class SlowDownDecision(StrEnum):
    """Outcome of poll-interval tracking; SLOW_DOWN maps to RFC 8628 slow_down."""

    OK = "ok"
    SLOW_DOWN = "slow_down"
|
||||
|
||||
|
||||
@dataclass
class DeviceFlowState:
    """JSON-serialized device-flow state stored under the device_code key.

    ``minted_token`` is plaintext between approve and the next poll;
    DEL'd after the poll reads it.
    """

    user_code: str
    client_id: str
    device_label: str
    status: DeviceFlowStatus
    # Populated on approve:
    subject_email: str | None = None
    account_id: str | None = None
    subject_issuer: str | None = None
    minted_token: str | None = None
    token_id: str | None = None
    created_at: str = ""  # ISO-8601 UTC timestamp string
    created_ip: str = ""
    last_poll_at: str = ""
    poll_payload: dict | None = field(default=None)

    def to_json(self) -> str:
        """Serialize to the JSON form stored in Redis."""
        return json.dumps(asdict(self))

    @classmethod
    def from_json(cls, raw: str) -> DeviceFlowState:
        """Rebuild from Redis JSON; restores ``status`` to the enum.

        Raises ValueError/KeyError on corrupt payloads (handled by callers).
        """
        data = json.loads(raw)
        if "status" in data:
            data["status"] = DeviceFlowStatus(data["status"])
        return cls(**data)
|
||||
|
||||
|
||||
def _random_device_code() -> str:
|
||||
return "dc_" + secrets.token_urlsafe(24)
|
||||
|
||||
|
||||
def _random_user_code_segment() -> str:
    """One 4-char segment drawn from the ambiguity-free alphabet."""
    chars = [secrets.choice(USER_CODE_ALPHABET) for _ in range(USER_CODE_SEGMENT_LEN)]
    return "".join(chars)
|
||||
|
||||
|
||||
def _random_user_code() -> str:
    """Human-typable user_code: two 4-char segments joined by a dash."""
    return "-".join((_random_user_code_segment(), _random_user_code_segment()))
|
||||
|
||||
|
||||
class StateNotFoundError(Exception):
    """No Redis state exists for the given device_code (expired or unknown)."""

    pass
|
||||
|
||||
|
||||
class InvalidTransitionError(Exception):
    """Attempted approve/deny on a state that is no longer PENDING."""

    pass
|
||||
|
||||
|
||||
class UserCodeExhaustedError(Exception):
    """Could not NX-claim a unique user_code within the attempt budget."""

    pass
|
||||
|
||||
|
||||
class DeviceFlowRedis:
    """Redis-backed state machine for the RFC 8628 device-flow grant.

    Keys: ``device_code:<code>`` holds the JSON DeviceFlowState;
    ``user_code:<code>`` maps the human code back to the device_code.
    Both carry the same TTL so they expire together.
    """

    def __init__(self, redis_client) -> None:
        self._redis = redis_client
        # Pre-registered Lua script; EVALSHA on subsequent calls.
        self._consume_on_poll_script = redis_client.register_script(_CONSUME_ON_POLL_LUA)

    def start(self, client_id: str, device_label: str, created_ip: str) -> tuple[str, str, int]:
        """Begin a flow: allocate codes, store PENDING state.

        Returns (device_code, user_code, expires_in_seconds).
        """
        device_code = _random_device_code()
        user_code = self._claim_user_code(device_code)
        state = DeviceFlowState(
            user_code=user_code,
            client_id=client_id,
            device_label=device_label,
            status=DeviceFlowStatus.PENDING,
            created_at=datetime.now(UTC).isoformat(),
            created_ip=created_ip,
        )
        self._redis.setex(
            DEVICE_CODE_KEY_FMT.format(code=device_code),
            DEVICE_FLOW_TTL_SECONDS,
            state.to_json(),
        )
        return device_code, user_code, DEVICE_FLOW_TTL_SECONDS

    def _claim_user_code(self, device_code: str) -> str:
        """NX-claim a random user_code → device_code mapping.

        Raises UserCodeExhaustedError after USER_CODE_MAX_CLAIM_ATTEMPTS
        collisions (statistically near-impossible at this keyspace size).
        """
        for _ in range(USER_CODE_MAX_CLAIM_ATTEMPTS):
            user_code = _random_user_code()
            key = USER_CODE_KEY_FMT.format(code=user_code)
            # nx=True guarantees we never steal a live code.
            ok = self._redis.set(key, device_code, nx=True, ex=DEVICE_FLOW_TTL_SECONDS)
            if ok:
                return user_code
        raise UserCodeExhaustedError("could not allocate a unique user_code in 5 attempts")

    def load_by_user_code(self, user_code: str) -> tuple[str, DeviceFlowState] | None:
        """Resolve a user_code to (device_code, state); None if either key is gone."""
        raw_dc = self._redis.get(USER_CODE_KEY_FMT.format(code=user_code))
        if not raw_dc:
            return None
        device_code = raw_dc.decode() if isinstance(raw_dc, (bytes, bytearray)) else raw_dc
        state = self._load_state(device_code)
        if state is None:
            return None
        return device_code, state

    def load_by_device_code(self, device_code: str) -> DeviceFlowState | None:
        """Read-only state lookup; None when expired/unknown."""
        return self._load_state(device_code)

    def _load_state(self, device_code: str) -> DeviceFlowState | None:
        """GET + deserialize; corrupt payloads are logged and treated as missing."""
        raw = self._redis.get(DEVICE_CODE_KEY_FMT.format(code=device_code))
        if not raw:
            return None
        text_ = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
        try:
            return DeviceFlowState.from_json(text_)
        except (ValueError, KeyError):
            logger.exception("device_flow: corrupt state for %s", device_code)
            return None

    def approve(
        self,
        device_code: str,
        subject_email: str,
        account_id: str | None,
        minted_token: str,
        token_id: str,
        subject_issuer: str | None = None,
        poll_payload: dict | None = None,
    ) -> None:
        """PENDING → APPROVED; attaches subject identity and the minted token.

        Raises StateNotFoundError when the state is gone and
        InvalidTransitionError when it is not PENDING.
        """
        state = self._load_state(device_code)
        if state is None:
            raise StateNotFoundError(device_code)
        if state.status is not DeviceFlowStatus.PENDING:
            raise InvalidTransitionError(f"cannot approve {state.status}")

        state.status = DeviceFlowStatus.APPROVED
        state.subject_email = subject_email
        state.account_id = account_id
        state.subject_issuer = subject_issuer
        state.minted_token = minted_token
        state.token_id = token_id
        state.poll_payload = poll_payload

        # Re-SETEX with a floored TTL so the CLI has time to poll even when
        # approval lands near expiry.
        new_ttl = self._remaining_ttl(device_code, floor=APPROVED_TTL_SECONDS_MIN)
        self._redis.setex(DEVICE_CODE_KEY_FMT.format(code=device_code), new_ttl, state.to_json())

    def deny(self, device_code: str) -> None:
        """PENDING → DENIED; same precondition errors as approve()."""
        state = self._load_state(device_code)
        if state is None:
            raise StateNotFoundError(device_code)
        if state.status is not DeviceFlowStatus.PENDING:
            raise InvalidTransitionError(f"cannot deny {state.status}")
        state.status = DeviceFlowStatus.DENIED
        self._redis.setex(
            DEVICE_CODE_KEY_FMT.format(code=device_code),
            self._remaining_ttl(device_code, floor=1),
            state.to_json(),
        )

    def consume_on_poll(self, device_code: str) -> DeviceFlowState | None:
        """Race-safe via Lua EVAL: GET + status-check + DEL execute in a
        single Redis transaction so only one of N concurrent pollers
        observes the APPROVED state. Losers get None, mapped to
        expired_token by the caller.
        """
        raw = self._consume_on_poll_script(
            keys=[DEVICE_CODE_KEY_FMT.format(code=device_code)],
            args=[_USER_CODE_KEY_PREFIX],
        )
        if raw is None:
            return None
        text_ = raw.decode() if isinstance(raw, (bytes, bytearray)) else raw
        try:
            return DeviceFlowState.from_json(text_)
        except (ValueError, KeyError):
            logger.exception("device_flow: corrupt state on consume %s", device_code)
            return None

    def record_poll(self, device_code: str, interval_seconds: int) -> SlowDownDecision:
        """Track poll cadence; SLOW_DOWN when polls arrive faster than
        *interval_seconds* apart (RFC 8628 slow_down)."""
        now = time.time()
        # NOTE(review): literal "device_code:" prefix here — presumably
        # intentional mirror of _DEVICE_CODE_KEY_PREFIX; keep in sync.
        key = f"device_code:{device_code}:last_poll"
        prev_raw = self._redis.get(key)
        self._redis.setex(key, DEVICE_FLOW_TTL_SECONDS, str(now))
        if prev_raw is None:
            return SlowDownDecision.OK
        prev_s = prev_raw.decode() if isinstance(prev_raw, (bytes, bytearray)) else prev_raw
        try:
            prev = float(prev_s)
        except ValueError:
            # Unparseable marker — fail open rather than punish the client.
            return SlowDownDecision.OK
        if now - prev < interval_seconds:
            return SlowDownDecision.SLOW_DOWN
        return SlowDownDecision.OK

    def _remaining_ttl(self, device_code: str, floor: int) -> int:
        """``max(remaining, floor)`` — guarantees the CLI has at least
        ``floor`` seconds to poll after a near-expiry approve.
        """
        ttl = self._redis.ttl(DEVICE_CODE_KEY_FMT.format(code=device_code))
        # ttl < 0 covers Redis' -1 (no expiry) and -2 (missing key).
        if ttl is None or ttl < 0:
            return floor
        return max(int(ttl), floor)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Token mint — generate + upsert
|
||||
# ============================================================================
|
||||
|
||||
|
||||
# Token-mint parameters shared by both device-flow token kinds.
OAUTH_BODY_BYTES = 32  # ~256 bits entropy
PREFIX_OAUTH_ACCOUNT = "dfoa_"
PREFIX_OAUTH_EXTERNAL_SSO = "dfoe_"

# Sentinel issuer for account-flow rows. Postgres' default partial unique
# index treats NULLs as distinct, which would let two live `dfoa_` rows
# share (email, client, device) and break rotate-in-place. Storing a
# non-empty literal makes the composite key collide as intended.
ACCOUNT_ISSUER_SENTINEL = "dify:account"
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class MintResult:
    """Result of a token mint. Plaintext token surfaces to the caller once."""

    token: str  # plaintext bearer — never persisted
    token_id: uuid.UUID  # DB row id for audit/revocation
    expires_at: datetime
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
class UpsertOutcome:
    """Outcome of the rotate-in-place upsert."""

    token_id: uuid.UUID
    rotated: bool  # True when a live row for the same device was replaced
    old_hash: str | None  # prior token_hash, for Redis cache invalidation
|
||||
|
||||
|
||||
def generate_token(prefix: str) -> str:
    """Mint a plaintext bearer: *prefix* + ~256 bits of url-safe entropy."""
    body = secrets.token_urlsafe(OAUTH_BODY_BYTES)
    return f"{prefix}{body}"
|
||||
|
||||
|
||||
def sha256_hex(token: str) -> str:
    """Hex SHA-256 of *token* — the at-rest / lookup form of a bearer."""
    digest = hashlib.sha256(token.encode("utf-8"))
    return digest.hexdigest()
|
||||
|
||||
|
||||
def mint_oauth_token(
    # Accept either Session or Flask-SQLAlchemy's request-scoped wrapper —
    # the wrapper proxies the same execute/commit surface.
    session: Session | scoped_session,
    redis_client,
    *,
    subject_email: str,
    subject_issuer: str | None,
    account_id: str | None,
    client_id: str,
    device_label: str,
    prefix: str,
    ttl_days: int,
) -> MintResult:
    """Mint a device-flow bearer and persist its hash.

    Live row rotates in place via partial unique index
    ``uq_oauth_active_per_device``; hard-expired rows are excluded by the
    index predicate so re-login INSERTs fresh. Pre-rotate Redis entry is
    deleted so stale AuthContext drops immediately.

    Args:
        prefix: PREFIX_OAUTH_ACCOUNT or PREFIX_OAUTH_EXTERNAL_SSO.
        subject_issuer: None/sentinel for account flow; required non-empty
            for external SSO.

    Returns:
        MintResult carrying the one-time plaintext token.

    Raises:
        ValueError: on an unknown prefix or an issuer/prefix mismatch.
    """
    if prefix == PREFIX_OAUTH_ACCOUNT:
        # Account flow always writes the sentinel — caller may pass None
        # (for clarity) or the sentinel itself; nothing else is valid.
        if subject_issuer not in (None, ACCOUNT_ISSUER_SENTINEL):
            raise ValueError(f"account-flow token must use ACCOUNT_ISSUER_SENTINEL, got {subject_issuer!r}")
        subject_issuer = ACCOUNT_ISSUER_SENTINEL
    elif prefix == PREFIX_OAUTH_EXTERNAL_SSO:
        # Defense in depth: enterprise canonicalises + rejects empty,
        # but a regression there must not yield a NULL composite key here.
        if not subject_issuer or not subject_issuer.strip():
            raise ValueError("external-SSO token requires non-empty subject_issuer")
    else:
        raise ValueError(f"unknown oauth prefix: {prefix!r}")

    token = generate_token(prefix)
    new_hash = sha256_hex(token)
    expires_at = datetime.now(UTC) + timedelta(days=ttl_days)

    outcome = _upsert(
        session,
        subject_email=subject_email,
        subject_issuer=subject_issuer,
        account_id=account_id,
        client_id=client_id,
        device_label=device_label,
        prefix=prefix,
        new_hash=new_hash,
        expires_at=expires_at,
    )

    # Drop the cached AuthContext for the replaced token so it stops
    # authenticating immediately, not at cache expiry.
    if outcome.rotated and outcome.old_hash:
        redis_client.delete(TOKEN_CACHE_KEY_FMT.format(hash=outcome.old_hash))

    return MintResult(token=token, token_id=outcome.token_id, expires_at=expires_at)
|
||||
|
||||
|
||||
def _upsert(
    session: Session | scoped_session,
    *,
    subject_email: str,
    subject_issuer: str | None,
    account_id: str | None,
    client_id: str,
    device_label: str,
    prefix: str,
    new_hash: str,
    expires_at: datetime,
) -> UpsertOutcome:
    """Insert-or-rotate the single live token row for a device slot.

    Returns an UpsertOutcome carrying the row id, whether an existing live
    row was rotated in place, and (when rotated) the pre-rotate token hash
    so the caller can invalidate its Redis cache entry.
    """
    # Snapshot the currently-live row (if any) before the upsert so the
    # caller can purge the stale hash from Redis after rotation. The filter
    # mirrors the partial unique index: subject_issuer is never NULL here
    # (account flow stores a sentinel, external SSO is validated upstream).
    existing = session.execute(
        select(OAuthAccessToken.id, OAuthAccessToken.token_hash)
        .where(
            OAuthAccessToken.subject_email == subject_email,
            OAuthAccessToken.subject_issuer == subject_issuer,
            OAuthAccessToken.client_id == client_id,
            OAuthAccessToken.device_label == device_label,
            OAuthAccessToken.revoked_at.is_(None),
        )
        .limit(1)
    ).first()
    previous_hash = existing.token_hash if existing else None

    base_insert = pg_insert(OAuthAccessToken).values(
        subject_email=subject_email,
        subject_issuer=subject_issuer,
        account_id=account_id,
        client_id=client_id,
        device_label=device_label,
        prefix=prefix,
        token_hash=new_hash,
        expires_at=expires_at,
    )
    # Conflict target is the partial unique index over the live-row
    # composite key; on a hit we rotate the credential fields in place and
    # reset the bookkeeping timestamps.
    upsert = base_insert.on_conflict_do_update(
        index_elements=["subject_email", "subject_issuer", "client_id", "device_label"],
        index_where=OAuthAccessToken.revoked_at.is_(None),
        set_={
            "token_hash": base_insert.excluded.token_hash,
            "prefix": base_insert.excluded.prefix,
            "account_id": base_insert.excluded.account_id,
            "expires_at": base_insert.excluded.expires_at,
            "created_at": func.now(),
            "last_used_at": None,
        },
    ).returning(OAuthAccessToken.id)

    returned = session.execute(upsert).first()
    session.commit()

    if returned is None:
        raise RuntimeError("oauth_token upsert returned no row")

    return UpsertOutcome(
        token_id=uuid.UUID(str(returned.id)),
        rotated=existing is not None,
        old_hash=previous_hash,
    )
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TTL policy — days new OAuth tokens live
|
||||
# ============================================================================
|
||||
|
||||
|
||||
DEFAULT_OAUTH_TTL_DAYS = 14
|
||||
MIN_TTL_DAYS = 1
|
||||
MAX_TTL_DAYS = 365
|
||||
|
||||
_TTL_ENV_VAR = "OAUTH_TTL_DAYS"
|
||||
|
||||
|
||||
def oauth_ttl_days(tenant_id: str | None = None) -> int:
|
||||
"""``OAUTH_TTL_DAYS`` env, else default. EE tenant-level lookup
|
||||
is deferred; when it lands it wins over the env (Redis-cached 60s).
|
||||
"""
|
||||
_ = tenant_id
|
||||
|
||||
raw = os.environ.get(_TTL_ENV_VAR)
|
||||
if raw is None:
|
||||
return DEFAULT_OAUTH_TTL_DAYS
|
||||
try:
|
||||
value = int(raw)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"%s=%r is not an int; falling back to %d",
|
||||
_TTL_ENV_VAR,
|
||||
raw,
|
||||
DEFAULT_OAUTH_TTL_DAYS,
|
||||
)
|
||||
return DEFAULT_OAUTH_TTL_DAYS
|
||||
if value < MIN_TTL_DAYS:
|
||||
logger.warning("%s=%d below min %d; clamping", _TTL_ENV_VAR, value, MIN_TTL_DAYS)
|
||||
return MIN_TTL_DAYS
|
||||
if value > MAX_TTL_DAYS:
|
||||
logger.warning("%s=%d above max %d; clamping", _TTL_ENV_VAR, value, MAX_TTL_DAYS)
|
||||
return MAX_TTL_DAYS
|
||||
return value
|
||||
125
api/tests/integration_tests/controllers/openapi/conftest.py
Normal file
125
api/tests/integration_tests/controllers/openapi/conftest.py
Normal file
@ -0,0 +1,125 @@
|
||||
"""Shared fixtures for /openapi/v1/* integration tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models import Account, App, OAuthAccessToken, Tenant, TenantAccountJoin
|
||||
from models.account import AccountStatus
|
||||
|
||||
|
||||
def _sha256(token: str) -> str:
|
||||
return hashlib.sha256(token.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def disable_enterprise(monkeypatch):
    """Default to CE behaviour for /openapi/v1 tests. Tests that exercise the
    EE branch override this with their own monkeypatch in-test."""
    # Imported lazily so the fixture touches the same config singleton the
    # app under test reads at request time.
    from configs import dify_config

    # autouse: every test in this package starts with ENTERPRISE_ENABLED=False.
    monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False)
|
||||
|
||||
|
||||
@pytest.fixture
def workspace_account(flask_app: Flask) -> Generator[tuple[Account, Tenant, TenantAccountJoin], None, None]:
    """Create a tenant plus an owner account and clean both up on teardown.

    Yields ``(account, tenant, join)``. Rows are deleted in reverse
    dependency order (join first) so foreign keys stay satisfied.
    """
    with flask_app.app_context():
        tenant = Tenant(name="t1", status="normal")
        account = Account(email="u@example.com", name="u")
        db.session.add_all([tenant, account])
        db.session.commit()
        # Status is flipped after the first commit; the change is persisted
        # by the second commit below alongside the membership row.
        account.status = AccountStatus.ACTIVE
        join = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role="owner")
        db.session.add(join)
        db.session.commit()
        yield account, tenant, join
        db.session.delete(join)
        db.session.delete(account)
        db.session.delete(tenant)
        db.session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
def app_in_workspace(flask_app: Flask, workspace_account) -> Generator[App, None, None]:
    """A chat-mode App row in the workspace_account tenant; removed on teardown."""
    _, tenant, _ = workspace_account
    with flask_app.app_context():
        created = App(tenant_id=tenant.id, name="a", mode="chat", status="normal", enable_site=True, enable_api=True)
        db.session.add(created)
        db.session.commit()
        yield created
        db.session.delete(created)
        db.session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
def mint_token(flask_app: Flask):
    """Factory fixture; tracks minted rows and deletes them on teardown so
    the auth-related test runs don't accumulate `oauth_access_tokens` rows."""
    minted: list[OAuthAccessToken] = []

    def _mint(
        token: str,
        *,
        account_id: str | None,
        prefix: str,
        subject_email: str,
        subject_issuer: str | None,
    ) -> OAuthAccessToken:
        # Only the hash is persisted — mirrors production, which never
        # stores the raw bearer.
        with flask_app.app_context():
            row = OAuthAccessToken(
                token_hash=_sha256(token),
                prefix=prefix,
                account_id=account_id,
                subject_email=subject_email,
                subject_issuer=subject_issuer,
                client_id="difyctl",
                device_label="test-device",
                expires_at=datetime.now(UTC) + timedelta(hours=1),
            )
            db.session.add(row)
            db.session.commit()
            minted.append(row)
            return row

    yield _mint

    # Teardown: the session that created each row is gone, so re-attach via
    # merge before deleting.
    with flask_app.app_context():
        for row in minted:
            db.session.delete(db.session.merge(row))
        db.session.commit()
|
||||
|
||||
|
||||
@pytest.fixture
def account_token(workspace_account, mint_token) -> str:
    """Mint (and register for cleanup) an account-flow bearer for the fixture account."""
    account, _, _ = workspace_account
    bearer = "dfoa_" + uuid.uuid4().hex
    mint_token(
        bearer,
        account_id=account.id,
        prefix="dfoa_",
        subject_email=account.email,
        subject_issuer="dify:account",
    )
    return bearer
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _flush_auth_redis(flask_app: Flask) -> Generator[None, None, None]:
    """Clear auth/rate-limit Redis state before and after every test.

    Uses SCAN (``scan_iter``) instead of KEYS — KEYS is O(N) and blocks the
    Redis event loop on a shared instance — and deletes each pattern's
    matches in a single round trip instead of one DEL per key.
    """

    def _flush() -> None:
        with flask_app.app_context():
            for pattern in ("auth:*", "rl:*"):
                stale = list(redis_client.scan_iter(match=pattern))
                if stale:
                    redis_client.delete(*stale)

    _flush()
    yield
    _flush()
|
||||
252
api/tests/integration_tests/controllers/openapi/test_app_run.py
Normal file
252
api/tests/integration_tests/controllers/openapi/test_app_run.py
Normal file
@ -0,0 +1,252 @@
|
||||
"""Integration tests for POST /openapi/v1/apps/<id>/run."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from extensions.ext_database import db
|
||||
from models import App
|
||||
|
||||
|
||||
def test_run_chat_dispatches_to_chat_handler(flask_app, account_token, app_in_workspace, monkeypatch):
    """Chat-mode app dispatches to the chat handler with InvokeFrom.OPENAPI,
    and the server-side identity wins over any `user` field in the body."""
    captured = {}

    def _fake_generate(*, app_model, user, args, invoke_from, streaming):
        # Record what the controller forwarded so we can assert on it below.
        captured["mode"] = app_model.mode
        captured["args"] = args
        captured["invoke_from"] = invoke_from
        # Minimal blocking chat response envelope.
        return {
            "event": "message",
            "task_id": "t",
            "id": "m",
            "message_id": "m",
            "conversation_id": "c",
            "mode": "chat",
            "answer": "ok",
            "created_at": 0,
        }

    # Patch at the controller's import site so the route picks up the stub.
    monkeypatch.setattr(
        "controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate)
    )
    client = flask_app.test_client()
    # Body deliberately smuggles a `user` field to prove the server strips it.
    res = client.post(
        f"/openapi/v1/apps/{app_in_workspace.id}/run",
        json={"inputs": {}, "query": "hi", "response_mode": "blocking", "user": "spoof@x.com"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    assert res.get_json()["mode"] == "chat"
    assert captured["mode"] == "chat"
    assert captured["invoke_from"] == InvokeFrom.OPENAPI
    assert "user" not in captured["args"], "server must strip body.user; identity comes from bearer"
|
||||
|
||||
|
||||
@pytest.fixture
def app_with_mode(flask_app: Flask, workspace_account):
    """Factory that creates an App row in the workspace_account tenant with
    a specified mode. Tracks rows for teardown.
    """
    _, tenant, _ = workspace_account
    created: list[App] = []

    def _make(mode: str) -> App:
        with flask_app.app_context():
            app = App(
                tenant_id=tenant.id,
                name=f"a-{mode}",
                mode=mode,
                status="normal",
                enable_site=True,
                enable_api=True,
            )
            db.session.add(app)
            db.session.commit()
            # Refresh to populate DB defaults, then detach so the object
            # stays usable after this app_context exits.
            db.session.refresh(app)
            db.session.expunge(app)
            created.append(app)
            return app

    yield _make

    # Teardown: re-attach each detached row via merge before deleting.
    with flask_app.app_context():
        for app in created:
            db.session.delete(db.session.merge(app))
        db.session.commit()
|
||||
|
||||
|
||||
def test_run_chat_without_query_returns_422(flask_app, account_token, app_in_workspace, monkeypatch):
    """Chat apps require `query`; omitting it yields 422 with a stable error code."""
    resp = flask_app.test_client().post(
        f"/openapi/v1/apps/{app_in_workspace.id}/run",
        json={"inputs": {}, "response_mode": "blocking"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 422
    assert b"query_required_for_chat" in resp.data
|
||||
|
||||
|
||||
def test_run_completion_dispatches_to_completion_handler(
    flask_app, account_token, app_with_mode, monkeypatch
):
    """Completion-mode app dispatches to the completion handler (no query required)."""
    app = app_with_mode("completion")

    captured: dict = {}

    def _fake_generate(*, app_model, user, args, invoke_from, streaming):
        # Record dispatch inputs; return a minimal blocking completion envelope.
        captured["mode"] = app_model.mode
        captured["args"] = args
        return {
            "event": "message",
            "task_id": "t",
            "id": "m",
            "message_id": "m",
            "mode": "completion",
            "answer": "ok",
            "created_at": 0,
        }

    # Patch at the controller's import site so the route picks up the stub.
    monkeypatch.setattr(
        "controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate)
    )
    client = flask_app.test_client()
    res = client.post(
        f"/openapi/v1/apps/{app.id}/run",
        json={"inputs": {}, "response_mode": "blocking"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    assert res.get_json()["mode"] == "completion"
    assert captured["mode"] == "completion"
|
||||
|
||||
|
||||
def test_run_workflow_with_query_returns_422(flask_app, account_token, app_with_mode, monkeypatch):
    """Workflow apps reject `query` with a stable 422 error code."""
    workflow_app = app_with_mode("workflow")
    resp = flask_app.test_client().post(
        f"/openapi/v1/apps/{workflow_app.id}/run",
        json={"inputs": {}, "query": "hi", "response_mode": "blocking"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 422
    assert b"query_not_supported_for_workflow" in resp.data
|
||||
|
||||
|
||||
def test_run_workflow_no_query_dispatches_to_workflow_handler(
    flask_app, account_token, app_with_mode, monkeypatch
):
    """Workflow-mode app (no query) dispatches to the workflow handler and
    the response envelope carries mode/workflow_run_id."""
    app = app_with_mode("workflow")

    def _fake_generate(*, app_model, user, args, invoke_from, streaming):
        # Minimal blocking workflow-run envelope.
        return {
            "workflow_run_id": "wfr",
            "task_id": "t",
            "data": {"id": "wf-d", "workflow_id": "wf", "status": "succeeded"},
        }

    # Patch at the controller's import site so the route picks up the stub.
    monkeypatch.setattr(
        "controllers.openapi.app_run.AppGenerateService.generate", staticmethod(_fake_generate)
    )
    client = flask_app.test_client()
    res = client.post(
        f"/openapi/v1/apps/{app.id}/run",
        json={"inputs": {}, "response_mode": "blocking"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    body = res.get_json()
    # `mode` is added by the controller, not by the stubbed generator.
    assert body["mode"] == "workflow"
    assert body["workflow_run_id"] == "wfr"
|
||||
|
||||
|
||||
def test_run_unsupported_mode_returns_422(flask_app, account_token, app_with_mode, monkeypatch):
    """A mode the run endpoint cannot execute (e.g. channel) yields 422."""
    channel_app = app_with_mode("channel")
    resp = flask_app.test_client().post(
        f"/openapi/v1/apps/{channel_app.id}/run",
        json={"inputs": {}, "response_mode": "blocking"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 422
    assert b"mode_not_runnable" in resp.data
|
||||
|
||||
|
||||
def test_run_without_bearer_returns_401(flask_app, app_in_workspace):
    """No Authorization header at all → 401 before any app logic runs."""
    resp = flask_app.test_client().post(
        f"/openapi/v1/apps/{app_in_workspace.id}/run",
        json={"inputs": {}, "query": "hi"},
    )
    assert resp.status_code == 401
|
||||
|
||||
|
||||
def test_run_with_insufficient_scope_returns_403(
    flask_app, account_token, app_in_workspace, monkeypatch
):
    """Stub the authenticator to return an AuthContext with empty scopes."""
    from libs import oauth_bearer

    # Keep a reference to the real method so the stub can delegate to it —
    # the token must still authenticate; only its scopes are emptied.
    real_authenticate = oauth_bearer.BearerAuthenticator.authenticate

    def _stub_authenticate(self, token: str):
        ctx = real_authenticate(self, token)
        from dataclasses import replace

        # AuthContext is a (frozen) dataclass here — replace() produces a
        # scope-less copy without mutating the cached original.
        return replace(ctx, scopes=frozenset())

    monkeypatch.setattr(oauth_bearer.BearerAuthenticator, "authenticate", _stub_authenticate)

    client = flask_app.test_client()
    res = client.post(
        f"/openapi/v1/apps/{app_in_workspace.id}/run",
        json={"inputs": {}, "query": "hi"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    # Authenticated but unauthorized → 403, not 401.
    assert res.status_code == 403
|
||||
|
||||
|
||||
def test_run_with_unknown_app_returns_404(flask_app, account_token):
    """A random, never-created app id yields 404."""
    missing_id = uuid.uuid4()
    resp = flask_app.test_client().post(
        f"/openapi/v1/apps/{missing_id}/run",
        json={"inputs": {}, "query": "hi"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_run_streaming_returns_event_stream(
    flask_app, account_token, app_in_workspace, monkeypatch
):
    """streaming response_mode yields an SSE body with text/event-stream."""
    def _stream() -> Generator[str, None, None]:
        # One well-formed SSE frame is enough to prove pass-through.
        yield "event: message\ndata: {\"x\": 1}\n\n"

    # The generator must be created per request, hence the lambda wrapper.
    monkeypatch.setattr(
        "controllers.openapi.app_run.AppGenerateService.generate",
        staticmethod(lambda **kw: _stream()),
    )

    client = flask_app.test_client()
    res = client.post(
        f"/openapi/v1/apps/{app_in_workspace.id}/run",
        json={"inputs": {}, "query": "hi", "response_mode": "streaming"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    assert res.headers["Content-Type"].startswith("text/event-stream")
    assert b"event: message" in res.data
|
||||
|
||||
|
||||
def test_run_without_inputs_returns_422(flask_app, account_token, app_in_workspace):
    """`inputs` is mandatory; a body without it fails validation with 422."""
    resp = flask_app.test_client().post(
        f"/openapi/v1/apps/{app_in_workspace.id}/run",
        json={"query": "hi"},
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 422
|
||||
210
api/tests/integration_tests/controllers/openapi/test_apps.py
Normal file
210
api/tests/integration_tests/controllers/openapi/test_apps.py
Normal file
@ -0,0 +1,210 @@
|
||||
"""Integration tests for /openapi/v1/apps* read surface."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
from models import App
|
||||
|
||||
|
||||
def test_apps_bare_id_route_404(test_client, app_in_workspace, account_token):
    """The bare /apps/<id> path is intentionally unrouted — expect 404."""
    url = f"/openapi/v1/apps/{app_in_workspace.id}"
    assert test_client.get(url, headers={"Authorization": f"Bearer {account_token}"}).status_code == 404
|
||||
|
||||
|
||||
def test_apps_parameters_route_404(test_client, app_in_workspace, account_token):
    """/apps/<id>/parameters is not exposed on this surface — expect 404."""
    url = f"/openapi/v1/apps/{app_in_workspace.id}/parameters"
    assert test_client.get(url, headers={"Authorization": f"Bearer {account_token}"}).status_code == 404
|
||||
|
||||
|
||||
def test_apps_info_route_404(test_client, app_in_workspace, account_token):
    """/apps/<id>/info is not exposed on this surface — expect 404."""
    url = f"/openapi/v1/apps/{app_in_workspace.id}/info"
    assert test_client.get(url, headers={"Authorization": f"Bearer {account_token}"}).status_code == 404
|
||||
|
||||
|
||||
def test_apps_describe_returns_merged_shape(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
):
    """/describe merges app info and parameters into one payload."""
    resp = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/describe",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 200
    payload = resp.json
    assert payload["info"]["id"] == app_in_workspace.id
    assert payload["info"]["mode"] == "chat"
    assert isinstance(payload["parameters"], dict)
|
||||
|
||||
|
||||
def test_apps_describe_full_includes_input_schema(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
):
    """Default /describe (no fields filter) includes a JSON Schema 2020-12 input schema."""
    resp = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/describe",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 200
    payload = resp.json
    for section in ("info", "parameters", "input_schema"):
        assert payload[section] is not None
    assert payload["input_schema"]["$schema"] == "https://json-schema.org/draft/2020-12/schema"
|
||||
|
||||
|
||||
def test_apps_describe_fields_info_only(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
):
    """?fields=info returns info and nulls the other two sections."""
    resp = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 200
    payload = resp.json
    assert payload["info"] is not None
    assert payload["parameters"] is None
    assert payload["input_schema"] is None
|
||||
|
||||
|
||||
def test_apps_describe_fields_parameters_only(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
):
    """?fields=parameters returns parameters and nulls the other two sections."""
    resp = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=parameters",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 200
    payload = resp.json
    assert payload["info"] is None
    assert payload["parameters"] is not None
    assert payload["input_schema"] is None
|
||||
|
||||
|
||||
def test_apps_describe_fields_input_schema_only(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
):
    """?fields=input_schema returns only the schema section."""
    resp = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=input_schema",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 200
    payload = resp.json
    assert payload["info"] is None
    assert payload["parameters"] is None
    assert payload["input_schema"] is not None
|
||||
|
||||
|
||||
def test_apps_describe_fields_combined(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
):
    """Comma-separated fields select exactly the requested sections."""
    resp = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info,input_schema",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 200
    payload = resp.json
    assert payload["info"] is not None
    assert payload["parameters"] is None
    assert payload["input_schema"] is not None
|
||||
|
||||
|
||||
def test_apps_describe_fields_unknown_returns_422(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
):
    """An unrecognised fields value is rejected, not silently ignored."""
    resp = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=garbage",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_apps_describe_fields_extra_param_returns_422(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
):
    """Unknown extra query params (e.g. page) are rejected on /describe."""
    resp = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/describe?fields=info&page=1",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 422
|
||||
|
||||
|
||||
def test_apps_list_returns_pagination_envelope(
    test_client: FlaskClient,
    workspace_account,
    app_in_workspace: App,
    account_token: str,
):
    """App listing echoes page/limit, reports total, and contains the fixture app."""
    _, tenant, _ = workspace_account
    resp = test_client.get(
        f"/openapi/v1/apps?workspace_id={tenant.id}&page=1&limit=20",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 200
    envelope = resp.json
    assert envelope["page"] == 1
    assert envelope["limit"] == 20
    assert envelope["total"] >= 1
    listed_ids = [entry["id"] for entry in envelope["data"]]
    assert app_in_workspace.id in listed_ids
|
||||
|
||||
|
||||
def test_apps_list_requires_workspace_id(test_client: FlaskClient, account_token: str):
    """Omitting workspace_id is a client error, not an empty listing."""
    headers = {"Authorization": f"Bearer {account_token}"}
    assert test_client.get("/openapi/v1/apps", headers=headers).status_code == 400
|
||||
|
||||
|
||||
def test_apps_list_tag_no_match_returns_empty_data_not_400(
    test_client: FlaskClient,
    workspace_account,
    app_in_workspace: App,
    account_token: str,
):
    """A tag filter with no matches is a valid empty result, not an error."""
    _, tenant, _ = workspace_account
    resp = test_client.get(
        f"/openapi/v1/apps?workspace_id={tenant.id}&tag=nonexistent",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 200
    assert resp.json["data"] == []
|
||||
|
||||
|
||||
def test_account_sessions_returns_envelope(
    test_client: FlaskClient,
    account_token: str,
):
    """Sessions list uses the canonical pagination envelope, not the legacy key."""
    resp = test_client.get("/openapi/v1/account/sessions", headers={"Authorization": f"Bearer {account_token}"})
    assert resp.status_code == 200
    envelope = resp.json
    # canonical envelope shape
    assert isinstance(envelope["data"], list)
    for key in ("page", "limit", "total", "has_more"):
        assert key in envelope
    # the bearer's own minted session must appear
    assert any(entry["prefix"] == "dfoa_" for entry in envelope["data"])
    # legacy "sessions" key must NOT appear
    assert "sessions" not in envelope
|
||||
127
api/tests/integration_tests/controllers/openapi/test_auth.py
Normal file
127
api/tests/integration_tests/controllers/openapi/test_auth.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""Integration tests for the /openapi/v1 bearer auth surface.
|
||||
|
||||
Layer 0 (workspace membership), per-token rate limit, and read-scope (`apps:read`)
|
||||
acceptance/rejection on app-scoped routes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models import App, Tenant
|
||||
|
||||
|
||||
def test_info_accepts_account_bearer_with_apps_read_scope(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
) -> None:
    """An account bearer with apps:read can read app info in its own workspace."""
    resp = test_client.get(
        f"/openapi/v1/apps/{app_in_workspace.id}/info",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 200
    assert resp.json["id"] == app_in_workspace.id
|
||||
|
||||
|
||||
@pytest.fixture
def other_workspace_app(flask_app: Flask) -> Generator[App, None, None]:
    """A fresh app under a *different* tenant — caller has no membership row."""
    with flask_app.app_context():
        other_tenant = Tenant(name="other", status="normal")
        db.session.add(other_tenant)
        # Commit first so other_tenant.id is populated for the FK below.
        db.session.commit()
        app = App(
            tenant_id=other_tenant.id,
            name="b",
            mode="chat",
            status="normal",
            enable_site=True,
            enable_api=True,
        )
        db.session.add(app)
        db.session.commit()
        yield app
        # Teardown in reverse dependency order: app before its tenant.
        db.session.delete(app)
        db.session.delete(other_tenant)
        db.session.commit()
|
||||
|
||||
|
||||
def test_layer0_denies_account_bearer_without_membership(
    test_client: FlaskClient,
    account_token: str,
    other_workspace_app: App,
) -> None:
    """Account A bearer hitting an app under tenant B — Layer 0 denies on CE."""
    resp = test_client.get(
        f"/openapi/v1/apps/{other_workspace_app.id}/info",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert resp.status_code == 403
    assert resp.json.get("message") == "workspace_membership_revoked"
|
||||
|
||||
|
||||
def test_layer0_skipped_when_enterprise_enabled(
    test_client: FlaskClient,
    account_token: str,
    other_workspace_app: App,
    monkeypatch,
) -> None:
    """On EE, Layer 0 short-circuits — gateway RBAC owns tenant isolation.

    /info uses validate_bearer + require_workspace_member inline (no
    AppAuthzCheck), so a cross-tenant bearer reaches the app lookup and
    gets 200 — gateway is expected to enforce isolation upstream.
    """
    from configs import dify_config

    # Override the conftest autouse default for this test only.
    monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True)

    res = test_client.get(
        f"/openapi/v1/apps/{other_workspace_app.id}/info",
        headers={"Authorization": f"Bearer {account_token}"},
    )
    assert res.status_code == 200
    # Double-check the CE denial message is absent, not just the status.
    assert res.json.get("message") != "workspace_membership_revoked"
|
||||
|
||||
|
||||
def test_rate_limit_returns_429_after_60_requests(
    test_client: FlaskClient,
    account_token: str,
) -> None:
    """61st sequential GET to /account on the same bearer → 429 with Retry-After."""
    headers = {"Authorization": f"Bearer {account_token}"}
    for i in range(60):
        ok = test_client.get("/openapi/v1/account", headers=headers)
        assert ok.status_code == 200, f"unexpected fail at i={i}"

    limited = test_client.get("/openapi/v1/account", headers=headers)
    assert limited.status_code == 429
    assert limited.headers.get("Retry-After"), "Retry-After header missing"
    assert int(limited.headers["Retry-After"]) >= 1
    payload = limited.json or {}
    assert payload.get("error") == "rate_limited"
    assert isinstance(payload.get("retry_after_ms"), int)
    assert payload["retry_after_ms"] >= 1000
|
||||
|
||||
|
||||
def test_rate_limit_bucket_shared_across_surfaces(
    test_client: FlaskClient,
    app_in_workspace: App,
    account_token: str,
) -> None:
    """30 calls to /account + 30 calls to /apps/<id>/info on same token → 61st 429s."""
    headers = {"Authorization": f"Bearer {account_token}"}
    for url in ("/openapi/v1/account", f"/openapi/v1/apps/{app_in_workspace.id}/info"):
        for _ in range(30):
            assert test_client.get(url, headers=headers).status_code == 200

    assert test_client.get("/openapi/v1/account", headers=headers).status_code == 429
|
||||
@ -0,0 +1,54 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE, _resolve_app_authz_strategy
|
||||
from controllers.openapi.auth.pipeline import Pipeline
|
||||
from controllers.openapi.auth.steps import (
|
||||
AppAuthzCheck,
|
||||
AppResolver,
|
||||
BearerCheck,
|
||||
CallerMount,
|
||||
ScopeCheck,
|
||||
WorkspaceMembershipCheck,
|
||||
)
|
||||
from controllers.openapi.auth.strategies import (
|
||||
AccountMounter,
|
||||
AclStrategy,
|
||||
EndUserMounter,
|
||||
MembershipStrategy,
|
||||
)
|
||||
|
||||
|
||||
def test_pipeline_is_composed():
    """The module-level OAUTH_BEARER_PIPELINE is a real Pipeline instance."""
    pipeline = OAUTH_BEARER_PIPELINE
    assert isinstance(pipeline, Pipeline)
|
||||
|
||||
|
||||
def test_pipeline_step_order():
    """BearerCheck → ScopeCheck → AppResolver → WorkspaceMembershipCheck →
    AppAuthzCheck → CallerMount. Rate-limit is enforced inside
    `BearerAuthenticator.authenticate`, not as a separate pipeline step."""
    expected = [
        BearerCheck,
        ScopeCheck,
        AppResolver,
        WorkspaceMembershipCheck,
        AppAuthzCheck,
        CallerMount,
    ]
    steps = OAUTH_BEARER_PIPELINE._steps
    for position, step_type in enumerate(expected):
        assert isinstance(steps[position], step_type)
|
||||
|
||||
|
||||
def test_caller_mount_has_both_mounters():
    """The terminal CallerMount step must handle both subject kinds."""
    caller_mount = OAUTH_BEARER_PIPELINE._steps[5]
    mounter_types = {type(mounter) for mounter in caller_mount._mounters}
    assert {AccountMounter, EndUserMounter} <= mounter_types
assert EndUserMounter in kinds
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.composition.FeatureService")
def test_strategy_resolver_picks_acl_when_enabled(fs):
    """webapp auth enabled → ACL strategy governs app authz."""
    features = fs.get_system_features.return_value
    features.webapp_auth.enabled = True
    strategy = _resolve_app_authz_strategy()
    assert isinstance(strategy, AclStrategy)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.composition.FeatureService")
def test_strategy_resolver_picks_membership_when_disabled(fs):
    """webapp auth disabled → plain workspace-membership strategy."""
    features = fs.get_system_features.return_value
    features.webapp_auth.enabled = False
    strategy = _resolve_app_authz_strategy()
    assert isinstance(strategy, MembershipStrategy)
|
||||
@ -0,0 +1,21 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
|
||||
|
||||
def test_context_starts_unpopulated():
    """A freshly built Context carries no identity, app, or caller state."""
    fresh = Context(request=MagicMock(), required_scope="apps:run")
    unset_attrs = (
        "subject_type",
        "subject_email",
        "account_id",
        "app",
        "tenant",
        "caller",
        "caller_kind",
    )
    for attr in unset_attrs:
        assert getattr(fresh, attr) is None
    assert fresh.scopes == frozenset()
|
||||
|
||||
|
||||
def test_context_fields_are_mutable():
    """Pipeline steps mutate the shared Context in place; scopes is replaceable."""
    ctx = Context(request=MagicMock(), required_scope="apps:run")
    granted = frozenset({"full"})
    ctx.scopes = granted
    assert ctx.scopes == granted
    assert "full" in ctx.scopes
|
||||
@ -0,0 +1,61 @@
|
||||
from unittest.mock import MagicMock

import pytest
from flask import Flask

from controllers.openapi.auth.context import Context
from controllers.openapi.auth.pipeline import Pipeline


def test_run_invokes_each_step_in_order():
    """Steps execute sequentially, each exactly once, in declaration order."""
    trace = []

    class Step:
        def __init__(self, label):
            self.label = label

        def __call__(self, ctx):
            trace.append(self.label)

    Pipeline(Step("a"), Step("b"), Step("c")).run(Context(request=MagicMock(), required_scope="x"))
    assert trace == ["a", "b", "c"]


def test_run_short_circuits_on_raise():
    """A raising step aborts the pipeline; later steps never execute."""
    trace = []

    class Boom:
        def __call__(self, ctx):
            raise RuntimeError("boom")

    class Tail:
        def __call__(self, ctx):
            trace.append("ran")

    with pytest.raises(RuntimeError):
        Pipeline(Boom(), Tail()).run(Context(request=MagicMock(), required_scope="x"))
    assert trace == []


def test_guard_decorator_runs_pipeline_and_unpacks_handler_kwargs():
    """guard() runs the pipeline, then passes app/caller/caller_kind to the view."""
    captured = {}

    class FakeStep:
        def __call__(self, ctx):
            ctx.app = "APP"
            ctx.caller = "CALLER"
            ctx.caller_kind = "account"

    pipeline = Pipeline(FakeStep())

    @pipeline.guard(scope="apps:run")
    def handler(app_model, caller, caller_kind):
        captured["app_model"] = app_model
        captured["caller"] = caller
        captured["caller_kind"] = caller_kind
        return "ok"

    flask_app = Flask(__name__)
    with flask_app.test_request_context("/x", method="POST"):
        assert handler() == "ok"
    assert captured == {"app_model": "APP", "caller": "CALLER", "caller_kind": "account"}
@ -0,0 +1,64 @@
|
||||
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest
from werkzeug.exceptions import BadRequest, Forbidden, NotFound

from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import AppResolver
from models import TenantStatus


def _ctx(view_args):
    """Context whose request carries the given path parameters."""
    request = MagicMock()
    request.view_args = view_args
    return Context(request=request, required_scope="apps:run")


def _app(*, status="normal", enable_api=True):
    """Minimal stand-in for an App row."""
    return SimpleNamespace(id="app1", tenant_id="t1", status=status, enable_api=enable_api)


def _tenant(*, status=TenantStatus.NORMAL):
    """Minimal stand-in for a Tenant row."""
    return SimpleNamespace(id="t1", status=status)


def test_resolver_rejects_missing_path_param():
    with pytest.raises(BadRequest):
        AppResolver()(_ctx({}))


def test_resolver_rejects_none_view_args():
    with pytest.raises(BadRequest):
        AppResolver()(_ctx(None))


@patch("controllers.openapi.auth.steps.db")
def test_resolver_404_when_app_missing(mock_db):
    mock_db.session.get.side_effect = [None]
    with pytest.raises(NotFound):
        AppResolver()(_ctx({"app_id": "x"}))


@patch("controllers.openapi.auth.steps.db")
def test_resolver_403_when_disabled(mock_db):
    mock_db.session.get.side_effect = [_app(enable_api=False)]
    with pytest.raises(Forbidden) as exc:
        AppResolver()(_ctx({"app_id": "x"}))
    assert "service_api_disabled" in str(exc.value.description)


@patch("controllers.openapi.auth.steps.db")
def test_resolver_403_when_tenant_archived(mock_db):
    mock_db.session.get.side_effect = [_app(), _tenant(status=TenantStatus.ARCHIVE)]
    with pytest.raises(Forbidden):
        AppResolver()(_ctx({"app_id": "x"}))


@patch("controllers.openapi.auth.steps.db")
def test_resolver_populates_app_and_tenant(mock_db):
    mock_db.session.get.side_effect = [_app(), _tenant()]
    ctx = _ctx({"app_id": "x"})
    AppResolver()(ctx)
    assert ctx.app.id == "app1"
    assert ctx.tenant.id == "t1"
@ -0,0 +1,53 @@
|
||||
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest
from werkzeug.exceptions import Forbidden

from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import AppAuthzCheck
from controllers.openapi.auth.strategies import AclStrategy, MembershipStrategy
from libs.oauth_bearer import SubjectType


def _ctx(*, subject_type, account_id="acc1"):
    """Context pre-populated as if BearerCheck + AppResolver already ran."""
    ctx = Context(request=MagicMock(), required_scope="apps:run")
    ctx.subject_type = subject_type
    ctx.subject_email = "alice@example.com"
    ctx.account_id = account_id
    ctx.app = SimpleNamespace(id="app1")
    ctx.tenant = SimpleNamespace(id="t1")
    return ctx


@patch("controllers.openapi.auth.strategies.EnterpriseService")
def test_acl_strategy_calls_inner_api(ent):
    """ACL strategy delegates to the enterprise webapp-access check by email."""
    ent.WebAppAuth.is_user_allowed_to_access_webapp.return_value = True
    assert AclStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
    ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_called_once_with(
        user_id="alice@example.com",
        app_id="app1",
    )


@patch("controllers.openapi.auth.strategies._has_tenant_membership")
def test_membership_strategy_uses_join_lookup(member):
    """Membership strategy checks the (account, tenant) join."""
    member.return_value = True
    assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
    member.assert_called_once_with("acc1", "t1")


def test_membership_strategy_rejects_external_sso():
    """External SSO subjects have no account row, so membership must deny."""
    assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None)) is False


def test_app_authz_check_raises_when_strategy_denies():
    denying = SimpleNamespace(authorize=lambda c: False)
    with pytest.raises(Forbidden) as exc:
        AppAuthzCheck(lambda: denying)(_ctx(subject_type=SubjectType.ACCOUNT))
    assert "subject_no_app_access" in str(exc.value.description)


def test_app_authz_check_passes_when_strategy_allows():
    allowing = SimpleNamespace(authorize=lambda c: True)
    AppAuthzCheck(lambda: allowing)(_ctx(subject_type=SubjectType.ACCOUNT))
@ -0,0 +1,56 @@
|
||||
import uuid
from datetime import UTC, datetime
from unittest.mock import MagicMock, patch

import pytest
from werkzeug.exceptions import Unauthorized

from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import BearerCheck
from libs.oauth_bearer import AuthContext, InvalidBearerError, Scope, SubjectType


def _ctx(headers):
    """Context whose request carries the given HTTP headers."""
    request = MagicMock()
    request.headers = headers
    return Context(request=request, required_scope="apps:run")


def test_bearer_check_rejects_missing_header():
    with pytest.raises(Unauthorized):
        BearerCheck()(_ctx({}))


@patch("controllers.openapi.auth.steps.get_authenticator")
def test_bearer_check_rejects_unknown_prefix(get_auth):
    get_auth.return_value.authenticate.side_effect = InvalidBearerError("unknown token prefix")
    with pytest.raises(Unauthorized):
        BearerCheck()(_ctx({"Authorization": "Bearer xxx_abc"}))


@patch("controllers.openapi.auth.steps.get_authenticator")
def test_bearer_check_populates_context(get_auth):
    """A valid bearer copies every AuthContext field onto the pipeline context."""
    token_id = uuid.uuid4()
    auth_result = AuthContext(
        subject_type=SubjectType.ACCOUNT,
        subject_email="a@x.com",
        subject_issuer=None,
        account_id=None,
        scopes=frozenset({Scope.FULL}),
        token_id=token_id,
        source="oauth-account",
        expires_at=datetime.now(UTC),
        token_hash="hash-1",
        verified_tenants={},
    )
    get_auth.return_value.authenticate.return_value = auth_result

    ctx = _ctx({"Authorization": "Bearer dfoa_abc"})
    BearerCheck()(ctx)

    assert ctx.subject_type == SubjectType.ACCOUNT
    assert ctx.subject_email == "a@x.com"
    assert ctx.scopes == frozenset({Scope.FULL})
    assert ctx.source == "oauth-account"
    assert ctx.token_id == token_id
    assert ctx.token_hash == "hash-1"
@ -0,0 +1,157 @@
|
||||
"""Unit tests for WorkspaceMembershipCheck (Layer 0)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.openapi.auth.context import Context
|
||||
from controllers.openapi.auth.steps import WorkspaceMembershipCheck
|
||||
from libs.oauth_bearer import SubjectType
|
||||
|
||||
|
||||
def _ctx(*, subject_type, account_id, tenant_id, cached_verified_tenants=None, token_hash=None) -> Context:
|
||||
c = Context(request=MagicMock(), required_scope="apps:read")
|
||||
c.subject_type = subject_type
|
||||
c.account_id = account_id
|
||||
c.tenant = SimpleNamespace(id=tenant_id) if tenant_id else None
|
||||
c.cached_verified_tenants = cached_verified_tenants
|
||||
c.token_hash = token_hash
|
||||
return c
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def step():
|
||||
return WorkspaceMembershipCheck()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_skips_when_enterprise_enabled(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = True
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id=str(uuid.uuid4()),
|
||||
tenant_id=str(uuid.uuid4()),
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx) # no raise
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_skips_for_external_sso(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.EXTERNAL_SSO,
|
||||
account_id=None,
|
||||
tenant_id=str(uuid.uuid4()),
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx) # no raise
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_uses_cached_ok(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={"t1": True},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx)
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_uses_cached_denied(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={"t1": False},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
step(ctx)
|
||||
mock_db.session.execute.assert_not_called()
|
||||
mock_record.assert_not_called()
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_denies_when_no_membership(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
step(ctx)
|
||||
mock_record.assert_called_once_with("hash-1", "t1", False)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_denies_when_account_inactive(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.side_effect = [
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="banned")),
|
||||
]
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
step(ctx)
|
||||
mock_record.assert_called_once_with("hash-1", "t1", False)
|
||||
|
||||
|
||||
@patch("controllers.openapi.auth.steps.dify_config")
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
def test_allows_active_member(mock_db, mock_record, mock_cfg, step):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.side_effect = [
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="active")),
|
||||
]
|
||||
ctx = _ctx(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
account_id="a1",
|
||||
tenant_id="t1",
|
||||
cached_verified_tenants={},
|
||||
token_hash="hash-1",
|
||||
)
|
||||
step(ctx) # no raise
|
||||
mock_record.assert_called_once_with("hash-1", "t1", True)
|
||||
@ -0,0 +1,77 @@
|
||||
from types import SimpleNamespace
from unittest.mock import MagicMock, patch

import pytest
from werkzeug.exceptions import Unauthorized

from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import CallerMount
from controllers.openapi.auth.strategies import AccountMounter, EndUserMounter
from core.app.entities.app_invoke_entities import InvokeFrom
from libs.oauth_bearer import SubjectType


def _ctx(*, subject_type, account_id=None, subject_email=None):
    """Context shaped as it is after authn and app resolution."""
    ctx = Context(request=MagicMock(), required_scope="apps:run")
    ctx.subject_type = subject_type
    ctx.account_id = account_id
    ctx.subject_email = subject_email
    ctx.app = SimpleNamespace(id="app1")
    ctx.tenant = SimpleNamespace(id="t1")
    return ctx


@patch("controllers.openapi.auth.strategies._login_as")
@patch("controllers.openapi.auth.strategies.db")
def test_account_mounter(db, login):
    """Account subjects get the Account row mounted and logged in."""
    account = SimpleNamespace()
    db.session.get.return_value = account
    ctx = _ctx(subject_type=SubjectType.ACCOUNT, account_id="acc1")
    AccountMounter().mount(ctx)
    assert ctx.caller is account
    assert ctx.caller.current_tenant is ctx.tenant
    assert ctx.caller_kind == "account"
    login.assert_called_once_with(account)


@patch("controllers.openapi.auth.strategies._login_as")
@patch("controllers.openapi.auth.strategies.EndUserService")
def test_end_user_mounter(svc, login):
    """SSO subjects get an EndUser resolved/created by email."""
    end_user = SimpleNamespace()
    svc.get_or_create_end_user_by_type.return_value = end_user
    ctx = _ctx(subject_type=SubjectType.EXTERNAL_SSO, subject_email="a@x.com")
    EndUserMounter().mount(ctx)
    svc.get_or_create_end_user_by_type.assert_called_once_with(
        InvokeFrom.OPENAPI,
        tenant_id="t1",
        app_id="app1",
        user_id="a@x.com",
    )
    assert ctx.caller is end_user
    assert ctx.caller_kind == "end_user"


def test_caller_mount_dispatches_by_subject_type():
    """CallerMount picks the first mounter whose applies_to() matches."""
    seen = {}

    class Fake:
        def __init__(self, subject_type, tag):
            self._subject_type = subject_type
            self._tag = tag

        def applies_to(self, subject_type):
            return subject_type == self._subject_type

        def mount(self, ctx):
            seen["who"] = self._tag

    mount_step = CallerMount(
        Fake(SubjectType.ACCOUNT, "acct"),
        Fake(SubjectType.EXTERNAL_SSO, "sso"),
    )
    mount_step(_ctx(subject_type=SubjectType.EXTERNAL_SSO))
    assert seen == {"who": "sso"}


def test_caller_mount_raises_when_none_applies():
    """No applicable mounter means the subject cannot be authenticated."""
    with pytest.raises(Unauthorized):
        CallerMount()(_ctx(subject_type=SubjectType.ACCOUNT))
@ -0,0 +1,27 @@
|
||||
from unittest.mock import MagicMock

import pytest
from werkzeug.exceptions import Forbidden

from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import ScopeCheck


def _ctx(scopes, required):
    """Context holding granted scopes against one required scope."""
    ctx = Context(request=MagicMock(), required_scope=required)
    ctx.scopes = frozenset(scopes)
    return ctx


def test_scope_check_passes_on_full():
    """The wildcard 'full' scope satisfies any requirement."""
    ScopeCheck()(_ctx({"full"}, "apps:run"))


def test_scope_check_passes_on_explicit_match():
    ScopeCheck()(_ctx({"apps:run"}, "apps:run"))


def test_scope_check_rejects_when_missing():
    with pytest.raises(Forbidden) as exc:
        ScopeCheck()(_ctx({"apps:read"}, "apps:run"))
    assert "insufficient_scope" in str(exc.value.description)
15
api/tests/unit_tests/controllers/openapi/conftest.py
Normal file
15
api/tests/unit_tests/controllers/openapi/conftest.py
Normal file
@ -0,0 +1,15 @@
|
||||
import pytest

from controllers.openapi.auth.pipeline import Pipeline


@pytest.fixture
def bypass_pipeline(monkeypatch):
    """Stub Pipeline.run so endpoint decoration does not invoke real auth.

    Module-level @OAUTH_BEARER_PIPELINE.guard(...) captures the real
    pipeline object at import time, so mocking the module attribute does
    not undo it. Patching Pipeline.run on the class itself is the bypass
    that actually works.
    """
    monkeypatch.setattr(Pipeline, "run", lambda self, ctx: None)
138
api/tests/unit_tests/controllers/openapi/test_account.py
Normal file
138
api/tests/unit_tests/controllers/openapi/test_account.py
Normal file
@ -0,0 +1,138 @@
|
||||
"""User-scoped identity + session endpoints under /openapi/v1/account."""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.account import (
|
||||
AccountApi,
|
||||
AccountSessionByIdApi,
|
||||
AccountSessionsApi,
|
||||
AccountSessionsSelfApi,
|
||||
)
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(openapi_bp)
|
||||
return app
|
||||
|
||||
|
||||
def _rule(app: Flask, path: str):
|
||||
return next(r for r in app.url_map.iter_rules() if r.rule == path)
|
||||
|
||||
|
||||
def test_account_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/account" in rules
|
||||
|
||||
|
||||
def test_account_dispatches_to_class(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is AccountApi
|
||||
|
||||
|
||||
def test_account_sessions_self_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/account/sessions/self" in rules
|
||||
|
||||
|
||||
def test_sessions_self_dispatches_to_class(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account/sessions/self")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is AccountSessionsSelfApi
|
||||
|
||||
|
||||
def test_account_methods(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account")
|
||||
assert "GET" in rule.methods
|
||||
|
||||
|
||||
def test_sessions_self_methods(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account/sessions/self")
|
||||
assert "DELETE" in rule.methods
|
||||
|
||||
|
||||
def test_sessions_list_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/account/sessions" in rules
|
||||
|
||||
|
||||
def test_sessions_list_dispatches_to_sessions_api(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account/sessions")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is AccountSessionsApi
|
||||
assert "GET" in rule.methods
|
||||
|
||||
|
||||
def test_session_by_id_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/account/sessions/<string:session_id>" in rules
|
||||
|
||||
|
||||
def test_session_by_id_dispatches_to_correct_class(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/account/sessions/<string:session_id>")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is AccountSessionByIdApi
|
||||
assert "DELETE" in rule.methods
|
||||
|
||||
|
||||
def test_subject_match_for_account_filters_by_account_id():
|
||||
"""Account subject scopes queries via account_id."""
|
||||
import uuid as _uuid
|
||||
|
||||
from controllers.openapi.account import _subject_match
|
||||
from libs.oauth_bearer import AuthContext, SubjectType
|
||||
|
||||
aid = _uuid.uuid4()
|
||||
ctx = AuthContext(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
subject_email="user@example.com",
|
||||
subject_issuer="dify:account",
|
||||
account_id=aid,
|
||||
scopes=frozenset({"full"}),
|
||||
token_id=_uuid.uuid4(),
|
||||
source="oauth_account",
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
)
|
||||
clauses = _subject_match(ctx)
|
||||
# One predicate, on account_id
|
||||
assert len(clauses) == 1
|
||||
assert "account_id" in str(clauses[0])
|
||||
|
||||
|
||||
def test_subject_match_for_external_sso_filters_by_email_and_issuer():
|
||||
"""External SSO subject scopes via (subject_email, subject_issuer)
|
||||
AND account_id IS NULL — so a same-email account row from a
|
||||
federated tenant cannot be revoked through an SSO bearer.
|
||||
"""
|
||||
import uuid as _uuid
|
||||
|
||||
from controllers.openapi.account import _subject_match
|
||||
from libs.oauth_bearer import AuthContext, SubjectType
|
||||
|
||||
ctx = AuthContext(
|
||||
subject_type=SubjectType.EXTERNAL_SSO,
|
||||
subject_email="sso@partner.com",
|
||||
subject_issuer="https://idp.partner.com",
|
||||
account_id=None,
|
||||
scopes=frozenset({"apps:run"}),
|
||||
token_id=_uuid.uuid4(),
|
||||
source="oauth_external_sso",
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
)
|
||||
clauses = _subject_match(ctx)
|
||||
assert len(clauses) == 3
|
||||
rendered = " ".join(str(c) for c in clauses)
|
||||
assert "subject_email" in rendered
|
||||
assert "subject_issuer" in rendered
|
||||
assert "account_id IS NULL" in rendered
|
||||
@ -0,0 +1,48 @@
|
||||
"""Unit tests for AppDescribeQuery (`?fields=` allow-list)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.openapi.apps import AppDescribeQuery
|
||||
|
||||
|
||||
def test_no_fields_returns_none() -> None:
|
||||
q = AppDescribeQuery.model_validate({})
|
||||
assert q.fields is None
|
||||
|
||||
|
||||
def test_empty_string_returns_none() -> None:
|
||||
q = AppDescribeQuery.model_validate({"fields": ""})
|
||||
assert q.fields is None
|
||||
|
||||
|
||||
def test_single_field() -> None:
|
||||
q = AppDescribeQuery.model_validate({"fields": "info"})
|
||||
assert q.fields == {"info"}
|
||||
|
||||
|
||||
def test_comma_list() -> None:
|
||||
q = AppDescribeQuery.model_validate({"fields": "info,parameters"})
|
||||
assert q.fields == {"info", "parameters"}
|
||||
|
||||
|
||||
def test_whitespace_tolerant() -> None:
|
||||
q = AppDescribeQuery.model_validate({"fields": " info , input_schema "})
|
||||
assert q.fields == {"info", "input_schema"}
|
||||
|
||||
|
||||
def test_unknown_member_rejected() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AppDescribeQuery.model_validate({"fields": "garbage"})
|
||||
|
||||
|
||||
def test_unknown_among_known_rejected() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AppDescribeQuery.model_validate({"fields": "info,garbage"})
|
||||
|
||||
|
||||
def test_extra_param_forbidden() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
AppDescribeQuery.model_validate({"fields": "info", "page": "1"})
|
||||
105
api/tests/unit_tests/controllers/openapi/test_app_list_query.py
Normal file
105
api/tests/unit_tests/controllers/openapi/test_app_list_query.py
Normal file
@ -0,0 +1,105 @@
|
||||
"""Unit tests for AppListQuery — the /apps query-param validator.
|
||||
|
||||
Runs against the model directly, not the HTTP layer. Pins:
|
||||
- defaults match the plan (page=1, limit=20).
|
||||
- workspace_id is required.
|
||||
- numeric bounds enforced (page >= 1, limit in [1, MAX_PAGE_LIMIT]).
|
||||
- mode validates against the AppMode enum.
|
||||
- name and tag have length caps.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from controllers.openapi._models import MAX_PAGE_LIMIT
|
||||
from controllers.openapi.apps import AppListQuery
|
||||
|
||||
|
||||
def test_defaults():
|
||||
q = AppListQuery.model_validate({"workspace_id": "ws-1"})
|
||||
assert q.workspace_id == "ws-1"
|
||||
assert q.page == 1
|
||||
assert q.limit == 20
|
||||
assert q.mode is None
|
||||
assert q.name is None
|
||||
assert q.tag is None
|
||||
|
||||
|
||||
def test_workspace_id_required():
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({})
|
||||
|
||||
|
||||
def test_page_must_be_positive():
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "page": 0})
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "page": -1})
|
||||
|
||||
|
||||
def test_page_rejects_non_integer_string():
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "page": "abc"})
|
||||
|
||||
|
||||
def test_limit_must_be_positive():
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "limit": 0})
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "limit": -1})
|
||||
|
||||
|
||||
def test_limit_caps_at_max_page_limit():
|
||||
# Boundary accepts.
|
||||
q = AppListQuery.model_validate({"workspace_id": "ws-1", "limit": MAX_PAGE_LIMIT})
|
||||
assert q.limit == MAX_PAGE_LIMIT
|
||||
|
||||
# Just over rejects.
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "limit": MAX_PAGE_LIMIT + 1})
|
||||
|
||||
|
||||
def test_mode_whitelisted_against_app_mode():
|
||||
# Valid mode passes.
|
||||
q = AppListQuery.model_validate({"workspace_id": "ws-1", "mode": "chat"})
|
||||
assert q.mode is not None
|
||||
assert q.mode.value == "chat"
|
||||
|
||||
# Invalid mode rejects.
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "mode": "not-a-mode"})
|
||||
|
||||
|
||||
def test_name_length_capped():
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "name": "x" * 200})
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "name": "x" * 201})
|
||||
|
||||
|
||||
def test_tag_length_capped():
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "tag": "x" * 100})
|
||||
with pytest.raises(ValidationError):
|
||||
AppListQuery.model_validate({"workspace_id": "ws-1", "tag": "x" * 101})
|
||||
|
||||
|
||||
def test_all_fields_accept_valid_values():
|
||||
"""Pin the happy-path acceptance for every field in one place."""
|
||||
q = AppListQuery.model_validate(
|
||||
{
|
||||
"workspace_id": "ws-1",
|
||||
"page": 5,
|
||||
"limit": 50,
|
||||
"mode": "workflow",
|
||||
"name": "search",
|
||||
"tag": "prod",
|
||||
}
|
||||
)
|
||||
assert q.workspace_id == "ws-1"
|
||||
assert q.page == 5
|
||||
assert q.limit == 50
|
||||
assert q.mode is not None
|
||||
assert q.mode.value == "workflow"
|
||||
assert q.name == "search"
|
||||
assert q.tag == "prod"
|
||||
@ -0,0 +1,55 @@
|
||||
"""Unit tests for app payload-rendering helpers — independent of
|
||||
HTTP plumbing or DB. Pin the response shapes that are CLI contracts.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.openapi.apps import ( # pyright: ignore[reportPrivateUsage]
|
||||
_EMPTY_PARAMETERS,
|
||||
parameters_payload,
|
||||
)
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
|
||||
|
||||
def _fake_app(**overrides):
|
||||
base = {
|
||||
"id": "app1",
|
||||
"name": "X",
|
||||
"description": "d",
|
||||
"mode": "chat",
|
||||
"author_name": "alice",
|
||||
"tags": [SimpleNamespace(name="prod")],
|
||||
"updated_at": None,
|
||||
"enable_api": True,
|
||||
"workflow": None,
|
||||
"app_model_config": None,
|
||||
}
|
||||
base.update(overrides)
|
||||
return SimpleNamespace(**base)
|
||||
|
||||
|
||||
def test_parameters_payload_raises_app_unavailable_when_no_config():
|
||||
with pytest.raises(AppUnavailableError):
|
||||
parameters_payload(_fake_app(mode="chat", app_model_config=None))
|
||||
|
||||
|
||||
def test_empty_parameters_constant_matches_describe_fallback_shape():
|
||||
"""The fallback dict served by /describe when an app has no config
|
||||
must match the spec's stated keys (opening_statement, suggested_questions,
|
||||
user_input_form, file_upload, system_parameters)."""
|
||||
assert set(_EMPTY_PARAMETERS.keys()) == {
|
||||
"opening_statement",
|
||||
"suggested_questions",
|
||||
"user_input_form",
|
||||
"file_upload",
|
||||
"system_parameters",
|
||||
}
|
||||
assert _EMPTY_PARAMETERS["suggested_questions"] == []
|
||||
assert _EMPTY_PARAMETERS["user_input_form"] == []
|
||||
assert _EMPTY_PARAMETERS["opening_statement"] is None
|
||||
assert _EMPTY_PARAMETERS["file_upload"] is None
|
||||
assert _EMPTY_PARAMETERS["system_parameters"] == {}
|
||||
@ -0,0 +1,45 @@
|
||||
import pytest
from werkzeug.exceptions import InternalServerError

from controllers.openapi.app_run import (
    _DISPATCH,
    AppRunRequest,
    _unpack_blocking,
)
from models.model import AppMode


def test_dispatch_covers_runnable_modes():
    """Every runnable app mode has a dispatch entry — no more, no fewer."""
    runnable = {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.COMPLETION, AppMode.WORKFLOW}
    assert set(_DISPATCH) == runnable


def test_unpack_blocking_passes_through_mapping():
    assert _unpack_blocking({"a": 1}) == {"a": 1}


def test_unpack_blocking_unwraps_tuple():
    # Flask-style (body, status) tuples are unwrapped to the body.
    assert _unpack_blocking(({"a": 1}, 200)) == {"a": 1}


def test_unpack_blocking_rejects_non_mapping():
    with pytest.raises(InternalServerError):
        _unpack_blocking("not a mapping")


def test_app_run_request_strips_blank_conversation_id():
    request_body = AppRunRequest(inputs={}, conversation_id=" ")
    assert request_body.conversation_id is None


def test_app_run_request_rejects_invalid_uuid_conversation_id():
    from pydantic import ValidationError

    with pytest.raises(ValidationError, match="conversation_id must be a valid UUID"):
        AppRunRequest(inputs={}, conversation_id="not-a-uuid")


def test_app_run_request_accepts_valid_uuid_conversation_id():
    import uuid as _uuid

    conversation_id = str(_uuid.uuid4())
    request_body = AppRunRequest(inputs={}, conversation_id=conversation_id)
    assert request_body.conversation_id == conversation_id
"""Unit tests for AppPermittedListQuery — the /apps/permitted query validator.

Strict ConfigDict(extra='forbid'): cross-tenant tag/workspace_id are
unresolvable, so the model must reject them as 422 instead of silently
dropping them. Mode/name/page/limit have the same shape as AppListQuery.
"""

from __future__ import annotations

import pytest
from pydantic import ValidationError

from controllers.openapi.apps_permitted import AppPermittedListQuery


def test_query_defaults_match_apps_list():
    """With no params at all, defaults mirror the /apps list query."""
    query = AppPermittedListQuery.model_validate({})
    assert query.page == 1
    assert query.limit == 20
    assert query.mode is None
    assert query.name is None


def test_query_rejects_workspace_id():
    """workspace_id is meaningless for /permitted (cross-tenant); rejecting it
    forces CLI authors to drop the param rather than send it silently."""
    with pytest.raises(ValidationError):
        AppPermittedListQuery.model_validate({"workspace_id": "ws-1"})


def test_query_rejects_tag():
    """Tags are tenant-scoped; cross-tenant tag resolution is undefined."""
    with pytest.raises(ValidationError):
        AppPermittedListQuery.model_validate({"tag": "prod"})


def test_query_validates_mode_against_app_mode():
    with pytest.raises(ValidationError):
        AppPermittedListQuery.model_validate({"mode": "not-a-mode"})


def test_query_clamps_limit_at_max():
    with pytest.raises(ValidationError):
        AppPermittedListQuery.model_validate({"limit": 500})


def test_query_accepts_valid_mode():
    """Pin the happy path: AppMode values pass."""
    query = AppPermittedListQuery.model_validate({"mode": "chat"})
    assert query.mode is not None
    assert query.mode.value == "chat"
|
||||
"""Unit tests for the OpenAPI app-run audit log emitter."""

import logging

from controllers.openapi._audit import EVENT_APP_RUN_OPENAPI, emit_app_run


def test_event_constant():
    assert EVENT_APP_RUN_OPENAPI == "app.run.openapi"


def test_emit_app_run_logs_with_audit_extra(caplog):
    """emit_app_run must log at INFO with structured audit fields attached."""
    with caplog.at_level(logging.INFO, logger="controllers.openapi._audit"):
        emit_app_run(app_id="app1", tenant_id="t1", caller_kind="account", mode="chat")
    rec = next(r for r in caplog.records if r.message and "app.run.openapi" in r.message)
    # The audit flag and event name let log pipelines route these records.
    assert rec.audit is True
    assert rec.event == EVENT_APP_RUN_OPENAPI
    assert rec.app_id == "app1"
    assert rec.tenant_id == "t1"
    assert rec.caller_kind == "account"
    assert rec.mode == "chat"
||||
"""CORS posture for /openapi/v1/* — default empty allowlist (same-origin),
expandable via OPENAPI_CORS_ALLOW_ORIGINS. Cross-origin requests from
disallowed origins do not receive the Access-Control-Allow-Origin
header, which the browser then blocks.

Tests use a fresh Blueprint + Flask-CORS per case because the production
blueprint is a module-level singleton and can't be reconfigured once
registered.
"""

import builtins

from flask import Blueprint, Flask
from flask.views import MethodView
from flask_cors import CORS
from flask_restx import Resource

from configs import dify_config
from extensions.ext_blueprints import OPENAPI_HEADERS, OPENAPI_MAX_AGE_SECONDS
from libs.external_api import ExternalApi

# flask-restx expects MethodView to be importable as a builtin in some setups.
if not hasattr(builtins, "MethodView"):
    builtins.MethodView = MethodView  # type: ignore[attr-defined]


def _make_app(allowed_origins: list[str], blueprint_name: str) -> Flask:
    """Build a Flask app with a fresh openapi-style blueprint mirroring
    production CORS settings, parameterised on the origin allowlist.
    """
    bp = Blueprint(blueprint_name, __name__, url_prefix="/openapi/v1")
    api = ExternalApi(bp, version="1.0", title="OpenAPI Test", description="")

    @api.route("/_health")
    class _Health(Resource):
        def get(self):
            return {"ok": True}

    CORS(
        bp,
        resources={r"/*": {"origins": allowed_origins}},
        supports_credentials=True,
        allow_headers=list(OPENAPI_HEADERS),
        methods=["GET", "POST", "PATCH", "DELETE", "OPTIONS"],
        expose_headers=["X-Version"],
        max_age=OPENAPI_MAX_AGE_SECONDS,
    )

    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    flask_app.register_blueprint(bp)
    return flask_app


def test_default_openapi_cors_allowlist_is_empty():
    """Default config admits no cross-origin until operator opts in."""
    assert dify_config.OPENAPI_CORS_ALLOW_ORIGINS == []


def test_preflight_allowed_origin_returns_cors_headers():
    client = _make_app(["https://app.example.com"], "openapi_t1").test_client()
    resp = client.options(
        "/openapi/v1/_health",
        headers={
            "Origin": "https://app.example.com",
            "Access-Control-Request-Method": "GET",
        },
    )

    assert resp.headers.get("Access-Control-Allow-Origin") == "https://app.example.com"
    assert resp.headers.get("Access-Control-Max-Age") == str(OPENAPI_MAX_AGE_SECONDS)


def test_preflight_disallowed_origin_omits_cors_headers():
    client = _make_app(["https://app.example.com"], "openapi_t2").test_client()
    resp = client.options(
        "/openapi/v1/_health",
        headers={
            "Origin": "https://attacker.example",
            "Access-Control-Request-Method": "GET",
        },
    )

    # flask-cors omits Allow-Origin for disallowed origins; browser blocks.
    assert "Access-Control-Allow-Origin" not in resp.headers


def test_preflight_with_default_empty_allowlist_omits_cors_headers():
    client = _make_app([], "openapi_t3").test_client()
    resp = client.options(
        "/openapi/v1/_health",
        headers={
            "Origin": "https://app.example.com",
            "Access-Control-Request-Method": "GET",
        },
    )

    assert "Access-Control-Allow-Origin" not in resp.headers


def test_same_origin_request_succeeds_without_origin_header():
    client = _make_app(["https://app.example.com"], "openapi_t4").test_client()
    # Browsers don't send Origin on same-origin GETs.
    resp = client.get("/openapi/v1/_health")

    assert resp.status_code == 200
    assert resp.get_json() == {"ok": True}


def test_authorization_header_is_in_allow_headers():
    """Bearer-authed routes need Authorization in the preflight response."""
    client = _make_app(["https://app.example.com"], "openapi_t5").test_client()
    resp = client.options(
        "/openapi/v1/_health",
        headers={
            "Origin": "https://app.example.com",
            "Access-Control-Request-Method": "GET",
            "Access-Control-Request-Headers": "Authorization",
        },
    )

    allowed = resp.headers.get("Access-Control-Allow-Headers", "").lower()
    assert "authorization" in allowed
|
||||
"""Account-branch device-flow approve/deny under /openapi/v1."""

import builtins

import pytest
from flask import Flask
from flask.views import MethodView

from controllers.openapi import bp as openapi_bp
from controllers.openapi.oauth_device import DeviceApproveApi, DeviceDenyApi

# flask-restx expects MethodView to be importable as a builtin in some setups.
if not hasattr(builtins, "MethodView"):
    builtins.MethodView = MethodView  # type: ignore[attr-defined]


@pytest.fixture
def openapi_app() -> Flask:
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    flask_app.register_blueprint(openapi_bp)
    return flask_app


def _rule(app: Flask, path: str):
    """Return the first URL rule whose pattern equals *path*."""
    return next(r for r in app.url_map.iter_rules() if r.rule == path)


def test_approve_route_registered(openapi_app: Flask):
    registered = {r.rule for r in openapi_app.url_map.iter_rules()}
    assert "/openapi/v1/oauth/device/approve" in registered


def test_deny_route_registered(openapi_app: Flask):
    registered = {r.rule for r in openapi_app.url_map.iter_rules()}
    assert "/openapi/v1/oauth/device/deny" in registered


def test_approve_dispatches_to_class(openapi_app: Flask):
    rule = _rule(openapi_app, "/openapi/v1/oauth/device/approve")
    assert openapi_app.view_functions[rule.endpoint].view_class is DeviceApproveApi


def test_deny_dispatches_to_class(openapi_app: Flask):
    rule = _rule(openapi_app, "/openapi/v1/oauth/device/deny")
    assert openapi_app.view_functions[rule.endpoint].view_class is DeviceDenyApi


def test_approve_and_deny_methods(openapi_app: Flask):
    approve_rule = _rule(openapi_app, "/openapi/v1/oauth/device/approve")
    deny_rule = _rule(openapi_app, "/openapi/v1/oauth/device/deny")
    assert "POST" in approve_rule.methods
    assert "POST" in deny_rule.methods
||||
"""POST /openapi/v1/oauth/device/code is the canonical RFC 8628 device
authorization endpoint.

Tests verify URL routing without invoking the handler — invoking would
require Redis, which the unit-test runtime does not initialise.
"""

import builtins

import pytest
from flask import Flask
from flask.views import MethodView

from controllers.openapi import bp as openapi_bp
from controllers.openapi.oauth_device import OAuthDeviceCodeApi

# flask-restx expects MethodView to be importable as a builtin in some setups.
if not hasattr(builtins, "MethodView"):
    builtins.MethodView = MethodView  # type: ignore[attr-defined]

_CODE_PATH = "/openapi/v1/oauth/device/code"


@pytest.fixture
def openapi_app() -> Flask:
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    flask_app.register_blueprint(openapi_bp)
    return flask_app


def test_openapi_route_registered(openapi_app: Flask):
    registered = {r.rule for r in openapi_app.url_map.iter_rules()}
    assert _CODE_PATH in registered


def test_route_dispatches_to_class(openapi_app: Flask):
    rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == _CODE_PATH)
    assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceCodeApi


def test_route_accepts_post(openapi_app: Flask):
    rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == _CODE_PATH)
    assert "POST" in rule.methods


def test_known_client_ids_default_includes_difyctl():
    from configs import dify_config

    assert "difyctl" in dify_config.OPENAPI_KNOWN_CLIENT_IDS
|
||||
"""GET /openapi/v1/oauth/device/lookup is the canonical user-code lookup."""

import builtins

import pytest
from flask import Flask
from flask.views import MethodView

from controllers.openapi import bp as openapi_bp
from controllers.openapi.oauth_device import OAuthDeviceLookupApi

# flask-restx expects MethodView to be importable as a builtin in some setups.
if not hasattr(builtins, "MethodView"):
    builtins.MethodView = MethodView  # type: ignore[attr-defined]

_LOOKUP_PATH = "/openapi/v1/oauth/device/lookup"


@pytest.fixture
def openapi_app() -> Flask:
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    flask_app.register_blueprint(openapi_bp)
    return flask_app


def test_openapi_route_registered(openapi_app: Flask):
    registered = {r.rule for r in openapi_app.url_map.iter_rules()}
    assert _LOOKUP_PATH in registered


def test_route_dispatches_to_class(openapi_app: Flask):
    rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == _LOOKUP_PATH)
    assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceLookupApi


def test_route_accepts_get(openapi_app: Flask):
    rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == _LOOKUP_PATH)
    assert "GET" in rule.methods
|
||||
"""SSO-branch device-flow endpoints under /openapi/v1/oauth/device/."""

import builtins

import pytest
from flask import Flask
from flask.views import MethodView

from controllers.openapi import bp as openapi_bp
from controllers.openapi.oauth_device_sso import (
    approval_context,
    approve_external,
    sso_complete,
    sso_initiate,
)

# flask-restx expects MethodView to be importable as a builtin in some setups.
if not hasattr(builtins, "MethodView"):
    builtins.MethodView = MethodView  # type: ignore[attr-defined]


@pytest.fixture
def openapi_app() -> Flask:
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    flask_app.register_blueprint(openapi_bp)
    return flask_app


def _rule(app: Flask, path: str):
    """Return the first URL rule whose pattern equals *path*."""
    return next(r for r in app.url_map.iter_rules() if r.rule == path)


def test_sso_initiate_registered(openapi_app: Flask):
    registered = {r.rule for r in openapi_app.url_map.iter_rules()}
    assert "/openapi/v1/oauth/device/sso-initiate" in registered


def test_sso_complete_registered(openapi_app: Flask):
    registered = {r.rule for r in openapi_app.url_map.iter_rules()}
    assert "/openapi/v1/oauth/device/sso-complete" in registered


def test_approval_context_registered(openapi_app: Flask):
    registered = {r.rule for r in openapi_app.url_map.iter_rules()}
    assert "/openapi/v1/oauth/device/approval-context" in registered


def test_approve_external_registered(openapi_app: Flask):
    registered = {r.rule for r in openapi_app.url_map.iter_rules()}
    assert "/openapi/v1/oauth/device/approve-external" in registered


def test_sso_initiate_dispatches_to_function(openapi_app: Flask):
    rule = _rule(openapi_app, "/openapi/v1/oauth/device/sso-initiate")
    assert openapi_app.view_functions[rule.endpoint] is sso_initiate


def test_sso_complete_dispatches_to_function(openapi_app: Flask):
    rule = _rule(openapi_app, "/openapi/v1/oauth/device/sso-complete")
    assert openapi_app.view_functions[rule.endpoint] is sso_complete


def test_approval_context_dispatches_to_function(openapi_app: Flask):
    rule = _rule(openapi_app, "/openapi/v1/oauth/device/approval-context")
    assert openapi_app.view_functions[rule.endpoint] is approval_context


def test_approve_external_dispatches_to_function(openapi_app: Flask):
    rule = _rule(openapi_app, "/openapi/v1/oauth/device/approve-external")
    assert openapi_app.view_functions[rule.endpoint] is approve_external


def test_sso_complete_idp_callback_url_uses_canonical_path():
    """sso_initiate hardcodes the IdP callback URL — must point at the
    canonical /openapi/v1/ path so IdP-side ACS configuration matches.
    """
    from controllers.openapi import oauth_device_sso

    assert oauth_device_sso._SSO_COMPLETE_PATH == "/openapi/v1/oauth/device/sso-complete"
|
||||
"""POST /openapi/v1/oauth/device/token is the canonical poll endpoint."""

import builtins

import pytest
from flask import Flask
from flask.views import MethodView

from controllers.openapi import bp as openapi_bp
from controllers.openapi.oauth_device import OAuthDeviceTokenApi

# flask-restx expects MethodView to be importable as a builtin in some setups.
if not hasattr(builtins, "MethodView"):
    builtins.MethodView = MethodView  # type: ignore[attr-defined]

_TOKEN_PATH = "/openapi/v1/oauth/device/token"


@pytest.fixture
def openapi_app() -> Flask:
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    flask_app.register_blueprint(openapi_bp)
    return flask_app


def test_openapi_route_registered(openapi_app: Flask):
    registered = {r.rule for r in openapi_app.url_map.iter_rules()}
    assert _TOKEN_PATH in registered


def test_route_dispatches_to_class(openapi_app: Flask):
    rule = next(r for r in openapi_app.url_map.iter_rules() if r.rule == _TOKEN_PATH)
    assert openapi_app.view_functions[rule.endpoint].view_class is OAuthDeviceTokenApi
||||
"""Routing and response tests for the /openapi/v1/_health probe."""

import builtins

import pytest
from flask import Flask
from flask.views import MethodView

from controllers.openapi import bp as openapi_bp

# flask-restx expects MethodView to be importable as a builtin in some setups.
if not hasattr(builtins, "MethodView"):
    builtins.MethodView = MethodView  # type: ignore[attr-defined]


@pytest.fixture
def app() -> Flask:
    flask_app = Flask(__name__)
    flask_app.config["TESTING"] = True
    flask_app.register_blueprint(openapi_bp)
    return flask_app


def test_health_returns_ok(app: Flask):
    resp = app.test_client().get("/openapi/v1/_health")

    assert resp.status_code == 200
    assert resp.get_json() == {"ok": True}


def test_health_path_is_under_openapi_v1_prefix(app: Flask):
    """The probe lives only under the /openapi/v1 prefix — bare paths 404."""
    client = app.test_client()
    assert client.get("/_health").status_code == 404
    assert client.get("/v1/_health").status_code == 404
    assert client.get("/openapi/v1/_health").status_code == 200
|
||||
182
api/tests/unit_tests/controllers/openapi/test_input_schema.py
Normal file
182
api/tests/unit_tests/controllers/openapi/test_input_schema.py
Normal file
@ -0,0 +1,182 @@
|
||||
"""Unit tests for input_schema derivation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from controllers.openapi._input_schema import _form_to_jsonschema
|
||||
|
||||
|
||||
def _wrap(component: dict) -> list[dict]:
|
||||
"""user_input_form rows are single-key dicts: {"text-input": {...}}."""
|
||||
return [component]
|
||||
|
||||
|
||||
def test_text_input_required() -> None:
|
||||
form = _wrap({"text-input": {"variable": "industry", "label": "Industry", "required": True, "max_length": 200}})
|
||||
props, required = _form_to_jsonschema(form)
|
||||
assert props == {"industry": {"type": "string", "title": "Industry", "maxLength": 200}}
|
||||
assert required == ["industry"]
|
||||
|
||||
|
||||
def test_paragraph_optional() -> None:
|
||||
form = _wrap({"paragraph": {"variable": "context", "label": "Context", "required": False, "max_length": 4000}})
|
||||
props, required = _form_to_jsonschema(form)
|
||||
assert props["context"] == {"type": "string", "title": "Context", "maxLength": 4000}
|
||||
assert required == []
|
||||
|
||||
|
||||
def test_select_enum() -> None:
|
||||
form = _wrap(
|
||||
{
|
||||
"select": {
|
||||
"variable": "tier",
|
||||
"label": "Tier",
|
||||
"required": True,
|
||||
"options": ["free", "pro", "enterprise"],
|
||||
}
|
||||
}
|
||||
)
|
||||
props, required = _form_to_jsonschema(form)
|
||||
assert props == {"tier": {"type": "string", "title": "Tier", "enum": ["free", "pro", "enterprise"]}}
|
||||
assert required == ["tier"]
|
||||
|
||||
|
||||
def test_number() -> None:
|
||||
form = _wrap({"number": {"variable": "count", "label": "Count", "required": False}})
|
||||
props, _required = _form_to_jsonschema(form)
|
||||
assert props["count"] == {"type": "number", "title": "Count"}
|
||||
|
||||
|
||||
def test_file() -> None:
|
||||
form = _wrap({"file": {"variable": "doc", "label": "Doc", "required": True}})
|
||||
props, required = _form_to_jsonschema(form)
|
||||
assert props["doc"]["type"] == "object"
|
||||
assert "title" in props["doc"]
|
||||
assert required == ["doc"]
|
||||
|
||||
|
||||
def test_file_list() -> None:
|
||||
form = _wrap({"file-list": {"variable": "attachments", "label": "Attachments", "required": False}})
|
||||
props, _required = _form_to_jsonschema(form)
|
||||
assert props["attachments"]["type"] == "array"
|
||||
assert props["attachments"]["items"]["type"] == "object"
|
||||
|
||||
|
||||
def test_unknown_type_skipped() -> None:
|
||||
"""Forward-compat: unknown variable types are skipped, not 500'd."""
|
||||
form = _wrap({"future-type": {"variable": "x", "label": "X", "required": False}})
|
||||
props, required = _form_to_jsonschema(form)
|
||||
assert props == {}
|
||||
assert required == []
|
||||
|
||||
|
||||
def test_required_order_preserved() -> None:
|
||||
form = [
|
||||
{"text-input": {"variable": "a", "label": "A", "required": True}},
|
||||
{"text-input": {"variable": "b", "label": "B", "required": False}},
|
||||
{"text-input": {"variable": "c", "label": "C", "required": True}},
|
||||
]
|
||||
_props, required = _form_to_jsonschema(form)
|
||||
assert required == ["a", "c"]
|
||||
|
||||
|
||||
def test_max_length_omitted_when_zero() -> None:
|
||||
form = _wrap({"text-input": {"variable": "x", "label": "X", "required": False, "max_length": 0}})
|
||||
props, _ = _form_to_jsonschema(form)
|
||||
assert "maxLength" not in props["x"]
|
||||
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from controllers.openapi._input_schema import EMPTY_INPUT_SCHEMA, build_input_schema
|
||||
from controllers.service_api.app.error import AppUnavailableError
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
def _stub_app(mode: AppMode, *, form: list[dict] | None = None, has_workflow: bool | None = None):
|
||||
"""Returns a MagicMock whose .mode + workflow / app_model_config branch is wired up."""
|
||||
app = MagicMock()
|
||||
app.mode = mode
|
||||
if mode in (AppMode.WORKFLOW, AppMode.ADVANCED_CHAT):
|
||||
if has_workflow is False:
|
||||
app.workflow = None
|
||||
else:
|
||||
app.workflow = MagicMock()
|
||||
app.workflow.user_input_form.return_value = form or []
|
||||
app.workflow.features_dict = {}
|
||||
else:
|
||||
if has_workflow is False:
|
||||
app.app_model_config = None
|
||||
else:
|
||||
app.app_model_config = MagicMock()
|
||||
app.app_model_config.to_dict.return_value = {"user_input_form": form or []}
|
||||
return app
|
||||
|
||||
|
||||
def test_chat_mode_includes_query() -> None:
|
||||
app = _stub_app(AppMode.CHAT, form=[{"text-input": {"variable": "x", "label": "X", "required": True}}])
|
||||
schema = build_input_schema(app)
|
||||
assert schema["$schema"] == "https://json-schema.org/draft/2020-12/schema"
|
||||
assert "query" in schema["properties"]
|
||||
assert schema["properties"]["query"]["type"] == "string"
|
||||
assert schema["properties"]["query"]["minLength"] == 1
|
||||
assert "query" in schema["required"]
|
||||
assert "inputs" in schema["required"]
|
||||
assert schema["properties"]["inputs"]["additionalProperties"] is False
|
||||
|
||||
|
||||
def test_agent_chat_mode_includes_query() -> None:
|
||||
app = _stub_app(AppMode.AGENT_CHAT, form=[])
|
||||
schema = build_input_schema(app)
|
||||
assert "query" in schema["properties"]
|
||||
|
||||
|
||||
def test_advanced_chat_mode_includes_query() -> None:
|
||||
app = _stub_app(AppMode.ADVANCED_CHAT, form=[])
|
||||
schema = build_input_schema(app)
|
||||
assert "query" in schema["properties"]
|
||||
|
||||
|
||||
def test_workflow_mode_omits_query() -> None:
|
||||
app = _stub_app(AppMode.WORKFLOW, form=[])
|
||||
schema = build_input_schema(app)
|
||||
assert "query" not in schema["properties"]
|
||||
assert schema["required"] == ["inputs"]
|
||||
|
||||
|
||||
def test_completion_mode_omits_query() -> None:
|
||||
app = _stub_app(AppMode.COMPLETION, form=[])
|
||||
schema = build_input_schema(app)
|
||||
assert "query" not in schema["properties"]
|
||||
assert schema["required"] == ["inputs"]
|
||||
|
||||
|
||||
def test_inputs_required_driven_by_form() -> None:
|
||||
app = _stub_app(
|
||||
AppMode.CHAT,
|
||||
form=[
|
||||
{"text-input": {"variable": "industry", "label": "Industry", "required": True}},
|
||||
{"text-input": {"variable": "context", "label": "Context", "required": False}},
|
||||
],
|
||||
)
|
||||
schema = build_input_schema(app)
|
||||
assert schema["properties"]["inputs"]["required"] == ["industry"]
|
||||
|
||||
|
||||
def test_misconfigured_chat_raises_app_unavailable() -> None:
|
||||
app = _stub_app(AppMode.CHAT, has_workflow=False)
|
||||
with pytest.raises(AppUnavailableError):
|
||||
build_input_schema(app)
|
||||
|
||||
|
||||
def test_misconfigured_workflow_raises_app_unavailable() -> None:
|
||||
app = _stub_app(AppMode.WORKFLOW, has_workflow=False)
|
||||
with pytest.raises(AppUnavailableError):
|
||||
build_input_schema(app)
|
||||
|
||||
|
||||
def test_empty_input_schema_sentinel_shape() -> None:
|
||||
assert EMPTY_INPUT_SCHEMA["type"] == "object"
|
||||
assert EMPTY_INPUT_SCHEMA["properties"] == {}
|
||||
assert EMPTY_INPUT_SCHEMA["required"] == []
|
||||
"""Unit tests for the small OpenAPI response models (usage, metadata, describe)."""

from controllers.openapi._models import MessageMetadata, UsageInfo


def test_usage_info_defaults_zero():
    usage = UsageInfo()
    assert usage.prompt_tokens == 0
    assert usage.completion_tokens == 0
    assert usage.total_tokens == 0


def test_message_metadata_accepts_partial():
    meta = MessageMetadata(usage=UsageInfo(total_tokens=10))
    assert meta.usage.total_tokens == 10
    assert meta.retriever_resources == []


def test_describe_response_all_blocks_optional() -> None:
    from controllers.openapi._models import AppDescribeResponse

    dumped = AppDescribeResponse().model_dump(mode="json", exclude_none=False)
    assert dumped == {"info": None, "parameters": None, "input_schema": None}


def test_describe_response_input_schema_field() -> None:
    from controllers.openapi._models import AppDescribeResponse

    schema = {"$schema": "https://json-schema.org/draft/2020-12/schema", "type": "object"}
    dumped = AppDescribeResponse(input_schema=schema).model_dump(mode="json", exclude_none=False)
    assert dumped["input_schema"] == schema
    assert dumped["info"] is None
    assert dumped["parameters"] is None
||||
"""Unit tests for PaginationEnvelope generic Pydantic model."""

from __future__ import annotations

from pydantic import BaseModel

from controllers.openapi._models import PaginationEnvelope


class _Row(BaseModel):
    id: str
    name: str


def test_envelope_basic_fields():
    envelope = PaginationEnvelope[_Row](page=1, limit=20, total=42, has_more=True, data=[_Row(id="a", name="A")])
    assert envelope.model_dump(mode="json") == {
        "page": 1,
        "limit": 20,
        "total": 42,
        "has_more": True,
        "data": [{"id": "a", "name": "A"}],
    }


def test_envelope_empty_data_no_more():
    envelope = PaginationEnvelope[_Row](page=1, limit=20, total=0, has_more=False, data=[])
    dumped = envelope.model_dump(mode="json")
    assert dumped["data"] == []
    assert dumped["has_more"] is False


def test_envelope_has_more_true_when_total_exceeds_page_window():
    envelope = PaginationEnvelope[_Row].build(page=1, limit=20, total=42, items=[_Row(id="a", name="A")])
    assert envelope.has_more is True


def test_envelope_has_more_false_when_total_within_page_window():
    envelope = PaginationEnvelope[_Row].build(page=2, limit=20, total=22, items=[_Row(id="a", name="A")])
    assert envelope.has_more is False


def test_envelope_has_more_false_for_last_page():
    envelope = PaginationEnvelope[_Row].build(page=3, limit=20, total=42, items=[_Row(id="a", name="A")])
    assert envelope.has_more is False


def test_max_page_limit_is_200():
    from controllers.openapi._models import MAX_PAGE_LIMIT

    assert MAX_PAGE_LIMIT == 200


def test_envelope_uses_pep695_generics():
    """Verify the class uses PEP 695 native generic syntax (not legacy Generic[T])."""
    from controllers.openapi._models import PaginationEnvelope

    # PEP 695 syntax populates __type_params__; the legacy Generic[T] form does not.
    assert PaginationEnvelope.__type_params__, "expected PEP 695 native generic syntax"

    assert {"page", "limit", "total", "has_more", "data"} <= set(PaginationEnvelope.model_fields)


def test_app_info_response_dump_matches_spec():
    from controllers.openapi._models import AppInfoResponse

    info = AppInfoResponse(
        id="app1",
        name="X",
        description="d",
        mode="chat",
        author="alice",
        tags=[{"name": "prod"}],
    )
    assert info.model_dump(mode="json") == {
        "id": "app1",
        "name": "X",
        "description": "d",
        "mode": "chat",
        "author": "alice",
        "tags": [{"name": "prod"}],
    }


def test_app_describe_response_nests_info_and_parameters():
    from controllers.openapi._models import AppDescribeInfo, AppDescribeResponse

    info = AppDescribeInfo(
        id="app1",
        name="X",
        mode="chat",
        description=None,
        tags=[],
        author=None,
        updated_at="2026-05-05T00:00:00+00:00",
        service_api_enabled=True,
    )
    dumped = AppDescribeResponse(info=info, parameters={"opening_statement": None}).model_dump(mode="json")
    assert dumped["info"]["service_api_enabled"] is True
    assert dumped["parameters"]["opening_statement"] is None


def test_response_models_dump_per_mode():
    from controllers.openapi._models import (
        ChatMessageResponse, CompletionMessageResponse, WorkflowRunResponse, WorkflowRunData,
    )

    chat = ChatMessageResponse(
        event="message", task_id="t1", id="m1", message_id="m1",
        conversation_id="c1", mode="chat", answer="hi", created_at=0,
    )
    assert chat.model_dump(mode="json")["mode"] == "chat"

    workflow = WorkflowRunResponse(
        workflow_run_id="r1", task_id="t1",
        data=WorkflowRunData(id="r1", workflow_id="w1", status="succeeded"),
    )
    workflow_dump = workflow.model_dump(mode="json")
    assert workflow_dump["data"]["status"] == "succeeded"
    assert workflow_dump["mode"] == "workflow"

    completion = CompletionMessageResponse(
        event="message", task_id="t2", id="m2", message_id="m2",
        mode="completion", answer="ok", created_at=0,
    )
    assert completion.model_dump(mode="json")["mode"] == "completion"
|
||||
58
api/tests/unit_tests/controllers/openapi/test_workspaces.py
Normal file
58
api/tests/unit_tests/controllers/openapi/test_workspaces.py
Normal file
@ -0,0 +1,58 @@
|
||||
"""Phase E step 17: workspace reads at /openapi/v1/workspaces. Bearer-authed
|
||||
list + member-gated detail. No legacy /v1/ equivalent — the cookie-authed
|
||||
/console/api/workspaces is a separate consumer that stays in console.
|
||||
"""
|
||||
|
||||
import builtins
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask.views import MethodView
|
||||
|
||||
from controllers.openapi import bp as openapi_bp
|
||||
from controllers.openapi.workspaces import WorkspaceByIdApi, WorkspacesApi
|
||||
|
||||
if not hasattr(builtins, "MethodView"):
|
||||
builtins.MethodView = MethodView # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openapi_app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
app.register_blueprint(openapi_bp)
|
||||
return app
|
||||
|
||||
|
||||
def _rule(app: Flask, path: str):
|
||||
return next(r for r in app.url_map.iter_rules() if r.rule == path)
|
||||
|
||||
|
||||
def test_workspaces_list_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/workspaces" in rules
|
||||
|
||||
|
||||
def test_workspaces_list_dispatches_to_workspaces_api(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/workspaces")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is WorkspacesApi
|
||||
assert "GET" in rule.methods
|
||||
|
||||
|
||||
def test_workspace_by_id_route_registered(openapi_app: Flask):
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/openapi/v1/workspaces/<string:workspace_id>" in rules
|
||||
|
||||
|
||||
def test_workspace_by_id_dispatches_to_correct_class(openapi_app: Flask):
|
||||
rule = _rule(openapi_app, "/openapi/v1/workspaces/<string:workspace_id>")
|
||||
assert openapi_app.view_functions[rule.endpoint].view_class is WorkspaceByIdApi
|
||||
assert "GET" in rule.methods
|
||||
|
||||
|
||||
def test_console_legacy_workspaces_route_not_remounted_on_openapi(openapi_app: Flask):
|
||||
"""Phase E only adds the bearer-authed mounts on /openapi/v1/.
|
||||
The cookie-authed /console/api/workspaces stays where it is.
|
||||
"""
|
||||
rules = {r.rule for r in openapi_app.url_map.iter_rules()}
|
||||
assert "/console/api/workspaces" not in rules
|
||||
9
api/tests/unit_tests/core/app/test_invoke_from.py
Normal file
9
api/tests/unit_tests/core/app/test_invoke_from.py
Normal file
@ -0,0 +1,9 @@
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
|
||||
|
||||
def test_openapi_variant_present():
|
||||
assert InvokeFrom.OPENAPI.value == "openapi"
|
||||
|
||||
|
||||
def test_openapi_distinct_from_service_api():
|
||||
assert InvokeFrom.OPENAPI != InvokeFrom.SERVICE_API
|
||||
29
api/tests/unit_tests/libs/test_oauth_bearer.py
Normal file
29
api/tests/unit_tests/libs/test_oauth_bearer.py
Normal file
@ -0,0 +1,29 @@
|
||||
"""Unit tests for the openapi bearer-scope catalog and TokenKind registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
def test_apps_read_permitted_scope_present():
|
||||
from libs.oauth_bearer import Scope
|
||||
|
||||
assert Scope.APPS_READ_PERMITTED.value == "apps:read:permitted"
|
||||
|
||||
|
||||
def test_dfoe_token_kind_carries_apps_read_permitted():
|
||||
from libs.oauth_bearer import Scope, build_registry
|
||||
|
||||
registry = build_registry(MagicMock(), MagicMock())
|
||||
dfoe = next(k for k in registry.kinds() if k.prefix == "dfoe_")
|
||||
assert Scope.APPS_READ_PERMITTED in dfoe.scopes
|
||||
|
||||
|
||||
def test_dfoa_token_kind_does_not_carry_apps_read_permitted():
|
||||
"""dfoa_ relies on Scope.FULL umbrella; the explicit permitted scope
|
||||
is reserved for dfoe_."""
|
||||
from libs.oauth_bearer import Scope, build_registry
|
||||
|
||||
registry = build_registry(MagicMock(), MagicMock())
|
||||
dfoa = next(k for k in registry.kinds() if k.prefix == "dfoa_")
|
||||
assert Scope.APPS_READ_PERMITTED not in dfoa.scopes
|
||||
94
api/tests/unit_tests/libs/test_oauth_bearer_layer0_cache.py
Normal file
94
api/tests/unit_tests/libs/test_oauth_bearer_layer0_cache.py
Normal file
@ -0,0 +1,94 @@
|
||||
"""Unit tests for record_layer0_verdict — merge L0 verdict into AuthContext cache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from libs.oauth_bearer import record_layer0_verdict
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis():
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@patch("libs.oauth_bearer.redis_client")
|
||||
def test_no_op_when_cache_entry_missing(mock_redis):
|
||||
mock_redis.get.return_value = None
|
||||
record_layer0_verdict("h1", "t1", True)
|
||||
mock_redis.setex.assert_not_called()
|
||||
|
||||
|
||||
@patch("libs.oauth_bearer.redis_client")
|
||||
def test_no_op_when_cache_entry_invalid_marker(mock_redis):
|
||||
mock_redis.get.return_value = b"invalid"
|
||||
record_layer0_verdict("h1", "t1", True)
|
||||
mock_redis.setex.assert_not_called()
|
||||
|
||||
|
||||
@patch("libs.oauth_bearer.redis_client")
|
||||
def test_no_op_when_json_malformed(mock_redis):
|
||||
mock_redis.get.return_value = b"not json"
|
||||
record_layer0_verdict("h1", "t1", True)
|
||||
mock_redis.setex.assert_not_called()
|
||||
|
||||
|
||||
@patch("libs.oauth_bearer.redis_client")
|
||||
def test_no_op_when_ttl_expired(mock_redis):
|
||||
mock_redis.get.return_value = json.dumps(
|
||||
{
|
||||
"subject_email": "e",
|
||||
"subject_issuer": None,
|
||||
"account_id": None,
|
||||
"token_id": "tid",
|
||||
"expires_at": None,
|
||||
}
|
||||
).encode()
|
||||
mock_redis.ttl.return_value = -1
|
||||
record_layer0_verdict("h1", "t1", True)
|
||||
mock_redis.setex.assert_not_called()
|
||||
|
||||
|
||||
@patch("libs.oauth_bearer.redis_client")
|
||||
def test_merges_new_tenant_verdict(mock_redis):
|
||||
mock_redis.get.return_value = json.dumps(
|
||||
{
|
||||
"subject_email": "e",
|
||||
"subject_issuer": None,
|
||||
"account_id": None,
|
||||
"token_id": "tid",
|
||||
"expires_at": None,
|
||||
"verified_tenants": {"t0": True},
|
||||
}
|
||||
).encode()
|
||||
mock_redis.ttl.return_value = 42
|
||||
|
||||
record_layer0_verdict("h1", "t1", False)
|
||||
|
||||
mock_redis.setex.assert_called_once()
|
||||
args = mock_redis.setex.call_args
|
||||
assert args.args[0] == "auth:token:h1"
|
||||
assert args.args[1] == 42 # remaining TTL preserved
|
||||
written = json.loads(args.args[2])
|
||||
assert written["verified_tenants"] == {"t0": True, "t1": False}
|
||||
|
||||
|
||||
@patch("libs.oauth_bearer.redis_client")
|
||||
def test_merges_when_field_absent_from_legacy_entry(mock_redis):
|
||||
"""Backward compat: legacy cache entry without verified_tenants field."""
|
||||
mock_redis.get.return_value = json.dumps(
|
||||
{
|
||||
"subject_email": "e",
|
||||
"subject_issuer": None,
|
||||
"account_id": None,
|
||||
"token_id": "tid",
|
||||
"expires_at": None,
|
||||
}
|
||||
).encode()
|
||||
mock_redis.ttl.return_value = 42
|
||||
record_layer0_verdict("h1", "t1", True)
|
||||
written = json.loads(mock_redis.setex.call_args.args[2])
|
||||
assert written["verified_tenants"] == {"t1": True}
|
||||
84
api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py
Normal file
84
api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py
Normal file
@ -0,0 +1,84 @@
|
||||
"""require_scope is a route-level gate run after validate_bearer.
|
||||
Tests use a fake auth_ctx attached directly to flask.g — no
|
||||
authenticator wiring needed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from libs.oauth_bearer import (
|
||||
AuthContext,
|
||||
Scope,
|
||||
SubjectType,
|
||||
require_scope,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app() -> Flask:
|
||||
app = Flask(__name__)
|
||||
app.config["TESTING"] = True
|
||||
return app
|
||||
|
||||
|
||||
def _ctx(scopes) -> AuthContext:
|
||||
return AuthContext(
|
||||
subject_type=SubjectType.ACCOUNT,
|
||||
subject_email="user@example.com",
|
||||
subject_issuer="dify:account",
|
||||
account_id=uuid.uuid4(),
|
||||
scopes=scopes,
|
||||
token_id=uuid.uuid4(),
|
||||
source="oauth_account",
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants={},
|
||||
)
|
||||
|
||||
|
||||
def test_require_scope_allows_when_scope_present(app: Flask):
|
||||
@require_scope("apps:read")
|
||||
def view():
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context():
|
||||
g.auth_ctx = _ctx(frozenset({"apps:read"}))
|
||||
assert view() == "ok"
|
||||
|
||||
|
||||
def test_require_scope_rejects_when_scope_missing(app: Flask):
|
||||
@require_scope("apps:write")
|
||||
def view():
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context():
|
||||
g.auth_ctx = _ctx(frozenset({"apps:read"}))
|
||||
with pytest.raises(Forbidden) as exc:
|
||||
view()
|
||||
assert "insufficient_scope: apps:write" in str(exc.value.description)
|
||||
|
||||
|
||||
def test_require_scope_full_passes_any_check(app: Flask):
|
||||
@require_scope("apps:write")
|
||||
def view():
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context():
|
||||
g.auth_ctx = _ctx(frozenset({Scope.FULL}))
|
||||
assert view() == "ok"
|
||||
|
||||
|
||||
def test_require_scope_without_validate_bearer_raises_runtime_error(app: Flask):
|
||||
@require_scope("apps:read")
|
||||
def view():
|
||||
return "ok"
|
||||
|
||||
with app.test_request_context():
|
||||
# No g.auth_ctx — validate_bearer was forgotten
|
||||
with pytest.raises(RuntimeError, match="stack @validate_bearer above @require_scope"):
|
||||
view()
|
||||
74
api/tests/unit_tests/libs/test_rate_limit_bearer.py
Normal file
74
api/tests/unit_tests/libs/test_rate_limit_bearer.py
Normal file
@ -0,0 +1,74 @@
|
||||
"""Unit tests for the per-token bearer rate limit primitive."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import timedelta
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import TooManyRequests
|
||||
|
||||
from libs.helper import RateLimiter
|
||||
from libs.rate_limit import (
|
||||
LIMIT_BEARER_PER_TOKEN,
|
||||
enforce_bearer_rate_limit,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis():
|
||||
return MagicMock()
|
||||
|
||||
|
||||
def test_limit_bearer_per_token_uses_60_per_minute_default():
|
||||
assert LIMIT_BEARER_PER_TOKEN.limit == 60
|
||||
assert LIMIT_BEARER_PER_TOKEN.window == timedelta(minutes=1)
|
||||
|
||||
|
||||
def test_seconds_until_available_returns_remaining_window(mock_redis):
|
||||
"""ZSET oldest entry score = 100; window = 60s; now = 130s → remaining = 30s."""
|
||||
rl = RateLimiter("rl:bearer:token", max_attempts=60, time_window=60, redis_client=mock_redis)
|
||||
mock_redis.zrange.return_value = [(b"member-1", 100.0)]
|
||||
with patch("libs.helper.time.time", return_value=130):
|
||||
assert rl.seconds_until_available("k1") == 30
|
||||
|
||||
|
||||
def test_seconds_until_available_floor_one_second(mock_redis):
|
||||
"""Even when math says <1s remaining, return at least 1 so client backs off measurably."""
|
||||
rl = RateLimiter("rl:bearer:token", max_attempts=60, time_window=60, redis_client=mock_redis)
|
||||
mock_redis.zrange.return_value = [(b"member-1", 119.5)]
|
||||
with patch("libs.helper.time.time", return_value=180):
|
||||
# window expired (180 > 119.5+60=179.5 by 0.5s) — bucket is actually free now
|
||||
# but this method only called when is_rate_limited() == True; defensive floor.
|
||||
assert rl.seconds_until_available("k1") >= 1
|
||||
|
||||
|
||||
def test_seconds_until_available_empty_bucket(mock_redis):
|
||||
"""No entries → 1s sentinel (defensive; should not be reached when limited)."""
|
||||
rl = RateLimiter("rl:bearer:token", max_attempts=60, time_window=60, redis_client=mock_redis)
|
||||
mock_redis.zrange.return_value = []
|
||||
assert rl.seconds_until_available("k1") == 1
|
||||
|
||||
|
||||
@patch("libs.rate_limit._build_limiter")
|
||||
def test_enforce_bearer_rate_limit_passes_under_limit(mock_build):
|
||||
limiter = MagicMock()
|
||||
limiter.is_rate_limited.return_value = False
|
||||
mock_build.return_value = limiter
|
||||
enforce_bearer_rate_limit("hash-1")
|
||||
limiter.increment_rate_limit.assert_called_once_with("token:hash-1")
|
||||
|
||||
|
||||
@patch("libs.rate_limit._build_limiter")
|
||||
def test_enforce_bearer_rate_limit_raises_429_with_retry_after(mock_build):
|
||||
limiter = MagicMock()
|
||||
limiter.is_rate_limited.return_value = True
|
||||
limiter.seconds_until_available.return_value = 23
|
||||
mock_build.return_value = limiter
|
||||
with pytest.raises(TooManyRequests) as exc:
|
||||
enforce_bearer_rate_limit("hash-1")
|
||||
headers = dict(exc.value.get_response().headers)
|
||||
assert headers.get("Retry-After") == "23"
|
||||
body = exc.value.get_response().get_json() or {}
|
||||
assert body.get("error") == "rate_limited"
|
||||
assert body.get("retry_after_ms") == 23000
|
||||
93
api/tests/unit_tests/libs/test_workspace_member_helper.py
Normal file
93
api/tests/unit_tests/libs/test_workspace_member_helper.py
Normal file
@ -0,0 +1,93 @@
|
||||
"""Unit tests for require_workspace_member."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from libs.oauth_bearer import AuthContext, Scope, SubjectType, require_workspace_member
|
||||
|
||||
|
||||
def _ctx(verified: dict[str, bool] | None = None, *, account: bool = True) -> AuthContext:
|
||||
return AuthContext(
|
||||
subject_type=SubjectType.ACCOUNT if account else SubjectType.EXTERNAL_SSO,
|
||||
subject_email="e@example.com",
|
||||
subject_issuer=None,
|
||||
account_id=uuid.uuid4() if account else None,
|
||||
scopes=frozenset({Scope.FULL}),
|
||||
token_id=uuid.uuid4(),
|
||||
source="oauth_account",
|
||||
expires_at=None,
|
||||
token_hash="h1",
|
||||
verified_tenants=dict(verified or {}),
|
||||
)
|
||||
|
||||
|
||||
@patch("libs.oauth_bearer.dify_config")
|
||||
def test_skips_when_enterprise_enabled(mock_cfg):
|
||||
mock_cfg.ENTERPRISE_ENABLED = True
|
||||
require_workspace_member(_ctx(), "t1")
|
||||
|
||||
|
||||
@patch("libs.oauth_bearer.dify_config")
|
||||
def test_skips_for_external_sso(mock_cfg):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
require_workspace_member(_ctx(account=False), "t1")
|
||||
|
||||
|
||||
@patch("libs.oauth_bearer.db")
|
||||
@patch("libs.oauth_bearer.dify_config")
|
||||
def test_uses_cached_ok_no_db_access(mock_cfg, mock_db):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
require_workspace_member(_ctx({"t1": True}), "t1")
|
||||
mock_db.session.execute.assert_not_called()
|
||||
|
||||
|
||||
@patch("libs.oauth_bearer.db")
|
||||
@patch("libs.oauth_bearer.dify_config")
|
||||
def test_uses_cached_denied(mock_cfg, mock_db):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
require_workspace_member(_ctx({"t1": False}), "t1")
|
||||
mock_db.session.execute.assert_not_called()
|
||||
|
||||
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
@patch("libs.oauth_bearer.dify_config")
|
||||
def test_denies_when_no_membership(mock_cfg, mock_db, mock_record):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
require_workspace_member(_ctx({}), "t1")
|
||||
mock_record.assert_called_once_with("h1", "t1", False)
|
||||
|
||||
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
@patch("libs.oauth_bearer.dify_config")
|
||||
def test_denies_when_account_inactive(mock_cfg, mock_db, mock_record):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.side_effect = [
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="banned")),
|
||||
]
|
||||
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
|
||||
require_workspace_member(_ctx({}), "t1")
|
||||
mock_record.assert_called_once_with("h1", "t1", False)
|
||||
|
||||
|
||||
@patch("libs.oauth_bearer.record_layer0_verdict")
|
||||
@patch("libs.oauth_bearer.db")
|
||||
@patch("libs.oauth_bearer.dify_config")
|
||||
def test_allows_active_member(mock_cfg, mock_db, mock_record):
|
||||
mock_cfg.ENTERPRISE_ENABLED = False
|
||||
mock_db.session.execute.side_effect = [
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
|
||||
MagicMock(scalar_one_or_none=MagicMock(return_value="active")),
|
||||
]
|
||||
require_workspace_member(_ctx({}), "t1")
|
||||
mock_record.assert_called_once_with("h1", "t1", True)
|
||||
@ -0,0 +1,57 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.enterprise.app_permitted_service import PermittedAppsPage, list_permitted_apps
|
||||
from services.errors.enterprise import EnterpriseAPIError
|
||||
|
||||
WRAPPER = "services.enterprise.app_permitted_service.EnterpriseService.WebAppAuth.list_externally_accessible_apps"
|
||||
|
||||
|
||||
def test_list_permitted_apps_decodes_camelcase_response():
|
||||
fake_body = {
|
||||
"data": [{"appId": "a"}, {"appId": "b"}],
|
||||
"total": 2,
|
||||
"hasMore": False,
|
||||
}
|
||||
with patch(WRAPPER, return_value=fake_body) as m:
|
||||
page = list_permitted_apps(page=1, limit=10)
|
||||
|
||||
assert isinstance(page, PermittedAppsPage)
|
||||
assert page.total == 2
|
||||
assert page.has_more is False
|
||||
assert page.app_ids == ["a", "b"]
|
||||
m.assert_called_once_with(page=1, limit=10, mode=None, name=None)
|
||||
|
||||
|
||||
def test_list_permitted_apps_passes_filters_to_wrapper():
|
||||
fake_body = {"data": [], "total": 0, "hasMore": False}
|
||||
with patch(WRAPPER, return_value=fake_body) as m:
|
||||
list_permitted_apps(page=2, limit=5, mode="workflow", name="alpha")
|
||||
|
||||
m.assert_called_once_with(page=2, limit=5, mode="workflow", name="alpha")
|
||||
|
||||
|
||||
def test_list_permitted_apps_503_on_ee_error():
|
||||
with patch(WRAPPER, side_effect=EnterpriseAPIError("boom", status_code=500)):
|
||||
from werkzeug.exceptions import ServiceUnavailable
|
||||
|
||||
with pytest.raises(ServiceUnavailable):
|
||||
list_permitted_apps(page=1, limit=10)
|
||||
|
||||
|
||||
def test_list_permitted_apps_503_on_status_error():
|
||||
with patch(WRAPPER, side_effect=EnterpriseAPIError("bad key", status_code=401)):
|
||||
from werkzeug.exceptions import ServiceUnavailable
|
||||
|
||||
with pytest.raises(ServiceUnavailable):
|
||||
list_permitted_apps(page=1, limit=10)
|
||||
|
||||
|
||||
def test_list_permitted_apps_handles_empty_response():
|
||||
fake_body = {"data": [], "total": 0, "hasMore": False}
|
||||
with patch(WRAPPER, return_value=fake_body):
|
||||
page = list_permitted_apps(page=1, limit=10)
|
||||
assert page.app_ids == []
|
||||
assert page.total == 0
|
||||
assert page.has_more is False
|
||||
@ -188,6 +188,31 @@ class TestWebAppAuth:
|
||||
|
||||
req.send_request.assert_called_once_with("DELETE", "/webapp/clean", params={"appId": "a1"})
|
||||
|
||||
def test_list_externally_accessible_apps_minimal_call(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"data": [], "total": 0, "hasMore": False}
|
||||
result = EnterpriseService.WebAppAuth.list_externally_accessible_apps(page=1, limit=10)
|
||||
|
||||
assert result == {"data": [], "total": 0, "hasMore": False}
|
||||
req.send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/webapp/externally-accessible-apps",
|
||||
json={"page": 1, "limit": 10},
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
def test_list_externally_accessible_apps_with_filters(self):
|
||||
with patch(f"{MODULE}.EnterpriseRequest") as req:
|
||||
req.send_request.return_value = {"data": [], "total": 0, "hasMore": False}
|
||||
EnterpriseService.WebAppAuth.list_externally_accessible_apps(page=2, limit=5, mode="workflow", name="alpha")
|
||||
|
||||
req.send_request.assert_called_once_with(
|
||||
"POST",
|
||||
"/webapp/externally-accessible-apps",
|
||||
json={"page": 2, "limit": 5, "mode": "workflow", "name": "alpha"},
|
||||
timeout=5.0,
|
||||
)
|
||||
|
||||
|
||||
class TestJoinDefaultWorkspace:
|
||||
def test_join_default_workspace_success(self):
|
||||
|
||||
@ -27,6 +27,11 @@ server {
|
||||
include proxy.conf;
|
||||
}
|
||||
|
||||
location /openapi {
|
||||
proxy_pass http://api:5001;
|
||||
include proxy.conf;
|
||||
}
|
||||
|
||||
location /files {
|
||||
proxy_pass http://api:5001;
|
||||
include proxy.conf;
|
||||
|
||||
97
web/app/device/components/authorize-account.tsx
Normal file
97
web/app/device/components/authorize-account.tsx
Normal file
@ -0,0 +1,97 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import { useState } from 'react'
|
||||
import { deviceApproveAccount, deviceDenyAccount } from '@/service/device-flow'
|
||||
import { approveErrorCopy } from '../utils/error-copy'
|
||||
|
||||
type Props = {
|
||||
userCode: string
|
||||
accountEmail?: string
|
||||
defaultWorkspace?: string
|
||||
onApproved: () => void
|
||||
onDenied: () => void
|
||||
onError: (message: string) => void
|
||||
}
|
||||
|
||||
/**
|
||||
* AuthorizeAccount is the account-branch authorize screen. Called with a
|
||||
* live console session already established (user bounced through /signin).
|
||||
* Posts to /openapi/v1/oauth/device/{approve,deny}; these endpoints mint
|
||||
* the dfoa_ token server-side.
|
||||
*/
|
||||
const AuthorizeAccount: FC<Props> = ({
|
||||
userCode, accountEmail, defaultWorkspace, onApproved, onDenied, onError,
|
||||
}) => {
|
||||
const [busy, setBusy] = useState(false)
|
||||
|
||||
const approve = async () => {
|
||||
setBusy(true)
|
||||
try {
|
||||
await deviceApproveAccount(userCode)
|
||||
onApproved()
|
||||
}
|
||||
catch (e) {
|
||||
onError(approveErrorCopy(e))
|
||||
}
|
||||
finally {
|
||||
setBusy(false)
|
||||
}
|
||||
}
|
||||
|
||||
const deny = async () => {
|
||||
setBusy(true)
|
||||
try {
|
||||
await deviceDenyAccount(userCode)
|
||||
onDenied()
|
||||
}
|
||||
catch (e) {
|
||||
onError(approveErrorCopy(e))
|
||||
}
|
||||
finally {
|
||||
setBusy(false)
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-6">
|
||||
<div>
|
||||
<h2 className="text-2xl font-semibold text-text-primary">Authorize Dify CLI</h2>
|
||||
<p className="mt-2 text-sm text-text-secondary">
|
||||
Dify CLI (difyctl) is requesting access to your account.
|
||||
{' '}If you did not start this from your terminal, click Cancel.
|
||||
</p>
|
||||
</div>
|
||||
<div className="rounded-lg border border-components-panel-border bg-components-panel-bg px-4 py-3">
|
||||
{accountEmail && (
|
||||
<p className="text-sm text-text-secondary">
|
||||
Signed in as <span className="font-medium text-text-primary">{accountEmail}</span>
|
||||
</p>
|
||||
)}
|
||||
{defaultWorkspace && (
|
||||
<p className="mt-1 text-sm text-text-secondary">
|
||||
Default workspace: <span className="font-medium text-text-primary">{defaultWorkspace}</span>
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex gap-3">
|
||||
<button
|
||||
onClick={approve}
|
||||
disabled={busy}
|
||||
className="flex-1 rounded-lg bg-components-button-primary-bg px-4 py-3 text-components-button-primary-text font-medium hover:bg-components-button-primary-bg-hover disabled:opacity-50"
|
||||
>
|
||||
Authorize
|
||||
</button>
|
||||
<button
|
||||
onClick={deny}
|
||||
disabled={busy}
|
||||
className="flex-1 rounded-lg border border-components-button-secondary-border bg-components-button-secondary-bg px-4 py-3 text-components-button-secondary-text font-medium hover:bg-components-button-secondary-bg-hover disabled:opacity-50"
|
||||
>
|
||||
Cancel
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default AuthorizeAccount
|
||||
98
web/app/device/components/authorize-sso.tsx
Normal file
98
web/app/device/components/authorize-sso.tsx
Normal file
@ -0,0 +1,98 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import { useEffect, useState } from 'react'
|
||||
import type { ApprovalContext } from '@/service/device-flow'
|
||||
import { approveExternal, fetchApprovalContext } from '@/service/device-flow'
|
||||
import { approveErrorCopy } from '../utils/error-copy'
|
||||
|
||||
type Props = {
|
||||
onApproved: () => void
|
||||
onError: (message: string) => void
|
||||
}
|
||||
|
||||
/**
|
||||
* AuthorizeSSO is the external-SSO branch authorize screen. On mount it
|
||||
* fetches /openapi/v1/oauth/device/approval-context to learn subject_email,
|
||||
* issuer, user_code, and csrf_token from the device_approval_grant cookie.
|
||||
* On Approve click, posts /openapi/v1/oauth/device/approve-external with
|
||||
* the CSRF header.
|
||||
*
|
||||
* The user_code in state is bound to the cookie by server; we do not accept
|
||||
* one from the URL because the SSO branch deliberately detaches from the
|
||||
* pre-SSO ?user_code=... query param.
|
||||
*/
|
||||
const AuthorizeSSO: FC<Props> = ({ onApproved, onError }) => {
|
||||
const [ctx, setCtx] = useState<ApprovalContext | null>(null)
|
||||
const [busy, setBusy] = useState(false)
|
||||
const [loadErr, setLoadErr] = useState<string | null>(null)
|
||||
|
||||
useEffect(() => {
|
||||
let cancelled = false
|
||||
fetchApprovalContext()
|
||||
.then((c) => { if (!cancelled) setCtx(c) })
|
||||
.catch((e) => {
|
||||
if (!cancelled)
|
||||
setLoadErr(approveErrorCopy(e))
|
||||
})
|
||||
return () => { cancelled = true }
|
||||
}, [])
|
||||
|
||||
const approve = async () => {
|
||||
if (!ctx) return
|
||||
setBusy(true)
|
||||
try {
|
||||
await approveExternal(ctx, ctx.user_code)
|
||||
onApproved()
|
||||
}
|
||||
catch (e) {
|
||||
onError(approveErrorCopy(e))
|
||||
}
|
||||
finally {
|
||||
setBusy(false)
|
||||
}
|
||||
}
|
||||
|
||||
if (loadErr) {
|
||||
return (
|
||||
<div>
|
||||
<h2 className="text-2xl font-semibold text-text-primary">This session is no longer valid</h2>
|
||||
<p className="mt-2 text-sm text-text-secondary">
|
||||
Run <code className="rounded bg-components-panel-bg px-1">difyctl auth login</code> again to start a new sign-in.
|
||||
</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
if (!ctx) {
|
||||
return <div className="text-sm text-text-secondary">Loading session…</div>
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-6">
|
||||
<div>
|
||||
<h2 className="text-2xl font-semibold text-text-primary">Authorize Dify CLI</h2>
|
||||
<p className="mt-2 text-sm text-text-secondary">
|
||||
Dify CLI (difyctl) is requesting access via SSO. If you did not start
|
||||
this from your terminal, close this tab.
|
||||
</p>
|
||||
</div>
|
||||
<div className="rounded-lg border border-components-panel-border bg-components-panel-bg px-4 py-3">
|
||||
<p className="text-sm text-text-secondary">
|
||||
Signed in as <span className="font-medium text-text-primary">{ctx.subject_email}</span>
|
||||
</p>
|
||||
<p className="mt-1 text-sm text-text-secondary">
|
||||
Issuer: <span className="font-medium text-text-primary">{ctx.subject_issuer}</span>
|
||||
</p>
|
||||
</div>
|
||||
<button
|
||||
onClick={approve}
|
||||
disabled={busy}
|
||||
className="rounded-lg bg-components-button-primary-bg px-4 py-3 text-components-button-primary-text font-medium hover:bg-components-button-primary-bg-hover disabled:opacity-50"
|
||||
>
|
||||
Authorize
|
||||
</button>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default AuthorizeSSO
|
||||
60
web/app/device/components/chooser.tsx
Normal file
60
web/app/device/components/chooser.tsx
Normal file
@ -0,0 +1,60 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import { useRouter } from '@/next/navigation'
|
||||
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'
|
||||
|
||||
type Props = {
|
||||
userCode: string
|
||||
ssoAvailable: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Chooser renders the two-button device-auth login selector. Account button
|
||||
* seeds postLoginRedirect + navigates to /signin so every existing account
|
||||
* login method (password / email-code / social OAuth / account-SSO) flows
|
||||
* through its usual plumbing. SSO button hits /openapi/v1/oauth/device/sso-initiate
|
||||
* directly — the SSO branch skips /signin entirely.
|
||||
*
|
||||
* v1.0 scope: only account-SSO honours postLoginRedirect (via sso-auth's
|
||||
* return_to plumbing). Password / email-code / social-OAuth users land on
|
||||
* /signin's default post-login target and manually return to the /device
|
||||
* URL printed by the CLI. That's not great UX; a follow-up milestone
|
||||
* generalises post-signin redirect to all methods.
|
||||
*/
|
||||
const Chooser: FC<Props> = ({ userCode, ssoAvailable }) => {
|
||||
const router = useRouter()
|
||||
|
||||
const onAccount = () => {
|
||||
setPostLoginRedirect(`/device?user_code=${encodeURIComponent(userCode)}`)
|
||||
router.push('/signin')
|
||||
}
|
||||
|
||||
const onSSO = () => {
|
||||
// Full-page navigation, not router.push — /openapi/v1/oauth/device/sso-initiate
|
||||
// issues a 302 to the IdP. Next's client router can't follow cross-
|
||||
// origin redirects; a plain window.location assignment handles it.
|
||||
window.location.href = `/openapi/v1/oauth/device/sso-initiate?user_code=${encodeURIComponent(userCode)}`
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex flex-col gap-3">
|
||||
<button
|
||||
onClick={onAccount}
|
||||
className="rounded-lg bg-components-button-primary-bg px-4 py-3 text-components-button-primary-text font-medium hover:bg-components-button-primary-bg-hover"
|
||||
>
|
||||
Sign in with Dify account
|
||||
</button>
|
||||
{ssoAvailable && (
|
||||
<button
|
||||
onClick={onSSO}
|
||||
className="rounded-lg border border-components-button-secondary-border bg-components-button-secondary-bg px-4 py-3 text-components-button-secondary-text font-medium hover:bg-components-button-secondary-bg-hover"
|
||||
>
|
||||
Sign in with SSO
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default Chooser
|
||||
45
web/app/device/components/code-input.tsx
Normal file
45
web/app/device/components/code-input.tsx
Normal file
@ -0,0 +1,45 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import { useCallback } from 'react'
|
||||
import { normaliseUserCodeInput } from '../utils/user-code'
|
||||
|
||||
type Props = {
|
||||
value: string
|
||||
onChange: (normalised: string) => void
|
||||
disabled?: boolean
|
||||
autoFocus?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* CodeInput renders the user_code text field with live normalisation
|
||||
* (uppercase, reduced alphabet, XXXX-XXXX hyphenation).
|
||||
*
|
||||
* The onChange callback receives the normalised value only — the parent does
|
||||
* not need to run validation itself.
|
||||
*/
|
||||
const CodeInput: FC<Props> = ({ value, onChange, disabled, autoFocus }) => {
|
||||
const handle = useCallback((raw: string) => {
|
||||
onChange(normaliseUserCodeInput(raw))
|
||||
}, [onChange])
|
||||
|
||||
return (
|
||||
<input
|
||||
type="text"
|
||||
inputMode="text"
|
||||
autoCapitalize="characters"
|
||||
autoComplete="off"
|
||||
spellCheck={false}
|
||||
placeholder="ABCD-1234"
|
||||
maxLength={9}
|
||||
aria-label="one-time code"
|
||||
className="w-full rounded-lg border border-components-input-border-normal bg-components-input-bg-normal px-4 py-3 text-center text-2xl font-mono tracking-wider text-text-primary focus:border-components-input-border-active focus:outline-none"
|
||||
value={value}
|
||||
disabled={disabled}
|
||||
autoFocus={autoFocus}
|
||||
onChange={e => handle(e.target.value)}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export default CodeInput
|
||||
215
web/app/device/page.tsx
Normal file
215
web/app/device/page.tsx
Normal file
@ -0,0 +1,215 @@
|
||||
'use client'
|
||||
|
||||
import { useEffect, useState } from 'react'
|
||||
import { usePathname, useRouter, useSearchParams } from '@/next/navigation'
|
||||
import { useQuery } from '@tanstack/react-query'
|
||||
import { systemFeaturesQueryOptions } from '@/service/system-features'
|
||||
import { commonQueryKeys, userProfileQueryOptions } from '@/service/use-common'
|
||||
import { post } from '@/service/base'
|
||||
import type { ICurrentWorkspace } from '@/models/common'
|
||||
import { deviceLookup } from '@/service/device-flow'
|
||||
import CodeInput from './components/code-input'
|
||||
import Chooser from './components/chooser'
|
||||
import AuthorizeAccount from './components/authorize-account'
|
||||
import AuthorizeSSO from './components/authorize-sso'
|
||||
import { isValidUserCode } from './utils/user-code'
|
||||
import { classifyLookupError } from './utils/error-copy'
|
||||
|
||||
type View =
|
||||
| { kind: 'code_entry' }
|
||||
| { kind: 'chooser'; userCode: string }
|
||||
| { kind: 'authorize_account'; userCode: string }
|
||||
| { kind: 'authorize_sso' }
|
||||
| { kind: 'success' }
|
||||
| { kind: 'error_expired' }
|
||||
| { kind: 'error_rate_limited' }
|
||||
| { kind: 'error_lookup_failed' }
|
||||
|
||||
export default function DevicePage() {
|
||||
const searchParams = useSearchParams()
|
||||
const router = useRouter()
|
||||
const pathname = usePathname()
|
||||
const urlUserCode = (searchParams.get('user_code') || '').trim().toUpperCase()
|
||||
const ssoVerified = searchParams.get('sso_verified') === '1'
|
||||
|
||||
const [typed, setTyped] = useState('')
|
||||
const [view, setView] = useState<View>({ kind: 'code_entry' })
|
||||
const [errMsg, setErrMsg] = useState<string | null>(null)
|
||||
|
||||
// Account subject + workspace identity (for the authorize-account screen).
|
||||
// Logged-out is a valid landing state on /device — disable refetch storms
|
||||
// and skip workspace probe until profile resolves (avoids /current + chained
|
||||
// /refresh-token 401 loops while the user is still entering the code).
|
||||
const { data: userResp, isError: profileErr } = useQuery({
|
||||
...userProfileQueryOptions(),
|
||||
throwOnError: false,
|
||||
retry: false,
|
||||
refetchOnWindowFocus: false,
|
||||
refetchOnMount: false,
|
||||
})
|
||||
const account = userResp?.profile
|
||||
const { data: currentWorkspace } = useQuery<ICurrentWorkspace>({
|
||||
queryKey: commonQueryKeys.currentWorkspace,
|
||||
queryFn: () => post<ICurrentWorkspace>('/workspaces/current'),
|
||||
enabled: !!account && !profileErr,
|
||||
retry: false,
|
||||
refetchOnWindowFocus: false,
|
||||
})
|
||||
const { data: sys } = useQuery(systemFeaturesQueryOptions())
|
||||
// Device-flow SSO branch uses external-user (webapp) SSO, not console SSO —
|
||||
// backend mints EXTERNAL_SSO tokens via Enterprise's external ACS. Gate on
|
||||
// webapp_auth.{enabled, allow_sso} + a configured webapp SSO protocol.
|
||||
const ssoAvailable = !!sys?.webapp_auth?.enabled
|
||||
&& !!sys?.webapp_auth?.allow_sso
|
||||
&& (sys?.webapp_auth?.sso_config?.protocol || '') !== ''
|
||||
|
||||
// URL-driven view transitions. Only advances while the user is still on
|
||||
// the entry/chooser screens — never clobbers terminal views (success /
|
||||
// error_expired / authorize_*) when userProfile refetches.
|
||||
// After consuming the params, scrub them from the URL so they don't
|
||||
// leak via history / Referer / server logs (RFC 8628 §5.4).
|
||||
useEffect(() => {
|
||||
if (view.kind !== 'code_entry' && view.kind !== 'chooser') return
|
||||
// Post-login bounce: chooser holds the typed code, account just loaded.
|
||||
// The URL was already scrubbed on the first effect run, so urlUserCode
|
||||
// is empty here — advance using the userCode stashed in view state.
|
||||
if (view.kind === 'chooser' && account) {
|
||||
setView({ kind: 'authorize_account', userCode: view.userCode })
|
||||
return
|
||||
}
|
||||
let consumed = false
|
||||
if (ssoVerified) {
|
||||
setView({ kind: 'authorize_sso' })
|
||||
consumed = true
|
||||
}
|
||||
else if (urlUserCode && isValidUserCode(urlUserCode)) {
|
||||
if (account)
|
||||
setView({ kind: 'authorize_account', userCode: urlUserCode })
|
||||
else
|
||||
setView({ kind: 'chooser', userCode: urlUserCode })
|
||||
consumed = true
|
||||
}
|
||||
if (consumed && (urlUserCode || ssoVerified))
|
||||
router.replace(pathname)
|
||||
}, [urlUserCode, ssoVerified, account, view, router, pathname])
|
||||
|
||||
const onContinue = async () => {
|
||||
if (!isValidUserCode(typed)) return
|
||||
try {
|
||||
const reply = await deviceLookup(typed)
|
||||
if (!reply.valid) {
|
||||
setView({ kind: 'error_expired' })
|
||||
return
|
||||
}
|
||||
}
|
||||
catch (e) {
|
||||
const outcome = classifyLookupError(e)
|
||||
if (outcome === 'rate_limited')
|
||||
setView({ kind: 'error_rate_limited' })
|
||||
else if (outcome === 'failed')
|
||||
setView({ kind: 'error_lookup_failed' })
|
||||
else
|
||||
setView({ kind: 'error_expired' })
|
||||
return
|
||||
}
|
||||
if (account) setView({ kind: 'authorize_account', userCode: typed })
|
||||
else setView({ kind: 'chooser', userCode: typed })
|
||||
}
|
||||
|
||||
return (
|
||||
<main className="mx-auto flex min-h-screen max-w-lg flex-col items-center justify-center px-6 py-10">
|
||||
<div className="w-full rounded-xl border border-components-panel-border bg-components-panel-bg p-8 shadow-sm">
|
||||
{view.kind === 'code_entry' && (
|
||||
<div className="flex flex-col gap-5">
|
||||
<div>
|
||||
<h1 className="text-2xl font-semibold text-text-primary">Authorize Dify CLI</h1>
|
||||
<p className="mt-2 text-sm text-text-secondary">
|
||||
Enter the code shown in your terminal.
|
||||
</p>
|
||||
</div>
|
||||
<CodeInput value={typed} onChange={setTyped} autoFocus />
|
||||
<button
|
||||
onClick={onContinue}
|
||||
disabled={!isValidUserCode(typed)}
|
||||
className="rounded-lg bg-components-button-primary-bg px-4 py-3 text-components-button-primary-text font-medium hover:bg-components-button-primary-bg-hover disabled:opacity-50"
|
||||
>
|
||||
Continue
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{view.kind === 'chooser' && (
|
||||
<div className="flex flex-col gap-5">
|
||||
<div>
|
||||
<h1 className="text-2xl font-semibold text-text-primary">Sign in to authorize</h1>
|
||||
<p className="mt-2 text-sm text-text-secondary">
|
||||
Code <span className="font-mono">{view.userCode}</span> is valid. Choose how to sign in.
|
||||
</p>
|
||||
</div>
|
||||
<Chooser userCode={view.userCode} ssoAvailable={ssoAvailable} />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{view.kind === 'authorize_account' && (
|
||||
<AuthorizeAccount
|
||||
userCode={view.userCode}
|
||||
accountEmail={account?.email}
|
||||
defaultWorkspace={currentWorkspace?.name}
|
||||
onApproved={() => setView({ kind: 'success' })}
|
||||
onDenied={() => setView({ kind: 'error_expired' })}
|
||||
onError={e => setErrMsg(e)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{view.kind === 'authorize_sso' && (
|
||||
<AuthorizeSSO
|
||||
onApproved={() => setView({ kind: 'success' })}
|
||||
onError={e => setErrMsg(e)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{view.kind === 'success' && (
|
||||
<div>
|
||||
<h1 className="text-2xl font-semibold text-text-primary">You're signed in</h1>
|
||||
<p className="mt-2 text-sm text-text-secondary">Return to your terminal to continue.</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{view.kind === 'error_expired' && (
|
||||
<div>
|
||||
<h1 className="text-2xl font-semibold text-text-primary">This code is no longer valid</h1>
|
||||
<p className="mt-2 text-sm text-text-secondary">
|
||||
The code may have expired or already been used. Run
|
||||
{' '}
|
||||
<code className="rounded bg-components-panel-bg px-1">difyctl auth login</code>
|
||||
{' '}
|
||||
again to get a new one.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{view.kind === 'error_rate_limited' && (
|
||||
<div>
|
||||
<h1 className="text-2xl font-semibold text-text-primary">Too many attempts</h1>
|
||||
<p className="mt-2 text-sm text-text-secondary">
|
||||
We've received too many requests for this code. Wait a moment and try again.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{view.kind === 'error_lookup_failed' && (
|
||||
<div>
|
||||
<h1 className="text-2xl font-semibold text-text-primary">Could not verify the code</h1>
|
||||
<p className="mt-2 text-sm text-text-secondary">
|
||||
Something went wrong on our side. Try again in a moment.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{errMsg && (
|
||||
<p className="mt-4 text-sm text-text-destructive">{errMsg}</p>
|
||||
)}
|
||||
</div>
|
||||
</main>
|
||||
)
|
||||
}
|
||||
41
web/app/device/utils/error-copy.ts
Normal file
41
web/app/device/utils/error-copy.ts
Normal file
@ -0,0 +1,41 @@
|
||||
// Translate a DeviceFlowError (or any thrown value) into user-facing copy.
|
||||
// Centralised so account/SSO branches surface the same words for the same
|
||||
// failure mode and so a new server error code can be wired up here once.
|
||||
|
||||
import { DeviceFlowError } from '@/service/device-flow'
|
||||
|
||||
const APPROVE_COPY: Record<string, string> = {
|
||||
rate_limited: 'Too many attempts. Wait a moment and try again.',
|
||||
no_session: 'Your session has expired. Run difyctl auth login again to start over.',
|
||||
invalid_session: 'Your session has expired. Run difyctl auth login again to start over.',
|
||||
session_already_consumed: 'This session was already used. Run difyctl auth login again.',
|
||||
csrf_mismatch: 'Could not verify the request. Refresh the page and try again.',
|
||||
forbidden: 'Could not verify the request. Refresh the page and try again.',
|
||||
expired_or_unknown: 'This code is no longer valid.',
|
||||
not_found: 'This code is no longer valid.',
|
||||
user_code_mismatch: 'This code does not match the active session. Run difyctl auth login again.',
|
||||
user_code_not_pending: 'This code was already approved or denied.',
|
||||
already_resolved: 'This code was already approved or denied.',
|
||||
state_lost: 'The flow expired before approval completed. Run difyctl auth login again.',
|
||||
approve_in_progress: 'An approval is already in progress for this code.',
|
||||
conflict: 'This code is no longer in a state we can approve.',
|
||||
server_error: 'Something went wrong on our side. Try again in a moment.',
|
||||
}
|
||||
|
||||
const DEFAULT_MESSAGE = 'Could not complete the request. Please try again.'
|
||||
|
||||
export function approveErrorCopy(err: unknown): string {
|
||||
if (err instanceof DeviceFlowError)
|
||||
return APPROVE_COPY[err.code] ?? DEFAULT_MESSAGE
|
||||
return DEFAULT_MESSAGE
|
||||
}
|
||||
|
||||
export type LookupOutcome = 'expired' | 'rate_limited' | 'failed'
|
||||
|
||||
export function classifyLookupError(err: unknown): LookupOutcome {
|
||||
if (err instanceof DeviceFlowError) {
|
||||
if (err.code === 'rate_limited' || err.status === 429) return 'rate_limited'
|
||||
if (err.code === 'server_error' || err.status >= 500) return 'failed'
|
||||
}
|
||||
return 'expired'
|
||||
}
|
||||
37
web/app/device/utils/user-code.ts
Normal file
37
web/app/device/utils/user-code.ts
Normal file
@ -0,0 +1,37 @@
|
||||
// user-code.ts — input normalisation + validation for the RFC 8628
|
||||
// 8-character user_code format the CLI prints to stderr.
|
||||
//
|
||||
// Format: XXXX-XXXX, uppercase, reduced alphabet (no 0/O, 1/I/l, 2/Z). Low
|
||||
// entropy by design — humans type it — so the server-side rate-limit + TTL +
|
||||
// single-use properties are what defend it, not the alphabet.
|
||||
|
||||
export const USER_CODE_ALPHABET = 'ABCDEFGHJKLMNPQRSTUVWXY3456789' // excludes 0 O 1 I L 2 Z
|
||||
|
||||
/**
|
||||
* normaliseUserCodeInput prepares raw input for display in the code field:
|
||||
* strips non-alphanumerics, uppercases, drops disallowed characters, and
|
||||
* inserts the hyphen after the fourth accepted char.
|
||||
*
|
||||
* Returns at most 9 chars ("XXXX-XXXX"); longer input is truncated.
|
||||
*/
|
||||
export function normaliseUserCodeInput(raw: string): string {
|
||||
const cleaned: string[] = []
|
||||
for (const ch of raw.toUpperCase()) {
|
||||
if (USER_CODE_ALPHABET.includes(ch))
|
||||
cleaned.push(ch)
|
||||
if (cleaned.length === 8)
|
||||
break
|
||||
}
|
||||
if (cleaned.length <= 4)
|
||||
return cleaned.join('')
|
||||
return `${cleaned.slice(0, 4).join('')}-${cleaned.slice(4).join('')}`
|
||||
}
|
||||
|
||||
/**
|
||||
* isValidUserCode tests whether the normalised form is a complete XXXX-XXXX
|
||||
* token suitable for submission to /openapi/v1/oauth/device/lookup.
|
||||
*/
|
||||
export function isValidUserCode(normalised: string): boolean {
|
||||
return /^[A-Z0-9]{4}-[A-Z0-9]{4}$/.test(normalised)
|
||||
&& [...normalised.replace('-', '')].every(c => USER_CODE_ALPHABET.includes(c))
|
||||
}
|
||||
@ -1,15 +1,68 @@
|
||||
let postLoginRedirect: string | null = null
|
||||
// Persists target across full-page redirects within the same tab (social
|
||||
// OAuth, SSO IdP bounce). sessionStorage is tab-scoped so concurrent
|
||||
// /device tabs don't clobber each other. 15-min TTL drops stale values.
|
||||
// Same-origin + exact-path whitelist prevents open-redirect.
|
||||
//
|
||||
// Signup-via-email-link opening in a new tab is out of scope — that tab
|
||||
// starts with an empty sessionStorage and falls to /apps default.
|
||||
|
||||
const KEY = 'dify_post_login_redirect'
|
||||
const TTL_MS = 15 * 60 * 1000
|
||||
|
||||
const ALLOWED: Record<string, ReadonlySet<string>> = {
|
||||
'/device': new Set(['user_code', 'sso_verified']),
|
||||
'/account/oauth/authorize': new Set(['client_id', 'scope', 'state', 'redirect_uri']),
|
||||
}
|
||||
|
||||
function validate(target: string): string | null {
|
||||
if (typeof window === 'undefined') return null
|
||||
try {
|
||||
const url = new URL(target, window.location.origin)
|
||||
if (url.origin !== window.location.origin) return null
|
||||
const allowedKeys = ALLOWED[url.pathname]
|
||||
if (!allowedKeys) return null
|
||||
for (const key of url.searchParams.keys()) {
|
||||
if (!allowedKeys.has(key)) return null
|
||||
}
|
||||
return url.pathname + (url.search || '')
|
||||
}
|
||||
catch {
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
export const setPostLoginRedirect = (value: string | null) => {
|
||||
postLoginRedirect = value
|
||||
}
|
||||
|
||||
export const resolvePostLoginRedirect = () => {
|
||||
if (postLoginRedirect) {
|
||||
const redirectUrl = postLoginRedirect
|
||||
postLoginRedirect = null
|
||||
return redirectUrl
|
||||
if (typeof window === 'undefined') return
|
||||
if (value === null) {
|
||||
try { sessionStorage.removeItem(KEY) } catch {}
|
||||
return
|
||||
}
|
||||
const safe = validate(value)
|
||||
if (!safe) return
|
||||
try {
|
||||
sessionStorage.setItem(KEY, JSON.stringify({ target: safe, ts: Date.now() }))
|
||||
}
|
||||
catch {}
|
||||
}
|
||||
|
||||
export const resolvePostLoginRedirect = (): string | null => {
|
||||
if (typeof window === 'undefined') return null
|
||||
let raw: string | null = null
|
||||
try {
|
||||
raw = sessionStorage.getItem(KEY)
|
||||
sessionStorage.removeItem(KEY)
|
||||
}
|
||||
catch {
|
||||
return null
|
||||
}
|
||||
if (!raw) return null
|
||||
try {
|
||||
const parsed = JSON.parse(raw)
|
||||
if (typeof parsed?.target !== 'string' || typeof parsed?.ts !== 'number') return null
|
||||
if (Date.now() - parsed.ts > TTL_MS) return null
|
||||
return validate(parsed.target)
|
||||
}
|
||||
catch {
|
||||
return null
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
@ -30,6 +30,20 @@ const nextConfig: NextConfig = {
|
||||
},
|
||||
]
|
||||
},
|
||||
// Anti-framing for device-flow surfaces. A framed /device page could UI-trick
|
||||
// a victim with a valid device_approval_grant cookie into approving a
|
||||
// device_code — functionally CSRF, bypasses the double-submit token. Deny
|
||||
// framing outright on every device-flow route; no trusted embedder exists.
|
||||
async headers() {
|
||||
const antiFrame = [
|
||||
{ key: 'X-Frame-Options', value: 'DENY' },
|
||||
{ key: 'Content-Security-Policy', value: "frame-ancestors 'none'" },
|
||||
]
|
||||
return [
|
||||
{ source: '/device', headers: antiFrame },
|
||||
{ source: '/device/:path*', headers: antiFrame },
|
||||
]
|
||||
},
|
||||
output: 'standalone',
|
||||
compiler: {
|
||||
removeConsole: isDev ? false : { exclude: ['warn', 'error'] },
|
||||
|
||||
@ -794,6 +794,11 @@ export const request = async<T>(url: string, options = {}, otherOptions?: IOther
|
||||
const [refreshErr] = await asyncRunSafe(refreshAccessTokenOrReLogin(TIME_OUT))
|
||||
if (refreshErr === null)
|
||||
return baseFetch<T>(url, options, otherOptionsForBaseFetch)
|
||||
// /device is the device-flow chooser; logged-out is a valid state
|
||||
// there. Redirecting to /signin loses the user_code context and
|
||||
// the post-login flow lands on /apps instead of returning here.
|
||||
if (location.pathname === `${basePath}/device`)
|
||||
return Promise.reject(err)
|
||||
if (location.pathname !== `${basePath}/signin` || !IS_CE_EDITION) {
|
||||
jumpTo(loginUrl)
|
||||
return Promise.reject(err)
|
||||
|
||||
134
web/service/device-flow.ts
Normal file
134
web/service/device-flow.ts
Normal file
@ -0,0 +1,134 @@
|
||||
// Web-side calls into the Dify device-flow endpoints. All routes now sit
|
||||
// under /openapi/v1/oauth/device/* (Phase G of the openapi migration). The
|
||||
// approve/deny endpoints still require the console session cookie + CSRF
|
||||
// token; lookup is unauthenticated; the SSO branch uses cookie + per-flow
|
||||
// CSRF baked into the approval-context response.
|
||||
//
|
||||
// /openapi/v1/oauth/device/lookup (public — GET)
|
||||
// /openapi/v1/oauth/device/approve (cookie + CSRF — POST)
|
||||
// /openapi/v1/oauth/device/deny (cookie + CSRF — POST)
|
||||
// /openapi/v1/oauth/device/approval-context (cookie — GET)
|
||||
// /openapi/v1/oauth/device/approve-external (cookie + per-flow CSRF — POST)
|
||||
//
|
||||
// /openapi/v1/* is its own URL prefix, so we bypass service/base's
|
||||
// API_PREFIX (which targets /console/api) and call fetch directly.
|
||||
|
||||
import Cookies from 'js-cookie'
|
||||
import { CSRF_COOKIE_NAME, CSRF_HEADER_NAME } from '@/config'
|
||||
|
||||
const DEVICE_BASE = '/openapi/v1/oauth/device'
|
||||
|
||||
// Typed error thrown by every wrapper here. The page/component layer
|
||||
// switches on `code` to choose user-facing copy / view; never render
|
||||
// `status` or raw body to the user.
|
||||
export class DeviceFlowError extends Error {
|
||||
constructor(public code: string, public status: number) {
|
||||
super(code)
|
||||
this.name = 'DeviceFlowError'
|
||||
}
|
||||
}
|
||||
|
||||
// Translate a non-2xx fetch Response into a DeviceFlowError. Honours the
|
||||
// server contract `{"error": "<code>"}` and falls back to a status-class
|
||||
// code so callers can still dispatch (rate_limited / server_error / ...).
|
||||
async function failFromResponse(res: Response): Promise<never> {
|
||||
let serverCode = ''
|
||||
try {
|
||||
const body = await res.clone().json()
|
||||
if (body && typeof body.error === 'string') serverCode = body.error
|
||||
}
|
||||
catch { /* non-JSON body — fall through to status mapping */ }
|
||||
|
||||
const code = serverCode || statusFallbackCode(res.status)
|
||||
throw new DeviceFlowError(code, res.status)
|
||||
}
|
||||
|
||||
function statusFallbackCode(status: number): string {
|
||||
if (status === 429) return 'rate_limited'
|
||||
if (status === 401) return 'no_session'
|
||||
if (status === 403) return 'forbidden'
|
||||
if (status === 404) return 'not_found'
|
||||
if (status === 409) return 'conflict'
|
||||
if (status >= 500) return 'server_error'
|
||||
return 'unknown'
|
||||
}
|
||||
|
||||
function consoleCsrfHeader(): Record<string, string> {
|
||||
return { [CSRF_HEADER_NAME]: Cookies.get(CSRF_COOKIE_NAME()) || '' }
|
||||
}
|
||||
|
||||
// ----- Account branch --------------------------------------------------------
|
||||
|
||||
export type DeviceLookupReply = {
|
||||
valid: boolean
|
||||
expires_in_remaining: number
|
||||
client_id: string
|
||||
}
|
||||
|
||||
export async function deviceLookup(user_code: string): Promise<DeviceLookupReply> {
|
||||
const res = await fetch(`${DEVICE_BASE}/lookup?user_code=${encodeURIComponent(user_code)}`, {
|
||||
method: 'GET',
|
||||
})
|
||||
if (!res.ok) await failFromResponse(res)
|
||||
return res.json()
|
||||
}
|
||||
|
||||
export async function deviceApproveAccount(user_code: string): Promise<{ status: 'approved' }> {
|
||||
const res = await fetch(`${DEVICE_BASE}/approve`, {
|
||||
method: 'POST',
|
||||
credentials: 'include',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...consoleCsrfHeader(),
|
||||
},
|
||||
body: JSON.stringify({ user_code }),
|
||||
})
|
||||
if (!res.ok) await failFromResponse(res)
|
||||
return res.json()
|
||||
}
|
||||
|
||||
export async function deviceDenyAccount(user_code: string): Promise<{ status: 'denied' }> {
|
||||
const res = await fetch(`${DEVICE_BASE}/deny`, {
|
||||
method: 'POST',
|
||||
credentials: 'include',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
...consoleCsrfHeader(),
|
||||
},
|
||||
body: JSON.stringify({ user_code }),
|
||||
})
|
||||
if (!res.ok) await failFromResponse(res)
|
||||
return res.json()
|
||||
}
|
||||
|
||||
// ----- SSO branch (cookie-authed via /openapi/v1/oauth/device/*) -----------
|
||||
|
||||
export type ApprovalContext = {
|
||||
subject_email: string
|
||||
subject_issuer: string
|
||||
user_code: string
|
||||
csrf_token: string
|
||||
expires_at: string
|
||||
}
|
||||
|
||||
export async function fetchApprovalContext(): Promise<ApprovalContext> {
|
||||
const res = await fetch(`${DEVICE_BASE}/approval-context`, {
|
||||
method: 'GET',
|
||||
credentials: 'include',
|
||||
})
|
||||
if (!res.ok) await failFromResponse(res)
|
||||
return res.json()
|
||||
}
|
||||
|
||||
export async function approveExternal(ctx: ApprovalContext, user_code: string): Promise<void> {
|
||||
const res = await fetch(`${DEVICE_BASE}/approve-external`, {
|
||||
method: 'POST',
|
||||
credentials: 'include',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'X-CSRF-Token': ctx.csrf_token,
|
||||
},
|
||||
body: JSON.stringify({ user_code }),
|
||||
})
|
||||
if (!res.ok) await failFromResponse(res)
|
||||
}
|
||||
Reference in New Issue
Block a user