Compare commits

..

2 Commits

Author SHA1 Message Date
14d4f6f2d4 Merge branch 'feat/openapi-auth-pipeline' into build/cli 2026-05-26 03:39:21 -07:00
9b25980b09 feat(openapi): redesign auth pipeline — one pipeline per token type with PipelineRouter
Replace the single mutable-context Pipeline with a two-phase, condition-driven
system dispatched by token type.

New architecture:
- TokenType(StrEnum) replaces source: str on AuthContext / TokenKind
- AuthPipeline: pure prepare→auth step runner; no guard()
- PipelineRoute: binds AuthPipeline to an optional required_edition gate
- PipelineRouter: single guard() entry point; runs edition/license/token-type
  pre-gates then dispatches to the registered pipeline for the token type
- Cond / When: composable predicates for conditional step dispatch
- AuthData: frozen Pydantic model produced by the prepare phase; carries
  token_id so endpoints don't need to call get_auth_ctx() for identity fields
- Edition enum + current_edition(): CE / EE / SAAS discriminator

Two pipelines in composition.py:
- account_pipeline  — OAUTH_ACCOUNT tokens
- external_sso_pipeline — OAUTH_EXTERNAL_SSO tokens (EE enforced at route level)

All /openapi/v1 endpoints migrated to auth_router.guard().
Old context.py, steps.py, strategies.py, surface_gate.py deleted.
WORKSPACE_READ scope added; cached_verdicts renamed to membership_cache.
2026-05-26 03:16:28 -07:00
45 changed files with 1677 additions and 1619 deletions

View File

@ -4,7 +4,7 @@ from datetime import UTC, datetime
from flask import request
from flask_restx import Resource
from werkzeug.exceptions import BadRequest, NotFound
from werkzeug.exceptions import NotFound
from controllers.openapi import openapi_ns
from controllers.openapi._models import (
@ -17,18 +17,17 @@ from controllers.openapi._models import (
SessionRow,
WorkspacePayload,
)
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
AuthContext,
SubjectType,
Scope,
TokenType,
get_auth_ctx,
validate_bearer,
)
from libs.rate_limit import (
LIMIT_ME_PER_ACCOUNT,
LIMIT_ME_PER_EMAIL,
enforce,
)
from services.account_service import AccountService, TenantService
@ -42,32 +41,18 @@ from services.oauth_device_flow import (
@openapi_ns.route("/account")
class AccountApi(Resource):
@openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
def get(self):
ctx = get_auth_ctx()
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, *, auth_data: AuthData):
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{auth_data.account_id}")
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}")
else:
enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}")
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
return AccountResponse(
subject_type=ctx.subject_type,
subject_email=ctx.subject_email,
subject_issuer=ctx.subject_issuer,
account=None,
workspaces=[],
default_workspace_id=None,
).model_dump(mode="json")
account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if ctx.account_id else None
memberships = TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else []
account_id_str = str(auth_data.account_id) if auth_data.account_id else None
account = AccountService.get_account_by_id(db.session, account_id_str) if account_id_str else None
memberships = TenantService.get_account_memberships(db.session, account_id_str) if account_id_str else []
default_ws_id = _pick_default_workspace(memberships)
return AccountResponse(
subject_type=ctx.subject_type,
subject_email=ctx.subject_email or (account.email if account else None),
subject_type="account",
subject_email=account.email if account else None,
account=_account_payload(account) if account else None,
workspaces=[_workspace_payload(m) for m in memberships],
default_workspace_id=default_ws_id,
@ -77,19 +62,17 @@ class AccountApi(Resource):
@openapi_ns.route("/account/sessions/self")
class AccountSessionsSelfApi(Resource):
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
def delete(self):
ctx = get_auth_ctx()
_require_oauth_subject(ctx)
revoke_oauth_token(db.session, redis_client, str(ctx.token_id))
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def delete(self, *, auth_data: AuthData):
revoke_oauth_token(db.session, redis_client, str(auth_data.token_id))
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
@openapi_ns.route("/account/sessions")
class AccountSessionsApi(Resource):
@openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
def get(self):
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, *, auth_data: AuthData):
ctx = get_auth_ctx()
now = datetime.now(UTC)
page = int(request.args.get("page", "1"))
@ -122,10 +105,9 @@ class AccountSessionsApi(Resource):
@openapi_ns.route("/account/sessions/<string:session_id>")
class AccountSessionByIdApi(Resource):
@openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
def delete(self, session_id: str):
@auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def delete(self, session_id: str, *, auth_data: AuthData):
ctx = get_auth_ctx()
_require_oauth_subject(ctx)
# 404 (not 403) on cross-subject so the endpoint doesn't leak
# token IDs that belong to other subjects.
@ -136,13 +118,6 @@ class AccountSessionByIdApi(Resource):
return RevokeResponse(status="revoked").model_dump(mode="json"), 200
def _require_oauth_subject(ctx: AuthContext) -> None:
if not ctx.source.startswith("oauth"):
raise BadRequest(
"this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs"
)
def _iso(dt: datetime | None) -> str | None:
if dt is None:
return None

View File

@ -16,7 +16,8 @@ import services
from controllers.openapi import openapi_ns
from controllers.openapi._audit import emit_app_run
from controllers.openapi._models import AppRunRequest
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from controllers.service_api.app.error import (
AppUnavailableError,
CompletionRequestError,
@ -124,8 +125,11 @@ _DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = {
class AppRunApi(Resource):
@openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__])
@openapi_ns.response(200, "Run result (SSE stream)")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, app_model: App, caller, caller_kind: str):
@auth_router.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, *, auth_data: AuthData):
app_model = auth_data.app
caller = auth_data.caller
caller_kind = auth_data.caller_kind
body = request.get_json(silent=True) or {}
try:
payload = AppRunRequest.model_validate(body)
@ -158,8 +162,11 @@ class AppRunApi(Resource):
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/stop")
class AppRunTaskStopApi(Resource):
@openapi_ns.response(200, "Task stopped")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
@auth_router.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, task_id: str, *, auth_data: AuthData):
app_model = auth_data.app
caller = auth_data.caller
caller_kind = auth_data.caller_kind
AppQueueManager.set_stop_flag_no_user_check(task_id)
GraphEngineManager(redis_client).send_stop_command(task_id)
return {"result": "success"}

View File

@ -1,9 +1,4 @@
"""GET /openapi/v1/apps and per-app reads.
Decorator order: `method_decorators` is innermost-first. `validate_bearer`
is last → outermost → publishes the auth ContextVar before `require_scope`
reads it.
"""
"""GET /openapi/v1/apps and per-app reads."""
from __future__ import annotations
@ -28,31 +23,17 @@ from controllers.openapi._models import (
AppListRow,
TagItem,
)
from controllers.openapi.auth.surface_gate import accept_subjects
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from controllers.service_api.app.error import AppUnavailableError
from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict
from extensions.ext_database import db
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
AuthContext,
Scope,
SubjectType,
get_auth_ctx,
require_scope,
require_workspace_member,
validate_bearer,
)
from libs.oauth_bearer import Scope, TokenType
from models import App
from services.account_service import TenantService
from services.app_service import AppListParams, AppService
from services.tag_service import TagService
_APPS_READ_DECORATORS = [
require_scope(Scope.APPS_READ),
accept_subjects(SubjectType.ACCOUNT),
validate_bearer(accept=ACCEPT_USER_ANY),
]
_ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"})
@ -66,13 +47,9 @@ _EMPTY_PARAMETERS: dict[str, Any] = {
class AppReadResource(Resource):
"""Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks."""
method_decorators = _APPS_READ_DECORATORS
def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]:
ctx: AuthContext = get_auth_ctx()
"""Base for per-app read endpoints; subclasses call `_load()` for membership/exists checks."""
def _load(self, app_id: str, workspace_id: str | None = None) -> App:
try:
parsed_uuid = _uuid.UUID(app_id)
is_uuid = True
@ -99,8 +76,7 @@ class AppReadResource(Resource):
raise Conflict("".join(lines))
app = matches[0]
require_workspace_member(ctx, str(app.tenant_id))
return app, ctx
return app
def parameters_payload(app: App) -> dict:
@ -114,13 +90,14 @@ def parameters_payload(app: App) -> dict:
class AppDescribeApi(AppReadResource):
@openapi_ns.doc(params=query_params_from_model(AppDescribeQuery))
@openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__])
def get(self, app_id: str):
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, app_id: str, *, auth_data: AuthData):
try:
query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
app, _ = self._load(app_id, workspace_id=query.workspace_id)
app = self._load(app_id, workspace_id=query.workspace_id)
requested = query.fields
want_info = requested is None or "info" in requested
@ -168,20 +145,16 @@ class AppDescribeApi(AppReadResource):
@openapi_ns.route("/apps")
class AppListApi(Resource):
method_decorators = _APPS_READ_DECORATORS
@openapi_ns.doc(params=query_params_from_model(AppListQuery))
@openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__])
def get(self):
ctx: AuthContext = get_auth_ctx()
@auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, *, auth_data: AuthData):
try:
query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:
raise UnprocessableEntity(exc.json())
workspace_id = query.workspace_id
require_workspace_member(ctx, workspace_id)
empty = (
AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump(
@ -237,7 +210,7 @@ class AppListApi(Resource):
openapi_visible=True,
)
pagination = AppService().get_paginate_apps(str(ctx.account_id), workspace_id, params)
pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params)
if pagination is None:
return empty

View File

@ -18,37 +18,27 @@ from controllers.openapi._models import (
PermittedExternalAppsListQuery,
PermittedExternalAppsListResponse,
)
from controllers.openapi.auth.surface_gate import accept_subjects
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData, Edition
from extensions.ext_database import db
from libs.device_flow_security import enterprise_only
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
Scope,
SubjectType,
require_scope,
validate_bearer,
)
from libs.oauth_bearer import Scope, TokenType
from models import App
from services.account_service import TenantService
from services.app_service import AppService
from services.enterprise.app_permitted_service import list_permitted_apps
from services.openapi.license_gate import license_required
@openapi_ns.route("/permitted-external-apps")
class PermittedExternalAppsListApi(Resource):
method_decorators = [
require_scope(Scope.APPS_READ_PERMITTED_EXTERNAL),
license_required,
accept_subjects(SubjectType.EXTERNAL_SSO),
validate_bearer(accept=ACCEPT_USER_ANY),
enterprise_only,
]
@openapi_ns.response(
200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__]
)
def get(self):
@auth_router.guard(
scope=Scope.APPS_READ_PERMITTED_EXTERNAL,
allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}),
edition=frozenset({Edition.EE}),
)
def get(self, *, auth_data: AuthData):
try:
query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True))
except ValidationError as exc:

View File

@ -1,3 +1,3 @@
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from controllers.openapi.auth.composition import auth_router
__all__ = ["OAUTH_BEARER_PIPELINE"]
__all__ = ["auth_router"]

View File

@ -1,46 +1,64 @@
"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints.
Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative
paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip
the pipeline and use `validate_bearer + require_scope + require_workspace_member`
inline — they don't need `AppAuthzCheck`/`CallerMount`.
"""
from __future__ import annotations
from controllers.openapi.auth.pipeline import Pipeline
from controllers.openapi.auth.steps import (
AppAuthzCheck,
AppResolver,
BearerCheck,
CallerMount,
ScopeCheck,
SurfaceCheck,
WorkspaceMembershipCheck,
from controllers.openapi.auth.conditions import (
EDITION_CE,
EDITION_EE,
LOADED_APP_IS_PRIVATE,
PATH_HAS_APP_ID,
WEBAPP_AUTH_ENABLED,
)
from controllers.openapi.auth.strategies import (
AccountMounter,
AclStrategy,
AppAuthzStrategy,
EndUserMounter,
MembershipStrategy,
from controllers.openapi.auth.data import Edition
from controllers.openapi.auth.flow import When
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
from controllers.openapi.auth.prepare import (
build_external_identity,
load_account,
load_app,
load_app_access_mode,
load_tenant,
resolve_external_user,
)
from libs.oauth_bearer import SubjectType
from services.feature_service import FeatureService
def _resolve_app_authz_strategy() -> AppAuthzStrategy:
if FeatureService.get_system_features().webapp_auth.enabled:
return AclStrategy()
return MembershipStrategy()
OAUTH_BEARER_PIPELINE = Pipeline(
BearerCheck(),
SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})),
ScopeCheck(),
AppResolver(),
WorkspaceMembershipCheck(),
AppAuthzCheck(_resolve_app_authz_strategy),
CallerMount(AccountMounter(), EndUserMounter()),
from controllers.openapi.auth.verify import (
check_acl,
check_app_access,
check_membership,
check_private_app_permission,
check_scope,
)
from libs.oauth_bearer import TokenType
account_pipeline = AuthPipeline(
prepare=[
When(PATH_HAS_APP_ID, then=load_app),
When(PATH_HAS_APP_ID, then=load_tenant),
load_account, # unconditional — this IS the account pipeline
When(PATH_HAS_APP_ID & EDITION_EE, then=load_app_access_mode),
],
auth=[
check_scope,
When(EDITION_CE & PATH_HAS_APP_ID, then=check_membership),
When(EDITION_EE & PATH_HAS_APP_ID & ~WEBAPP_AUTH_ENABLED, then=check_app_access),
When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl),
When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
],
)
external_sso_pipeline = AuthPipeline(
prepare=[
build_external_identity,
When(PATH_HAS_APP_ID, then=load_app),
When(PATH_HAS_APP_ID, then=load_tenant),
When(PATH_HAS_APP_ID, then=resolve_external_user),
When(PATH_HAS_APP_ID, then=load_app_access_mode),
],
auth=[
check_scope,
When(PATH_HAS_APP_ID & WEBAPP_AUTH_ENABLED, then=check_acl),
When(LOADED_APP_IS_PRIVATE, then=check_private_app_permission),
],
)
auth_router = PipelineRouter({
TokenType.OAUTH_ACCOUNT: PipelineRoute(account_pipeline),
TokenType.OAUTH_EXTERNAL_SSO: PipelineRoute(external_sso_pipeline, required_edition=frozenset({Edition.EE})),
})

View File

@ -0,0 +1,53 @@
from __future__ import annotations
from collections.abc import Callable
from controllers.openapi.auth.data import AuthData, Edition, RequestContext, current_edition
from libs.oauth_bearer import TokenType
from services.enterprise.enterprise_service import WebAppAccessMode
from services.feature_service import FeatureService
CondFn = Callable[[RequestContext, AuthData | None], bool]
class Cond:
def __init__(self, fn: CondFn) -> None:
self._fn = fn
def __call__(self, ctx: RequestContext, data: AuthData | None = None) -> bool:
return self._fn(ctx, data)
def __and__(self, other: Cond) -> Cond:
return Cond(lambda ctx, data: self(ctx, data) and other(ctx, data))
def __or__(self, other: Cond) -> Cond:
return Cond(lambda ctx, data: self(ctx, data) or other(ctx, data))
def __invert__(self) -> Cond:
return Cond(lambda ctx, data: not self(ctx, data))
def request_cond(fn: Callable[[RequestContext], bool]) -> Cond:
return Cond(lambda ctx, _: fn(ctx))
def data_cond(fn: Callable[[AuthData], bool]) -> Cond:
return Cond(lambda _, data: data is not None and fn(data))
def config_cond(fn: Callable[[], bool]) -> Cond:
return Cond(lambda _, __: fn())
TOKEN_IS_OAUTH_ACCOUNT = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_ACCOUNT)
TOKEN_IS_OAUTH_EXTERNAL_SSO = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_EXTERNAL_SSO)
PATH_HAS_APP_ID = request_cond(lambda ctx: "app_id" in ctx.path_params)
EDITION_CE = config_cond(lambda: current_edition() == Edition.CE)
EDITION_EE = config_cond(lambda: current_edition() == Edition.EE)
EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS)
WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled)
LOADED_APP_IS_PRIVATE = data_cond(lambda data: data.app_access_mode == WebAppAccessMode.PRIVATE)

View File

@ -1,68 +0,0 @@
"""Mutable per-request context for the openapi auth pipeline.
Every field starts None / empty and is filled in by a step. The pipeline
is the only thing that should construct or mutate Context — handlers
read populated values via the decorator's kwargs unpacking.
Context is intentionally decoupled from Flask's ``Request``: the pipeline
guard extracts whatever transport-level inputs the steps need (bearer
token, path params) at the boundary and writes them into Context fields,
so steps stay testable without a request object and won't leak coupling
to a specific framework.
"""
from __future__ import annotations
import uuid
from collections.abc import Mapping
from contextvars import Token
from dataclasses import dataclass, field
from datetime import datetime
from typing import TYPE_CHECKING, Literal, Protocol
from werkzeug.exceptions import Unauthorized
from libs.oauth_bearer import AuthContext, Scope, SubjectType
if TYPE_CHECKING:
from models import App, Tenant
@dataclass
class Context:
required_scope: Scope
bearer_token: str | None = None
path_params: Mapping[str, str] = field(default_factory=dict)
subject_type: SubjectType | None = None
subject_email: str | None = None
subject_issuer: str | None = None
account_id: uuid.UUID | None = None
scopes: frozenset[Scope] = field(default_factory=frozenset)
token_id: uuid.UUID | None = None
token_hash: str | None = None
cached_verified_tenants: dict[str, bool] | None = None
source: str | None = None
expires_at: datetime | None = None
app: App | None = None
tenant: Tenant | None = None
caller: object | None = None
caller_kind: Literal["account", "end_user"] | None = None
auth_ctx_reset_token: Token[AuthContext] | None = None
@property
def must_tenant(self) -> Tenant:
if not self.tenant:
raise Unauthorized("tenant is not associated")
return self.tenant
@property
def must_subject_type(self) -> SubjectType:
if not self.subject_type:
raise Unauthorized("subject_type unset — BearerCheck did not run")
return self.subject_type
class Step(Protocol):
"""One responsibility. Mutate ctx or raise to short-circuit."""
def __call__(self, ctx: Context) -> None: ...

View File

@ -0,0 +1,62 @@
from __future__ import annotations
import uuid
from enum import StrEnum
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field
from configs import dify_config
from libs.oauth_bearer import Scope, TokenType
from models.account import Account, Tenant
from models.model import App, EndUser
from services.enterprise.enterprise_service import WebAppAccessMode
class Edition(StrEnum):
CE = "ce"
EE = "ee"
SAAS = "saas"
def current_edition() -> Edition:
if dify_config.EDITION == "CLOUD":
return Edition.SAAS
if dify_config.ENTERPRISE_ENABLED:
return Edition.EE
return Edition.CE
class ExternalIdentity(BaseModel):
model_config = ConfigDict(frozen=True)
email: str
issuer: str | None = None
class RequestContext(BaseModel):
model_config = ConfigDict(frozen=True)
token_type: TokenType
scope: Scope | None = None
path_params: dict[str, str]
class AuthData(BaseModel):
model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
required_scope: Scope | None = None
token_type: TokenType
account_id: uuid.UUID | None = None
token_hash: str
token_id: uuid.UUID | None = None
scopes: frozenset[Scope]
tenants: dict[str, bool] = Field(default_factory=dict)
external_identity: ExternalIdentity | None = None
app: App | None = None
tenant: Tenant | None = None
app_access_mode: WebAppAccessMode | None = None
caller: Account | EndUser | None = None
caller_kind: Literal["account", "end_user"] | None = None

View File

@ -0,0 +1,19 @@
from __future__ import annotations
from collections.abc import Callable
from typing import Any
from controllers.openapi.auth.conditions import Cond
from controllers.openapi.auth.data import AuthData, RequestContext
class When:
def __init__(self, condition: Cond, *, then: Callable[[Any], None]) -> None:
self.condition = condition
self._step = then
def applies(self, ctx: RequestContext, data: AuthData | None = None) -> bool:
return self.condition(ctx, data)
def __call__(self, arg: Any) -> None:
self._step(arg)

View File

@ -1,51 +1,224 @@
"""Pipeline IS the auth scheme.
"""Auth pipeline — entry point for all openapi auth.
`Pipeline.guard(scope=…)` is the only attachment point for endpoints
that is the design lock-in: forgetting an auth layer is structurally
impossible because there is no "sometimes wrap, sometimes don't" choice.
`PipelineRouter.guard()` is the only attachment point for endpoints.
`AuthPipeline` is a pure step-runner with no routing concerns.
`PipelineRoute` binds a pipeline to optional edition requirements.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from functools import wraps
from typing import Any
from flask import request
from flask import current_app, request
from flask_login import user_logged_in
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from controllers.openapi.auth.context import Context, Step
from libs.oauth_bearer import Scope, extract_bearer, reset_auth_ctx
from controllers.openapi._audit import emit_wrong_surface
from controllers.openapi.auth.data import (
AuthData,
Edition,
RequestContext,
current_edition,
)
from controllers.openapi.auth.flow import When
from libs.oauth_bearer import (
AuthContext,
Scope,
TokenType,
extract_bearer,
get_authenticator,
reset_auth_ctx,
set_auth_ctx,
)
from services.feature_service import FeatureService, LicenseStatus
# ---------------------------------------------------------------------------
# New design: AuthPipeline / PipelineRoute / PipelineRouter
# ---------------------------------------------------------------------------
class Pipeline:
def __init__(self, *steps: Step) -> None:
self._steps = steps
class AuthPipeline:
"""Pure step-runner — no routing, no guard.
def run(self, ctx: Context) -> None:
for step in self._steps:
step(ctx)
`prepare` steps receive a mutable builder dict (includes `path_params`).
`auth` steps receive the fully constructed, frozen `AuthData`.
"""
def guard(self, *, scope: Scope):
def decorator(view):
def __init__(self, prepare: list, auth: list) -> None:
self._prepare = prepare
self._auth = auth
def _run(
self,
identity: AuthContext,
args: tuple,
kwargs: dict,
view: Callable,
*,
scope: Scope | None,
) -> Any:
req_ctx = RequestContext(
token_type=identity.token_type,
scope=scope,
path_params=dict(request.view_args or {}),
)
builder = _init_builder(identity, scope)
builder["path_params"] = dict(req_ctx.path_params)
for step in self._prepare:
if _should_run(step, req_ctx, data=None):
step(builder)
builder.pop("path_params", None)
builder.pop("_subject_email", None)
builder.pop("_subject_issuer", None)
data = AuthData(**builder)
for step in self._auth:
if _should_run(step, req_ctx, data=data):
step(data)
reset_token = set_auth_ctx(identity)
if data.caller:
_mount_flask_login(data.caller)
try:
kwargs["auth_data"] = data
return view(*args, **kwargs)
finally:
reset_auth_ctx(reset_token)
@dataclass(frozen=True)
class PipelineRoute:
pipeline: AuthPipeline
required_edition: frozenset[Edition] | None = None
class PipelineRouter:
"""Entry point for openapi auth.
`guard()` is the decorator that endpoints attach to. It applies
global gates (edition, token type) then dispatches to the matching
`PipelineRoute` for the token type.
"""
def __init__(self, routes: dict[TokenType, PipelineRoute]) -> None:
self._routes = routes
def guard(
self,
*,
scope: Scope | None = None,
allowed_token_types: frozenset[TokenType] | None = None,
edition: frozenset[Edition] | None = None,
) -> Callable:
def decorator(view: Callable) -> Callable:
@wraps(view)
def decorated(*args, **kwargs):
# Extract transport-level inputs at the boundary so steps
# stay decoupled from Flask's request object.
ctx = Context(
required_scope=scope,
bearer_token=extract_bearer(request),
path_params=dict(request.view_args or {}),
def decorated(*args: Any, **kwargs: Any) -> Any:
return self._execute(
args,
kwargs,
view,
scope=scope,
allowed_token_types=allowed_token_types,
edition=edition,
)
try:
self.run(ctx)
kwargs.update(
app_model=ctx.app,
caller=ctx.caller,
caller_kind=ctx.caller_kind,
)
return view(*args, **kwargs)
finally:
if ctx.auth_ctx_reset_token is not None:
reset_auth_ctx(ctx.auth_ctx_reset_token)
return decorated
return decorator
def _execute(
self,
args: tuple,
kwargs: dict,
view: Callable,
*,
scope: Scope | None,
allowed_token_types: frozenset[TokenType] | None,
edition: frozenset[Edition] | None,
) -> Any:
# Gate 1: endpoint-level edition (404 — feature doesn't exist here)
if edition is not None and current_edition() not in edition:
raise NotFound()
# Gate 2: EE license for endpoint-level edition requirement
if edition is not None and Edition.EE in edition:
_check_license()
token = extract_bearer(request)
if not token:
raise Unauthorized("bearer required")
identity = get_authenticator().authenticate(token)
# Gate 3: endpoint-level token type allowlist (403)
if allowed_token_types is not None and identity.token_type not in allowed_token_types:
emit_wrong_surface(
subject_type=_subject_type_str(identity),
attempted_path=request.path,
client_id=getattr(identity, "client_id", None),
token_id=str(identity.token_id) if identity.token_id else None,
)
raise Forbidden("unsupported_token_type")
route = self._routes.get(identity.token_type)
if route is None:
raise Forbidden("unsupported_token_type")
# Gate 4: route-level edition invariant (token type requires EE)
if route.required_edition is not None:
if current_edition() not in route.required_edition:
raise Forbidden("external_sso_requires_ee")
if Edition.EE in route.required_edition:
_check_license()
return route.pipeline._run(identity, args, kwargs, view, scope=scope)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _should_run(step: Any, req_ctx: RequestContext, data: AuthData | None) -> bool:
if isinstance(step, When):
return step.applies(req_ctx, data)
return True
def _init_builder(identity: AuthContext, scope: Scope | None) -> dict:
return {
"token_type": identity.token_type,
"account_id": identity.account_id,
"token_hash": identity.token_hash,
"token_id": identity.token_id,
"scopes": frozenset(identity.scopes),
"tenants": dict(identity.verified_tenants),
"required_scope": scope,
"_subject_email": identity.subject_email,
"_subject_issuer": identity.subject_issuer,
}
def _subject_type_str(identity: Any) -> str | None:
subject = getattr(identity, "subject_type", None)
if subject is None:
return None
return subject.value if hasattr(subject, "value") else str(subject)
def _check_license() -> None:
settings = FeatureService.get_system_features()
if settings.license.status in {LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST}:
raise Forbidden("license_invalid")
def _mount_flask_login(user: Any) -> None:
current_app.login_manager._update_request_context_with_user(user) # type: ignore[attr-defined]
user_logged_in.send(current_app._get_current_object(), user=user) # type: ignore[attr-defined]

View File

@ -0,0 +1,78 @@
from __future__ import annotations
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from controllers.openapi.auth.data import ExternalIdentity
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from models.account import TenantStatus
from services.account_service import AccountService, TenantService
from services.app_service import AppService
from services.end_user_service import EndUserService
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
def build_external_identity(builder: dict) -> None:
email = builder.pop("_subject_email", None)
issuer = builder.pop("_subject_issuer", None)
if email:
builder["external_identity"] = ExternalIdentity(email=email, issuer=issuer)
def load_app(builder: dict) -> None:
app_id = builder["path_params"]["app_id"]
app = AppService.get_app_by_id(db.session, app_id)
if not app or app.status != "normal":
raise NotFound("app not found")
if not app.enable_api:
raise Forbidden("service_api_disabled")
builder["app"] = app
def load_tenant(builder: dict) -> None:
app = builder["app"]
tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id))
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("workspace unavailable")
builder["tenant"] = tenant
def load_account(builder: dict) -> None:
account = AccountService.get_account_by_id(db.session, str(builder["account_id"]))
if account is None:
raise Unauthorized("account not found")
tenant = builder.get("tenant")
if tenant:
account.current_tenant = tenant
builder["caller"] = account
builder["caller_kind"] = "account"
def resolve_external_user(builder: dict) -> None:
tenant = builder.get("tenant")
app = builder.get("app")
ext: ExternalIdentity | None = builder.get("external_identity")
if not all([tenant, app, ext]):
raise Unauthorized("missing context for external user resolution")
end_user = EndUserService.get_or_create_end_user_by_type(
InvokeFrom.OPENAPI,
tenant_id=str(tenant.id), # type: ignore[union-attr]
app_id=str(app.id), # type: ignore[union-attr]
user_id=ext.email, # type: ignore[union-attr]
)
builder["caller"] = end_user
builder["caller_kind"] = "end_user"
def load_app_access_mode(builder: dict) -> None:
app = builder.get("app")
if app is None:
return
try:
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app.id))
if settings is None:
builder["app_access_mode"] = None
return
builder["app_access_mode"] = WebAppAccessMode(settings.access_mode)
except ValueError:
builder["app_access_mode"] = None

View File

@ -1,170 +0,0 @@
"""Pipeline steps. Each is one responsibility.
`BearerCheck` is the only step that touches the token registry; downstream
steps see only the populated `Context`. `BearerCheck` also publishes the
resolved identity to the openapi auth ``ContextVar`` (the same one the
decorator-level :func:`libs.oauth_bearer.validate_bearer` writes to) so the
surface gate and any handler reading the request-scoped context has a single
source of truth across both auth-attach paths. The reset token is stashed
on `ctx.auth_ctx_reset_token`; `Pipeline.guard` resets the ContextVar in
its `finally` so worker-thread reuse can't leak identity across requests.
"""
from __future__ import annotations
from collections.abc import Callable
from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized
from configs import dify_config
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter
from controllers.openapi.auth.surface_gate import check_surface
from extensions.ext_database import db
from libs.oauth_bearer import (
AuthContext,
InvalidBearerError,
Scope,
SubjectType,
check_workspace_membership,
get_authenticator,
set_auth_ctx,
)
from models import TenantStatus
from services.account_service import TenantService
from services.app_service import AppService
class BearerCheck:
"""Resolve bearer → populate identity fields. Rate-limit is enforced
inside `BearerAuthenticator.authenticate`, so no separate step here.
Also publishes the resolved `AuthContext` via
:func:`libs.oauth_bearer.set_auth_ctx` — same shape the decorator-level
``validate_bearer`` writes — so the surface gate + downstream readers
don't see two different identity sources. The reset token is parked on
``ctx.auth_ctx_reset_token`` for `Pipeline.guard` to consume."""
def __call__(self, ctx: Context) -> None:
if not ctx.bearer_token:
raise Unauthorized("bearer required")
try:
authn = get_authenticator().authenticate(ctx.bearer_token)
except InvalidBearerError as e:
raise Unauthorized(str(e))
ctx.subject_type = authn.subject_type
ctx.subject_email = authn.subject_email
ctx.subject_issuer = authn.subject_issuer
ctx.account_id = authn.account_id
ctx.scopes = frozenset(authn.scopes)
ctx.source = authn.source
ctx.token_id = authn.token_id
ctx.expires_at = authn.expires_at
ctx.token_hash = authn.token_hash
ctx.cached_verified_tenants = dict(authn.verified_tenants)
ctx.auth_ctx_reset_token = set_auth_ctx(authn)
class ScopeCheck:
"""Verify ctx.scopes (already populated by BearerCheck) covers required."""
def __call__(self, ctx: Context) -> None:
if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes:
return
raise Forbidden("insufficient_scope")
class SurfaceCheck:
"""Reject the request if the resolved subject is not in `accepted`."""
def __init__(self, *, accepted: frozenset[SubjectType]) -> None:
self._accepted = accepted
def __call__(self, ctx: Context) -> None:
check_surface(self._accepted)
class AppResolver:
"""Read ``app_id`` from ``ctx.path_params``; populate ctx.app + ctx.tenant.
Every endpoint using the OAuth bearer pipeline must declare
``<string:app_id>`` in its route — that is the design lock-in (no body /
header coupling). ``Pipeline.guard`` lifts ``request.view_args`` into
``ctx.path_params`` at the boundary so this step doesn't need to know
about the request object.
"""
def __call__(self, ctx: Context) -> None:
app_id = ctx.path_params.get("app_id")
if not app_id:
raise BadRequest("app_id is required in path")
app = AppService.get_app_by_id(db.session, app_id)
if not app or app.status != "normal":
raise NotFound("app not found")
if not app.enable_api:
raise Forbidden("service_api_disabled")
tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id))
if tenant is None or tenant.status == TenantStatus.ARCHIVE:
raise Forbidden("workspace unavailable")
ctx.app, ctx.tenant = app, tenant
class WorkspaceMembershipCheck:
"""Layer 0 — workspace membership gate.
CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers
(dfoa_) only — SSO subjects skip.
"""
def __call__(self, ctx: Context) -> None:
if dify_config.ENTERPRISE_ENABLED:
return
if ctx.subject_type != SubjectType.ACCOUNT:
return
if ctx.account_id is None or ctx.tenant is None:
raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run")
if ctx.token_hash is None:
raise Unauthorized("token_hash unset — BearerCheck did not run")
check_workspace_membership(
account_id=ctx.account_id,
tenant_id=ctx.must_tenant.id,
token_hash=ctx.token_hash,
cached_verdicts=ctx.cached_verified_tenants or {},
)
class AppAuthzCheck:
def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None:
self._resolve = resolve_strategy
def __call__(self, ctx: Context) -> None:
if not self._resolve().authorize(ctx):
raise Forbidden("subject_no_app_access")
class CallerMount:
def __init__(self, *mounters: CallerMounter) -> None:
self._mounters = mounters
def __call__(self, ctx: Context) -> None:
if ctx.subject_type is None:
raise Unauthorized("subject_type unset — BearerCheck did not run")
for m in self._mounters:
if m.applies_to(ctx.must_subject_type):
m.mount(ctx)
return
raise Unauthorized("no caller mounter for subject type")
__all__ = [
"AppAuthzCheck",
"AppResolver",
"AuthContext",
"BearerCheck",
"CallerMount",
"ScopeCheck",
"SurfaceCheck",
"WorkspaceMembershipCheck",
]

View File

@ -1,168 +0,0 @@
"""Strategy classes for the openapi auth pipeline.
App authorization (Acl/Membership) and caller mounting (Account/EndUser)
vary along independent axes; each strategy is one class so the pipeline
composition stays a flat list.
"""
from __future__ import annotations
from typing import Protocol
from flask import current_app
from flask_login import user_logged_in
from controllers.openapi.auth.context import Context
from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.oauth_bearer import SubjectType
from services.account_service import AccountService, TenantService
from services.end_user_service import EndUserService
from services.enterprise.enterprise_service import (
EnterpriseService,
WebAppAccessMode,
)
class AppAuthzStrategy(Protocol):
def authorize(self, ctx: Context) -> bool: ...
class AclStrategy:
"""Per-app ACL, evaluated in two stages.
The EE gateway has already enforced tenancy and workspace membership
by the time this strategy runs, so AclStrategy only owns per-app ACL:
1. Subject vs access-mode compatibility (pure rule table). External-SSO
bearers belong to public-facing apps only; account bearers cover the
full set. A mismatch is an immediate deny — no IO.
2. For modes that pair with the subject, decide whether the inner
permission API must run. Only `PRIVATE` (per-app selected-user list)
requires it; the remaining modes are pass-through.
"""
_ALLOWED_MODES_BY_SUBJECT: dict[SubjectType, frozenset[WebAppAccessMode]] = {
SubjectType.ACCOUNT: frozenset(
{
WebAppAccessMode.PUBLIC,
WebAppAccessMode.SSO_VERIFIED,
WebAppAccessMode.PRIVATE_ALL,
WebAppAccessMode.PRIVATE,
}
),
SubjectType.EXTERNAL_SSO: frozenset(
{
WebAppAccessMode.PUBLIC,
WebAppAccessMode.SSO_VERIFIED,
}
),
}
_MODES_REQUIRING_INNER_CHECK: frozenset[WebAppAccessMode] = frozenset({WebAppAccessMode.PRIVATE})
def authorize(self, ctx: Context) -> bool:
if ctx.app is None:
return False
access_mode = self._fetch_access_mode(ctx.app.id)
if access_mode is None:
return False
if not self._subject_allowed_for_mode(ctx.must_subject_type, access_mode):
return False
if access_mode not in self._MODES_REQUIRING_INNER_CHECK:
return True
return self._inner_permission_check(ctx)
@staticmethod
def _fetch_access_mode(app_id: str) -> WebAppAccessMode | None:
settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id)
if settings is None:
return None
try:
return WebAppAccessMode(settings.access_mode)
except ValueError:
return None
@classmethod
def _subject_allowed_for_mode(cls, subject_type: SubjectType, access_mode: WebAppAccessMode) -> bool:
return access_mode in cls._ALLOWED_MODES_BY_SUBJECT.get(subject_type, frozenset())
def _inner_permission_check(self, ctx: Context) -> bool:
if ctx.app is None:
return False
user_id = self._resolve_user_id(ctx)
if user_id is None:
return False
return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=user_id,
app_id=ctx.app.id,
)
@staticmethod
def _resolve_user_id(ctx: Context) -> str | None:
if ctx.subject_type == SubjectType.ACCOUNT:
return str(ctx.account_id) if ctx.account_id is not None else None
if ctx.subject_email is None:
return None
account = AccountService.get_account_by_email(db.session, ctx.subject_email)
return str(account.id) if account is not None else None
class MembershipStrategy:
"""Tenant-membership fallback.
Used when webapp-auth is disabled (CE deployment). Account-bearing
subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is
denied (it requires the webapp-auth surface).
"""
def authorize(self, ctx: Context) -> bool:
if ctx.subject_type == SubjectType.EXTERNAL_SSO:
return False
if ctx.tenant is None:
return False
return TenantService.account_belongs_to_tenant(db.session, ctx.account_id, ctx.tenant.id)
def _login_as(user) -> None:
"""Set Flask-Login request user so downstream services see the caller."""
current_app.login_manager._update_request_context_with_user(user) # type:ignore
user_logged_in.send(current_app._get_current_object(), user=user) # type:ignore
class CallerMounter(Protocol):
def applies_to(self, subject_type: SubjectType) -> bool: ...
def mount(self, ctx: Context) -> None: ...
class AccountMounter:
def applies_to(self, subject_type: SubjectType) -> bool:
return subject_type == SubjectType.ACCOUNT
def mount(self, ctx: Context) -> None:
if ctx.account_id is None:
raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run")
account = AccountService.get_account_by_id(db.session, str(ctx.account_id))
if account is None:
raise RuntimeError("AccountMounter: account row missing for resolved bearer")
account.current_tenant = ctx.must_tenant
_login_as(account)
ctx.caller, ctx.caller_kind = account, "account"
class EndUserMounter:
def applies_to(self, subject_type: SubjectType) -> bool:
return subject_type == SubjectType.EXTERNAL_SSO
def mount(self, ctx: Context) -> None:
if ctx.tenant is None or ctx.app is None or ctx.subject_email is None:
raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run")
end_user = EndUserService.get_or_create_end_user_by_type(
InvokeFrom.OPENAPI,
tenant_id=ctx.tenant.id,
app_id=ctx.app.id,
user_id=ctx.subject_email,
)
_login_as(end_user)
ctx.caller, ctx.caller_kind = end_user, "end_user"

View File

@ -1,89 +0,0 @@
"""Surface gate.
`@accept_subjects(...)` is the route-level form. `SurfaceCheck` (pipeline
step) is the pipeline-level form. Both delegate to `check_surface` so the
audit emit + canonical-path message are single-sourced.
Subjects come from `libs.oauth_bearer.SubjectType` directly — no parallel
vocabulary. Caller hits the wrong surface → 403 ``wrong_surface`` + audit
``openapi.wrong_surface_denied``.
"""
from __future__ import annotations
from collections.abc import Callable
from functools import wraps
from typing import TypeVar
from flask import request
from werkzeug.exceptions import Forbidden
from controllers.openapi._audit import emit_wrong_surface
from libs.oauth_bearer import SubjectType, try_get_auth_ctx
_CANONICAL_PATH: dict[SubjectType, str] = {
SubjectType.ACCOUNT: "/openapi/v1/apps",
SubjectType.EXTERNAL_SSO: "/openapi/v1/permitted-external-apps",
}
F = TypeVar("F", bound=Callable[..., object])
def check_surface(accepted: frozenset[SubjectType]) -> None:
"""Enforce that the resolved subject is in ``accepted``.
Reads the openapi auth ContextVar via :func:`try_get_auth_ctx`. Raises
``Forbidden`` with ``wrong_surface`` + canonical-path hint on miss;
emits ``openapi.wrong_surface_denied`` audit. If no auth context is
set the bearer layer didn't run — that's a wiring bug, not a
user-driven failure, so surface it as a ``RuntimeError`` instead of
a silent 403.
"""
ctx = try_get_auth_ctx()
if ctx is None:
raise RuntimeError(
"check_surface called without an auth context; stack validate_bearer or BearerCheck above the surface gate"
)
subject = _coerce_subject_type(getattr(ctx, "subject_type", None))
if subject in accepted:
return
canonical = _CANONICAL_PATH.get(subject, "/openapi/v1/") if subject else "/openapi/v1/"
emit_wrong_surface(
subject_type=subject.value if subject else None,
attempted_path=request.path,
client_id=getattr(ctx, "client_id", None),
token_id=_stringify(getattr(ctx, "token_id", None)),
)
raise Forbidden(description=f"wrong_surface (canonical: {canonical})")
def accept_subjects(*accepted: SubjectType) -> Callable[[F], F]:
accepted_set: frozenset[SubjectType] = frozenset(accepted)
def deco(fn: F) -> F:
@wraps(fn)
def wrapper(*args: object, **kwargs: object) -> object:
check_surface(accepted_set)
return fn(*args, **kwargs)
return wrapper # type: ignore[return-value]
return deco
def _coerce_subject_type(raw: object) -> SubjectType | None:
if raw is None:
return None
if isinstance(raw, SubjectType):
return raw
if isinstance(raw, str):
return SubjectType(raw)
return None
def _stringify(value: object) -> str | None:
if value is None:
return None
return str(value)

View File

@ -0,0 +1,82 @@
from __future__ import annotations
from werkzeug.exceptions import Forbidden, Unauthorized
from controllers.openapi.auth.data import AuthData
from extensions.ext_database import db
from libs.oauth_bearer import Scope, TokenType, check_workspace_membership
from services.account_service import AccountService, TenantService
from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode
def check_scope(data: AuthData) -> None:
if data.required_scope is None:
return
if Scope.FULL in data.scopes or data.required_scope in data.scopes:
return
raise Forbidden("insufficient_scope")
# reject_external_sso removed — PipelineRouter._execute raises Forbidden("external_sso_requires_ee")
# directly when route.required_edition is not satisfied. Not a pipeline step.
def check_membership(data: AuthData) -> None:
if data.tenant is None:
raise Unauthorized("tenant unset")
check_workspace_membership(
account_id=data.account_id,
tenant_id=data.tenant.id,
token_hash=data.token_hash,
membership_cache=data.tenants,
)
def check_app_access(data: AuthData) -> None:
if data.tenant is None:
return
if not TenantService.account_belongs_to_tenant(db.session, data.account_id, data.tenant.id):
raise Forbidden("subject_no_app_access")
_ALLOWED_MODES_BY_TOKEN_TYPE: dict[TokenType, frozenset[WebAppAccessMode]] = {
TokenType.OAUTH_ACCOUNT: frozenset({
WebAppAccessMode.PUBLIC,
WebAppAccessMode.SSO_VERIFIED,
WebAppAccessMode.PRIVATE_ALL,
WebAppAccessMode.PRIVATE,
}),
TokenType.OAUTH_EXTERNAL_SSO: frozenset({
WebAppAccessMode.PUBLIC,
WebAppAccessMode.SSO_VERIFIED,
}),
}
def check_acl(data: AuthData) -> None:
if data.app is None or data.app_access_mode is None:
raise Forbidden("app or access mode not loaded")
allowed_modes = _ALLOWED_MODES_BY_TOKEN_TYPE.get(data.token_type, frozenset())
if data.app_access_mode not in allowed_modes:
raise Forbidden("subject_not_allowed_for_access_mode")
def check_private_app_permission(data: AuthData) -> None:
if data.app is None:
raise Forbidden("app not loaded")
user_id = _resolve_user_id(data)
if user_id is None:
raise Forbidden("cannot resolve user for private app check")
if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(
user_id=user_id, app_id=data.app.id
):
raise Forbidden("user_not_allowed_for_private_app")
def _resolve_user_id(data: AuthData) -> str | None:
if data.token_type == TokenType.OAUTH_ACCOUNT:
return str(data.account_id) if data.account_id is not None else None
if data.external_identity is None:
return None
account = AccountService.get_account_by_email(db.session, data.external_identity.email)
return str(account.id) if account is not None else None

View File

@ -17,11 +17,11 @@ from controllers.common.errors import (
UnsupportedFileTypeError,
)
from controllers.openapi import openapi_ns
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from extensions.ext_database import db
from fields.file_fields import FileResponse
from libs.oauth_bearer import Scope
from models import Account, App
from services.file_service import FileService
@ -39,8 +39,11 @@ class AppFileUploadApi(Resource):
}
)
@openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__])
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, app_model: App, caller: Account, caller_kind: str):
@auth_router.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, *, auth_data: AuthData):
app_model = auth_data.app
caller = auth_data.caller
caller_kind = auth_data.caller_kind
if "file" not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:

View File

@ -17,7 +17,8 @@ from werkzeug.exceptions import BadRequest, NotFound
from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values
from controllers.common.schema import register_schema_models
from controllers.openapi import openapi_ns
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface
from extensions.ext_database import db
from libs.helper import to_timestamp
@ -55,8 +56,11 @@ def _ensure_form_is_allowed_for_openapi(form) -> None:
@openapi_ns.route("/apps/<string:app_id>/form/human_input/<string:form_token>")
class OpenApiWorkflowHumanInputFormApi(Resource):
@openapi_ns.response(200, "Form definition")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def get(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
@auth_router.guard(scope=Scope.APPS_RUN)
def get(self, app_id: str, form_token: str, *, auth_data: AuthData):
app_model = auth_data.app
caller = auth_data.caller
caller_kind = auth_data.caller_kind
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
@ -69,8 +73,11 @@ class OpenApiWorkflowHumanInputFormApi(Resource):
@openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__])
@openapi_ns.response(200, "Form submitted")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str):
@auth_router.guard(scope=Scope.APPS_RUN)
def post(self, app_id: str, form_token: str, *, auth_data: AuthData):
app_model = auth_data.app
caller = auth_data.caller
caller_kind = auth_data.caller_kind
payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {})
service = HumanInputService(db.engine)

View File

@ -17,7 +17,8 @@ from sqlalchemy.orm import sessionmaker
from werkzeug.exceptions import NotFound, UnprocessableEntity
from controllers.openapi import openapi_ns
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
@ -28,7 +29,7 @@ from core.workflow.human_input_policy import HumanInputSurface
from extensions.ext_database import db
from libs.oauth_bearer import Scope
from models.enums import CreatorUserRole
from models.model import App, AppMode
from models.model import AppMode
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream
@ -36,8 +37,11 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream
@openapi_ns.route("/apps/<string:app_id>/tasks/<string:task_id>/events")
class OpenApiWorkflowEventsApi(Resource):
@openapi_ns.response(200, "SSE event stream")
@OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN)
def get(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str):
@auth_router.guard(scope=Scope.APPS_RUN)
def get(self, app_id: str, task_id: str, *, auth_data: AuthData):
app_model = auth_data.app
caller = auth_data.caller
caller_kind = auth_data.caller_kind
app_mode = AppMode.value_of(app_model.mode)
if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}:
raise UnprocessableEntity("mode_not_supported_for_event_reconnect")

View File

@ -15,14 +15,10 @@ from werkzeug.exceptions import NotFound
from controllers.openapi import openapi_ns
from controllers.openapi._models import WorkspaceDetailResponse, WorkspaceListResponse, WorkspaceSummaryResponse
from controllers.openapi.auth.surface_gate import accept_subjects
from controllers.openapi.auth.composition import auth_router
from controllers.openapi.auth.data import AuthData
from extensions.ext_database import db
from libs.oauth_bearer import (
ACCEPT_USER_ANY,
SubjectType,
get_auth_ctx,
validate_bearer,
)
from libs.oauth_bearer import Scope, TokenType
from models import Tenant, TenantAccountJoin
from services.account_service import TenantService
@ -30,12 +26,9 @@ from services.account_service import TenantService
@openapi_ns.route("/workspaces")
class WorkspacesApi(Resource):
@openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
def get(self):
ctx = get_auth_ctx()
rows = TenantService.get_workspaces_for_account(db.session, str(ctx.account_id))
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, *, auth_data: AuthData):
rows = TenantService.get_workspaces_for_account(db.session, str(auth_data.account_id))
return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200
@ -43,12 +36,9 @@ class WorkspacesApi(Resource):
@openapi_ns.route("/workspaces/<string:workspace_id>")
class WorkspaceByIdApi(Resource):
@openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__])
@validate_bearer(accept=ACCEPT_USER_ANY)
@accept_subjects(SubjectType.ACCOUNT)
def get(self, workspace_id: str):
ctx = get_auth_ctx()
row = TenantService.find_workspace_for_account(db.session, str(ctx.account_id), workspace_id)
@auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def get(self, workspace_id: str, *, auth_data: AuthData):
row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id)
# 404 (not 403) on non-member so workspace IDs don't leak across tenants.
if row is None:
raise NotFound("workspace not found")

View File

@ -43,6 +43,11 @@ class SubjectType(StrEnum):
EXTERNAL_SSO = "external_sso"
class TokenType(StrEnum):
OAUTH_ACCOUNT = "oauth_account"
OAUTH_EXTERNAL_SSO = "oauth_external_sso"
class Scope(StrEnum):
"""Catalog of bearer scopes recognised by the openapi surface.
@ -55,6 +60,7 @@ class Scope(StrEnum):
APPS_READ = "apps:read"
APPS_READ_PERMITTED_EXTERNAL = "apps:read:permitted-external"
APPS_RUN = "apps:run"
WORKSPACE_READ = "workspace:read"
class Accepts(StrEnum):
@ -77,7 +83,7 @@ _SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = {
class AuthContext:
"""Per-request identity published via :data:`_auth_ctx_var`
(see :func:`set_auth_ctx` / :func:`get_auth_ctx`). ``scopes`` /
``subject_type`` / ``source`` come from the TokenKind, not the DB —
``subject_type`` / ``token_type`` come from the TokenKind, not the DB —
corrupt rows can't elevate scope.
`verified_tenants` is a snapshot of the Layer-0 verdict cache at
@ -92,7 +98,7 @@ class AuthContext:
client_id: str | None
scopes: frozenset[Scope]
token_id: uuid.UUID
source: str
token_type: TokenType
expires_at: datetime | None
token_hash: str
verified_tenants: dict[str, bool] = field(default_factory=dict)
@ -180,7 +186,7 @@ class TokenKind:
prefix: str
subject_type: SubjectType
scopes: frozenset[Scope]
source: str
token_type: TokenType
resolver: Resolver
def matches(self, token: str) -> bool:
@ -291,7 +297,7 @@ class BearerAuthenticator:
client_id=row.client_id,
scopes=kind.scopes,
token_id=row.token_id,
source=kind.source,
token_type=kind.token_type,
expires_at=row.expires_at,
token_hash=token_hash,
verified_tenants=dict(row.verified_tenants),
@ -483,7 +489,7 @@ def check_workspace_membership(
account_id: uuid.UUID | str,
tenant_id: str,
token_hash: str,
cached_verdicts: dict[str, bool],
membership_cache: dict[str, bool],
) -> None:
"""Layer-0 enforcement core. Raises `Forbidden` on deny, returns on allow.
@ -492,7 +498,7 @@ def check_workspace_membership(
short-circuiting on EE / SSO subjects before invoking — this function
runs the membership + active-status checks unconditionally.
"""
cached = cached_verdicts.get(tenant_id)
cached = membership_cache.get(tenant_id)
if cached is True:
return
if cached is False:
@ -530,7 +536,7 @@ def require_workspace_member(ctx: AuthContext, tenant_id: str) -> None:
account_id=ctx.account_id,
tenant_id=tenant_id,
token_hash=ctx.token_hash,
cached_verdicts=ctx.verified_tenants,
membership_cache=ctx.verified_tenants,
)
@ -664,14 +670,14 @@ def build_registry(session_factory, redis_client) -> TokenKindRegistry:
prefix=account.prefix,
subject_type=account.subject_type,
scopes=account.scopes,
source="oauth_account",
token_type=TokenType.OAUTH_ACCOUNT,
resolver=oauth.for_account(),
),
TokenKind(
prefix=external.prefix,
subject_type=external.subject_type,
scopes=external.scopes,
source="oauth_external_sso",
token_type=TokenType.OAUTH_EXTERNAL_SSO,
resolver=oauth.for_external_sso(),
),
]

View File

@ -1,66 +1,80 @@
from unittest.mock import patch
from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE, _resolve_app_authz_strategy
from controllers.openapi.auth.pipeline import Pipeline
from controllers.openapi.auth.steps import (
AppAuthzCheck,
AppResolver,
BearerCheck,
CallerMount,
ScopeCheck,
SurfaceCheck,
WorkspaceMembershipCheck,
)
from controllers.openapi.auth.strategies import (
AccountMounter,
AclStrategy,
EndUserMounter,
MembershipStrategy,
)
from libs.oauth_bearer import SubjectType
from controllers.openapi.auth.composition import account_pipeline, auth_router, external_sso_pipeline
from controllers.openapi.auth.flow import When
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
from libs.oauth_bearer import TokenType
def test_pipeline_is_composed():
assert isinstance(OAUTH_BEARER_PIPELINE, Pipeline)
def test_account_pipeline_is_auth_pipeline():
assert isinstance(account_pipeline, AuthPipeline)
def test_pipeline_step_order():
"""BearerCheck → SurfaceCheck → ScopeCheck → AppResolver →
WorkspaceMembershipCheck → AppAuthzCheck → CallerMount.
SurfaceCheck enforces the dfoa_/dfoe_ surface split + emits
`openapi.wrong_surface_denied`. Rate-limit is enforced inside
`BearerAuthenticator.authenticate`, not as a separate pipeline step."""
steps = OAUTH_BEARER_PIPELINE._steps
assert isinstance(steps[0], BearerCheck)
assert isinstance(steps[1], SurfaceCheck)
assert isinstance(steps[2], ScopeCheck)
assert isinstance(steps[3], AppResolver)
assert isinstance(steps[4], WorkspaceMembershipCheck)
assert isinstance(steps[5], AppAuthzCheck)
assert isinstance(steps[6], CallerMount)
def test_external_sso_pipeline_is_auth_pipeline():
assert isinstance(external_sso_pipeline, AuthPipeline)
def test_pipeline_surface_check_accepts_account_only():
"""Current pipeline serves /apps/<id>/run — account surface only."""
surface = OAUTH_BEARER_PIPELINE._steps[1]
assert isinstance(surface, SurfaceCheck)
assert surface._accepted == frozenset({SubjectType.ACCOUNT})
def test_auth_router_is_pipeline_router():
assert isinstance(auth_router, PipelineRouter)
def test_caller_mount_has_both_mounters():
cm = OAUTH_BEARER_PIPELINE._steps[6]
kinds = {type(m) for m in cm._mounters}
assert AccountMounter in kinds
assert EndUserMounter in kinds
def test_account_pipeline_prepare_has_four_entries():
assert len(account_pipeline._prepare) == 4
@patch("controllers.openapi.auth.composition.FeatureService")
def test_strategy_resolver_picks_acl_when_enabled(fs):
fs.get_system_features.return_value.webapp_auth.enabled = True
assert isinstance(_resolve_app_authz_strategy(), AclStrategy)
def test_account_auth_list_has_five_entries():
assert len(account_pipeline._auth) == 5
@patch("controllers.openapi.auth.composition.FeatureService")
def test_strategy_resolver_picks_membership_when_disabled(fs):
fs.get_system_features.return_value.webapp_auth.enabled = False
assert isinstance(_resolve_app_authz_strategy(), MembershipStrategy)
def test_external_sso_pipeline_prepare_has_five_entries():
assert len(external_sso_pipeline._prepare) == 5
def test_external_sso_auth_list_has_three_entries():
# check_scope (unconditional) + 2 When entries
assert len(external_sso_pipeline._auth) == 3
def test_account_pipeline_has_unconditional_load_account():
# load_account is the only bare (non-When) entry in account prepare
non_when = [s for s in account_pipeline._prepare if not isinstance(s, When)]
assert len(non_when) == 1
def test_external_sso_pipeline_first_prepare_is_build_external_identity():
from controllers.openapi.auth.prepare import build_external_identity
assert external_sso_pipeline._prepare[0] is build_external_identity
def test_external_sso_pipeline_remaining_prepare_entries_are_when():
assert all(isinstance(s, When) for s in external_sso_pipeline._prepare[1:])
def test_first_auth_entry_is_check_scope_in_both_pipelines():
# check_scope is unconditional (not a When) and comes first in auth
assert not isinstance(account_pipeline._auth[0], When)
assert not isinstance(external_sso_pipeline._auth[0], When)
def test_remaining_auth_entries_are_when_for_account():
assert all(isinstance(s, When) for s in account_pipeline._auth[1:])
def test_remaining_auth_entries_are_when_for_external_sso():
assert all(isinstance(s, When) for s in external_sso_pipeline._auth[1:])
def test_router_routes_contain_both_token_types():
assert TokenType.OAUTH_ACCOUNT in auth_router._routes
assert TokenType.OAUTH_EXTERNAL_SSO in auth_router._routes
def test_external_sso_route_has_ee_required_edition():
route = auth_router._routes[TokenType.OAUTH_EXTERNAL_SSO]
assert isinstance(route, PipelineRoute)
from controllers.openapi.auth.data import Edition
assert route.required_edition == frozenset({Edition.EE})
def test_account_route_has_no_required_edition():
route = auth_router._routes[TokenType.OAUTH_ACCOUNT]
assert isinstance(route, PipelineRoute)
assert route.required_edition is None

View File

@ -0,0 +1,149 @@
from unittest.mock import MagicMock, patch
from controllers.openapi.auth.conditions import (
EDITION_CE,
EDITION_EE,
EDITION_SAAS,
LOADED_APP_IS_PRIVATE,
PATH_HAS_APP_ID,
TOKEN_IS_OAUTH_ACCOUNT,
TOKEN_IS_OAUTH_EXTERNAL_SSO,
WEBAPP_AUTH_ENABLED,
Cond,
config_cond,
data_cond,
request_cond,
)
from controllers.openapi.auth.data import AuthData, Edition, RequestContext
from libs.oauth_bearer import TokenType
from services.enterprise.enterprise_service import WebAppAccessMode
def _ctx(token_type=TokenType.OAUTH_ACCOUNT, path_params=None):
return RequestContext(
token_type=token_type,
path_params=path_params or {},
)
def _data(**kwargs):
defaults: dict = {"token_type": TokenType.OAUTH_ACCOUNT, "token_hash": "x", "scopes": frozenset()}
defaults.update(kwargs)
return AuthData(**defaults)
# --- Cond operators ---
def test_and_both_true():
a = Cond(lambda ctx, _: True)
b = Cond(lambda ctx, _: True)
assert (a & b)(_ctx()) is True
def test_and_one_false():
a = Cond(lambda ctx, _: True)
b = Cond(lambda ctx, _: False)
assert (a & b)(_ctx()) is False
def test_or_one_true():
a = Cond(lambda ctx, _: False)
b = Cond(lambda ctx, _: True)
assert (a | b)(_ctx()) is True
def test_or_both_false():
a = Cond(lambda ctx, _: False)
b = Cond(lambda ctx, _: False)
assert (a | b)(_ctx()) is False
def test_invert():
a = Cond(lambda ctx, _: True)
assert (~a)(_ctx()) is False
def test_chain_and_or():
always_true = Cond(lambda ctx, _: True)
always_false = Cond(lambda ctx, _: False)
assert ((always_true | always_false) & always_true)(_ctx()) is True
# --- Helper constructors ---
def test_request_cond_ignores_data():
c = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_ACCOUNT)
assert c(_ctx(TokenType.OAUTH_ACCOUNT)) is True
assert c(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is False
def test_data_cond_returns_false_when_data_none():
c = data_cond(lambda data: True)
assert c(_ctx(), None) is False
def test_data_cond_evaluates_when_data_present():
c = data_cond(lambda data: data.token_hash == "secret")
assert c(_ctx(), _data(token_hash="secret")) is True
assert c(_ctx(), _data(token_hash="other")) is False
def test_config_cond_ignores_ctx_and_data():
c = config_cond(lambda: True)
assert c(_ctx()) is True
c2 = config_cond(lambda: False)
assert c2(_ctx(), _data()) is False
# --- Pre-built conditions ---
def test_token_is_oauth_account():
assert TOKEN_IS_OAUTH_ACCOUNT(_ctx(TokenType.OAUTH_ACCOUNT)) is True
assert TOKEN_IS_OAUTH_ACCOUNT(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is False
def test_token_is_oauth_external_sso():
assert TOKEN_IS_OAUTH_EXTERNAL_SSO(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is True
def test_path_has_app_id_true():
assert PATH_HAS_APP_ID(_ctx(path_params={"app_id": "abc"})) is True
def test_path_has_app_id_false():
assert PATH_HAS_APP_ID(_ctx(path_params={})) is False
def test_edition_ce():
with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.CE):
assert EDITION_CE(_ctx()) is True
assert EDITION_EE(_ctx()) is False
assert EDITION_SAAS(_ctx()) is False
def test_edition_ee():
with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.EE):
assert EDITION_EE(_ctx()) is True
assert EDITION_CE(_ctx()) is False
def test_edition_saas():
with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.SAAS):
assert EDITION_SAAS(_ctx()) is True
def test_webapp_auth_enabled():
mock_features = MagicMock()
mock_features.webapp_auth.enabled = True
with patch("controllers.openapi.auth.conditions.FeatureService.get_system_features", return_value=mock_features):
assert WEBAPP_AUTH_ENABLED(_ctx()) is True
def test_loaded_app_is_private():
data_private = _data(app_access_mode=WebAppAccessMode.PRIVATE)
data_public = _data(app_access_mode=WebAppAccessMode.PUBLIC)
data_none = _data(app_access_mode=None)
assert LOADED_APP_IS_PRIVATE(_ctx(), data_private) is True
assert LOADED_APP_IS_PRIVATE(_ctx(), data_public) is False
assert LOADED_APP_IS_PRIVATE(_ctx(), data_none) is False
assert LOADED_APP_IS_PRIVATE(_ctx(), None) is False

View File

@ -1,21 +0,0 @@
from controllers.openapi.auth.context import Context
def test_context_starts_unpopulated():
ctx = Context(required_scope="apps:run")
assert ctx.bearer_token is None
assert ctx.path_params == {}
assert ctx.subject_type is None
assert ctx.subject_email is None
assert ctx.account_id is None
assert ctx.scopes == frozenset()
assert ctx.app is None
assert ctx.tenant is None
assert ctx.caller is None
assert ctx.caller_kind is None
def test_context_fields_are_mutable():
ctx = Context(required_scope="apps:run")
ctx.scopes = frozenset({"full"})
assert "full" in ctx.scopes

View File

@ -0,0 +1,116 @@
import uuid
from unittest.mock import patch
import pytest
from pydantic import ValidationError
from controllers.openapi.auth.data import (
AuthData,
Edition,
ExternalIdentity,
RequestContext,
current_edition,
)
from libs.oauth_bearer import Scope, TokenType
# --- Edition / current_edition ---
def test_current_edition_saas():
with patch("controllers.openapi.auth.data.dify_config") as cfg:
cfg.EDITION = "CLOUD"
cfg.ENTERPRISE_ENABLED = True
assert current_edition() == Edition.SAAS
def test_current_edition_ee():
with patch("controllers.openapi.auth.data.dify_config") as cfg:
cfg.EDITION = "SELF_HOSTED"
cfg.ENTERPRISE_ENABLED = True
assert current_edition() == Edition.EE
def test_current_edition_ce():
with patch("controllers.openapi.auth.data.dify_config") as cfg:
cfg.EDITION = "SELF_HOSTED"
cfg.ENTERPRISE_ENABLED = False
assert current_edition() == Edition.CE
# --- ExternalIdentity ---
def test_external_identity_frozen():
ei = ExternalIdentity(email="a@b.com", issuer="idp")
with pytest.raises(ValidationError):
ei.email = "other@b.com" # type: ignore[misc]
def test_external_identity_issuer_optional():
ei = ExternalIdentity(email="a@b.com")
assert ei.issuer is None
# --- RequestContext ---
def test_request_context_frozen():
ctx = RequestContext(
token_type=TokenType.OAUTH_ACCOUNT,
path_params={"app_id": "123"},
)
with pytest.raises(ValidationError):
ctx.token_type = TokenType.OAUTH_EXTERNAL_SSO # type: ignore[misc]
def test_request_context_scope_optional():
ctx = RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={})
assert ctx.scope is None
# --- AuthData ---
def test_auth_data_frozen():
data = AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
token_hash="abc",
scopes=frozenset({Scope.FULL}),
)
with pytest.raises(ValidationError):
data.token_type = TokenType.OAUTH_EXTERNAL_SSO # type: ignore[misc]
def test_auth_data_account_id_optional():
data = AuthData(
token_type=TokenType.OAUTH_EXTERNAL_SSO,
token_hash="abc",
scopes=frozenset({Scope.APPS_RUN}),
external_identity=ExternalIdentity(email="u@sso.com"),
)
assert data.account_id is None
def test_auth_data_external_identity_none_for_account():
data = AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
account_id=uuid.uuid4(),
token_hash="abc",
scopes=frozenset({Scope.FULL}),
)
assert data.external_identity is None
def test_auth_data_tenants_default_empty():
data = AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
token_hash="abc",
scopes=frozenset(),
)
assert data.tenants == {}
def test_auth_data_token_id_optional():
data = AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
token_hash="abc",
scopes=frozenset(),
)
assert data.token_id is None

View File

@ -0,0 +1,42 @@
import inspect
from controllers.openapi.auth.conditions import Cond
from controllers.openapi.auth.data import AuthData, RequestContext
from controllers.openapi.auth.flow import When
from libs.oauth_bearer import TokenType
def _ctx():
return RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={})
def _data():
return AuthData(token_type=TokenType.OAUTH_ACCOUNT, token_hash="x", scopes=frozenset())
def test_applies_returns_true_when_condition_true():
w = When(Cond(lambda ctx, _: True), then=lambda b: None)
assert w.applies(_ctx()) is True
def test_applies_returns_false_when_condition_false():
w = When(Cond(lambda ctx, _: False), then=lambda b: None)
assert w.applies(_ctx()) is False
def test_applies_with_data():
w = When(Cond(lambda ctx, data: data is not None), then=lambda b: None)
assert w.applies(_ctx(), _data()) is True
assert w.applies(_ctx(), None) is False
def test_call_invokes_step():
calls = []
w = When(Cond(lambda ctx, _: True), then=lambda arg: calls.append(arg))
w("payload")
assert calls == ["payload"]
def test_then_is_keyword_only():
sig = inspect.signature(When.__init__)
assert sig.parameters["then"].kind.name == "KEYWORD_ONLY"

View File

@ -1,59 +1,204 @@
import uuid
from unittest.mock import MagicMock, patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.pipeline import Pipeline
from controllers.openapi.auth.data import AuthData, Edition
from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter
from libs.oauth_bearer import Scope, TokenType
def test_run_invokes_each_step_in_order():
calls = []
class S:
def __init__(self, tag):
self.tag = tag
def __call__(self, ctx):
calls.append(self.tag)
Pipeline(S("a"), S("b"), S("c")).run(Context(required_scope="x"))
assert calls == ["a", "b", "c"]
def _make_identity(
token_type=TokenType.OAUTH_ACCOUNT,
account_id=None,
scopes=None,
token_hash="testhash",
subject_email=None,
subject_issuer=None,
verified_tenants=None,
token_id=None,
):
identity = MagicMock()
identity.token_type = token_type
identity.account_id = account_id or uuid.uuid4()
identity.scopes = scopes or frozenset({Scope.FULL})
identity.token_hash = token_hash
identity.subject_email = subject_email
identity.subject_issuer = subject_issuer
identity.verified_tenants = verified_tenants or {}
identity.token_id = token_id or uuid.uuid4()
return identity
def test_run_short_circuits_on_raise():
calls = []
class Boom:
def __call__(self, ctx):
raise RuntimeError("boom")
class Tail:
def __call__(self, ctx):
calls.append("ran")
with pytest.raises(RuntimeError):
Pipeline(Boom(), Tail()).run(Context(required_scope="x"))
assert calls == []
@pytest.fixture
def app():
return Flask(__name__)
def test_guard_decorator_runs_pipeline_and_unpacks_handler_kwargs():
seen = {}
def _make_router(token_type=TokenType.OAUTH_ACCOUNT, prepare=None, auth=None):
pipeline = AuthPipeline(prepare=prepare or [], auth=auth or [])
return PipelineRouter({token_type: PipelineRoute(pipeline)})
class FakeStep:
def __call__(self, ctx):
ctx.app = "APP"
ctx.caller = "CALLER"
ctx.caller_kind = "account"
pipeline = Pipeline(FakeStep())
def _fake_identity():
return _make_identity()
@pipeline.guard(scope="apps:run")
def handler(app_model, caller, caller_kind):
seen["app_model"] = app_model
seen["caller"] = caller
seen["caller_kind"] = caller_kind
return "ok"
app = Flask(__name__)
with app.test_request_context("/x", method="POST"):
assert handler() == "ok"
assert seen == {"app_model": "APP", "caller": "CALLER", "caller_kind": "account"}
# --- PipelineRouter.guard ---
def test_guard_passes_auth_data_to_view(app):
router = _make_router()
received = {}
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), \
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, \
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()), \
patch("controllers.openapi.auth.pipeline.reset_auth_ctx"):
mock_auth.return_value.authenticate.return_value = _fake_identity()
@router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def view(*, auth_data):
received["data"] = auth_data
view()
assert isinstance(received["data"], AuthData)
def test_guard_edition_gate_returns_404(app):
router = _make_router()
with app.test_request_context("/test"):
with patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE):
@router.guard(scope=Scope.FULL, edition=frozenset({Edition.EE}))
def view(*, auth_data):
pass
with pytest.raises(NotFound):
view()
def test_guard_token_type_gate_returns_403(app):
router = _make_router()
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), \
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, \
patch("controllers.openapi.auth.pipeline.emit_wrong_surface"), \
patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE):
identity = _fake_identity()
identity.token_type = TokenType.OAUTH_EXTERNAL_SSO
mock_auth.return_value.authenticate.return_value = identity
@router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT}))
def view(*, auth_data):
pass
with pytest.raises(Forbidden):
view()
def test_guard_unregistered_token_type_returns_403(app):
# Router has only OAUTH_ACCOUNT; present OAUTH_EXTERNAL_SSO without allowed_token_types gate
router = _make_router(token_type=TokenType.OAUTH_ACCOUNT)
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), \
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, \
patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE):
identity = _fake_identity()
identity.token_type = TokenType.OAUTH_EXTERNAL_SSO
mock_auth.return_value.authenticate.return_value = identity
@router.guard(scope=Scope.FULL)
def view(*, auth_data):
pass
with pytest.raises(Forbidden):
view()
def test_guard_no_bearer_returns_401(app):
router = _make_router()
with app.test_request_context("/test"):
with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value=None):
@router.guard(scope=Scope.FULL)
def view(*, auth_data):
pass
with pytest.raises(Unauthorized):
view()
def test_guard_runs_prepare_steps_in_order(app):
order = []
def p1(b):
order.append("p1")
def p2(b):
order.append("p2")
router = _make_router(prepare=[p1, p2])
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), \
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, \
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()), \
patch("controllers.openapi.auth.pipeline.reset_auth_ctx"):
mock_auth.return_value.authenticate.return_value = _fake_identity()
@router.guard(scope=Scope.FULL)
def view(*, auth_data):
pass
view()
assert order == ["p1", "p2"]
def test_guard_resets_auth_ctx_on_exception(app):
router = _make_router()
reset_called = []
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), \
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, \
patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value="tok"), \
patch("controllers.openapi.auth.pipeline.reset_auth_ctx", side_effect=lambda t: reset_called.append(t)):
mock_auth.return_value.authenticate.return_value = _fake_identity()
@router.guard(scope=Scope.FULL)
def view(*, auth_data):
raise RuntimeError("boom")
with pytest.raises(RuntimeError):
view()
assert reset_called == ["tok"]
def test_router_rejects_token_type_on_wrong_edition(app):
pipeline = AuthPipeline(prepare=[], auth=[])
route = PipelineRoute(pipeline, required_edition=frozenset({Edition.EE}))
router = PipelineRouter({TokenType.OAUTH_EXTERNAL_SSO: route})
with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}):
with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), \
patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, \
patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE):
identity = _make_identity(token_type=TokenType.OAUTH_EXTERNAL_SSO)
mock_auth.return_value.authenticate.return_value = identity
@router.guard(scope=Scope.APPS_RUN)
def view(*, auth_data):
pass
with pytest.raises(Forbidden):
view()

View File

@ -0,0 +1,187 @@
import uuid
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from controllers.openapi.auth.data import ExternalIdentity
from controllers.openapi.auth.prepare import (
build_external_identity,
load_account,
load_app,
load_app_access_mode,
load_tenant,
resolve_external_user,
)
# --- load_app ---
def test_load_app_writes_app_to_builder():
app = MagicMock()
app.status = "normal"
app.enable_api = True
builder = {"path_params": {"app_id": "abc"}}
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app):
load_app(builder)
assert builder["app"] is app
def test_load_app_raises_not_found_when_missing():
builder = {"path_params": {"app_id": "missing"}}
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=None):
with pytest.raises(NotFound):
load_app(builder)
def test_load_app_raises_not_found_when_not_normal():
app = MagicMock()
app.status = "archived"
builder = {"path_params": {"app_id": "abc"}}
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app):
with pytest.raises(NotFound):
load_app(builder)
def test_load_app_raises_forbidden_when_api_disabled():
app = MagicMock()
app.status = "normal"
app.enable_api = False
builder = {"path_params": {"app_id": "abc"}}
with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app):
with pytest.raises(Forbidden):
load_app(builder)
# --- load_tenant ---
def test_load_tenant_writes_tenant():
app = MagicMock()
app.tenant_id = uuid.uuid4()
tenant = MagicMock()
tenant.status = "normal"
builder = {"app": app}
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=tenant):
load_tenant(builder)
assert builder["tenant"] is tenant
def test_load_tenant_raises_forbidden_when_archived():
from models.account import TenantStatus
app = MagicMock()
app.tenant_id = uuid.uuid4()
tenant = MagicMock()
tenant.status = TenantStatus.ARCHIVE
builder = {"app": app}
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=tenant):
with pytest.raises(Forbidden):
load_tenant(builder)
def test_load_tenant_raises_forbidden_when_missing():
app = MagicMock()
app.tenant_id = uuid.uuid4()
builder = {"app": app}
with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=None):
with pytest.raises(Forbidden):
load_tenant(builder)
# --- load_account ---
def test_load_account_writes_caller():
account = MagicMock()
account_id = uuid.uuid4()
builder = {"account_id": account_id}
with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=account):
load_account(builder)
assert builder["caller"] is account
assert builder["caller_kind"] == "account"
def test_load_account_sets_current_tenant_when_tenant_present():
account = MagicMock()
tenant = MagicMock()
builder = {"account_id": uuid.uuid4(), "tenant": tenant}
with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=account):
load_account(builder)
assert account.current_tenant is tenant
def test_load_account_raises_unauthorized_when_not_found():
builder = {"account_id": uuid.uuid4()}
with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=None):
with pytest.raises(Unauthorized):
load_account(builder)
# --- resolve_external_user ---
def test_resolve_external_user_writes_caller():
tenant = MagicMock()
app = MagicMock()
end_user = MagicMock()
ext = ExternalIdentity(email="user@sso.com")
builder = {"tenant": tenant, "app": app, "external_identity": ext}
with patch("controllers.openapi.auth.prepare.EndUserService.get_or_create_end_user_by_type", return_value=end_user):
resolve_external_user(builder)
assert builder["caller"] is end_user
assert builder["caller_kind"] == "end_user"
def test_resolve_external_user_raises_unauthorized_when_context_missing():
builder = {"tenant": None, "app": MagicMock(), "external_identity": ExternalIdentity(email="u@s.com")}
with pytest.raises(Unauthorized):
resolve_external_user(builder)
# --- load_app_access_mode ---
def test_load_app_access_mode_writes_mode():
from services.enterprise.enterprise_service import WebAppAccessMode
app = MagicMock()
app.id = "app-1"
settings = MagicMock()
settings.access_mode = "public"
builder = {"app": app}
with patch(
"controllers.openapi.auth.prepare.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
return_value=settings,
):
load_app_access_mode(builder)
assert builder["app_access_mode"] == WebAppAccessMode.PUBLIC
def test_load_app_access_mode_writes_none_when_value_error():
app = MagicMock()
app.id = "app-1"
builder = {"app": app}
with patch(
"controllers.openapi.auth.prepare.EnterpriseService.WebAppAuth.get_app_access_mode_by_id",
side_effect=ValueError("No data found."),
):
load_app_access_mode(builder)
assert builder["app_access_mode"] is None
def test_load_app_access_mode_no_op_when_app_missing():
builder = {}
load_app_access_mode(builder)
assert "app_access_mode" not in builder
# --- build_external_identity ---
def test_build_external_identity_constructs_from_builder_keys():
from controllers.openapi.auth.data import ExternalIdentity
builder = {"_subject_email": "u@sso.com", "_subject_issuer": "idp"}
build_external_identity(builder)
assert isinstance(builder["external_identity"], ExternalIdentity)
assert builder["external_identity"].email == "u@sso.com"
assert "_subject_email" not in builder
def test_build_external_identity_no_op_when_email_missing():
builder = {"_subject_email": None, "_subject_issuer": None}
build_external_identity(builder)
assert "external_identity" not in builder

View File

@ -1,64 +0,0 @@
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import AppResolver
from models import TenantStatus
def _ctx(path_params: dict[str, str] | None) -> Context:
return Context(required_scope="apps:run", path_params=path_params or {})
def _app(*, status="normal", enable_api=True):
return SimpleNamespace(id="app1", tenant_id="t1", status=status, enable_api=enable_api)
def _tenant(*, status=TenantStatus.NORMAL):
return SimpleNamespace(id="t1", status=status)
def test_resolver_rejects_missing_path_param():
with pytest.raises(BadRequest):
AppResolver()(_ctx({}))
def test_resolver_rejects_empty_path_params():
# `Pipeline.guard` always seeds an empty dict when Flask reports no
# view args, so a missing `app_id` key surfaces here as BadRequest.
with pytest.raises(BadRequest):
AppResolver()(_ctx(None))
@patch("controllers.openapi.auth.steps.db")
def test_resolver_404_when_app_missing(db):
db.session.get.side_effect = [None]
with pytest.raises(NotFound):
AppResolver()(_ctx({"app_id": "x"}))
@patch("controllers.openapi.auth.steps.db")
def test_resolver_403_when_disabled(db):
db.session.get.side_effect = [_app(enable_api=False)]
with pytest.raises(Forbidden) as exc:
AppResolver()(_ctx({"app_id": "x"}))
assert "service_api_disabled" in str(exc.value.description)
@patch("controllers.openapi.auth.steps.db")
def test_resolver_403_when_tenant_archived(db):
db.session.get.side_effect = [_app(), _tenant(status=TenantStatus.ARCHIVE)]
with pytest.raises(Forbidden):
AppResolver()(_ctx({"app_id": "x"}))
@patch("controllers.openapi.auth.steps.db")
def test_resolver_populates_app_and_tenant(db):
db.session.get.side_effect = [_app(), _tenant()]
ctx = _ctx({"app_id": "x"})
AppResolver()(ctx)
assert ctx.app.id == "app1"
assert ctx.tenant.id == "t1"

View File

@ -1,76 +0,0 @@
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import AppAuthzCheck
from controllers.openapi.auth.strategies import AclStrategy, MembershipStrategy
from libs.oauth_bearer import SubjectType
def _ctx(*, subject_type, account_id="acc1"):
c = Context(required_scope="apps:run")
c.subject_type = subject_type
c.subject_email = "alice@example.com"
c.account_id = account_id
c.app = SimpleNamespace(id="app1")
c.tenant = SimpleNamespace(id="t1")
return c
@patch("controllers.openapi.auth.strategies.EnterpriseService")
def test_acl_strategy_private_calls_inner_api(ent):
ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode="private")
ent.WebAppAuth.is_user_allowed_to_access_webapp.return_value = True
assert AclStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_called_once_with(
user_id="acc1",
app_id="app1",
)
@pytest.mark.parametrize(
("access_mode", "subject_type", "expected"),
[
("public", SubjectType.ACCOUNT, True),
("public", SubjectType.EXTERNAL_SSO, True),
("sso_verified", SubjectType.ACCOUNT, True),
("sso_verified", SubjectType.EXTERNAL_SSO, True),
("private_all", SubjectType.ACCOUNT, True),
("private_all", SubjectType.EXTERNAL_SSO, False),
("private", SubjectType.EXTERNAL_SSO, False),
],
)
@patch("controllers.openapi.auth.strategies.EnterpriseService")
def test_acl_strategy_subject_mode_matrix(ent, access_mode, subject_type, expected):
"""Step 1 matrix: subject vs access-mode compatibility. No inner API call expected."""
ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode=access_mode)
account_id = "acc1" if subject_type == SubjectType.ACCOUNT else None
assert AclStrategy().authorize(_ctx(subject_type=subject_type, account_id=account_id)) is expected
ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_not_called()
@patch("controllers.openapi.auth.strategies.TenantService.account_belongs_to_tenant")
@patch("controllers.openapi.auth.strategies.db")
def test_membership_strategy_uses_join_lookup(db_mock, member):
member.return_value = True
assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True
member.assert_called_once_with(db_mock.session, "acc1", "t1")
def test_membership_strategy_rejects_external_sso():
assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None)) is False
def test_app_authz_check_raises_when_strategy_denies():
deny = SimpleNamespace(authorize=lambda c: False)
with pytest.raises(Forbidden) as exc:
AppAuthzCheck(lambda: deny)(_ctx(subject_type=SubjectType.ACCOUNT))
assert "subject_no_app_access" in str(exc.value.description)
def test_app_authz_check_passes_when_strategy_allows():
allow = SimpleNamespace(authorize=lambda c: True)
AppAuthzCheck(lambda: allow)(_ctx(subject_type=SubjectType.ACCOUNT))

View File

@ -1,83 +0,0 @@
import uuid
from datetime import UTC, datetime
from unittest.mock import patch
import pytest
from flask import Flask
from werkzeug.exceptions import Unauthorized
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import BearerCheck
from libs.oauth_bearer import (
AuthContext,
InvalidBearerError,
Scope,
SubjectType,
reset_auth_ctx,
try_get_auth_ctx,
)
def _ctx(bearer_token: str | None) -> Context:
return Context(required_scope="apps:run", bearer_token=bearer_token)
def test_bearer_check_rejects_missing_header():
app = Flask(__name__)
with app.test_request_context(), pytest.raises(Unauthorized):
BearerCheck()(_ctx(None))
@patch("controllers.openapi.auth.steps.get_authenticator")
def test_bearer_check_rejects_unknown_prefix(get_auth):
get_auth.return_value.authenticate.side_effect = InvalidBearerError("invalid_bearer")
app = Flask(__name__)
with app.test_request_context(), pytest.raises(Unauthorized):
BearerCheck()(_ctx("xxx_abc"))
@patch("controllers.openapi.auth.steps.get_authenticator")
def test_bearer_check_populates_context_and_publishes_auth_ctx(get_auth):
tok_id = uuid.uuid4()
authn = AuthContext(
subject_type=SubjectType.ACCOUNT,
subject_email="a@x.com",
subject_issuer=None,
account_id=None,
client_id="difyctl",
scopes=frozenset({Scope.FULL}),
token_id=tok_id,
source="oauth-account",
expires_at=datetime.now(UTC),
token_hash="hash-1",
verified_tenants={},
)
get_auth.return_value.authenticate.return_value = authn
app = Flask(__name__)
ctx = _ctx("dfoa_abc")
with app.test_request_context():
BearerCheck()(ctx)
try:
assert ctx.subject_type == SubjectType.ACCOUNT
assert ctx.subject_email == "a@x.com"
assert ctx.scopes == frozenset({Scope.FULL})
assert ctx.source == "oauth-account"
assert ctx.token_id == tok_id
assert ctx.token_hash == "hash-1"
# BearerCheck must also publish the same identity on the
# openapi auth ContextVar so the surface gate + downstream
# handlers don't see two different identity sources between
# the decorator + pipeline paths. The reset token is parked
# on `ctx.auth_ctx_reset_token` for `Pipeline.guard` to
# consume in its `finally`.
published = try_get_auth_ctx()
assert published is authn
assert published.client_id == "difyctl"
assert ctx.auth_ctx_reset_token is not None
finally:
# In production `Pipeline.guard` resets the ContextVar; in
# this isolated step-level test we reset it ourselves so the
# value doesn't leak into the next test on the same worker.
assert ctx.auth_ctx_reset_token is not None
reset_auth_ctx(ctx.auth_ctx_reset_token)

View File

@ -1,157 +0,0 @@
"""Unit tests for WorkspaceMembershipCheck (Layer 0)."""
from __future__ import annotations
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import WorkspaceMembershipCheck
from libs.oauth_bearer import SubjectType
def _ctx(*, subject_type, account_id, tenant_id, cached_verified_tenants=None, token_hash=None) -> Context:
c = Context(required_scope="apps:read")
c.subject_type = subject_type
c.account_id = account_id
c.tenant = SimpleNamespace(id=tenant_id) if tenant_id else None
c.cached_verified_tenants = cached_verified_tenants
c.token_hash = token_hash
return c
@pytest.fixture
def step():
return WorkspaceMembershipCheck()
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_skips_when_enterprise_enabled(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = True
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id=str(uuid.uuid4()),
tenant_id=str(uuid.uuid4()),
cached_verified_tenants={},
token_hash="hash-1",
)
step(ctx) # no raise
mock_db.session.execute.assert_not_called()
mock_record.assert_not_called()
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_skips_for_external_sso(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
ctx = _ctx(
subject_type=SubjectType.EXTERNAL_SSO,
account_id=None,
tenant_id=str(uuid.uuid4()),
cached_verified_tenants={},
token_hash="hash-1",
)
step(ctx) # no raise
mock_db.session.execute.assert_not_called()
mock_record.assert_not_called()
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_uses_cached_ok(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id="a1",
tenant_id="t1",
cached_verified_tenants={"t1": True},
token_hash="hash-1",
)
step(ctx)
mock_db.session.execute.assert_not_called()
mock_record.assert_not_called()
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_uses_cached_denied(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id="a1",
tenant_id="t1",
cached_verified_tenants={"t1": False},
token_hash="hash-1",
)
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
step(ctx)
mock_db.session.execute.assert_not_called()
mock_record.assert_not_called()
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_denies_when_no_membership(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
mock_db.session.execute.return_value.scalar_one_or_none.return_value = None
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id="a1",
tenant_id="t1",
cached_verified_tenants={},
token_hash="hash-1",
)
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
step(ctx)
mock_record.assert_called_once_with("hash-1", "t1", False)
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_denies_when_account_inactive(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
mock_db.session.execute.side_effect = [
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
MagicMock(scalar_one_or_none=MagicMock(return_value="banned")),
]
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id="a1",
tenant_id="t1",
cached_verified_tenants={},
token_hash="hash-1",
)
with pytest.raises(Forbidden, match="workspace_membership_revoked"):
step(ctx)
mock_record.assert_called_once_with("hash-1", "t1", False)
@patch("controllers.openapi.auth.steps.dify_config")
@patch("libs.oauth_bearer.record_layer0_verdict")
@patch("libs.oauth_bearer.db")
def test_allows_active_member(mock_db, mock_record, mock_cfg, step):
mock_cfg.ENTERPRISE_ENABLED = False
mock_db.session.execute.side_effect = [
MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")),
MagicMock(scalar_one_or_none=MagicMock(return_value="active")),
]
ctx = _ctx(
subject_type=SubjectType.ACCOUNT,
account_id="a1",
tenant_id="t1",
cached_verified_tenants={},
token_hash="hash-1",
)
step(ctx) # no raise
mock_record.assert_called_once_with("hash-1", "t1", True)

View File

@ -1,77 +0,0 @@
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from werkzeug.exceptions import Unauthorized
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import CallerMount
from controllers.openapi.auth.strategies import AccountMounter, EndUserMounter
from core.app.entities.app_invoke_entities import InvokeFrom
from libs.oauth_bearer import SubjectType
def _ctx(*, subject_type, account_id=None, subject_email=None):
c = Context(required_scope="apps:run")
c.subject_type = subject_type
c.account_id = account_id
c.subject_email = subject_email
c.app = SimpleNamespace(id="app1")
c.tenant = SimpleNamespace(id="t1")
return c
@patch("controllers.openapi.auth.strategies._login_as")
@patch("controllers.openapi.auth.strategies.db")
def test_account_mounter(db, login):
account = SimpleNamespace()
db.session.get.return_value = account
ctx = _ctx(subject_type=SubjectType.ACCOUNT, account_id="acc1")
AccountMounter().mount(ctx)
assert ctx.caller is account
assert ctx.caller.current_tenant is ctx.tenant
assert ctx.caller_kind == "account"
login.assert_called_once_with(account)
@patch("controllers.openapi.auth.strategies._login_as")
@patch("controllers.openapi.auth.strategies.EndUserService")
def test_end_user_mounter(svc, login):
eu = SimpleNamespace()
svc.get_or_create_end_user_by_type.return_value = eu
ctx = _ctx(subject_type=SubjectType.EXTERNAL_SSO, subject_email="a@x.com")
EndUserMounter().mount(ctx)
svc.get_or_create_end_user_by_type.assert_called_once_with(
InvokeFrom.OPENAPI,
tenant_id="t1",
app_id="app1",
user_id="a@x.com",
)
assert ctx.caller is eu
assert ctx.caller_kind == "end_user"
def test_caller_mount_dispatches_by_subject_type():
seen = {}
class Fake:
def __init__(self, st, tag):
self._st, self._tag = st, tag
def applies_to(self, st):
return st == self._st
def mount(self, ctx):
seen["who"] = self._tag
cm = CallerMount(
Fake(SubjectType.ACCOUNT, "acct"),
Fake(SubjectType.EXTERNAL_SSO, "sso"),
)
cm(_ctx(subject_type=SubjectType.EXTERNAL_SSO))
assert seen == {"who": "sso"}
def test_caller_mount_raises_when_none_applies():
with pytest.raises(Unauthorized):
CallerMount()(_ctx(subject_type=SubjectType.ACCOUNT))

View File

@ -1,25 +0,0 @@
import pytest
from werkzeug.exceptions import Forbidden
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import ScopeCheck
def _ctx(scopes, required):
c = Context(required_scope=required)
c.scopes = frozenset(scopes)
return c
def test_scope_check_passes_on_full():
ScopeCheck()(_ctx({"full"}, "apps:run"))
def test_scope_check_passes_on_explicit_match():
ScopeCheck()(_ctx({"apps:run"}, "apps:run"))
def test_scope_check_rejects_when_missing():
with pytest.raises(Forbidden) as exc:
ScopeCheck()(_ctx({"apps:read"}, "apps:run"))
assert "insufficient_scope" in str(exc.value.description)

View File

@ -1,239 +0,0 @@
"""Surface gate tests.
The gate has two attachment forms — decorator (`accept_subjects`) and
pipeline step (`SurfaceCheck`) — and both must:
- 403 on mismatched subject type with a canonical-path hint
- emit `openapi.wrong_surface_denied` once with the right payload
- pass-through on match
- raise RuntimeError (not 403) if the auth ContextVar is unset — that's
a wiring bug, not a user-driven failure
Identity is published via `libs.oauth_bearer.set_auth_ctx` / read with
`try_get_auth_ctx`. Tests wrap the publish in a `_publish_auth_ctx`
context manager so the ContextVar resets even when an assertion fails;
that keeps state from leaking into the next test on the same worker.
"""
from __future__ import annotations
import uuid
from collections.abc import Iterator
from contextlib import contextmanager
from datetime import UTC, datetime
from unittest.mock import patch
import pytest
from flask import Flask
from werkzeug.exceptions import Forbidden
from controllers.openapi.auth.context import Context
from controllers.openapi.auth.steps import SurfaceCheck
from controllers.openapi.auth.surface_gate import _coerce_subject_type, accept_subjects, check_surface
from libs.oauth_bearer import AuthContext, Scope, SubjectType, reset_auth_ctx, set_auth_ctx
@contextmanager
def _publish_auth_ctx(ctx: AuthContext) -> Iterator[None]:
token = set_auth_ctx(ctx)
try:
yield
finally:
reset_auth_ctx(token)
def _account_ctx() -> AuthContext:
return AuthContext(
subject_type=SubjectType.ACCOUNT,
subject_email="user@example.com",
subject_issuer="dify:account",
account_id=uuid.uuid4(),
client_id="difyctl",
scopes=frozenset({Scope.FULL}),
token_id=uuid.uuid4(),
source="oauth_account",
expires_at=datetime.now(UTC),
token_hash="h1",
verified_tenants={},
)
def _sso_ctx() -> AuthContext:
return AuthContext(
subject_type=SubjectType.EXTERNAL_SSO,
subject_email="sso@partner.com",
subject_issuer="https://idp.partner.com",
account_id=None,
client_id="difyctl",
scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED_EXTERNAL}),
token_id=uuid.uuid4(),
source="oauth_external_sso",
expires_at=datetime.now(UTC),
token_hash="h2",
verified_tenants={},
)
# ---------------------------------------------------------------------------
# check_surface — shared core
# ---------------------------------------------------------------------------
def test_check_surface_passes_when_subject_in_accepted():
app = Flask(__name__)
with app.test_request_context("/openapi/v1/apps"), _publish_auth_ctx(_account_ctx()):
check_surface(frozenset({SubjectType.ACCOUNT})) # no raise
def test_check_surface_rejects_on_wrong_subject_and_emits_audit():
app = Flask(__name__)
with app.test_request_context("/openapi/v1/permitted-external-apps"), _publish_auth_ctx(_account_ctx()):
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
with pytest.raises(Forbidden) as exc:
check_surface(frozenset({SubjectType.EXTERNAL_SSO}))
assert "wrong_surface" in exc.value.description
# canonical-path hint should point at the caller's surface,
# not the surface they were rejected from
assert "/openapi/v1/apps" in exc.value.description
emit.assert_called_once()
kwargs = emit.call_args.kwargs
assert kwargs["subject_type"] == SubjectType.ACCOUNT.value
assert kwargs["attempted_path"] == "/openapi/v1/permitted-external-apps"
assert kwargs["client_id"] == "difyctl"
assert kwargs["token_id"] is not None
def test_check_surface_rejects_sso_on_account_surface():
app = Flask(__name__)
with app.test_request_context("/openapi/v1/apps"), _publish_auth_ctx(_sso_ctx()):
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
with pytest.raises(Forbidden):
check_surface(frozenset({SubjectType.ACCOUNT}))
kwargs = emit.call_args.kwargs
assert kwargs["subject_type"] == SubjectType.EXTERNAL_SSO.value
def test_check_surface_runtime_error_when_auth_ctx_missing():
"""Missing auth ContextVar means the bearer layer didn't run — wiring
bug, not a user-driven failure. Surface as RuntimeError (loud) so a
future refactor doesn't accidentally let a route skip authentication
and return a 403 that looks identical to a legitimate wrong-surface
deny.
"""
app = Flask(__name__)
with app.test_request_context("/openapi/v1/apps"):
with pytest.raises(RuntimeError):
check_surface(frozenset({SubjectType.ACCOUNT}))
# ---------------------------------------------------------------------------
# @accept_subjects — decorator form
# ---------------------------------------------------------------------------
def _make_app() -> Flask:
app = Flask(__name__)
@app.route("/account-only")
@accept_subjects(SubjectType.ACCOUNT)
def _account_only():
return "ok"
@app.route("/external-only")
@accept_subjects(SubjectType.EXTERNAL_SSO)
def _external_only():
return "ok"
return app
def test_accept_subjects_decorator_passes_on_match():
app = _make_app()
with app.test_request_context("/account-only"), _publish_auth_ctx(_account_ctx()):
# Re-route through the decorated function by reaching for view_function
view = app.view_functions["_account_only"]
assert view() == "ok"
def test_accept_subjects_decorator_403_on_miss():
app = _make_app()
with app.test_request_context("/external-only"), _publish_auth_ctx(_account_ctx()):
view = app.view_functions["_external_only"]
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface"):
with pytest.raises(Forbidden):
view()
# ---------------------------------------------------------------------------
# SurfaceCheck — pipeline step form
# ---------------------------------------------------------------------------
def _pipeline_ctx() -> Context:
# SurfaceCheck reads ``request.path`` from Flask's global request — set up
# via ``app.test_request_context`` in the calling tests — not from Context.
return Context(required_scope=Scope.APPS_RUN)
def test_surface_check_passes_on_match():
step = SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT}))
app = Flask(__name__)
with app.test_request_context("/openapi/v1/apps/x/run"), _publish_auth_ctx(_account_ctx()):
step(_pipeline_ctx()) # no raise
def test_surface_check_rejects_on_miss_and_emits_audit():
step = SurfaceCheck(accepted=frozenset({SubjectType.EXTERNAL_SSO}))
app = Flask(__name__)
with app.test_request_context("/openapi/v1/apps/x/run"), _publish_auth_ctx(_account_ctx()):
with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit:
with pytest.raises(Forbidden):
step(_pipeline_ctx())
emit.assert_called_once()
# ---------------------------------------------------------------------------
# _coerce_subject_type — normalises whatever sat on ctx.subject_type
# ---------------------------------------------------------------------------
#
# The gate reads `ctx.subject_type` via `getattr(..., None)`, so the value
# could be a real enum (happy path), a raw string (e.g. rehydrated from a
# dict-shaped context), `None` (attribute missing), or something unexpected
# from a buggy upstream. The coercer must collapse all of that to
# `SubjectType | None` so `check_surface` can do a clean set-membership
# check and emit a clean audit payload.
def test_coerce_subject_type_returns_none_for_none():
assert _coerce_subject_type(None) is None
def test_coerce_subject_type_returns_enum_instance_unchanged():
# Identity matters: we don't want to round-trip through the string
# constructor for an already-valid enum.
assert _coerce_subject_type(SubjectType.ACCOUNT) is SubjectType.ACCOUNT
assert _coerce_subject_type(SubjectType.EXTERNAL_SSO) is SubjectType.EXTERNAL_SSO
@pytest.mark.parametrize(
("raw", "expected"),
[
("account", SubjectType.ACCOUNT),
("external_sso", SubjectType.EXTERNAL_SSO),
],
)
def test_coerce_subject_type_parses_known_strings(raw: str, expected: SubjectType):
assert _coerce_subject_type(raw) is expected
def test_coerce_subject_type_raises_on_unknown_string():
# Unknown strings reach `SubjectType(raw)` which raises ValueError.
# We surface that loudly rather than silently returning None, because
# a string that *looks* like a subject type but isn't is almost
# certainly an upstream bug worth catching.
with pytest.raises(ValueError):
_coerce_subject_type("not_a_subject")
@pytest.mark.parametrize("raw", [123, 1.5, b"account", object(), ["account"], {"account"}])
def test_coerce_subject_type_returns_none_for_non_string_non_enum(raw: object):
assert _coerce_subject_type(raw) is None

View File

@ -0,0 +1,156 @@
import uuid
from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden, Unauthorized
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.verify import (
check_acl,
check_app_access,
check_membership,
check_private_app_permission,
check_scope,
)
from libs.oauth_bearer import Scope, TokenType
from models.account import Tenant
from models.model import App
from services.enterprise.enterprise_service import WebAppAccessMode
def _data(**kwargs) -> AuthData:
defaults: dict = {"token_type": TokenType.OAUTH_ACCOUNT, "token_hash": "hash", "scopes": frozenset({Scope.FULL})}
defaults.update(kwargs)
return AuthData(**defaults)
# --- check_scope ---
def test_check_scope_passes_when_required_is_none():
check_scope(_data(required_scope=None))
def test_check_scope_passes_when_full_in_scopes():
check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.FULL})))
def test_check_scope_passes_when_exact_scope_present():
check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.APPS_RUN})))
def test_check_scope_raises_forbidden_when_scope_missing():
with pytest.raises(Forbidden, match="insufficient_scope"):
check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.APPS_READ})))
# reject_external_sso is no longer a pipeline step — PipelineRouter._execute raises
# Forbidden("external_sso_requires_ee") directly when route.required_edition is not satisfied.
# Test coverage for this is in test_pipeline.py (test_router_rejects_token_type_on_wrong_edition).
# --- check_membership ---
def test_check_membership_raises_unauthorized_when_tenant_none():
with pytest.raises(Unauthorized):
check_membership(_data(tenant=None))
def test_check_membership_calls_check_workspace_membership():
tenant = MagicMock(spec=Tenant)
tenant.id = "tenant-1"
data = _data(
account_id=uuid.uuid4(),
token_hash="myhash",
tenants={"tenant-1": True},
tenant=tenant,
)
with patch("controllers.openapi.auth.verify.check_workspace_membership") as mock_cwm:
check_membership(data)
mock_cwm.assert_called_once_with(
account_id=data.account_id,
tenant_id="tenant-1",
token_hash="myhash",
membership_cache=data.tenants,
)
# --- check_app_access ---
def test_check_app_access_passes_when_tenant_none():
check_app_access(_data(tenant=None))
def test_check_app_access_passes_when_member():
tenant = MagicMock(spec=Tenant)
tenant.id = "t1"
data = _data(account_id=uuid.uuid4(), tenant=tenant)
with patch("controllers.openapi.auth.verify.TenantService.account_belongs_to_tenant", return_value=True):
check_app_access(data)
def test_check_app_access_raises_when_not_member():
tenant = MagicMock(spec=Tenant)
tenant.id = "t1"
data = _data(account_id=uuid.uuid4(), tenant=tenant)
with patch("controllers.openapi.auth.verify.TenantService.account_belongs_to_tenant", return_value=False):
with pytest.raises(Forbidden, match="subject_no_app_access"):
check_app_access(data)
# --- check_acl ---
def test_check_acl_raises_when_app_or_mode_missing():
with pytest.raises(Forbidden):
check_acl(_data(app=None, app_access_mode=None))
def test_check_acl_account_allowed_for_public():
app = MagicMock(spec=App)
data = _data(token_type=TokenType.OAUTH_ACCOUNT, app=app, app_access_mode=WebAppAccessMode.PUBLIC)
check_acl(data)
def test_check_acl_external_sso_blocked_for_private():
app = MagicMock(spec=App)
data = _data(
token_type=TokenType.OAUTH_EXTERNAL_SSO,
app=app,
app_access_mode=WebAppAccessMode.PRIVATE,
)
with pytest.raises(Forbidden, match="subject_not_allowed_for_access_mode"):
check_acl(data)
def test_check_acl_external_sso_allowed_for_sso_verified():
app = MagicMock(spec=App)
data = _data(
token_type=TokenType.OAUTH_EXTERNAL_SSO,
app=app,
app_access_mode=WebAppAccessMode.SSO_VERIFIED,
)
check_acl(data)
# --- check_private_app_permission ---
def test_check_private_app_permission_raises_when_app_none():
with pytest.raises(Forbidden):
check_private_app_permission(_data(app=None))
def test_check_private_app_permission_raises_when_user_not_allowed():
app = MagicMock(spec=App)
app.id = "app-1"
data = _data(account_id=uuid.uuid4(), app=app)
target = "controllers.openapi.auth.verify.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp"
with patch(target, return_value=False):
with pytest.raises(Forbidden, match="user_not_allowed_for_private_app"):
check_private_app_permission(data)
def test_check_private_app_permission_passes_when_allowed():
app = MagicMock(spec=App)
app.id = "app-1"
data = _data(account_id=uuid.uuid4(), app=app)
target = "controllers.openapi.auth.verify.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp"
with patch(target, return_value=True):
check_private_app_permission(data)

View File

@ -1,20 +1,36 @@
import uuid
import pytest
from flask import Flask
from controllers.openapi import bp as openapi_bp
from controllers.openapi.auth.pipeline import Pipeline
from controllers.openapi.auth.data import AuthData
from controllers.openapi.auth.pipeline import PipelineRouter
from libs.oauth_bearer import Scope, TokenType
def _stub_execute(self, args, kwargs, view, *, scope=None, allowed_token_types=None, edition=None):
"""Bypass all auth logic; inject minimal AuthData and call the view directly."""
kwargs["auth_data"] = AuthData(
token_type=TokenType.OAUTH_ACCOUNT,
account_id=uuid.uuid4(),
token_hash="test",
token_id=uuid.uuid4(),
scopes=frozenset({Scope.FULL}),
required_scope=scope,
)
return view(*args, **kwargs)
@pytest.fixture
def bypass_pipeline(monkeypatch):
"""Stub Pipeline.run so endpoint decoration does not invoke real auth.
"""Stub PipelineRouter._execute so endpoints skip real auth at request time.
Module-level @OAUTH_BEARER_PIPELINE.guard(...) captures the real
pipeline at import time; mocking the module attribute does not undo
that. Patching Pipeline.run on the class is the bypass that actually
works.
Module-level @auth_router.guard(...) captures the real router at import
time — patching guard itself does nothing. Patching _execute on the class
is the seam that fires at request time.
"""
monkeypatch.setattr(Pipeline, "run", lambda self, ctx: None)
monkeypatch.setattr(PipelineRouter, "_execute", _stub_execute)
@pytest.fixture

View File

@ -86,7 +86,7 @@ def test_subject_match_for_account_filters_by_account_id():
"""Account subject scopes queries via account_id."""
import uuid as _uuid
from libs.oauth_bearer import AuthContext, SubjectType
from libs.oauth_bearer import AuthContext, SubjectType, TokenType
from services.oauth_device_flow import subject_match_clauses
aid = _uuid.uuid4()
@ -98,7 +98,7 @@ def test_subject_match_for_account_filters_by_account_id():
client_id="difyctl",
scopes=frozenset({"full"}),
token_id=_uuid.uuid4(),
source="oauth_account",
token_type=TokenType.OAUTH_ACCOUNT,
expires_at=None,
token_hash="h1",
verified_tenants={},
@ -116,7 +116,7 @@ def test_subject_match_for_external_sso_filters_by_email_and_issuer():
"""
import uuid as _uuid
from libs.oauth_bearer import AuthContext, SubjectType
from libs.oauth_bearer import AuthContext, SubjectType, TokenType
from services.oauth_device_flow import subject_match_clauses
ctx = AuthContext(
@ -127,7 +127,7 @@ def test_subject_match_for_external_sso_filters_by_email_and_issuer():
client_id="difyctl",
scopes=frozenset({"apps:run"}),
token_id=_uuid.uuid4(),
source="oauth_external_sso",
token_type=TokenType.OAUTH_EXTERNAL_SSO,
expires_at=None,
token_hash="h1",
verified_tenants={},

View File

@ -57,7 +57,11 @@ def test_stop_task_endpoint_registered(openapi_app):
def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, monkeypatch):
import uuid
from controllers.openapi.app_run import AppRunTaskStopApi
from controllers.openapi.auth.data import AuthData
from libs.oauth_bearer import Scope, TokenType
queue_mock = Mock()
graph_mock = Mock()
@ -69,15 +73,23 @@ def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, mo
monkeypatch.setattr(run_module, "GraphEngineManager", graph_mock)
monkeypatch.setattr(run_module, "redis_client", object())
auth_data = AuthData.model_construct(
token_type=TokenType.OAUTH_ACCOUNT,
account_id=uuid.uuid4(),
token_hash="test",
scopes=frozenset({Scope.FULL}),
app=SimpleNamespace(id="app-1", tenant_id="t-1"),
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
)
api = AppRunTaskStopApi()
with app.test_request_context("/openapi/v1/apps/app-1/tasks/task-1/stop", method="POST"):
result = api.post.__wrapped__(
api,
app_id="app-1",
task_id="task-1",
app_model=SimpleNamespace(id="app-1", tenant_id="t-1"),
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
auth_data=auth_data,
)
queue_mock.set_stop_flag_no_user_check.assert_called_once_with("task-1")

View File

@ -4,6 +4,7 @@ from __future__ import annotations
import json
import sys
import uuid
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import Mock
@ -11,9 +12,23 @@ from unittest.mock import Mock
import pytest
from werkzeug.exceptions import NotFound
from controllers.openapi.auth.data import AuthData
from libs.oauth_bearer import Scope, TokenType
from models.human_input import RecipientType
def _make_auth_data(app_model, caller, caller_kind):
return AuthData.model_construct(
token_type=TokenType.OAUTH_ACCOUNT,
account_id=uuid.uuid4(),
token_hash="test",
scopes=frozenset({Scope.FULL}),
app=app_model,
caller=caller,
caller_kind=caller_kind,
)
class TestOpenApiHumanInputFormGet:
def test_get_success(self, app, bypass_pipeline, monkeypatch):
from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi
@ -43,15 +58,14 @@ class TestOpenApiHumanInputFormGet:
api = OpenApiWorkflowHumanInputFormApi()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
caller = SimpleNamespace(id="acct-1")
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
resp = api.get.__wrapped__(
api,
app_id="app-1",
form_token="tok-1",
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
auth_data=_make_auth_data(app_model, caller, "account"),
)
payload = json.loads(resp.get_data(as_text=True))
@ -71,6 +85,7 @@ class TestOpenApiHumanInputFormGet:
api = OpenApiWorkflowHumanInputFormApi()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
caller = SimpleNamespace(id="acct-1")
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/bad"):
with pytest.raises(NotFound):
@ -78,9 +93,7 @@ class TestOpenApiHumanInputFormGet:
api,
app_id="app-1",
form_token="bad",
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
auth_data=_make_auth_data(app_model, caller, "account"),
)
def test_get_form_wrong_app(self, app, bypass_pipeline, monkeypatch):
@ -97,6 +110,7 @@ class TestOpenApiHumanInputFormGet:
api = OpenApiWorkflowHumanInputFormApi()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
caller = SimpleNamespace(id="acct-1")
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
with pytest.raises(NotFound):
@ -104,9 +118,7 @@ class TestOpenApiHumanInputFormGet:
api,
app_id="app-1",
form_token="tok-1",
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
auth_data=_make_auth_data(app_model, caller, "account"),
)
def test_get_form_wrong_surface(self, app, bypass_pipeline, monkeypatch):
@ -126,6 +138,7 @@ class TestOpenApiHumanInputFormGet:
api = OpenApiWorkflowHumanInputFormApi()
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
caller = SimpleNamespace(id="acct-1")
with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"):
with pytest.raises(NotFound):
@ -133,9 +146,7 @@ class TestOpenApiHumanInputFormGet:
api,
app_id="app-1",
form_token="tok-1",
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
auth_data=_make_auth_data(app_model, caller, "account"),
)
@ -172,9 +183,7 @@ class TestOpenApiHumanInputFormPost:
api,
app_id="app-1",
form_token="tok-1",
app_model=app_model,
caller=caller,
caller_kind="account",
auth_data=_make_auth_data(app_model, caller, "account"),
)
service_mock.submit_form_by_token.assert_called_once_with(
@ -211,9 +220,7 @@ class TestOpenApiHumanInputFormPost:
api,
app_id="app-1",
form_token="tok-1",
app_model=app_model,
caller=caller,
caller_kind="end_user",
auth_data=_make_auth_data(app_model, caller, "end_user"),
)
service_mock.submit_form_by_token.assert_called_once_with(

View File

@ -3,15 +3,30 @@
from __future__ import annotations
import sys
import uuid
from types import SimpleNamespace
from unittest.mock import Mock
import pytest
from werkzeug.exceptions import NotFound
from controllers.openapi.auth.data import AuthData
from libs.oauth_bearer import Scope, TokenType
from models.enums import CreatorUserRole
def _make_auth_data(app_model, caller, caller_kind):
return AuthData.model_construct(
token_type=TokenType.OAUTH_ACCOUNT,
account_id=uuid.uuid4(),
token_hash="test",
scopes=frozenset({Scope.FULL}),
app=app_model,
caller=caller,
caller_kind=caller_kind,
)
def _make_workflow_run(
*,
app_id="app-1",
@ -50,6 +65,7 @@ class TestOpenApiWorkflowEventsApi:
from models.model import AppMode
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
caller = SimpleNamespace(id="acct-1")
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
with pytest.raises(NotFound):
@ -57,9 +73,7 @@ class TestOpenApiWorkflowEventsApi:
api,
app_id="app-1",
task_id="wf-run-1",
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
auth_data=_make_auth_data(app_model, caller, "account"),
)
def test_not_found_when_run_belongs_to_different_app(self, app, bypass_pipeline, monkeypatch):
@ -77,6 +91,7 @@ class TestOpenApiWorkflowEventsApi:
from models.model import AppMode
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
caller = SimpleNamespace(id="acct-1")
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
with pytest.raises(NotFound):
@ -84,9 +99,7 @@ class TestOpenApiWorkflowEventsApi:
api,
app_id="app-1",
task_id="wf-run-1",
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
auth_data=_make_auth_data(app_model, caller, "account"),
)
def test_account_caller_checks_created_by_account(self, app, bypass_pipeline, monkeypatch):
@ -115,6 +128,7 @@ class TestOpenApiWorkflowEventsApi:
from models.model import AppMode
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
caller = SimpleNamespace(id="acct-1")
api = self._get_api()
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
@ -123,9 +137,7 @@ class TestOpenApiWorkflowEventsApi:
api,
app_id="app-1",
task_id="wf-run-1",
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
auth_data=_make_auth_data(app_model, caller, "account"),
)
assert resp.mimetype == "text/event-stream"
@ -143,6 +155,7 @@ class TestOpenApiWorkflowEventsApi:
from models.model import AppMode
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
caller = SimpleNamespace(id="acct-1")
api = self._get_api()
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
@ -151,9 +164,7 @@ class TestOpenApiWorkflowEventsApi:
api,
app_id="app-1",
task_id="wf-run-1",
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
auth_data=_make_auth_data(app_model, caller, "account"),
)
def test_end_user_caller_checks_created_by_end_user(self, app, bypass_pipeline, monkeypatch):
@ -179,6 +190,7 @@ class TestOpenApiWorkflowEventsApi:
from models.model import AppMode
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
caller = SimpleNamespace(id="eu-1")
api = self._get_api()
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
@ -186,9 +198,7 @@ class TestOpenApiWorkflowEventsApi:
api,
app_id="app-1",
task_id="wf-run-1",
app_model=app_model,
caller=SimpleNamespace(id="eu-1"),
caller_kind="end_user",
auth_data=_make_auth_data(app_model, caller, "end_user"),
)
assert resp.mimetype == "text/event-stream"
@ -222,6 +232,7 @@ class TestOpenApiWorkflowEventsApi:
from models.model import AppMode
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW)
caller = SimpleNamespace(id="acct-1")
api = self._get_api()
with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"):
@ -229,9 +240,7 @@ class TestOpenApiWorkflowEventsApi:
api,
app_id="app-1",
task_id="wf-run-1",
app_model=app_model,
caller=SimpleNamespace(id="acct-1"),
caller_kind="account",
auth_data=_make_auth_data(app_model, caller, "account"),
)
assert resp.mimetype == "text/event-stream"
chunks = list(resp.response)

View File

@ -11,6 +11,7 @@ from libs.oauth_bearer import (
SubjectType,
TokenKind,
TokenKindRegistry,
TokenType,
)
@ -21,7 +22,7 @@ def _registry_with_resolver(resolver) -> TokenKindRegistry:
prefix="dfoa_",
subject_type=SubjectType.ACCOUNT,
scopes=frozenset({Scope.FULL}),
source="oauth_account",
token_type=TokenType.OAUTH_ACCOUNT,
resolver=resolver,
)
]
@ -63,7 +64,7 @@ def test_unknown_prefix_raises_generic_invalid_bearer():
prefix="dfoa_",
subject_type=SubjectType.ACCOUNT,
scopes=frozenset({Scope.FULL}),
source="oauth_account",
token_type=TokenType.OAUTH_ACCOUNT,
resolver=MagicMock(),
)
]

View File

@ -19,6 +19,7 @@ from libs.oauth_bearer import (
AuthContext,
Scope,
SubjectType,
TokenType,
require_scope,
reset_auth_ctx,
set_auth_ctx,
@ -50,7 +51,7 @@ def _ctx(scopes) -> AuthContext:
client_id="difyctl",
scopes=scopes,
token_id=uuid.uuid4(),
source="oauth_account",
token_type=TokenType.OAUTH_ACCOUNT,
expires_at=None,
token_hash="h1",
verified_tenants={},

View File

@ -8,7 +8,7 @@ from unittest.mock import MagicMock, patch
import pytest
from werkzeug.exceptions import Forbidden
from libs.oauth_bearer import AuthContext, Scope, SubjectType, require_workspace_member
from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, require_workspace_member
def _ctx(verified: dict[str, bool] | None = None, *, account: bool = True) -> AuthContext:
@ -20,7 +20,7 @@ def _ctx(verified: dict[str, bool] | None = None, *, account: bool = True) -> Au
client_id="difyctl",
scopes=frozenset({Scope.FULL}),
token_id=uuid.uuid4(),
source="oauth_account",
token_type=TokenType.OAUTH_ACCOUNT if account else TokenType.OAUTH_EXTERNAL_SSO,
expires_at=None,
token_hash="h1",
verified_tenants=dict(verified or {}),

View File

@ -787,7 +787,7 @@ export const request = async<T>(url: string, options = {}, otherOptions?: IOther
isPublicAPI = false,
silent,
} = otherOptionsForBaseFetch
if (isPublicAPI && code === 'unauthorized' && IS_CE_EDITION) {
if (isPublicAPI && code === 'unauthorized') {
requiredWebSSOLogin()
return Promise.reject(err)
}