From d2788d7aba123e646c63be00213d341a3f601896 Mon Sep 17 00:00:00 2001 From: Xiyuan Chen <52963600+GareArc@users.noreply.github.com> Date: Wed, 27 May 2026 05:45:30 -0700 Subject: [PATCH] feat(openapi): redesign auth pipeline with per-token-type routing (#36693) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/controllers/openapi/account.py | 65 ++-- api/controllers/openapi/app_run.py | 13 +- api/controllers/openapi/apps.py | 53 +-- .../openapi/apps_permitted_external.py | 28 +- api/controllers/openapi/auth/__init__.py | 4 +- api/controllers/openapi/auth/composition.py | 100 +++--- api/controllers/openapi/auth/conditions.py | 53 +++ api/controllers/openapi/auth/context.py | 68 ---- api/controllers/openapi/auth/data.py | 69 ++++ api/controllers/openapi/auth/flow.py | 19 ++ api/controllers/openapi/auth/pipeline.py | 224 +++++++++++-- api/controllers/openapi/auth/prepare.py | 67 ++++ api/controllers/openapi/auth/steps.py | 170 ---------- api/controllers/openapi/auth/strategies.py | 168 ---------- api/controllers/openapi/auth/verify.py | 82 +++++ api/controllers/openapi/files.py | 9 +- api/controllers/openapi/human_input_form.py | 13 +- api/controllers/openapi/workflow_events.py | 10 +- api/controllers/openapi/workspaces.py | 96 ++---- api/libs/oauth_bearer.py | 25 +- .../openapi/auth/test_composition.py | 115 +++---- .../openapi/auth/test_conditions.py | 143 +++++++++ .../controllers/openapi/auth/test_context.py | 21 -- .../controllers/openapi/auth/test_data.py | 117 +++++++ .../controllers/openapi/auth/test_flow.py | 42 +++ .../controllers/openapi/auth/test_pipeline.py | 302 +++++++++++++++--- .../controllers/openapi/auth/test_prepare.py | 183 +++++++++++ .../openapi/auth/test_role_gate.py | 6 +- .../openapi/auth/test_step_app_resolver.py | 64 ---- .../openapi/auth/test_step_authz.py | 76 ----- .../openapi/auth/test_step_bearer.py | 83 ----- .../openapi/auth/test_step_layer0.py | 157 --------- .../openapi/auth/test_step_mount.py | 77 ----- .../openapi/auth/test_step_scope.py | 25 -- .../openapi/auth/test_surface_gate.py | 239 -------------- .../controllers/openapi/auth/test_verify.py | 142 ++++++++ .../controllers/openapi/conftest.py | 30 +- .../controllers/openapi/test_account.py | 8 +- .../openapi/test_app_run_streaming.py | 18 +- .../openapi/test_human_input_form.py | 43 +-- .../openapi/test_workflow_events_openapi.py | 45 +-- .../openapi/test_workspaces_members.py | 66 ++-- .../test_oauth_bearer_rate_limit_ordering.py | 5 +- .../libs/test_oauth_bearer_require_scope.py | 3 +- .../libs/test_workspace_member_helper.py | 4 +- .../services/test_oauth_device_flow.py | 6 +- 46 files changed, 1740 insertions(+), 1616 deletions(-) create mode 100644 api/controllers/openapi/auth/conditions.py delete mode 100644 api/controllers/openapi/auth/context.py create mode 100644 api/controllers/openapi/auth/data.py create mode 100644 api/controllers/openapi/auth/flow.py create mode 100644 api/controllers/openapi/auth/prepare.py delete mode 100644 api/controllers/openapi/auth/steps.py delete mode 100644 api/controllers/openapi/auth/strategies.py create mode 100644 api/controllers/openapi/auth/verify.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_conditions.py delete mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_context.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_data.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_flow.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_prepare.py delete mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_step_app_resolver.py delete mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py delete mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py delete mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_step_layer0.py delete mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_step_mount.py delete mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_step_scope.py delete mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_surface_gate.py create mode 100644 api/tests/unit_tests/controllers/openapi/auth/test_verify.py diff --git a/api/controllers/openapi/account.py b/api/controllers/openapi/account.py index 602d7e7ab4..256a822dcb 100644 --- a/api/controllers/openapi/account.py +++ b/api/controllers/openapi/account.py @@ -4,7 +4,7 @@ from datetime import UTC, datetime from flask import request from flask_restx import Resource -from werkzeug.exceptions import BadRequest, NotFound +from werkzeug.exceptions import NotFound from controllers.openapi import openapi_ns from controllers.openapi._models import ( @@ -17,18 +17,17 @@ from controllers.openapi._models import ( SessionRow, WorkspacePayload, ) +from controllers.openapi.auth.composition import auth_router +from controllers.openapi.auth.data import AuthData from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.oauth_bearer import ( - ACCEPT_USER_ANY, - AuthContext, - SubjectType, + Scope, + TokenType, get_auth_ctx, - validate_bearer, ) from libs.rate_limit import ( LIMIT_ME_PER_ACCOUNT, - LIMIT_ME_PER_EMAIL, enforce, ) from services.account_service import AccountService, TenantService @@ -42,32 +41,18 @@ from services.oauth_device_flow import ( @openapi_ns.route("/account") class AccountApi(Resource): @openapi_ns.response(200, "Account info", openapi_ns.models[AccountResponse.__name__]) - @validate_bearer(accept=ACCEPT_USER_ANY) - def get(self): - ctx = get_auth_ctx() + @auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def get(self, *, auth_data: AuthData): + enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{auth_data.account_id}") - if ctx.subject_type == SubjectType.EXTERNAL_SSO: - enforce(LIMIT_ME_PER_EMAIL, key=f"subject:{ctx.subject_email}") - else: - enforce(LIMIT_ME_PER_ACCOUNT, key=f"account:{ctx.account_id}") - - if ctx.subject_type == SubjectType.EXTERNAL_SSO: - return AccountResponse( - subject_type=ctx.subject_type, - subject_email=ctx.subject_email, - subject_issuer=ctx.subject_issuer, - account=None, - workspaces=[], - default_workspace_id=None, - ).model_dump(mode="json") - - account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) if ctx.account_id else None - memberships = TenantService.get_account_memberships(db.session, str(ctx.account_id)) if ctx.account_id else [] + account_id_str = str(auth_data.account_id) if auth_data.account_id else None + account = AccountService.get_account_by_id(db.session, account_id_str) if account_id_str else None + memberships = TenantService.get_account_memberships(db.session, account_id_str) if account_id_str else [] default_ws_id = _pick_default_workspace(memberships) return AccountResponse( - subject_type=ctx.subject_type, - subject_email=ctx.subject_email or (account.email if account else None), + subject_type="account", + subject_email=account.email if account else None, account=_account_payload(account) if account else None, workspaces=[_workspace_payload(m) for m in memberships], default_workspace_id=default_ws_id, @@ -77,19 +62,17 @@ class AccountApi(Resource): @openapi_ns.route("/account/sessions/self") class AccountSessionsSelfApi(Resource): @openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__]) - @validate_bearer(accept=ACCEPT_USER_ANY) - def delete(self): - ctx = get_auth_ctx() - _require_oauth_subject(ctx) - revoke_oauth_token(db.session, redis_client, str(ctx.token_id)) + @auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def delete(self, *, auth_data: AuthData): + revoke_oauth_token(db.session, redis_client, str(auth_data.token_id)) return RevokeResponse(status="revoked").model_dump(mode="json"), 200 @openapi_ns.route("/account/sessions") class AccountSessionsApi(Resource): @openapi_ns.response(200, "Session list", openapi_ns.models[SessionListResponse.__name__]) - @validate_bearer(accept=ACCEPT_USER_ANY) - def get(self): + @auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def get(self, *, auth_data: AuthData): ctx = get_auth_ctx() now = datetime.now(UTC) page = int(request.args.get("page", "1")) @@ -122,10 +105,9 @@ class AccountSessionsApi(Resource): @openapi_ns.route("/account/sessions/") class AccountSessionByIdApi(Resource): @openapi_ns.response(200, "Session revoked", openapi_ns.models[RevokeResponse.__name__]) - @validate_bearer(accept=ACCEPT_USER_ANY) - def delete(self, session_id: str): + @auth_router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def delete(self, session_id: str, *, auth_data: AuthData): ctx = get_auth_ctx() - _require_oauth_subject(ctx) # 404 (not 403) on cross-subject so the endpoint doesn't leak # token IDs that belong to other subjects. @@ -136,13 +118,6 @@ class AccountSessionByIdApi(Resource): return RevokeResponse(status="revoked").model_dump(mode="json"), 200 -def _require_oauth_subject(ctx: AuthContext) -> None: - if not ctx.source.startswith("oauth"): - raise BadRequest( - "this endpoint revokes OAuth bearer tokens; use /openapi/v1/personal-access-tokens/self for PATs" - ) - - def _iso(dt: datetime | None) -> str | None: if dt is None: return None diff --git a/api/controllers/openapi/app_run.py b/api/controllers/openapi/app_run.py index 95a26d50fa..8ef94740c9 100644 --- a/api/controllers/openapi/app_run.py +++ b/api/controllers/openapi/app_run.py @@ -16,7 +16,8 @@ import services from controllers.openapi import openapi_ns from controllers.openapi._audit import emit_app_run from controllers.openapi._models import AppRunRequest -from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE +from controllers.openapi.auth.composition import auth_router +from controllers.openapi.auth.data import AuthData from controllers.service_api.app.error import ( AppUnavailableError, CompletionRequestError, @@ -124,8 +125,9 @@ _DISPATCH: dict[AppMode, Callable[[App, Any, AppRunRequest], Any]] = { class AppRunApi(Resource): @openapi_ns.expect(openapi_ns.models[AppRunRequest.__name__]) @openapi_ns.response(200, "Run result (SSE stream)") - @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN) - def post(self, app_id: str, app_model: App, caller, caller_kind: str): + @auth_router.guard(scope=Scope.APPS_RUN) + def post(self, app_id: str, *, auth_data: AuthData): + app_model, caller, caller_kind = auth_data.require_app_context() body = request.get_json(silent=True) or {} try: payload = AppRunRequest.model_validate(body) @@ -158,8 +160,9 @@ class AppRunApi(Resource): @openapi_ns.route("/apps//tasks//stop") class AppRunTaskStopApi(Resource): @openapi_ns.response(200, "Task stopped") - @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN) - def post(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str): + @auth_router.guard(scope=Scope.APPS_RUN) + def post(self, app_id: str, task_id: str, *, auth_data: AuthData): + app_model, caller, caller_kind = auth_data.require_app_context() AppQueueManager.set_stop_flag_no_user_check(task_id) GraphEngineManager(redis_client).send_stop_command(task_id) return {"result": "success"} diff --git a/api/controllers/openapi/apps.py b/api/controllers/openapi/apps.py index 8a3fc81809..d3bc4e4680 100644 --- a/api/controllers/openapi/apps.py +++ b/api/controllers/openapi/apps.py @@ -1,9 +1,4 @@ -"""GET /openapi/v1/apps and per-app reads. - -Decorator order: `method_decorators` is innermost-first. `validate_bearer` -is last → outermost → publishes the auth ContextVar before `require_scope` -reads it. -""" +"""GET /openapi/v1/apps and per-app reads.""" from __future__ import annotations @@ -28,31 +23,17 @@ from controllers.openapi._models import ( AppListRow, TagItem, ) -from controllers.openapi.auth.surface_gate import accept_subjects +from controllers.openapi.auth.composition import auth_router +from controllers.openapi.auth.data import AuthData from controllers.service_api.app.error import AppUnavailableError from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict from extensions.ext_database import db -from libs.oauth_bearer import ( - ACCEPT_USER_ANY, - AuthContext, - Scope, - SubjectType, - get_auth_ctx, - require_scope, - require_workspace_member, - validate_bearer, -) +from libs.oauth_bearer import Scope, TokenType from models import App from services.account_service import TenantService from services.app_service import AppListParams, AppService from services.tag_service import TagService -_APPS_READ_DECORATORS = [ - require_scope(Scope.APPS_READ), - accept_subjects(SubjectType.ACCOUNT), - validate_bearer(accept=ACCEPT_USER_ANY), -] - _ALLOWED_DESCRIBE_FIELDS: frozenset[str] = frozenset({"info", "parameters", "input_schema"}) @@ -66,13 +47,9 @@ _EMPTY_PARAMETERS: dict[str, Any] = { class AppReadResource(Resource): - """Base for per-app read endpoints; subclasses call `_load()` for SSO/membership/exists checks.""" - - method_decorators = _APPS_READ_DECORATORS - - def _load(self, app_id: str, workspace_id: str | None = None) -> tuple[App, AuthContext]: - ctx: AuthContext = get_auth_ctx() + """Base for per-app read endpoints; subclasses call `_load()` for membership/exists checks.""" + def _load(self, app_id: str, workspace_id: str | None = None) -> App: try: parsed_uuid = _uuid.UUID(app_id) is_uuid = True @@ -99,8 +76,7 @@ class AppReadResource(Resource): raise Conflict("".join(lines)) app = matches[0] - require_workspace_member(ctx, str(app.tenant_id)) - return app, ctx + return app def parameters_payload(app: App) -> dict: @@ -114,13 +90,14 @@ def parameters_payload(app: App) -> dict: class AppDescribeApi(AppReadResource): @openapi_ns.doc(params=query_params_from_model(AppDescribeQuery)) @openapi_ns.response(200, "App description", openapi_ns.models[AppDescribeResponse.__name__]) - def get(self, app_id: str): + @auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def get(self, app_id: str, *, auth_data: AuthData): try: query = AppDescribeQuery.model_validate(request.args.to_dict(flat=True)) except ValidationError as exc: raise UnprocessableEntity(exc.json()) - app, _ = self._load(app_id, workspace_id=query.workspace_id) + app = self._load(app_id, workspace_id=query.workspace_id) requested = query.fields want_info = requested is None or "info" in requested @@ -168,20 +145,16 @@ class AppDescribeApi(AppReadResource): @openapi_ns.route("/apps") class AppListApi(Resource): - method_decorators = _APPS_READ_DECORATORS - @openapi_ns.doc(params=query_params_from_model(AppListQuery)) @openapi_ns.response(200, "App list", openapi_ns.models[AppListResponse.__name__]) - def get(self): - ctx: AuthContext = get_auth_ctx() - + @auth_router.guard(scope=Scope.APPS_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def get(self, *, auth_data: AuthData): try: query: AppListQuery = AppListQuery.model_validate(request.args.to_dict(flat=True)) except ValidationError as exc: raise UnprocessableEntity(exc.json()) workspace_id = query.workspace_id - require_workspace_member(ctx, workspace_id) empty = ( AppListResponse(page=query.page, limit=query.limit, total=0, has_more=False, data=[]).model_dump( @@ -237,7 +210,7 @@ class AppListApi(Resource): openapi_visible=True, ) - pagination = AppService().get_paginate_apps(str(ctx.account_id), workspace_id, params) + pagination = AppService().get_paginate_apps(str(auth_data.account_id), workspace_id, params) if pagination is None: return empty diff --git a/api/controllers/openapi/apps_permitted_external.py b/api/controllers/openapi/apps_permitted_external.py index 9359dca228..f86fd34a19 100644 --- a/api/controllers/openapi/apps_permitted_external.py +++ b/api/controllers/openapi/apps_permitted_external.py @@ -18,37 +18,27 @@ from controllers.openapi._models import ( PermittedExternalAppsListQuery, PermittedExternalAppsListResponse, ) -from controllers.openapi.auth.surface_gate import accept_subjects +from controllers.openapi.auth.composition import auth_router +from controllers.openapi.auth.data import AuthData, Edition from extensions.ext_database import db -from libs.device_flow_security import enterprise_only -from libs.oauth_bearer import ( - ACCEPT_USER_ANY, - Scope, - SubjectType, - require_scope, - validate_bearer, -) +from libs.oauth_bearer import Scope, TokenType from models import App from services.account_service import TenantService from services.app_service import AppService from services.enterprise.app_permitted_service import list_permitted_apps -from services.openapi.license_gate import license_required @openapi_ns.route("/permitted-external-apps") class PermittedExternalAppsListApi(Resource): - method_decorators = [ - require_scope(Scope.APPS_READ_PERMITTED_EXTERNAL), - license_required, - accept_subjects(SubjectType.EXTERNAL_SSO), - validate_bearer(accept=ACCEPT_USER_ANY), - enterprise_only, - ] - @openapi_ns.response( 200, "Permitted external apps list", openapi_ns.models[PermittedExternalAppsListResponse.__name__] ) - def get(self): + @auth_router.guard( + scope=Scope.APPS_READ_PERMITTED_EXTERNAL, + allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO}), + edition=frozenset({Edition.EE}), + ) + def get(self, *, auth_data: AuthData): try: query = PermittedExternalAppsListQuery.model_validate(request.args.to_dict(flat=True)) except ValidationError as exc: diff --git a/api/controllers/openapi/auth/__init__.py b/api/controllers/openapi/auth/__init__.py index 17ac5493d0..0460788c18 100644 --- a/api/controllers/openapi/auth/__init__.py +++ b/api/controllers/openapi/auth/__init__.py @@ -1,3 +1,3 @@ -from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE +from controllers.openapi.auth.composition import auth_router -__all__ = ["OAUTH_BEARER_PIPELINE"] +__all__ = ["auth_router"] diff --git a/api/controllers/openapi/auth/composition.py b/api/controllers/openapi/auth/composition.py index 973ddd75a2..c2c3e12873 100644 --- a/api/controllers/openapi/auth/composition.py +++ b/api/controllers/openapi/auth/composition.py @@ -1,46 +1,64 @@ -"""`OAUTH_BEARER_PIPELINE` — the auth scheme for openapi `/run` endpoints. - -Endpoints attach via `@OAUTH_BEARER_PIPELINE.guard(scope=…)`. No alternative -paths. Read endpoints (`/apps`, `/info`, `/parameters`, `/describe`) skip -the pipeline and use `validate_bearer + require_scope + require_workspace_member` -inline — they don't need `AppAuthzCheck`/`CallerMount`. -""" - from __future__ import annotations -from controllers.openapi.auth.pipeline import Pipeline -from controllers.openapi.auth.steps import ( - AppAuthzCheck, - AppResolver, - BearerCheck, - CallerMount, - ScopeCheck, - SurfaceCheck, - WorkspaceMembershipCheck, +from controllers.openapi.auth.conditions import ( + EDITION_CE, + EDITION_EE, + LOADED_APP_IS_PRIVATE, + PATH_HAS_APP_ID, + WEBAPP_AUTH_ENABLED, ) -from controllers.openapi.auth.strategies import ( - AccountMounter, - AclStrategy, - AppAuthzStrategy, - EndUserMounter, - MembershipStrategy, +from controllers.openapi.auth.data import Edition +from controllers.openapi.auth.flow import When +from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter +from controllers.openapi.auth.prepare import ( + load_account, + load_app, + load_app_access_mode, + load_tenant, + resolve_external_user, ) -from libs.oauth_bearer import SubjectType -from services.feature_service import FeatureService - - -def _resolve_app_authz_strategy() -> AppAuthzStrategy: - if FeatureService.get_system_features().webapp_auth.enabled: - return AclStrategy() - return MembershipStrategy() - - -OAUTH_BEARER_PIPELINE = Pipeline( - BearerCheck(), - SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})), - ScopeCheck(), - AppResolver(), - WorkspaceMembershipCheck(), - AppAuthzCheck(_resolve_app_authz_strategy), - CallerMount(AccountMounter(), EndUserMounter()), +from controllers.openapi.auth.verify import ( + check_acl, + check_app_access, + check_membership, + check_private_app_permission, + check_scope, +) +from libs.oauth_bearer import TokenType + +account_pipeline = AuthPipeline( + prepare=[ + When(PATH_HAS_APP_ID, then=load_app), + When(PATH_HAS_APP_ID, then=load_tenant), + load_account, # all tokens here are account tokens + When(PATH_HAS_APP_ID & EDITION_EE, then=load_app_access_mode), + ], + auth=[ + check_scope, + When(EDITION_CE & PATH_HAS_APP_ID, then=check_membership), + When(EDITION_EE & PATH_HAS_APP_ID & ~WEBAPP_AUTH_ENABLED, then=check_app_access), + When(PATH_HAS_APP_ID & EDITION_EE & WEBAPP_AUTH_ENABLED, then=check_acl), + When(EDITION_EE & LOADED_APP_IS_PRIVATE, then=check_private_app_permission), + ], +) + +external_sso_pipeline = AuthPipeline( + prepare=[ + When(PATH_HAS_APP_ID, then=load_app), + When(PATH_HAS_APP_ID, then=load_tenant), + When(PATH_HAS_APP_ID, then=resolve_external_user), + When(PATH_HAS_APP_ID, then=load_app_access_mode), + ], + auth=[ + check_scope, + When(PATH_HAS_APP_ID & WEBAPP_AUTH_ENABLED, then=check_acl), + When(LOADED_APP_IS_PRIVATE, then=check_private_app_permission), + ], +) + +auth_router = PipelineRouter( + { + TokenType.OAUTH_ACCOUNT: PipelineRoute(account_pipeline), + TokenType.OAUTH_EXTERNAL_SSO: PipelineRoute(external_sso_pipeline, required_edition=frozenset({Edition.EE})), + } ) diff --git a/api/controllers/openapi/auth/conditions.py b/api/controllers/openapi/auth/conditions.py new file mode 100644 index 0000000000..2399fc04f1 --- /dev/null +++ b/api/controllers/openapi/auth/conditions.py @@ -0,0 +1,53 @@ +from __future__ import annotations + +from collections.abc import Callable + +from controllers.openapi.auth.data import AuthData, Edition, RequestContext, current_edition +from libs.oauth_bearer import TokenType +from services.enterprise.enterprise_service import WebAppAccessMode +from services.feature_service import FeatureService + +CondFn = Callable[[RequestContext, AuthData | None], bool] + + +class Cond: + def __init__(self, fn: CondFn) -> None: + self._fn = fn + + def __call__(self, ctx: RequestContext, data: AuthData | None = None) -> bool: + return self._fn(ctx, data) + + def __and__(self, other: Cond) -> Cond: + return Cond(lambda ctx, data: self(ctx, data) and other(ctx, data)) + + def __or__(self, other: Cond) -> Cond: + return Cond(lambda ctx, data: self(ctx, data) or other(ctx, data)) + + def __invert__(self) -> Cond: + return Cond(lambda ctx, data: not self(ctx, data)) + + +def request_cond(fn: Callable[[RequestContext], bool]) -> Cond: + return Cond(lambda ctx, _: fn(ctx)) + + +def data_cond(fn: Callable[[AuthData], bool]) -> Cond: + return Cond(lambda _, data: data is not None and fn(data)) + + +def config_cond(fn: Callable[[], bool]) -> Cond: + return Cond(lambda _, __: fn()) + + +TOKEN_IS_OAUTH_ACCOUNT = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_ACCOUNT) +TOKEN_IS_OAUTH_EXTERNAL_SSO = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_EXTERNAL_SSO) + +PATH_HAS_APP_ID = request_cond(lambda ctx: "app_id" in ctx.path_params) + +EDITION_CE = config_cond(lambda: current_edition() == Edition.CE) +EDITION_EE = config_cond(lambda: current_edition() == Edition.EE) +EDITION_SAAS = config_cond(lambda: current_edition() == Edition.SAAS) + +WEBAPP_AUTH_ENABLED = config_cond(lambda: FeatureService.get_system_features().webapp_auth.enabled) + +LOADED_APP_IS_PRIVATE = data_cond(lambda data: data.app_access_mode == WebAppAccessMode.PRIVATE) diff --git a/api/controllers/openapi/auth/context.py b/api/controllers/openapi/auth/context.py deleted file mode 100644 index 95013627f0..0000000000 --- a/api/controllers/openapi/auth/context.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Mutable per-request context for the openapi auth pipeline. - -Every field starts None / empty and is filled in by a step. The pipeline -is the only thing that should construct or mutate Context — handlers -read populated values via the decorator's kwargs unpacking. - -Context is intentionally decoupled from Flask's ``Request``: the pipeline -guard extracts whatever transport-level inputs the steps need (bearer -token, path params) at the boundary and writes them into Context fields, -so steps stay testable without a request object and won't leak coupling -to a specific framework. -""" - -from __future__ import annotations - -import uuid -from collections.abc import Mapping -from contextvars import Token -from dataclasses import dataclass, field -from datetime import datetime -from typing import TYPE_CHECKING, Literal, Protocol - -from werkzeug.exceptions import Unauthorized - -from libs.oauth_bearer import AuthContext, Scope, SubjectType - -if TYPE_CHECKING: - from models import App, Tenant - - -@dataclass -class Context: - required_scope: Scope - bearer_token: str | None = None - path_params: Mapping[str, str] = field(default_factory=dict) - subject_type: SubjectType | None = None - subject_email: str | None = None - subject_issuer: str | None = None - account_id: uuid.UUID | None = None - scopes: frozenset[Scope] = field(default_factory=frozenset) - token_id: uuid.UUID | None = None - token_hash: str | None = None - cached_verified_tenants: dict[str, bool] | None = None - source: str | None = None - expires_at: datetime | None = None - app: App | None = None - tenant: Tenant | None = None - caller: object | None = None - caller_kind: Literal["account", "end_user"] | None = None - auth_ctx_reset_token: Token[AuthContext] | None = None - - @property - def must_tenant(self) -> Tenant: - if not self.tenant: - raise Unauthorized("tenant is not associated") - return self.tenant - - @property - def must_subject_type(self) -> SubjectType: - if not self.subject_type: - raise Unauthorized("subject_type unset — BearerCheck did not run") - return self.subject_type - - -class Step(Protocol): - """One responsibility. Mutate ctx or raise to short-circuit.""" - - def __call__(self, ctx: Context) -> None: ... diff --git a/api/controllers/openapi/auth/data.py b/api/controllers/openapi/auth/data.py new file mode 100644 index 0000000000..30973b5e9b --- /dev/null +++ b/api/controllers/openapi/auth/data.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import uuid +from enum import StrEnum +from typing import Literal + +from pydantic import BaseModel, ConfigDict, Field +from werkzeug.exceptions import InternalServerError + +from configs import dify_config +from libs.oauth_bearer import Scope, TokenType +from models.account import Account, Tenant +from models.model import App, EndUser +from services.enterprise.enterprise_service import WebAppAccessMode + + +class Edition(StrEnum): + CE = "ce" + EE = "ee" + SAAS = "saas" + + +def current_edition() -> Edition: + if dify_config.EDITION == "CLOUD": + return Edition.SAAS + if dify_config.ENTERPRISE_ENABLED: + return Edition.EE + return Edition.CE + + +class ExternalIdentity(BaseModel): + model_config = ConfigDict(frozen=True) + + email: str + issuer: str | None = None + + +class RequestContext(BaseModel): + model_config = ConfigDict(frozen=True) + + token_type: TokenType + scope: Scope | None = None + path_params: dict[str, str] + + +class AuthData(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + + required_scope: Scope | None = None + token_type: TokenType + account_id: uuid.UUID | None = None + token_hash: str + token_id: uuid.UUID | None = None + scopes: frozenset[Scope] + tenants: dict[str, bool] = Field(default_factory=dict) + external_identity: ExternalIdentity | None = None + path_params: dict[str, str] = Field(default_factory=dict) + + app: App | None = None + tenant: Tenant | None = None + app_access_mode: WebAppAccessMode | None = None + + caller: Account | EndUser | None = None + caller_kind: Literal["account", "end_user"] | None = None + + def require_app_context(self) -> tuple[App, Account | EndUser, Literal["account", "end_user"]]: + if self.app is None or self.caller is None or self.caller_kind is None: + raise InternalServerError("pipeline_invariant_violated: app context missing") + return self.app, self.caller, self.caller_kind diff --git a/api/controllers/openapi/auth/flow.py b/api/controllers/openapi/auth/flow.py new file mode 100644 index 0000000000..eee1378cf4 --- /dev/null +++ b/api/controllers/openapi/auth/flow.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from controllers.openapi.auth.conditions import Cond +from controllers.openapi.auth.data import AuthData, RequestContext + + +class When: + def __init__(self, condition: Cond, *, then: Callable[[Any], None]) -> None: + self.condition = condition + self._step = then + + def applies(self, ctx: RequestContext, data: AuthData | None = None) -> bool: + return self.condition(ctx, data) + + def __call__(self, arg: Any) -> None: + self._step(arg) diff --git a/api/controllers/openapi/auth/pipeline.py b/api/controllers/openapi/auth/pipeline.py index 096b1b7ea3..e992e5e5ab 100644 --- a/api/controllers/openapi/auth/pipeline.py +++ b/api/controllers/openapi/auth/pipeline.py @@ -1,51 +1,209 @@ -"""Pipeline IS the auth scheme. +"""Auth pipeline — entry point for all openapi auth. -`Pipeline.guard(scope=…)` is the only attachment point for endpoints — -that is the design lock-in: forgetting an auth layer is structurally -impossible because there is no "sometimes wrap, sometimes don't" choice. +`PipelineRouter.guard()` is the only attachment point for endpoints. +`AuthPipeline` is a pure step-runner with no routing concerns. +`PipelineRoute` binds a pipeline to optional edition requirements. """ from __future__ import annotations +from collections.abc import Callable +from dataclasses import dataclass from functools import wraps +from typing import Any -from flask import request +from flask import current_app, request +from flask_login import user_logged_in +from werkzeug.exceptions import Forbidden, NotFound, Unauthorized -from controllers.openapi.auth.context import Context, Step -from libs.oauth_bearer import Scope, extract_bearer, reset_auth_ctx +from controllers.openapi._audit import emit_wrong_surface +from controllers.openapi.auth.data import ( + AuthData, + Edition, + ExternalIdentity, + RequestContext, + current_edition, +) +from controllers.openapi.auth.flow import When +from libs.oauth_bearer import ( + AuthContext, + Scope, + TokenType, + extract_bearer, + get_authenticator, + reset_auth_ctx, + set_auth_ctx, +) +from services.feature_service import FeatureService, LicenseStatus -class Pipeline: - def __init__(self, *steps: Step) -> None: - self._steps = steps +class AuthPipeline: + """Pure step-runner — no routing, no guard. - def run(self, ctx: Context) -> None: - for step in self._steps: - step(ctx) + Both `prepare` and `auth` steps receive the same `AuthData` instance. + `prepare` steps populate it; `auth` steps validate it. + """ - def guard(self, *, scope: Scope): - def decorator(view): + def __init__(self, prepare: list, auth: list) -> None: + self._prepare = prepare + self._auth = auth + + def _run( + self, + identity: AuthContext, + args: tuple, + kwargs: dict, + view: Callable, + *, + scope: Scope | None, + ) -> Any: + req_ctx = RequestContext( + token_type=identity.token_type, + scope=scope, + path_params=dict(request.view_args or {}), + ) + + data = AuthData( + token_type=identity.token_type, + account_id=identity.account_id, + token_hash=identity.token_hash, + token_id=identity.token_id, + scopes=frozenset(identity.scopes), + tenants=dict(identity.verified_tenants), + required_scope=scope, + path_params=dict(req_ctx.path_params), + external_identity=( + ExternalIdentity(email=identity.subject_email, issuer=identity.subject_issuer) + if identity.subject_email + else None + ), + ) + + for step in self._prepare: + if _should_run(step, req_ctx, data=None): + step(data) + + for step in self._auth: + if _should_run(step, req_ctx, data=data): + step(data) + + reset_token = set_auth_ctx(identity) + if data.caller: + _mount_flask_login(data.caller) + + try: + kwargs["auth_data"] = data + return view(*args, **kwargs) + finally: + reset_auth_ctx(reset_token) + + +@dataclass(frozen=True) +class PipelineRoute: + pipeline: AuthPipeline + required_edition: frozenset[Edition] | None = None + + +class PipelineRouter: + """Entry point for openapi auth. + + `guard()` is the decorator that endpoints attach to. It applies + global gates (edition, token type) then dispatches to the matching + `PipelineRoute` for the token type. + """ + + def __init__(self, routes: dict[TokenType, PipelineRoute]) -> None: + self._routes = routes + + def guard( + self, + *, + scope: Scope | None = None, + allowed_token_types: frozenset[TokenType] | None = None, + edition: frozenset[Edition] | None = None, + ) -> Callable: + def decorator(view: Callable) -> Callable: @wraps(view) - def decorated(*args, **kwargs): - # Extract transport-level inputs at the boundary so steps - # stay decoupled from Flask's request object. - ctx = Context( - required_scope=scope, - bearer_token=extract_bearer(request), - path_params=dict(request.view_args or {}), + def decorated(*args: Any, **kwargs: Any) -> Any: + return self._execute( + args, + kwargs, + view, + scope=scope, + allowed_token_types=allowed_token_types, + edition=edition, ) - try: - self.run(ctx) - kwargs.update( - app_model=ctx.app, - caller=ctx.caller, - caller_kind=ctx.caller_kind, - ) - return view(*args, **kwargs) - finally: - if ctx.auth_ctx_reset_token is not None: - reset_auth_ctx(ctx.auth_ctx_reset_token) return decorated return decorator + + def _execute( + self, + args: tuple, + kwargs: dict, + view: Callable, + *, + scope: Scope | None, + allowed_token_types: frozenset[TokenType] | None, + edition: frozenset[Edition] | None, + ) -> Any: + # 404 not 403 — this edition doesn't expose the feature at all + if edition is not None and current_edition() not in edition: + raise NotFound() + + license_checked = False + if edition is not None and Edition.EE in edition: + _check_license() + license_checked = True + + token = extract_bearer(request) + if not token: + raise Unauthorized("bearer required") + + identity = get_authenticator().authenticate(token) + + if allowed_token_types is not None and identity.token_type not in allowed_token_types: + emit_wrong_surface( + subject_type=_subject_type_str(identity), + attempted_path=request.path, + client_id=getattr(identity, "client_id", None), + token_id=str(identity.token_id) if identity.token_id else None, + ) + raise Forbidden("unsupported_token_type") + + route = self._routes.get(identity.token_type) + if route is None: + raise Forbidden("unsupported_token_type") + + if route.required_edition is not None: + if current_edition() not in route.required_edition: + raise Forbidden("external_sso_requires_ee") + if not license_checked and Edition.EE in route.required_edition: + _check_license() + + return route.pipeline._run(identity, args, kwargs, view, scope=scope) + + +def _should_run(step: Any, req_ctx: RequestContext, data: AuthData | None) -> bool: + if isinstance(step, When): + return step.applies(req_ctx, data) + return True + + +def _subject_type_str(identity: Any) -> str | None: + subject = getattr(identity, "subject_type", None) + if subject is None: + return None + return subject.value if hasattr(subject, "value") else str(subject) + + +def _check_license() -> None: + settings = FeatureService.get_system_features() + if settings.license.status in {LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST}: + raise Forbidden("license_invalid") + + +def _mount_flask_login(user: Any) -> None: + current_app.login_manager._update_request_context_with_user(user) # type: ignore[attr-defined] + user_logged_in.send(current_app._get_current_object(), user=user) # type: ignore[attr-defined] diff --git a/api/controllers/openapi/auth/prepare.py b/api/controllers/openapi/auth/prepare.py new file mode 100644 index 0000000000..fe6e031b50 --- /dev/null +++ b/api/controllers/openapi/auth/prepare.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from werkzeug.exceptions import Forbidden, InternalServerError, NotFound, Unauthorized + +from controllers.openapi.auth.data import AuthData +from core.app.entities.app_invoke_entities import InvokeFrom +from extensions.ext_database import db +from models.account import TenantStatus +from services.account_service import AccountService, TenantService +from services.app_service import AppService +from services.end_user_service import EndUserService +from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode + + +def load_app(data: AuthData) -> None: + app_id = data.path_params["app_id"] + app = AppService.get_app_by_id(db.session, app_id) + if not app or app.status != "normal": + raise NotFound("app not found") + if not app.enable_api: + raise Forbidden("service_api_disabled") + data.app = app + + +def load_tenant(data: AuthData) -> None: + if data.app is None: + raise InternalServerError("pipeline_invariant_violated: app not loaded before load_tenant") + tenant = TenantService.get_tenant_by_id(db.session, str(data.app.tenant_id)) + if tenant is None or tenant.status == TenantStatus.ARCHIVE: + raise Forbidden("workspace unavailable") + data.tenant = tenant + + +def load_account(data: AuthData) -> None: + account = AccountService.get_account_by_id(db.session, str(data.account_id)) + if account is None: + raise Unauthorized("account not found") + if data.tenant: + account.current_tenant = data.tenant + data.caller = account + data.caller_kind = "account" + + +def resolve_external_user(data: AuthData) -> None: + if data.tenant is None or data.app is None or data.external_identity is None: + raise Unauthorized("missing context for external user resolution") + end_user = EndUserService.get_or_create_end_user_by_type( + InvokeFrom.OPENAPI, + tenant_id=str(data.tenant.id), + app_id=str(data.app.id), + user_id=data.external_identity.email, + ) + data.caller = end_user + data.caller_kind = "end_user" + + +def load_app_access_mode(data: AuthData) -> None: + if data.app is None: + return + try: + settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(data.app.id)) + if settings is None: + data.app_access_mode = None + return + data.app_access_mode = WebAppAccessMode(settings.access_mode) + except ValueError: + data.app_access_mode = None diff --git a/api/controllers/openapi/auth/steps.py b/api/controllers/openapi/auth/steps.py deleted file mode 100644 index 40a168b489..0000000000 --- a/api/controllers/openapi/auth/steps.py +++ /dev/null @@ -1,170 +0,0 @@ -"""Pipeline steps. Each is one responsibility. - -`BearerCheck` is the only step that touches the token registry; downstream -steps see only the populated `Context`. `BearerCheck` also publishes the -resolved identity to the openapi auth ``ContextVar`` (the same one the -decorator-level :func:`libs.oauth_bearer.validate_bearer` writes to) so the -surface gate and any handler reading the request-scoped context has a single -source of truth across both auth-attach paths. The reset token is stashed -on `ctx.auth_ctx_reset_token`; `Pipeline.guard` resets the ContextVar in -its `finally` so worker-thread reuse can't leak identity across requests. -""" - -from __future__ import annotations - -from collections.abc import Callable - -from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized - -from configs import dify_config -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.strategies import AppAuthzStrategy, CallerMounter -from controllers.openapi.auth.surface_gate import check_surface -from extensions.ext_database import db -from libs.oauth_bearer import ( - AuthContext, - InvalidBearerError, - Scope, - SubjectType, - check_workspace_membership, - get_authenticator, - set_auth_ctx, -) -from models import TenantStatus -from services.account_service import TenantService -from services.app_service import AppService - - -class BearerCheck: - """Resolve bearer → populate identity fields. Rate-limit is enforced - inside `BearerAuthenticator.authenticate`, so no separate step here. - Also publishes the resolved `AuthContext` via - :func:`libs.oauth_bearer.set_auth_ctx` — same shape the decorator-level - ``validate_bearer`` writes — so the surface gate + downstream readers - don't see two different identity sources. The reset token is parked on - ``ctx.auth_ctx_reset_token`` for `Pipeline.guard` to consume.""" - - def __call__(self, ctx: Context) -> None: - if not ctx.bearer_token: - raise Unauthorized("bearer required") - - try: - authn = get_authenticator().authenticate(ctx.bearer_token) - except InvalidBearerError as e: - raise Unauthorized(str(e)) - - ctx.subject_type = authn.subject_type - ctx.subject_email = authn.subject_email - ctx.subject_issuer = authn.subject_issuer - ctx.account_id = authn.account_id - ctx.scopes = frozenset(authn.scopes) - ctx.source = authn.source - ctx.token_id = authn.token_id - ctx.expires_at = authn.expires_at - ctx.token_hash = authn.token_hash - ctx.cached_verified_tenants = dict(authn.verified_tenants) - ctx.auth_ctx_reset_token = set_auth_ctx(authn) - - -class ScopeCheck: - """Verify ctx.scopes (already populated by BearerCheck) covers required.""" - - def __call__(self, ctx: Context) -> None: - if Scope.FULL in ctx.scopes or ctx.required_scope in ctx.scopes: - return - raise Forbidden("insufficient_scope") - - -class SurfaceCheck: - """Reject the request if the resolved subject is not in `accepted`.""" - - def __init__(self, *, accepted: frozenset[SubjectType]) -> None: - self._accepted = accepted - - def __call__(self, ctx: Context) -> None: - check_surface(self._accepted) - - -class AppResolver: - """Read ``app_id`` from ``ctx.path_params``; populate ctx.app + ctx.tenant. - - Every endpoint using the OAuth bearer pipeline must declare - ```` in its route — that is the design lock-in (no body / - header coupling). ``Pipeline.guard`` lifts ``request.view_args`` into - ``ctx.path_params`` at the boundary so this step doesn't need to know - about the request object. - """ - - def __call__(self, ctx: Context) -> None: - app_id = ctx.path_params.get("app_id") - if not app_id: - raise BadRequest("app_id is required in path") - app = AppService.get_app_by_id(db.session, app_id) - if not app or app.status != "normal": - raise NotFound("app not found") - if not app.enable_api: - raise Forbidden("service_api_disabled") - tenant = TenantService.get_tenant_by_id(db.session, str(app.tenant_id)) - if tenant is None or tenant.status == TenantStatus.ARCHIVE: - raise Forbidden("workspace unavailable") - ctx.app, ctx.tenant = app, tenant - - -class WorkspaceMembershipCheck: - """Layer 0 — workspace membership gate. - - CE-only (skipped when ENTERPRISE_ENABLED). Account-subject bearers - (dfoa_) only — SSO subjects skip. - """ - - def __call__(self, ctx: Context) -> None: - if dify_config.ENTERPRISE_ENABLED: - return - if ctx.subject_type != SubjectType.ACCOUNT: - return - if ctx.account_id is None or ctx.tenant is None: - raise Unauthorized("account_id or tenant unset — BearerCheck or AppResolver did not run") - if ctx.token_hash is None: - raise Unauthorized("token_hash unset — BearerCheck did not run") - - check_workspace_membership( - account_id=ctx.account_id, - tenant_id=ctx.must_tenant.id, - token_hash=ctx.token_hash, - cached_verdicts=ctx.cached_verified_tenants or {}, - ) - - -class AppAuthzCheck: - def __init__(self, resolve_strategy: Callable[[], AppAuthzStrategy]) -> None: - self._resolve = resolve_strategy - - def __call__(self, ctx: Context) -> None: - if not self._resolve().authorize(ctx): - raise Forbidden("subject_no_app_access") - - -class CallerMount: - def __init__(self, *mounters: CallerMounter) -> None: - self._mounters = mounters - - def __call__(self, ctx: Context) -> None: - if ctx.subject_type is None: - raise Unauthorized("subject_type unset — BearerCheck did not run") - for m in self._mounters: - if m.applies_to(ctx.must_subject_type): - m.mount(ctx) - return - raise Unauthorized("no caller mounter for subject type") - - -__all__ = [ - "AppAuthzCheck", - "AppResolver", - "AuthContext", - "BearerCheck", - "CallerMount", - "ScopeCheck", - "SurfaceCheck", - "WorkspaceMembershipCheck", -] diff --git a/api/controllers/openapi/auth/strategies.py b/api/controllers/openapi/auth/strategies.py deleted file mode 100644 index aaaaadd948..0000000000 --- a/api/controllers/openapi/auth/strategies.py +++ /dev/null @@ -1,168 +0,0 @@ -"""Strategy classes for the openapi auth pipeline. - -App authorization (Acl/Membership) and caller mounting (Account/EndUser) -vary along independent axes; each strategy is one class so the pipeline -composition stays a flat list. -""" - -from __future__ import annotations - -from typing import Protocol - -from flask import current_app -from flask_login import user_logged_in - -from controllers.openapi.auth.context import Context -from core.app.entities.app_invoke_entities import InvokeFrom -from extensions.ext_database import db -from libs.oauth_bearer import SubjectType -from services.account_service import AccountService, TenantService -from services.end_user_service import EndUserService -from services.enterprise.enterprise_service import ( - EnterpriseService, - WebAppAccessMode, -) - - -class AppAuthzStrategy(Protocol): - def authorize(self, ctx: Context) -> bool: ... - - -class AclStrategy: - """Per-app ACL, evaluated in two stages. - - The EE gateway has already enforced tenancy and workspace membership - by the time this strategy runs, so AclStrategy only owns per-app ACL: - - 1. Subject vs access-mode compatibility (pure rule table). External-SSO - bearers belong to public-facing apps only; account bearers cover the - full set. A mismatch is an immediate deny — no IO. - 2. For modes that pair with the subject, decide whether the inner - permission API must run. Only `PRIVATE` (per-app selected-user list) - requires it; the remaining modes are pass-through. - """ - - _ALLOWED_MODES_BY_SUBJECT: dict[SubjectType, frozenset[WebAppAccessMode]] = { - SubjectType.ACCOUNT: frozenset( - { - WebAppAccessMode.PUBLIC, - WebAppAccessMode.SSO_VERIFIED, - WebAppAccessMode.PRIVATE_ALL, - WebAppAccessMode.PRIVATE, - } - ), - SubjectType.EXTERNAL_SSO: frozenset( - { - WebAppAccessMode.PUBLIC, - WebAppAccessMode.SSO_VERIFIED, - } - ), - } - - _MODES_REQUIRING_INNER_CHECK: frozenset[WebAppAccessMode] = frozenset({WebAppAccessMode.PRIVATE}) - - def authorize(self, ctx: Context) -> bool: - if ctx.app is None: - return False - access_mode = self._fetch_access_mode(ctx.app.id) - if access_mode is None: - return False - if not self._subject_allowed_for_mode(ctx.must_subject_type, access_mode): - return False - if access_mode not in self._MODES_REQUIRING_INNER_CHECK: - return True - return self._inner_permission_check(ctx) - - @staticmethod - def _fetch_access_mode(app_id: str) -> WebAppAccessMode | None: - settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=app_id) - if settings is None: - return None - try: - return WebAppAccessMode(settings.access_mode) - except ValueError: - return None - - @classmethod - def _subject_allowed_for_mode(cls, subject_type: SubjectType, access_mode: WebAppAccessMode) -> bool: - return access_mode in cls._ALLOWED_MODES_BY_SUBJECT.get(subject_type, frozenset()) - - def _inner_permission_check(self, ctx: Context) -> bool: - if ctx.app is None: - return False - user_id = self._resolve_user_id(ctx) - if user_id is None: - return False - return EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( - user_id=user_id, - app_id=ctx.app.id, - ) - - @staticmethod - def _resolve_user_id(ctx: Context) -> str | None: - if ctx.subject_type == SubjectType.ACCOUNT: - return str(ctx.account_id) if ctx.account_id is not None else None - if ctx.subject_email is None: - return None - account = AccountService.get_account_by_email(db.session, ctx.subject_email) - return str(account.id) if account is not None else None - - -class MembershipStrategy: - """Tenant-membership fallback. - - Used when webapp-auth is disabled (CE deployment). Account-bearing - subjects pass if they have a TenantAccountJoin row; EXTERNAL_SSO is - denied (it requires the webapp-auth surface). - """ - - def authorize(self, ctx: Context) -> bool: - if ctx.subject_type == SubjectType.EXTERNAL_SSO: - return False - if ctx.tenant is None: - return False - return TenantService.account_belongs_to_tenant(db.session, ctx.account_id, ctx.tenant.id) - - -def _login_as(user) -> None: - """Set Flask-Login request user so downstream services see the caller.""" - current_app.login_manager._update_request_context_with_user(user) # type:ignore - user_logged_in.send(current_app._get_current_object(), user=user) # type:ignore - - -class CallerMounter(Protocol): - def applies_to(self, subject_type: SubjectType) -> bool: ... - - def mount(self, ctx: Context) -> None: ... - - -class AccountMounter: - def applies_to(self, subject_type: SubjectType) -> bool: - return subject_type == SubjectType.ACCOUNT - - def mount(self, ctx: Context) -> None: - if ctx.account_id is None: - raise RuntimeError("AccountMounter: account_id unset — BearerCheck did not run") - account = AccountService.get_account_by_id(db.session, str(ctx.account_id)) - if account is None: - raise RuntimeError("AccountMounter: account row missing for resolved bearer") - account.current_tenant = ctx.must_tenant - _login_as(account) - ctx.caller, ctx.caller_kind = account, "account" - - -class EndUserMounter: - def applies_to(self, subject_type: SubjectType) -> bool: - return subject_type == SubjectType.EXTERNAL_SSO - - def mount(self, ctx: Context) -> None: - if ctx.tenant is None or ctx.app is None or ctx.subject_email is None: - raise RuntimeError("EndUserMounter: tenant/app/subject_email unset — earlier steps did not run") - end_user = EndUserService.get_or_create_end_user_by_type( - InvokeFrom.OPENAPI, - tenant_id=ctx.tenant.id, - app_id=ctx.app.id, - user_id=ctx.subject_email, - ) - _login_as(end_user) - ctx.caller, ctx.caller_kind = end_user, "end_user" diff --git a/api/controllers/openapi/auth/verify.py b/api/controllers/openapi/auth/verify.py new file mode 100644 index 0000000000..22410b3374 --- /dev/null +++ b/api/controllers/openapi/auth/verify.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from werkzeug.exceptions import Forbidden, Unauthorized + +from controllers.openapi.auth.data import AuthData +from extensions.ext_database import db +from libs.oauth_bearer import Scope, TokenType, check_workspace_membership +from services.account_service import AccountService, TenantService +from services.enterprise.enterprise_service import EnterpriseService, WebAppAccessMode + + +def check_scope(data: AuthData) -> None: + if data.required_scope is None: + return + if Scope.FULL in data.scopes or data.required_scope in data.scopes: + return + raise Forbidden("insufficient_scope") + + +def check_membership(data: AuthData) -> None: + if data.tenant is None: + raise Unauthorized("tenant unset") + if data.account_id is None: + raise Unauthorized("account_id unset") + check_workspace_membership( + account_id=data.account_id, + tenant_id=data.tenant.id, + token_hash=data.token_hash, + membership_cache=data.tenants, + ) + + +def check_app_access(data: AuthData) -> None: + if data.tenant is None: + return + if not TenantService.account_belongs_to_tenant(db.session, data.account_id, data.tenant.id): + raise Forbidden("subject_no_app_access") + + +_ALLOWED_MODES_BY_TOKEN_TYPE: dict[TokenType, frozenset[WebAppAccessMode]] = { + TokenType.OAUTH_ACCOUNT: frozenset( + { + WebAppAccessMode.PUBLIC, + WebAppAccessMode.SSO_VERIFIED, + WebAppAccessMode.PRIVATE_ALL, + WebAppAccessMode.PRIVATE, + } + ), + TokenType.OAUTH_EXTERNAL_SSO: frozenset( + { + WebAppAccessMode.PUBLIC, + WebAppAccessMode.SSO_VERIFIED, + } + ), +} + + +def check_acl(data: AuthData) -> None: + if data.app is None or data.app_access_mode is None: + raise Forbidden("app or access mode not loaded") + allowed_modes = _ALLOWED_MODES_BY_TOKEN_TYPE.get(data.token_type, frozenset()) + if data.app_access_mode not in allowed_modes: + raise Forbidden("subject_not_allowed_for_access_mode") + + +def check_private_app_permission(data: AuthData) -> None: + if data.app is None: + raise Forbidden("app not loaded") + user_id = _resolve_user_id(data) + if user_id is None: + raise Forbidden("cannot resolve user for private app check") + if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id=user_id, app_id=data.app.id): + raise Forbidden("user_not_allowed_for_private_app") + + +def _resolve_user_id(data: AuthData) -> str | None: + if data.token_type == TokenType.OAUTH_ACCOUNT: + return str(data.account_id) if data.account_id is not None else None + if data.external_identity is None: + return None + account = AccountService.get_account_by_email(db.session, data.external_identity.email) + return str(account.id) if account is not None else None diff --git a/api/controllers/openapi/files.py b/api/controllers/openapi/files.py index eb16015821..1a2c16abf9 100644 --- a/api/controllers/openapi/files.py +++ b/api/controllers/openapi/files.py @@ -17,11 +17,11 @@ from controllers.common.errors import ( UnsupportedFileTypeError, ) from controllers.openapi import openapi_ns -from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE +from controllers.openapi.auth.composition import auth_router +from controllers.openapi.auth.data import AuthData from extensions.ext_database import db from fields.file_fields import FileResponse from libs.oauth_bearer import Scope -from models import Account, App from services.file_service import FileService @@ -39,8 +39,9 @@ class AppFileUploadApi(Resource): } ) @openapi_ns.response(HTTPStatus.CREATED, "File uploaded", openapi_ns.models[FileResponse.__name__]) - @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN) - def post(self, app_id: str, app_model: App, caller: Account, caller_kind: str): + @auth_router.guard(scope=Scope.APPS_RUN) + def post(self, app_id: str, *, auth_data: AuthData): + app_model, caller, _ = auth_data.require_app_context() if "file" not in request.files: raise NoFileUploadedError() if len(request.files) > 1: diff --git a/api/controllers/openapi/human_input_form.py b/api/controllers/openapi/human_input_form.py index 7d54140efd..3c359406be 100644 --- a/api/controllers/openapi/human_input_form.py +++ b/api/controllers/openapi/human_input_form.py @@ -17,7 +17,8 @@ from werkzeug.exceptions import BadRequest, NotFound from controllers.common.human_input import HumanInputFormSubmitPayload, stringify_form_default_values from controllers.common.schema import register_schema_models from controllers.openapi import openapi_ns -from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE +from controllers.openapi.auth.composition import auth_router +from controllers.openapi.auth.data import AuthData from core.workflow.human_input_policy import HumanInputSurface, is_recipient_type_allowed_for_surface from extensions.ext_database import db from libs.helper import to_timestamp @@ -55,8 +56,9 @@ def _ensure_form_is_allowed_for_openapi(form) -> None: @openapi_ns.route("/apps//form/human_input/") class OpenApiWorkflowHumanInputFormApi(Resource): @openapi_ns.response(200, "Form definition") - @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN) - def get(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str): + @auth_router.guard(scope=Scope.APPS_RUN) + def get(self, app_id: str, form_token: str, *, auth_data: AuthData): + app_model, caller, caller_kind = auth_data.require_app_context() service = HumanInputService(db.engine) form = service.get_form_by_token(form_token) if form is None: @@ -69,8 +71,9 @@ class OpenApiWorkflowHumanInputFormApi(Resource): @openapi_ns.expect(openapi_ns.models[HumanInputFormSubmitPayload.__name__]) @openapi_ns.response(200, "Form submitted") - @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN) - def post(self, app_id: str, form_token: str, app_model: App, caller, caller_kind: str): + @auth_router.guard(scope=Scope.APPS_RUN) + def post(self, app_id: str, form_token: str, *, auth_data: AuthData): + app_model, caller, caller_kind = auth_data.require_app_context() payload = HumanInputFormSubmitPayload.model_validate(request.get_json(silent=True) or {}) service = HumanInputService(db.engine) diff --git a/api/controllers/openapi/workflow_events.py b/api/controllers/openapi/workflow_events.py index b14b2d400f..f21306e491 100644 --- a/api/controllers/openapi/workflow_events.py +++ b/api/controllers/openapi/workflow_events.py @@ -17,7 +17,8 @@ from sqlalchemy.orm import sessionmaker from werkzeug.exceptions import NotFound, UnprocessableEntity from controllers.openapi import openapi_ns -from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE +from controllers.openapi.auth.composition import auth_router +from controllers.openapi.auth.data import AuthData from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -28,7 +29,7 @@ from core.workflow.human_input_policy import HumanInputSurface from extensions.ext_database import db from libs.oauth_bearer import Scope from models.enums import CreatorUserRole -from models.model import App, AppMode +from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory from services.workflow_event_snapshot_service import build_workflow_event_stream @@ -36,8 +37,9 @@ from services.workflow_event_snapshot_service import build_workflow_event_stream @openapi_ns.route("/apps//tasks//events") class OpenApiWorkflowEventsApi(Resource): @openapi_ns.response(200, "SSE event stream") - @OAUTH_BEARER_PIPELINE.guard(scope=Scope.APPS_RUN) - def get(self, app_id: str, task_id: str, app_model: App, caller, caller_kind: str): + @auth_router.guard(scope=Scope.APPS_RUN) + def get(self, app_id: str, task_id: str, *, auth_data: AuthData): + app_model, caller, caller_kind = auth_data.require_app_context() app_mode = AppMode.value_of(app_model.mode) if app_mode not in {AppMode.WORKFLOW, AppMode.ADVANCED_CHAT}: raise UnprocessableEntity("mode_not_supported_for_event_reconnect") diff --git a/api/controllers/openapi/workspaces.py b/api/controllers/openapi/workspaces.py index fa2aca7dd0..b23012a810 100644 --- a/api/controllers/openapi/workspaces.py +++ b/api/controllers/openapi/workspaces.py @@ -35,15 +35,11 @@ from controllers.openapi._models import ( WorkspaceListResponse, WorkspaceSummaryResponse, ) +from controllers.openapi.auth.composition import auth_router +from controllers.openapi.auth.data import AuthData from controllers.openapi.auth.role_gate import require_workspace_role -from controllers.openapi.auth.surface_gate import accept_subjects from extensions.ext_database import db -from libs.oauth_bearer import ( - ACCEPT_USER_ANY, - SubjectType, - get_auth_ctx, - validate_bearer, -) +from libs.oauth_bearer import Scope, TokenType from models import Account, Tenant, TenantAccountJoin from models.account import TenantAccountRole, TenantStatus from services.account_service import AccountService, RegisterService, TenantService @@ -60,11 +56,6 @@ from services.feature_service import FeatureService def _validate_body[M: BaseModel](model: type[M]) -> M: - """Validate JSON body against ``model``. Validation errors → HTTP 400. - - The workspace spec is explicit that bad email / unknown role payloads - are 400, not Pydantic's default 422 — handle uniformly here. - """ body = request.get_json(silent=True) or {} try: return model.model_validate(body) @@ -91,7 +82,6 @@ def _load_tenant(workspace_id: str) -> Tenant: def _load_account(account_id: object) -> Account: - """Load the caller's Account. Missing == auth wiring bug, not user error.""" account = AccountService.get_account_by_id(db.session, str(account_id)) if account_id else None if account is None: raise RuntimeError("authenticated account_id has no Account row") @@ -99,13 +89,6 @@ def _load_account(account_id: object) -> Account: def _quota_error(*, code: str, message: str, hint: str) -> Forbidden: - """Build a 403 with envelope ``{code, message, hint}``. - - CLI ``error-mapper`` reads ``message`` and ``hint`` off the wire body - verbatim — the structured envelope lets it surface remediation guidance - (e.g. "upgrade your plan") without the CLI needing to know edition - semantics. - """ err = Forbidden(message) err.response = make_response( jsonify({"code": code, "message": message, "hint": hint}), @@ -115,16 +98,6 @@ def _quota_error(*, code: str, message: str, hint: str) -> Forbidden: def _check_member_invite_quota(tenant_id: str) -> None: - """Edition-aware member-count gate for invite. - - Both branches self-disable on CE because ``FeatureService.get_features`` - leaves ``billing.enabled`` and ``workspace_members.enabled`` False by - default; SaaS billing API and EE license activation are what flip them on. - - Mirrors the two checks the console invite path performs (decorator at - ``console/wraps.py:106`` for billing + inline at - ``console/workspace/members.py:130`` for license). - """ features = FeatureService.get_features(tenant_id) if features.billing.enabled: @@ -148,12 +121,9 @@ def _check_member_invite_quota(tenant_id: str) -> None: @openapi_ns.route("/workspaces") class WorkspacesApi(Resource): @openapi_ns.response(200, "Workspace list", openapi_ns.models[WorkspaceListResponse.__name__]) - @validate_bearer(accept=ACCEPT_USER_ANY) - @accept_subjects(SubjectType.ACCOUNT) - def get(self): - ctx = get_auth_ctx() - - rows = TenantService.get_workspaces_for_account(db.session, str(ctx.account_id)) + @auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def get(self, *, auth_data: AuthData): + rows = TenantService.get_workspaces_for_account(db.session, str(auth_data.account_id)) return WorkspaceListResponse(workspaces=list(starmap(_workspace_summary, rows))).model_dump(mode="json"), 200 @@ -161,12 +131,9 @@ class WorkspacesApi(Resource): @openapi_ns.route("/workspaces/") class WorkspaceByIdApi(Resource): @openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__]) - @validate_bearer(accept=ACCEPT_USER_ANY) - @accept_subjects(SubjectType.ACCOUNT) - def get(self, workspace_id: str): - ctx = get_auth_ctx() - - row = TenantService.find_workspace_for_account(db.session, str(ctx.account_id), workspace_id) + @auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def get(self, workspace_id: str, *, auth_data: AuthData): + row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id) # 404 (not 403) on non-member so workspace IDs don't leak across tenants. if row is None: raise NotFound("workspace not found") @@ -185,21 +152,17 @@ class WorkspaceSwitchApi(Resource): """ @openapi_ns.response(200, "Workspace detail", openapi_ns.models[WorkspaceDetailResponse.__name__]) - @validate_bearer(accept=ACCEPT_USER_ANY) - @accept_subjects(SubjectType.ACCOUNT) + @auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) @require_workspace_role() - def post(self, workspace_id: str): - ctx = get_auth_ctx() - account = _load_account(ctx.account_id) + def post(self, workspace_id: str, *, auth_data: AuthData): + account = _load_account(auth_data.account_id) try: TenantService.switch_tenant(account, workspace_id) except AccountNotLinkTenantError: - # Membership existed at gate time but Tenant.status != NORMAL or - # the row was just removed — treat as not-found. raise NotFound("workspace not found") - row = TenantService.find_workspace_for_account(db.session, str(ctx.account_id), workspace_id) + row = TenantService.find_workspace_for_account(db.session, str(auth_data.account_id), workspace_id) if row is None: raise NotFound("workspace not found") tenant, membership = row @@ -216,20 +179,15 @@ class WorkspaceMembersApi(Resource): @openapi_ns.doc(params=query_params_from_model(MemberListQuery)) @openapi_ns.response(200, "Member list", openapi_ns.models[MemberListResponse.__name__]) - @validate_bearer(accept=ACCEPT_USER_ANY) - @accept_subjects(SubjectType.ACCOUNT) + @auth_router.guard(scope=Scope.WORKSPACE_READ, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) @require_workspace_role() - def get(self, workspace_id: str): + def get(self, workspace_id: str, *, auth_data: AuthData): try: query = MemberListQuery.model_validate(request.args.to_dict(flat=True)) except ValidationError as exc: raise BadRequest(str(exc)) tenant = _load_tenant(workspace_id) - # Members per workspace are bounded by SaaS plan caps (≤50) or EE - # license seats (low thousands worst-case), so we materialize and - # slice in-memory rather than push pagination into the service — - # matches how the rest of the service exposes member lists. members = TenantService.get_tenant_members(tenant) total = len(members) start = (query.page - 1) * query.limit @@ -244,13 +202,11 @@ class WorkspaceMembersApi(Resource): @openapi_ns.expect(openapi_ns.models[MemberInvitePayload.__name__]) @openapi_ns.response(201, "Member invited", openapi_ns.models[MemberInviteResponse.__name__]) - @validate_bearer(accept=ACCEPT_USER_ANY) - @accept_subjects(SubjectType.ACCOUNT) + @auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) @require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN) - def post(self, workspace_id: str): + def post(self, workspace_id: str, *, auth_data: AuthData): payload = _validate_body(MemberInvitePayload) - ctx = get_auth_ctx() - inviter = _load_account(ctx.account_id) + inviter = _load_account(auth_data.account_id) tenant = _load_tenant(workspace_id) _check_member_invite_quota(str(tenant.id)) @@ -297,12 +253,10 @@ class WorkspaceMemberApi(Resource): """ @openapi_ns.response(200, "Member removed", openapi_ns.models[MemberActionResponse.__name__]) - @validate_bearer(accept=ACCEPT_USER_ANY) - @accept_subjects(SubjectType.ACCOUNT) + @auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) @require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN) - def delete(self, workspace_id: str, member_id: str): - ctx = get_auth_ctx() - operator = _load_account(ctx.account_id) + def delete(self, workspace_id: str, member_id: str, *, auth_data: AuthData): + operator = _load_account(auth_data.account_id) tenant = _load_tenant(workspace_id) member = AccountService.get_account_by_id(db.session, member_id) if member is None: @@ -330,13 +284,11 @@ class WorkspaceMemberRoleApi(Resource): @openapi_ns.expect(openapi_ns.models[MemberRoleUpdatePayload.__name__]) @openapi_ns.response(200, "Role updated", openapi_ns.models[MemberActionResponse.__name__]) - @validate_bearer(accept=ACCEPT_USER_ANY) - @accept_subjects(SubjectType.ACCOUNT) + @auth_router.guard(scope=Scope.WORKSPACE_WRITE, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) @require_workspace_role(TenantAccountRole.OWNER, TenantAccountRole.ADMIN) - def put(self, workspace_id: str, member_id: str): + def put(self, workspace_id: str, member_id: str, *, auth_data: AuthData): payload = _validate_body(MemberRoleUpdatePayload) - ctx = get_auth_ctx() - operator = _load_account(ctx.account_id) + operator = _load_account(auth_data.account_id) tenant = _load_tenant(workspace_id) member = AccountService.get_account_by_id(db.session, member_id) if member is None: diff --git a/api/libs/oauth_bearer.py b/api/libs/oauth_bearer.py index 6e8678eca0..7433c6c177 100644 --- a/api/libs/oauth_bearer.py +++ b/api/libs/oauth_bearer.py @@ -43,6 +43,11 @@ class SubjectType(StrEnum): EXTERNAL_SSO = "external_sso" +class TokenType(StrEnum): + OAUTH_ACCOUNT = "oauth_account" + OAUTH_EXTERNAL_SSO = "oauth_external_sso" + + class Scope(StrEnum): """Catalog of bearer scopes recognised by the openapi surface. @@ -55,6 +60,8 @@ class Scope(StrEnum): APPS_READ = "apps:read" APPS_READ_PERMITTED_EXTERNAL = "apps:read:permitted-external" APPS_RUN = "apps:run" + WORKSPACE_READ = "workspace:read" + WORKSPACE_WRITE = "workspace:write" class Accepts(StrEnum): @@ -77,7 +84,7 @@ _SUBJECT_TO_ACCEPT: dict[SubjectType, Accepts] = { class AuthContext: """Per-request identity published via :data:`_auth_ctx_var` (see :func:`set_auth_ctx` / :func:`get_auth_ctx`). ``scopes`` / - ``subject_type`` / ``source`` come from the TokenKind, not the DB — + ``subject_type`` / ``token_type`` come from the TokenKind, not the DB — corrupt rows can't elevate scope. `verified_tenants` is a snapshot of the Layer-0 verdict cache at @@ -92,7 +99,7 @@ class AuthContext: client_id: str | None scopes: frozenset[Scope] token_id: uuid.UUID - source: str + token_type: TokenType expires_at: datetime | None token_hash: str verified_tenants: dict[str, bool] = field(default_factory=dict) @@ -180,7 +187,7 @@ class TokenKind: prefix: str subject_type: SubjectType scopes: frozenset[Scope] - source: str + token_type: TokenType resolver: Resolver def matches(self, token: str) -> bool: @@ -291,7 +298,7 @@ class BearerAuthenticator: client_id=row.client_id, scopes=kind.scopes, token_id=row.token_id, - source=kind.source, + token_type=kind.token_type, expires_at=row.expires_at, token_hash=token_hash, verified_tenants=dict(row.verified_tenants), @@ -483,7 +490,7 @@ def check_workspace_membership( account_id: uuid.UUID | str, tenant_id: str, token_hash: str, - cached_verdicts: dict[str, bool], + membership_cache: dict[str, bool], ) -> None: """Layer-0 enforcement core. Raises `Forbidden` on deny, returns on allow. @@ -492,7 +499,7 @@ def check_workspace_membership( short-circuiting on EE / SSO subjects before invoking — this function runs the membership + active-status checks unconditionally. """ - cached = cached_verdicts.get(tenant_id) + cached = membership_cache.get(tenant_id) if cached is True: return if cached is False: @@ -530,7 +537,7 @@ def require_workspace_member(ctx: AuthContext, tenant_id: str) -> None: account_id=ctx.account_id, tenant_id=tenant_id, token_hash=ctx.token_hash, - cached_verdicts=ctx.verified_tenants, + membership_cache=ctx.verified_tenants, ) @@ -664,14 +671,14 @@ def build_registry(session_factory, redis_client) -> TokenKindRegistry: prefix=account.prefix, subject_type=account.subject_type, scopes=account.scopes, - source="oauth_account", + token_type=TokenType.OAUTH_ACCOUNT, resolver=oauth.for_account(), ), TokenKind( prefix=external.prefix, subject_type=external.subject_type, scopes=external.scopes, - source="oauth_external_sso", + token_type=TokenType.OAUTH_EXTERNAL_SSO, resolver=oauth.for_external_sso(), ), ] diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_composition.py b/api/tests/unit_tests/controllers/openapi/auth/test_composition.py index aa6478dd97..028d32009d 100644 --- a/api/tests/unit_tests/controllers/openapi/auth/test_composition.py +++ b/api/tests/unit_tests/controllers/openapi/auth/test_composition.py @@ -1,66 +1,73 @@ -from unittest.mock import patch - -from controllers.openapi.auth.composition import OAUTH_BEARER_PIPELINE, _resolve_app_authz_strategy -from controllers.openapi.auth.pipeline import Pipeline -from controllers.openapi.auth.steps import ( - AppAuthzCheck, - AppResolver, - BearerCheck, - CallerMount, - ScopeCheck, - SurfaceCheck, - WorkspaceMembershipCheck, -) -from controllers.openapi.auth.strategies import ( - AccountMounter, - AclStrategy, - EndUserMounter, - MembershipStrategy, -) -from libs.oauth_bearer import SubjectType +from controllers.openapi.auth.composition import account_pipeline, auth_router, external_sso_pipeline +from controllers.openapi.auth.flow import When +from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter +from libs.oauth_bearer import TokenType -def test_pipeline_is_composed(): - assert isinstance(OAUTH_BEARER_PIPELINE, Pipeline) +def test_account_pipeline_is_auth_pipeline(): + assert isinstance(account_pipeline, AuthPipeline) -def test_pipeline_step_order(): - """BearerCheck → SurfaceCheck → ScopeCheck → AppResolver → - WorkspaceMembershipCheck → AppAuthzCheck → CallerMount. - SurfaceCheck enforces the dfoa_/dfoe_ surface split + emits - `openapi.wrong_surface_denied`. Rate-limit is enforced inside - `BearerAuthenticator.authenticate`, not as a separate pipeline step.""" - steps = OAUTH_BEARER_PIPELINE._steps - assert isinstance(steps[0], BearerCheck) - assert isinstance(steps[1], SurfaceCheck) - assert isinstance(steps[2], ScopeCheck) - assert isinstance(steps[3], AppResolver) - assert isinstance(steps[4], WorkspaceMembershipCheck) - assert isinstance(steps[5], AppAuthzCheck) - assert isinstance(steps[6], CallerMount) +def test_external_sso_pipeline_is_auth_pipeline(): + assert isinstance(external_sso_pipeline, AuthPipeline) -def test_pipeline_surface_check_accepts_account_only(): - """Current pipeline serves /apps//run — account surface only.""" - surface = OAUTH_BEARER_PIPELINE._steps[1] - assert isinstance(surface, SurfaceCheck) - assert surface._accepted == frozenset({SubjectType.ACCOUNT}) +def test_auth_router_is_pipeline_router(): + assert isinstance(auth_router, PipelineRouter) -def test_caller_mount_has_both_mounters(): - cm = OAUTH_BEARER_PIPELINE._steps[6] - kinds = {type(m) for m in cm._mounters} - assert AccountMounter in kinds - assert EndUserMounter in kinds +def test_account_pipeline_prepare_has_four_entries(): + assert len(account_pipeline._prepare) == 4 -@patch("controllers.openapi.auth.composition.FeatureService") -def test_strategy_resolver_picks_acl_when_enabled(fs): - fs.get_system_features.return_value.webapp_auth.enabled = True - assert isinstance(_resolve_app_authz_strategy(), AclStrategy) +def test_account_auth_list_has_five_entries(): + assert len(account_pipeline._auth) == 5 -@patch("controllers.openapi.auth.composition.FeatureService") -def test_strategy_resolver_picks_membership_when_disabled(fs): - fs.get_system_features.return_value.webapp_auth.enabled = False - assert isinstance(_resolve_app_authz_strategy(), MembershipStrategy) +def test_external_sso_pipeline_prepare_has_four_entries(): + assert len(external_sso_pipeline._prepare) == 4 + + +def test_external_sso_auth_list_has_three_entries(): + assert len(external_sso_pipeline._auth) == 3 + + +def test_account_pipeline_has_unconditional_load_account(): + non_when = [s for s in account_pipeline._prepare if not isinstance(s, When)] + assert len(non_when) == 1 + + +def test_external_sso_pipeline_all_prepare_entries_are_when(): + assert all(isinstance(s, When) for s in external_sso_pipeline._prepare) + + +def test_first_auth_entry_is_check_scope_in_both_pipelines(): + assert not isinstance(account_pipeline._auth[0], When) + assert not isinstance(external_sso_pipeline._auth[0], When) + + +def test_remaining_auth_entries_are_when_for_account(): + assert all(isinstance(s, When) for s in account_pipeline._auth[1:]) + + +def test_remaining_auth_entries_are_when_for_external_sso(): + assert all(isinstance(s, When) for s in external_sso_pipeline._auth[1:]) + + +def test_router_routes_contain_both_token_types(): + assert TokenType.OAUTH_ACCOUNT in auth_router._routes + assert TokenType.OAUTH_EXTERNAL_SSO in auth_router._routes + + +def test_external_sso_route_has_ee_required_edition(): + route = auth_router._routes[TokenType.OAUTH_EXTERNAL_SSO] + assert isinstance(route, PipelineRoute) + from controllers.openapi.auth.data import Edition + + assert route.required_edition == frozenset({Edition.EE}) + + +def test_account_route_has_no_required_edition(): + route = auth_router._routes[TokenType.OAUTH_ACCOUNT] + assert isinstance(route, PipelineRoute) + assert route.required_edition is None diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_conditions.py b/api/tests/unit_tests/controllers/openapi/auth/test_conditions.py new file mode 100644 index 0000000000..8367933984 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_conditions.py @@ -0,0 +1,143 @@ +from unittest.mock import MagicMock, patch + +from controllers.openapi.auth.conditions import ( + EDITION_CE, + EDITION_EE, + EDITION_SAAS, + LOADED_APP_IS_PRIVATE, + PATH_HAS_APP_ID, + TOKEN_IS_OAUTH_ACCOUNT, + TOKEN_IS_OAUTH_EXTERNAL_SSO, + WEBAPP_AUTH_ENABLED, + Cond, + config_cond, + data_cond, + request_cond, +) +from controllers.openapi.auth.data import AuthData, Edition, RequestContext +from libs.oauth_bearer import TokenType +from services.enterprise.enterprise_service import WebAppAccessMode + + +def _ctx(token_type=TokenType.OAUTH_ACCOUNT, path_params=None): + return RequestContext( + token_type=token_type, + path_params=path_params or {}, + ) + + +def _data(**kwargs): + defaults: dict = {"token_type": TokenType.OAUTH_ACCOUNT, "token_hash": "x", "scopes": frozenset()} + defaults.update(kwargs) + return AuthData(**defaults) + + +def test_and_both_true(): + a = Cond(lambda ctx, _: True) + b = Cond(lambda ctx, _: True) + assert (a & b)(_ctx()) is True + + +def test_and_one_false(): + a = Cond(lambda ctx, _: True) + b = Cond(lambda ctx, _: False) + assert (a & b)(_ctx()) is False + + +def test_or_one_true(): + a = Cond(lambda ctx, _: False) + b = Cond(lambda ctx, _: True) + assert (a | b)(_ctx()) is True + + +def test_or_both_false(): + a = Cond(lambda ctx, _: False) + b = Cond(lambda ctx, _: False) + assert (a | b)(_ctx()) is False + + +def test_invert(): + a = Cond(lambda ctx, _: True) + assert (~a)(_ctx()) is False + + +def test_chain_and_or(): + always_true = Cond(lambda ctx, _: True) + always_false = Cond(lambda ctx, _: False) + assert ((always_true | always_false) & always_true)(_ctx()) is True + + +def test_request_cond_ignores_data(): + c = request_cond(lambda ctx: ctx.token_type == TokenType.OAUTH_ACCOUNT) + assert c(_ctx(TokenType.OAUTH_ACCOUNT)) is True + assert c(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is False + + +def test_data_cond_returns_false_when_data_none(): + c = data_cond(lambda data: True) + assert c(_ctx(), None) is False + + +def test_data_cond_evaluates_when_data_present(): + c = data_cond(lambda data: data.token_hash == "secret") + assert c(_ctx(), _data(token_hash="secret")) is True + assert c(_ctx(), _data(token_hash="other")) is False + + +def test_config_cond_ignores_ctx_and_data(): + c = config_cond(lambda: True) + assert c(_ctx()) is True + c2 = config_cond(lambda: False) + assert c2(_ctx(), _data()) is False + + +def test_token_is_oauth_account(): + assert TOKEN_IS_OAUTH_ACCOUNT(_ctx(TokenType.OAUTH_ACCOUNT)) is True + assert TOKEN_IS_OAUTH_ACCOUNT(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is False + + +def test_token_is_oauth_external_sso(): + assert TOKEN_IS_OAUTH_EXTERNAL_SSO(_ctx(TokenType.OAUTH_EXTERNAL_SSO)) is True + + +def test_path_has_app_id_true(): + assert PATH_HAS_APP_ID(_ctx(path_params={"app_id": "abc"})) is True + + +def test_path_has_app_id_false(): + assert PATH_HAS_APP_ID(_ctx(path_params={})) is False + + +def test_edition_ce(): + with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.CE): + assert EDITION_CE(_ctx()) is True + assert EDITION_EE(_ctx()) is False + assert EDITION_SAAS(_ctx()) is False + + +def test_edition_ee(): + with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.EE): + assert EDITION_EE(_ctx()) is True + assert EDITION_CE(_ctx()) is False + + +def test_edition_saas(): + with patch("controllers.openapi.auth.conditions.current_edition", return_value=Edition.SAAS): + assert EDITION_SAAS(_ctx()) is True + + +def test_webapp_auth_enabled(): + mock_features = MagicMock() + mock_features.webapp_auth.enabled = True + with patch("controllers.openapi.auth.conditions.FeatureService.get_system_features", return_value=mock_features): + assert WEBAPP_AUTH_ENABLED(_ctx()) is True + + +def test_loaded_app_is_private(): + data_private = _data(app_access_mode=WebAppAccessMode.PRIVATE) + data_public = _data(app_access_mode=WebAppAccessMode.PUBLIC) + data_none = _data(app_access_mode=None) + assert LOADED_APP_IS_PRIVATE(_ctx(), data_private) is True + assert LOADED_APP_IS_PRIVATE(_ctx(), data_public) is False + assert LOADED_APP_IS_PRIVATE(_ctx(), data_none) is False + assert LOADED_APP_IS_PRIVATE(_ctx(), None) is False diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_context.py b/api/tests/unit_tests/controllers/openapi/auth/test_context.py deleted file mode 100644 index cc9c011342..0000000000 --- a/api/tests/unit_tests/controllers/openapi/auth/test_context.py +++ /dev/null @@ -1,21 +0,0 @@ -from controllers.openapi.auth.context import Context - - -def test_context_starts_unpopulated(): - ctx = Context(required_scope="apps:run") - assert ctx.bearer_token is None - assert ctx.path_params == {} - assert ctx.subject_type is None - assert ctx.subject_email is None - assert ctx.account_id is None - assert ctx.scopes == frozenset() - assert ctx.app is None - assert ctx.tenant is None - assert ctx.caller is None - assert ctx.caller_kind is None - - -def test_context_fields_are_mutable(): - ctx = Context(required_scope="apps:run") - ctx.scopes = frozenset({"full"}) - assert "full" in ctx.scopes diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_data.py b/api/tests/unit_tests/controllers/openapi/auth/test_data.py new file mode 100644 index 0000000000..c39ed9c6d0 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_data.py @@ -0,0 +1,117 @@ +import uuid +from unittest.mock import patch + +import pytest +from pydantic import ValidationError + +from controllers.openapi.auth.data import ( + AuthData, + Edition, + ExternalIdentity, + RequestContext, + current_edition, +) +from libs.oauth_bearer import Scope, TokenType + + +def test_current_edition_saas(): + with patch("controllers.openapi.auth.data.dify_config") as cfg: + cfg.EDITION = "CLOUD" + cfg.ENTERPRISE_ENABLED = True + assert current_edition() == Edition.SAAS + + +def test_current_edition_ee(): + with patch("controllers.openapi.auth.data.dify_config") as cfg: + cfg.EDITION = "SELF_HOSTED" + cfg.ENTERPRISE_ENABLED = True + assert current_edition() == Edition.EE + + +def test_current_edition_ce(): + with patch("controllers.openapi.auth.data.dify_config") as cfg: + cfg.EDITION = "SELF_HOSTED" + cfg.ENTERPRISE_ENABLED = False + assert current_edition() == Edition.CE + + +def test_external_identity_frozen(): + ei = ExternalIdentity(email="a@b.com", issuer="idp") + with pytest.raises(ValidationError): + ei.email = "other@b.com" # type: ignore[misc] + + +def test_external_identity_issuer_optional(): + ei = ExternalIdentity(email="a@b.com") + assert ei.issuer is None + + +def test_request_context_frozen(): + ctx = RequestContext( + token_type=TokenType.OAUTH_ACCOUNT, + path_params={"app_id": "123"}, + ) + with pytest.raises(ValidationError): + ctx.token_type = TokenType.OAUTH_EXTERNAL_SSO # type: ignore[misc] + + +def test_request_context_scope_optional(): + ctx = RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={}) + assert ctx.scope is None + + +def test_auth_data_is_mutable(): + data = AuthData( + token_type=TokenType.OAUTH_ACCOUNT, + token_hash="abc", + scopes=frozenset({Scope.FULL}), + ) + data.token_type = TokenType.OAUTH_EXTERNAL_SSO + assert data.token_type == TokenType.OAUTH_EXTERNAL_SSO + + +def test_auth_data_path_params_defaults_empty(): + data = AuthData( + token_type=TokenType.OAUTH_ACCOUNT, + token_hash="abc", + scopes=frozenset(), + ) + assert data.path_params == {} + + +def test_auth_data_account_id_optional(): + data = AuthData( + token_type=TokenType.OAUTH_EXTERNAL_SSO, + token_hash="abc", + scopes=frozenset({Scope.APPS_RUN}), + external_identity=ExternalIdentity(email="u@sso.com"), + ) + assert data.account_id is None + + +def test_auth_data_external_identity_none_for_account(): + data = AuthData( + token_type=TokenType.OAUTH_ACCOUNT, + account_id=uuid.uuid4(), + token_hash="abc", + scopes=frozenset({Scope.FULL}), + ) + assert data.external_identity is None + + +def test_auth_data_tenants_default_empty(): + data = AuthData( + token_type=TokenType.OAUTH_ACCOUNT, + token_hash="abc", + scopes=frozenset(), + ) + assert data.tenants == {} + + +def test_auth_data_token_id_optional(): + data = AuthData( + token_type=TokenType.OAUTH_ACCOUNT, + token_hash="abc", + scopes=frozenset(), + ) + assert data.token_id is None diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_flow.py b/api/tests/unit_tests/controllers/openapi/auth/test_flow.py new file mode 100644 index 0000000000..3ea7ac2b12 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_flow.py @@ -0,0 +1,42 @@ +import inspect + +from controllers.openapi.auth.conditions import Cond +from controllers.openapi.auth.data import AuthData, RequestContext +from controllers.openapi.auth.flow import When +from libs.oauth_bearer import TokenType + + +def _ctx(): + return RequestContext(token_type=TokenType.OAUTH_ACCOUNT, path_params={}) + + +def _data(): + return AuthData(token_type=TokenType.OAUTH_ACCOUNT, token_hash="x", scopes=frozenset()) + + +def test_applies_returns_true_when_condition_true(): + w = When(Cond(lambda ctx, _: True), then=lambda b: None) + assert w.applies(_ctx()) is True + + +def test_applies_returns_false_when_condition_false(): + w = When(Cond(lambda ctx, _: False), then=lambda b: None) + assert w.applies(_ctx()) is False + + +def test_applies_with_data(): + w = When(Cond(lambda ctx, data: data is not None), then=lambda b: None) + assert w.applies(_ctx(), _data()) is True + assert w.applies(_ctx(), None) is False + + +def test_call_invokes_step(): + calls = [] + w = When(Cond(lambda ctx, _: True), then=lambda arg: calls.append(arg)) + w("payload") + assert calls == ["payload"] + + +def test_then_is_keyword_only(): + sig = inspect.signature(When.__init__) + assert sig.parameters["then"].kind.name == "KEYWORD_ONLY" diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_pipeline.py b/api/tests/unit_tests/controllers/openapi/auth/test_pipeline.py index 15538275f5..a92f90112f 100644 --- a/api/tests/unit_tests/controllers/openapi/auth/test_pipeline.py +++ b/api/tests/unit_tests/controllers/openapi/auth/test_pipeline.py @@ -1,59 +1,269 @@ +import uuid +from unittest.mock import MagicMock, patch + import pytest from flask import Flask +from werkzeug.exceptions import Forbidden, NotFound, Unauthorized -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.pipeline import Pipeline +from controllers.openapi.auth.data import AuthData, Edition +from controllers.openapi.auth.pipeline import AuthPipeline, PipelineRoute, PipelineRouter +from libs.oauth_bearer import Scope, TokenType -def test_run_invokes_each_step_in_order(): - calls = [] - - class S: - def __init__(self, tag): - self.tag = tag - - def __call__(self, ctx): - calls.append(self.tag) - - Pipeline(S("a"), S("b"), S("c")).run(Context(required_scope="x")) - assert calls == ["a", "b", "c"] +def _make_identity( + token_type=TokenType.OAUTH_ACCOUNT, + account_id=None, + scopes=None, + token_hash="testhash", + subject_email=None, + subject_issuer=None, + verified_tenants=None, + token_id=None, +): + identity = MagicMock() + identity.token_type = token_type + identity.account_id = account_id or uuid.uuid4() + identity.scopes = scopes or frozenset({Scope.FULL}) + identity.token_hash = token_hash + identity.subject_email = subject_email + identity.subject_issuer = subject_issuer + identity.verified_tenants = verified_tenants or {} + identity.token_id = token_id or uuid.uuid4() + return identity -def test_run_short_circuits_on_raise(): - calls = [] - - class Boom: - def __call__(self, ctx): - raise RuntimeError("boom") - - class Tail: - def __call__(self, ctx): - calls.append("ran") - - with pytest.raises(RuntimeError): - Pipeline(Boom(), Tail()).run(Context(required_scope="x")) - assert calls == [] +@pytest.fixture +def app(): + return Flask(__name__) -def test_guard_decorator_runs_pipeline_and_unpacks_handler_kwargs(): - seen = {} +def _make_router(token_type=TokenType.OAUTH_ACCOUNT, prepare=None, auth=None): + pipeline = AuthPipeline(prepare=prepare or [], auth=auth or []) + return PipelineRouter({token_type: PipelineRoute(pipeline)}) - class FakeStep: - def __call__(self, ctx): - ctx.app = "APP" - ctx.caller = "CALLER" - ctx.caller_kind = "account" - pipeline = Pipeline(FakeStep()) +def _fake_identity(): + return _make_identity() - @pipeline.guard(scope="apps:run") - def handler(app_model, caller, caller_kind): - seen["app_model"] = app_model - seen["caller"] = caller - seen["caller_kind"] = caller_kind - return "ok" - app = Flask(__name__) - with app.test_request_context("/x", method="POST"): - assert handler() == "ok" - assert seen == {"app_model": "APP", "caller": "CALLER", "caller_kind": "account"} +# --- PipelineRouter.guard --- + + +def test_guard_passes_auth_data_to_view(app): + router = _make_router() + received = {} + + with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}): + with ( + patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), + patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, + patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()), + patch("controllers.openapi.auth.pipeline.reset_auth_ctx"), + ): + mock_auth.return_value.authenticate.return_value = _fake_identity() + + @router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def view(*, auth_data): + received["data"] = auth_data + + view() + + assert isinstance(received["data"], AuthData) + + +def test_guard_edition_gate_returns_404(app): + router = _make_router() + + with app.test_request_context("/test"): + with patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE): + + @router.guard(scope=Scope.FULL, edition=frozenset({Edition.EE})) + def view(*, auth_data): + pass + + with pytest.raises(NotFound): + view() + + +def test_guard_token_type_gate_returns_403(app): + router = _make_router() + + with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}): + with ( + patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), + patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, + patch("controllers.openapi.auth.pipeline.emit_wrong_surface"), + patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE), + ): + identity = _fake_identity() + identity.token_type = TokenType.OAUTH_EXTERNAL_SSO + mock_auth.return_value.authenticate.return_value = identity + + @router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def view(*, auth_data): + pass + + with pytest.raises(Forbidden): + view() + + +def test_guard_unregistered_token_type_returns_403(app): + router = _make_router(token_type=TokenType.OAUTH_ACCOUNT) + + with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}): + with ( + patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), + patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, + patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE), + ): + identity = _fake_identity() + identity.token_type = TokenType.OAUTH_EXTERNAL_SSO + mock_auth.return_value.authenticate.return_value = identity + + @router.guard(scope=Scope.FULL) + def view(*, auth_data): + pass + + with pytest.raises(Forbidden): + view() + + +def test_guard_no_bearer_returns_401(app): + router = _make_router() + + with app.test_request_context("/test"): + with patch("controllers.openapi.auth.pipeline.extract_bearer", return_value=None): + + @router.guard(scope=Scope.FULL) + def view(*, auth_data): + pass + + with pytest.raises(Unauthorized): + view() + + +def test_guard_runs_prepare_steps_in_order(app): + order = [] + + def p1(b): + order.append("p1") + + def p2(b): + order.append("p2") + + router = _make_router(prepare=[p1, p2]) + + with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}): + with ( + patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), + patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, + patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()), + patch("controllers.openapi.auth.pipeline.reset_auth_ctx"), + ): + mock_auth.return_value.authenticate.return_value = _fake_identity() + + @router.guard(scope=Scope.FULL) + def view(*, auth_data): + pass + + view() + + assert order == ["p1", "p2"] + + +def test_guard_resets_auth_ctx_on_exception(app): + router = _make_router() + reset_called = [] + + with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}): + with ( + patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), + patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, + patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value="tok"), + patch("controllers.openapi.auth.pipeline.reset_auth_ctx", side_effect=lambda t: reset_called.append(t)), + ): + mock_auth.return_value.authenticate.return_value = _fake_identity() + + @router.guard(scope=Scope.FULL) + def view(*, auth_data): + raise RuntimeError("boom") + + with pytest.raises(RuntimeError): + view() + + assert reset_called == ["tok"] + + +def test_router_rejects_token_type_on_wrong_edition(app): + pipeline = AuthPipeline(prepare=[], auth=[]) + route = PipelineRoute(pipeline, required_edition=frozenset({Edition.EE})) + router = PipelineRouter({TokenType.OAUTH_EXTERNAL_SSO: route}) + + with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}): + with ( + patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), + patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, + patch("controllers.openapi.auth.pipeline.current_edition", return_value=Edition.CE), + ): + identity = _make_identity(token_type=TokenType.OAUTH_EXTERNAL_SSO) + mock_auth.return_value.authenticate.return_value = identity + + @router.guard(scope=Scope.APPS_RUN) + def view(*, auth_data): + pass + + with pytest.raises(Forbidden): + view() + + +def test_guard_populates_external_identity_from_subject_email(app): + from controllers.openapi.auth.data import ExternalIdentity + + router = _make_router(token_type=TokenType.OAUTH_EXTERNAL_SSO) + received = {} + + with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}): + with ( + patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), + patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, + patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()), + patch("controllers.openapi.auth.pipeline.reset_auth_ctx"), + ): + identity = _make_identity( + token_type=TokenType.OAUTH_EXTERNAL_SSO, + subject_email="user@sso.com", + subject_issuer="https://idp.example.com", + ) + mock_auth.return_value.authenticate.return_value = identity + + @router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_EXTERNAL_SSO})) + def view(*, auth_data): + received["data"] = auth_data + + view() + + assert isinstance(received["data"].external_identity, ExternalIdentity) + assert received["data"].external_identity.email == "user@sso.com" + assert received["data"].external_identity.issuer == "https://idp.example.com" + + +def test_guard_no_external_identity_when_subject_email_absent(app): + router = _make_router() + received = {} + + with app.test_request_context("/test", headers={"Authorization": "Bearer tok"}): + with ( + patch("controllers.openapi.auth.pipeline.extract_bearer", return_value="tok"), + patch("controllers.openapi.auth.pipeline.get_authenticator") as mock_auth, + patch("controllers.openapi.auth.pipeline.set_auth_ctx", return_value=MagicMock()), + patch("controllers.openapi.auth.pipeline.reset_auth_ctx"), + ): + mock_auth.return_value.authenticate.return_value = _make_identity(subject_email=None) + + @router.guard(scope=Scope.FULL, allowed_token_types=frozenset({TokenType.OAUTH_ACCOUNT})) + def view(*, auth_data): + received["data"] = auth_data + + view() + + assert received["data"].external_identity is None diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_prepare.py b/api/tests/unit_tests/controllers/openapi/auth/test_prepare.py new file mode 100644 index 0000000000..39d8aafa0e --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_prepare.py @@ -0,0 +1,183 @@ +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, NotFound, Unauthorized + +from controllers.openapi.auth.data import AuthData, ExternalIdentity +from controllers.openapi.auth.prepare import ( + load_account, + load_app, + load_app_access_mode, + load_tenant, + resolve_external_user, +) +from libs.oauth_bearer import TokenType + + +def _make_auth_data(**kwargs) -> AuthData: + mock_fields = {k: kwargs.pop(k) for k in ("app", "tenant", "caller") if k in kwargs} + data = AuthData( + token_type=kwargs.pop("token_type", TokenType.OAUTH_ACCOUNT), + token_hash=kwargs.pop("token_hash", "testhash"), + scopes=kwargs.pop("scopes", frozenset()), + **kwargs, + ) + for k, v in mock_fields.items(): + setattr(data, k, v) + return data + + +def test_load_app_writes_app_to_data(): + app = MagicMock() + app.status = "normal" + app.enable_api = True + data = _make_auth_data(path_params={"app_id": "abc"}) + with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app): + load_app(data) + assert data.app is app + + +def test_load_app_raises_not_found_when_missing(): + data = _make_auth_data(path_params={"app_id": "missing"}) + with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=None): + with pytest.raises(NotFound): + load_app(data) + + +def test_load_app_raises_not_found_when_not_normal(): + app = MagicMock() + app.status = "archived" + data = _make_auth_data(path_params={"app_id": "abc"}) + with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app): + with pytest.raises(NotFound): + load_app(data) + + +def test_load_app_raises_forbidden_when_api_disabled(): + app = MagicMock() + app.status = "normal" + app.enable_api = False + data = _make_auth_data(path_params={"app_id": "abc"}) + with patch("controllers.openapi.auth.prepare.AppService.get_app_by_id", return_value=app): + with pytest.raises(Forbidden): + load_app(data) + + +def test_load_tenant_writes_tenant(): + app = MagicMock() + app.tenant_id = uuid.uuid4() + tenant = MagicMock() + tenant.status = "normal" + data = _make_auth_data(app=app) + with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=tenant): + load_tenant(data) + assert data.tenant is tenant + + +def test_load_tenant_raises_forbidden_when_archived(): + from models.account import TenantStatus + + app = MagicMock() + app.tenant_id = uuid.uuid4() + tenant = MagicMock() + tenant.status = TenantStatus.ARCHIVE + data = _make_auth_data(app=app) + with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=tenant): + with pytest.raises(Forbidden): + load_tenant(data) + + +def test_load_tenant_raises_forbidden_when_missing(): + app = MagicMock() + app.tenant_id = uuid.uuid4() + data = _make_auth_data(app=app) + with patch("controllers.openapi.auth.prepare.TenantService.get_tenant_by_id", return_value=None): + with pytest.raises(Forbidden): + load_tenant(data) + + +def test_load_tenant_raises_500_when_app_not_loaded(): + from werkzeug.exceptions import InternalServerError + + data = _make_auth_data() + with pytest.raises(InternalServerError): + load_tenant(data) + + +def test_load_account_writes_caller(): + account = MagicMock() + account_id = uuid.uuid4() + data = _make_auth_data(account_id=account_id) + with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=account): + load_account(data) + assert data.caller is account + assert data.caller_kind == "account" + + +def test_load_account_sets_current_tenant_when_tenant_present(): + account = MagicMock() + tenant = MagicMock() + data = _make_auth_data(account_id=uuid.uuid4(), tenant=tenant) + with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=account): + load_account(data) + assert account.current_tenant is tenant + + +def test_load_account_raises_unauthorized_when_not_found(): + data = _make_auth_data(account_id=uuid.uuid4()) + with patch("controllers.openapi.auth.prepare.AccountService.get_account_by_id", return_value=None): + with pytest.raises(Unauthorized): + load_account(data) + + +def test_resolve_external_user_writes_caller(): + tenant = MagicMock() + app = MagicMock() + end_user = MagicMock() + ext = ExternalIdentity(email="user@sso.com") + data = _make_auth_data(tenant=tenant, app=app, external_identity=ext) + with patch("controllers.openapi.auth.prepare.EndUserService.get_or_create_end_user_by_type", return_value=end_user): + resolve_external_user(data) + assert data.caller is end_user + assert data.caller_kind == "end_user" + + +def test_resolve_external_user_raises_unauthorized_when_context_missing(): + data = _make_auth_data(tenant=None, app=MagicMock(), external_identity=ExternalIdentity(email="u@s.com")) + with pytest.raises(Unauthorized): + resolve_external_user(data) + + +def test_load_app_access_mode_writes_mode(): + from services.enterprise.enterprise_service import WebAppAccessMode + + app = MagicMock() + app.id = "app-1" + settings = MagicMock() + settings.access_mode = "public" + data = _make_auth_data(app=app) + with patch( + "controllers.openapi.auth.prepare.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", + return_value=settings, + ): + load_app_access_mode(data) + assert data.app_access_mode == WebAppAccessMode.PUBLIC + + +def test_load_app_access_mode_writes_none_when_value_error(): + app = MagicMock() + app.id = "app-1" + data = _make_auth_data(app=app) + with patch( + "controllers.openapi.auth.prepare.EnterpriseService.WebAppAuth.get_app_access_mode_by_id", + side_effect=ValueError("No data found."), + ): + load_app_access_mode(data) + assert data.app_access_mode is None + + +def test_load_app_access_mode_no_op_when_app_missing(): + data = _make_auth_data() + load_app_access_mode(data) + assert data.app_access_mode is None diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_role_gate.py b/api/tests/unit_tests/controllers/openapi/auth/test_role_gate.py index 9befc7dad3..68b436e824 100644 --- a/api/tests/unit_tests/controllers/openapi/auth/test_role_gate.py +++ b/api/tests/unit_tests/controllers/openapi/auth/test_role_gate.py @@ -26,7 +26,7 @@ from flask import Flask from werkzeug.exceptions import Forbidden, NotFound from controllers.openapi.auth.role_gate import require_workspace_role -from libs.oauth_bearer import AuthContext, Scope, SubjectType, reset_auth_ctx, set_auth_ctx +from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, reset_auth_ctx, set_auth_ctx from models.account import TenantAccountRole # Tokens from `_seed`'s `set_auth_ctx` calls, drained after each test so a @@ -55,7 +55,7 @@ def _account_ctx(account_id: uuid.UUID | None = None) -> AuthContext: client_id="difyctl", scopes=frozenset({Scope.FULL}), token_id=uuid.uuid4(), - source="oauth_account", + token_type=TokenType.OAUTH_ACCOUNT, expires_at=datetime.now(UTC), token_hash="h1", verified_tenants={}, @@ -71,7 +71,7 @@ def _sso_ctx() -> AuthContext: client_id="difyctl", scopes=frozenset({Scope.APPS_RUN}), token_id=uuid.uuid4(), - source="oauth_external_sso", + token_type=TokenType.OAUTH_EXTERNAL_SSO, expires_at=datetime.now(UTC), token_hash="h2", verified_tenants={}, diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_app_resolver.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_app_resolver.py deleted file mode 100644 index f051f1a71c..0000000000 --- a/api/tests/unit_tests/controllers/openapi/auth/test_step_app_resolver.py +++ /dev/null @@ -1,64 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import patch - -import pytest -from werkzeug.exceptions import BadRequest, Forbidden, NotFound - -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.steps import AppResolver -from models import TenantStatus - - -def _ctx(path_params: dict[str, str] | None) -> Context: - return Context(required_scope="apps:run", path_params=path_params or {}) - - -def _app(*, status="normal", enable_api=True): - return SimpleNamespace(id="app1", tenant_id="t1", status=status, enable_api=enable_api) - - -def _tenant(*, status=TenantStatus.NORMAL): - return SimpleNamespace(id="t1", status=status) - - -def test_resolver_rejects_missing_path_param(): - with pytest.raises(BadRequest): - AppResolver()(_ctx({})) - - -def test_resolver_rejects_empty_path_params(): - # `Pipeline.guard` always seeds an empty dict when Flask reports no - # view args, so a missing `app_id` key surfaces here as BadRequest. - with pytest.raises(BadRequest): - AppResolver()(_ctx(None)) - - -@patch("controllers.openapi.auth.steps.db") -def test_resolver_404_when_app_missing(db): - db.session.get.side_effect = [None] - with pytest.raises(NotFound): - AppResolver()(_ctx({"app_id": "x"})) - - -@patch("controllers.openapi.auth.steps.db") -def test_resolver_403_when_disabled(db): - db.session.get.side_effect = [_app(enable_api=False)] - with pytest.raises(Forbidden) as exc: - AppResolver()(_ctx({"app_id": "x"})) - assert "service_api_disabled" in str(exc.value.description) - - -@patch("controllers.openapi.auth.steps.db") -def test_resolver_403_when_tenant_archived(db): - db.session.get.side_effect = [_app(), _tenant(status=TenantStatus.ARCHIVE)] - with pytest.raises(Forbidden): - AppResolver()(_ctx({"app_id": "x"})) - - -@patch("controllers.openapi.auth.steps.db") -def test_resolver_populates_app_and_tenant(db): - db.session.get.side_effect = [_app(), _tenant()] - ctx = _ctx({"app_id": "x"}) - AppResolver()(ctx) - assert ctx.app.id == "app1" - assert ctx.tenant.id == "t1" diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py deleted file mode 100644 index 6a5933da3b..0000000000 --- a/api/tests/unit_tests/controllers/openapi/auth/test_step_authz.py +++ /dev/null @@ -1,76 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import patch - -import pytest -from werkzeug.exceptions import Forbidden - -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.steps import AppAuthzCheck -from controllers.openapi.auth.strategies import AclStrategy, MembershipStrategy -from libs.oauth_bearer import SubjectType - - -def _ctx(*, subject_type, account_id="acc1"): - c = Context(required_scope="apps:run") - c.subject_type = subject_type - c.subject_email = "alice@example.com" - c.account_id = account_id - c.app = SimpleNamespace(id="app1") - c.tenant = SimpleNamespace(id="t1") - return c - - -@patch("controllers.openapi.auth.strategies.EnterpriseService") -def test_acl_strategy_private_calls_inner_api(ent): - ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode="private") - ent.WebAppAuth.is_user_allowed_to_access_webapp.return_value = True - assert AclStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True - ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_called_once_with( - user_id="acc1", - app_id="app1", - ) - - -@pytest.mark.parametrize( - ("access_mode", "subject_type", "expected"), - [ - ("public", SubjectType.ACCOUNT, True), - ("public", SubjectType.EXTERNAL_SSO, True), - ("sso_verified", SubjectType.ACCOUNT, True), - ("sso_verified", SubjectType.EXTERNAL_SSO, True), - ("private_all", SubjectType.ACCOUNT, True), - ("private_all", SubjectType.EXTERNAL_SSO, False), - ("private", SubjectType.EXTERNAL_SSO, False), - ], -) -@patch("controllers.openapi.auth.strategies.EnterpriseService") -def test_acl_strategy_subject_mode_matrix(ent, access_mode, subject_type, expected): - """Step 1 matrix: subject vs access-mode compatibility. No inner API call expected.""" - ent.WebAppAuth.get_app_access_mode_by_id.return_value = SimpleNamespace(access_mode=access_mode) - account_id = "acc1" if subject_type == SubjectType.ACCOUNT else None - assert AclStrategy().authorize(_ctx(subject_type=subject_type, account_id=account_id)) is expected - ent.WebAppAuth.is_user_allowed_to_access_webapp.assert_not_called() - - -@patch("controllers.openapi.auth.strategies.TenantService.account_belongs_to_tenant") -@patch("controllers.openapi.auth.strategies.db") -def test_membership_strategy_uses_join_lookup(db_mock, member): - member.return_value = True - assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.ACCOUNT)) is True - member.assert_called_once_with(db_mock.session, "acc1", "t1") - - -def test_membership_strategy_rejects_external_sso(): - assert MembershipStrategy().authorize(_ctx(subject_type=SubjectType.EXTERNAL_SSO, account_id=None)) is False - - -def test_app_authz_check_raises_when_strategy_denies(): - deny = SimpleNamespace(authorize=lambda c: False) - with pytest.raises(Forbidden) as exc: - AppAuthzCheck(lambda: deny)(_ctx(subject_type=SubjectType.ACCOUNT)) - assert "subject_no_app_access" in str(exc.value.description) - - -def test_app_authz_check_passes_when_strategy_allows(): - allow = SimpleNamespace(authorize=lambda c: True) - AppAuthzCheck(lambda: allow)(_ctx(subject_type=SubjectType.ACCOUNT)) diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py deleted file mode 100644 index 329f158f30..0000000000 --- a/api/tests/unit_tests/controllers/openapi/auth/test_step_bearer.py +++ /dev/null @@ -1,83 +0,0 @@ -import uuid -from datetime import UTC, datetime -from unittest.mock import patch - -import pytest -from flask import Flask -from werkzeug.exceptions import Unauthorized - -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.steps import BearerCheck -from libs.oauth_bearer import ( - AuthContext, - InvalidBearerError, - Scope, - SubjectType, - reset_auth_ctx, - try_get_auth_ctx, -) - - -def _ctx(bearer_token: str | None) -> Context: - return Context(required_scope="apps:run", bearer_token=bearer_token) - - -def test_bearer_check_rejects_missing_header(): - app = Flask(__name__) - with app.test_request_context(), pytest.raises(Unauthorized): - BearerCheck()(_ctx(None)) - - -@patch("controllers.openapi.auth.steps.get_authenticator") -def test_bearer_check_rejects_unknown_prefix(get_auth): - get_auth.return_value.authenticate.side_effect = InvalidBearerError("invalid_bearer") - app = Flask(__name__) - with app.test_request_context(), pytest.raises(Unauthorized): - BearerCheck()(_ctx("xxx_abc")) - - -@patch("controllers.openapi.auth.steps.get_authenticator") -def test_bearer_check_populates_context_and_publishes_auth_ctx(get_auth): - tok_id = uuid.uuid4() - authn = AuthContext( - subject_type=SubjectType.ACCOUNT, - subject_email="a@x.com", - subject_issuer=None, - account_id=None, - client_id="difyctl", - scopes=frozenset({Scope.FULL}), - token_id=tok_id, - source="oauth-account", - expires_at=datetime.now(UTC), - token_hash="hash-1", - verified_tenants={}, - ) - get_auth.return_value.authenticate.return_value = authn - - app = Flask(__name__) - ctx = _ctx("dfoa_abc") - with app.test_request_context(): - BearerCheck()(ctx) - try: - assert ctx.subject_type == SubjectType.ACCOUNT - assert ctx.subject_email == "a@x.com" - assert ctx.scopes == frozenset({Scope.FULL}) - assert ctx.source == "oauth-account" - assert ctx.token_id == tok_id - assert ctx.token_hash == "hash-1" - # BearerCheck must also publish the same identity on the - # openapi auth ContextVar so the surface gate + downstream - # handlers don't see two different identity sources between - # the decorator + pipeline paths. The reset token is parked - # on `ctx.auth_ctx_reset_token` for `Pipeline.guard` to - # consume in its `finally`. - published = try_get_auth_ctx() - assert published is authn - assert published.client_id == "difyctl" - assert ctx.auth_ctx_reset_token is not None - finally: - # In production `Pipeline.guard` resets the ContextVar; in - # this isolated step-level test we reset it ourselves so the - # value doesn't leak into the next test on the same worker. - assert ctx.auth_ctx_reset_token is not None - reset_auth_ctx(ctx.auth_ctx_reset_token) diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_layer0.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_layer0.py deleted file mode 100644 index 82ea07d736..0000000000 --- a/api/tests/unit_tests/controllers/openapi/auth/test_step_layer0.py +++ /dev/null @@ -1,157 +0,0 @@ -"""Unit tests for WorkspaceMembershipCheck (Layer 0).""" - -from __future__ import annotations - -import uuid -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -import pytest -from werkzeug.exceptions import Forbidden - -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.steps import WorkspaceMembershipCheck -from libs.oauth_bearer import SubjectType - - -def _ctx(*, subject_type, account_id, tenant_id, cached_verified_tenants=None, token_hash=None) -> Context: - c = Context(required_scope="apps:read") - c.subject_type = subject_type - c.account_id = account_id - c.tenant = SimpleNamespace(id=tenant_id) if tenant_id else None - c.cached_verified_tenants = cached_verified_tenants - c.token_hash = token_hash - return c - - -@pytest.fixture -def step(): - return WorkspaceMembershipCheck() - - -@patch("controllers.openapi.auth.steps.dify_config") -@patch("libs.oauth_bearer.record_layer0_verdict") -@patch("libs.oauth_bearer.db") -def test_skips_when_enterprise_enabled(mock_db, mock_record, mock_cfg, step): - mock_cfg.ENTERPRISE_ENABLED = True - ctx = _ctx( - subject_type=SubjectType.ACCOUNT, - account_id=str(uuid.uuid4()), - tenant_id=str(uuid.uuid4()), - cached_verified_tenants={}, - token_hash="hash-1", - ) - step(ctx) # no raise - mock_db.session.execute.assert_not_called() - mock_record.assert_not_called() - - -@patch("controllers.openapi.auth.steps.dify_config") -@patch("libs.oauth_bearer.record_layer0_verdict") -@patch("libs.oauth_bearer.db") -def test_skips_for_external_sso(mock_db, mock_record, mock_cfg, step): - mock_cfg.ENTERPRISE_ENABLED = False - ctx = _ctx( - subject_type=SubjectType.EXTERNAL_SSO, - account_id=None, - tenant_id=str(uuid.uuid4()), - cached_verified_tenants={}, - token_hash="hash-1", - ) - step(ctx) # no raise - mock_db.session.execute.assert_not_called() - mock_record.assert_not_called() - - -@patch("controllers.openapi.auth.steps.dify_config") -@patch("libs.oauth_bearer.record_layer0_verdict") -@patch("libs.oauth_bearer.db") -def test_uses_cached_ok(mock_db, mock_record, mock_cfg, step): - mock_cfg.ENTERPRISE_ENABLED = False - ctx = _ctx( - subject_type=SubjectType.ACCOUNT, - account_id="a1", - tenant_id="t1", - cached_verified_tenants={"t1": True}, - token_hash="hash-1", - ) - step(ctx) - mock_db.session.execute.assert_not_called() - mock_record.assert_not_called() - - -@patch("controllers.openapi.auth.steps.dify_config") -@patch("libs.oauth_bearer.record_layer0_verdict") -@patch("libs.oauth_bearer.db") -def test_uses_cached_denied(mock_db, mock_record, mock_cfg, step): - mock_cfg.ENTERPRISE_ENABLED = False - ctx = _ctx( - subject_type=SubjectType.ACCOUNT, - account_id="a1", - tenant_id="t1", - cached_verified_tenants={"t1": False}, - token_hash="hash-1", - ) - with pytest.raises(Forbidden, match="workspace_membership_revoked"): - step(ctx) - mock_db.session.execute.assert_not_called() - mock_record.assert_not_called() - - -@patch("controllers.openapi.auth.steps.dify_config") -@patch("libs.oauth_bearer.record_layer0_verdict") -@patch("libs.oauth_bearer.db") -def test_denies_when_no_membership(mock_db, mock_record, mock_cfg, step): - mock_cfg.ENTERPRISE_ENABLED = False - mock_db.session.execute.return_value.scalar_one_or_none.return_value = None - ctx = _ctx( - subject_type=SubjectType.ACCOUNT, - account_id="a1", - tenant_id="t1", - cached_verified_tenants={}, - token_hash="hash-1", - ) - with pytest.raises(Forbidden, match="workspace_membership_revoked"): - step(ctx) - mock_record.assert_called_once_with("hash-1", "t1", False) - - -@patch("controllers.openapi.auth.steps.dify_config") -@patch("libs.oauth_bearer.record_layer0_verdict") -@patch("libs.oauth_bearer.db") -def test_denies_when_account_inactive(mock_db, mock_record, mock_cfg, step): - mock_cfg.ENTERPRISE_ENABLED = False - mock_db.session.execute.side_effect = [ - MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")), - MagicMock(scalar_one_or_none=MagicMock(return_value="banned")), - ] - ctx = _ctx( - subject_type=SubjectType.ACCOUNT, - account_id="a1", - tenant_id="t1", - cached_verified_tenants={}, - token_hash="hash-1", - ) - with pytest.raises(Forbidden, match="workspace_membership_revoked"): - step(ctx) - mock_record.assert_called_once_with("hash-1", "t1", False) - - -@patch("controllers.openapi.auth.steps.dify_config") -@patch("libs.oauth_bearer.record_layer0_verdict") -@patch("libs.oauth_bearer.db") -def test_allows_active_member(mock_db, mock_record, mock_cfg, step): - mock_cfg.ENTERPRISE_ENABLED = False - mock_db.session.execute.side_effect = [ - MagicMock(scalar_one_or_none=MagicMock(return_value="join-id")), - MagicMock(scalar_one_or_none=MagicMock(return_value="active")), - ] - ctx = _ctx( - subject_type=SubjectType.ACCOUNT, - account_id="a1", - tenant_id="t1", - cached_verified_tenants={}, - token_hash="hash-1", - ) - step(ctx) # no raise - mock_record.assert_called_once_with("hash-1", "t1", True) diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_mount.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_mount.py deleted file mode 100644 index 8c5ad38a16..0000000000 --- a/api/tests/unit_tests/controllers/openapi/auth/test_step_mount.py +++ /dev/null @@ -1,77 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import patch - -import pytest -from werkzeug.exceptions import Unauthorized - -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.steps import CallerMount -from controllers.openapi.auth.strategies import AccountMounter, EndUserMounter -from core.app.entities.app_invoke_entities import InvokeFrom -from libs.oauth_bearer import SubjectType - - -def _ctx(*, subject_type, account_id=None, subject_email=None): - c = Context(required_scope="apps:run") - c.subject_type = subject_type - c.account_id = account_id - c.subject_email = subject_email - c.app = SimpleNamespace(id="app1") - c.tenant = SimpleNamespace(id="t1") - return c - - -@patch("controllers.openapi.auth.strategies._login_as") -@patch("controllers.openapi.auth.strategies.db") -def test_account_mounter(db, login): - account = SimpleNamespace() - db.session.get.return_value = account - ctx = _ctx(subject_type=SubjectType.ACCOUNT, account_id="acc1") - AccountMounter().mount(ctx) - assert ctx.caller is account - assert ctx.caller.current_tenant is ctx.tenant - assert ctx.caller_kind == "account" - login.assert_called_once_with(account) - - -@patch("controllers.openapi.auth.strategies._login_as") -@patch("controllers.openapi.auth.strategies.EndUserService") -def test_end_user_mounter(svc, login): - eu = SimpleNamespace() - svc.get_or_create_end_user_by_type.return_value = eu - ctx = _ctx(subject_type=SubjectType.EXTERNAL_SSO, subject_email="a@x.com") - EndUserMounter().mount(ctx) - svc.get_or_create_end_user_by_type.assert_called_once_with( - InvokeFrom.OPENAPI, - tenant_id="t1", - app_id="app1", - user_id="a@x.com", - ) - assert ctx.caller is eu - assert ctx.caller_kind == "end_user" - - -def test_caller_mount_dispatches_by_subject_type(): - seen = {} - - class Fake: - def __init__(self, st, tag): - self._st, self._tag = st, tag - - def applies_to(self, st): - return st == self._st - - def mount(self, ctx): - seen["who"] = self._tag - - cm = CallerMount( - Fake(SubjectType.ACCOUNT, "acct"), - Fake(SubjectType.EXTERNAL_SSO, "sso"), - ) - cm(_ctx(subject_type=SubjectType.EXTERNAL_SSO)) - assert seen == {"who": "sso"} - - -def test_caller_mount_raises_when_none_applies(): - with pytest.raises(Unauthorized): - CallerMount()(_ctx(subject_type=SubjectType.ACCOUNT)) diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_step_scope.py b/api/tests/unit_tests/controllers/openapi/auth/test_step_scope.py deleted file mode 100644 index b4adbacd1e..0000000000 --- a/api/tests/unit_tests/controllers/openapi/auth/test_step_scope.py +++ /dev/null @@ -1,25 +0,0 @@ -import pytest -from werkzeug.exceptions import Forbidden - -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.steps import ScopeCheck - - -def _ctx(scopes, required): - c = Context(required_scope=required) - c.scopes = frozenset(scopes) - return c - - -def test_scope_check_passes_on_full(): - ScopeCheck()(_ctx({"full"}, "apps:run")) - - -def test_scope_check_passes_on_explicit_match(): - ScopeCheck()(_ctx({"apps:run"}, "apps:run")) - - -def test_scope_check_rejects_when_missing(): - with pytest.raises(Forbidden) as exc: - ScopeCheck()(_ctx({"apps:read"}, "apps:run")) - assert "insufficient_scope" in str(exc.value.description) diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_surface_gate.py b/api/tests/unit_tests/controllers/openapi/auth/test_surface_gate.py deleted file mode 100644 index f3b49b18da..0000000000 --- a/api/tests/unit_tests/controllers/openapi/auth/test_surface_gate.py +++ /dev/null @@ -1,239 +0,0 @@ -"""Surface gate tests. - -The gate has two attachment forms — decorator (`accept_subjects`) and -pipeline step (`SurfaceCheck`) — and both must: -- 403 on mismatched subject type with a canonical-path hint -- emit `openapi.wrong_surface_denied` once with the right payload -- pass-through on match -- raise RuntimeError (not 403) if the auth ContextVar is unset — that's - a wiring bug, not a user-driven failure - -Identity is published via `libs.oauth_bearer.set_auth_ctx` / read with -`try_get_auth_ctx`. Tests wrap the publish in a `_publish_auth_ctx` -context manager so the ContextVar resets even when an assertion fails; -that keeps state from leaking into the next test on the same worker. -""" - -from __future__ import annotations - -import uuid -from collections.abc import Iterator -from contextlib import contextmanager -from datetime import UTC, datetime -from unittest.mock import patch - -import pytest -from flask import Flask -from werkzeug.exceptions import Forbidden - -from controllers.openapi.auth.context import Context -from controllers.openapi.auth.steps import SurfaceCheck -from controllers.openapi.auth.surface_gate import _coerce_subject_type, accept_subjects, check_surface -from libs.oauth_bearer import AuthContext, Scope, SubjectType, reset_auth_ctx, set_auth_ctx - - -@contextmanager -def _publish_auth_ctx(ctx: AuthContext) -> Iterator[None]: - token = set_auth_ctx(ctx) - try: - yield - finally: - reset_auth_ctx(token) - - -def _account_ctx() -> AuthContext: - return AuthContext( - subject_type=SubjectType.ACCOUNT, - subject_email="user@example.com", - subject_issuer="dify:account", - account_id=uuid.uuid4(), - client_id="difyctl", - scopes=frozenset({Scope.FULL}), - token_id=uuid.uuid4(), - source="oauth_account", - expires_at=datetime.now(UTC), - token_hash="h1", - verified_tenants={}, - ) - - -def _sso_ctx() -> AuthContext: - return AuthContext( - subject_type=SubjectType.EXTERNAL_SSO, - subject_email="sso@partner.com", - subject_issuer="https://idp.partner.com", - account_id=None, - client_id="difyctl", - scopes=frozenset({Scope.APPS_RUN, Scope.APPS_READ_PERMITTED_EXTERNAL}), - token_id=uuid.uuid4(), - source="oauth_external_sso", - expires_at=datetime.now(UTC), - token_hash="h2", - verified_tenants={}, - ) - - -# --------------------------------------------------------------------------- -# check_surface — shared core -# --------------------------------------------------------------------------- - - -def test_check_surface_passes_when_subject_in_accepted(): - app = Flask(__name__) - with app.test_request_context("/openapi/v1/apps"), _publish_auth_ctx(_account_ctx()): - check_surface(frozenset({SubjectType.ACCOUNT})) # no raise - - -def test_check_surface_rejects_on_wrong_subject_and_emits_audit(): - app = Flask(__name__) - with app.test_request_context("/openapi/v1/permitted-external-apps"), _publish_auth_ctx(_account_ctx()): - with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit: - with pytest.raises(Forbidden) as exc: - check_surface(frozenset({SubjectType.EXTERNAL_SSO})) - assert "wrong_surface" in exc.value.description - # canonical-path hint should point at the caller's surface, - # not the surface they were rejected from - assert "/openapi/v1/apps" in exc.value.description - emit.assert_called_once() - kwargs = emit.call_args.kwargs - assert kwargs["subject_type"] == SubjectType.ACCOUNT.value - assert kwargs["attempted_path"] == "/openapi/v1/permitted-external-apps" - assert kwargs["client_id"] == "difyctl" - assert kwargs["token_id"] is not None - - -def test_check_surface_rejects_sso_on_account_surface(): - app = Flask(__name__) - with app.test_request_context("/openapi/v1/apps"), _publish_auth_ctx(_sso_ctx()): - with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit: - with pytest.raises(Forbidden): - check_surface(frozenset({SubjectType.ACCOUNT})) - kwargs = emit.call_args.kwargs - assert kwargs["subject_type"] == SubjectType.EXTERNAL_SSO.value - - -def test_check_surface_runtime_error_when_auth_ctx_missing(): - """Missing auth ContextVar means the bearer layer didn't run — wiring - bug, not a user-driven failure. Surface as RuntimeError (loud) so a - future refactor doesn't accidentally let a route skip authentication - and return a 403 that looks identical to a legitimate wrong-surface - deny. - """ - app = Flask(__name__) - with app.test_request_context("/openapi/v1/apps"): - with pytest.raises(RuntimeError): - check_surface(frozenset({SubjectType.ACCOUNT})) - - -# --------------------------------------------------------------------------- -# @accept_subjects — decorator form -# --------------------------------------------------------------------------- - - -def _make_app() -> Flask: - app = Flask(__name__) - - @app.route("/account-only") - @accept_subjects(SubjectType.ACCOUNT) - def _account_only(): - return "ok" - - @app.route("/external-only") - @accept_subjects(SubjectType.EXTERNAL_SSO) - def _external_only(): - return "ok" - - return app - - -def test_accept_subjects_decorator_passes_on_match(): - app = _make_app() - with app.test_request_context("/account-only"), _publish_auth_ctx(_account_ctx()): - # Re-route through the decorated function by reaching for view_function - view = app.view_functions["_account_only"] - assert view() == "ok" - - -def test_accept_subjects_decorator_403_on_miss(): - app = _make_app() - with app.test_request_context("/external-only"), _publish_auth_ctx(_account_ctx()): - view = app.view_functions["_external_only"] - with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface"): - with pytest.raises(Forbidden): - view() - - -# --------------------------------------------------------------------------- -# SurfaceCheck — pipeline step form -# --------------------------------------------------------------------------- - - -def _pipeline_ctx() -> Context: - # SurfaceCheck reads ``request.path`` from Flask's global request — set up - # via ``app.test_request_context`` in the calling tests — not from Context. - return Context(required_scope=Scope.APPS_RUN) - - -def test_surface_check_passes_on_match(): - step = SurfaceCheck(accepted=frozenset({SubjectType.ACCOUNT})) - app = Flask(__name__) - with app.test_request_context("/openapi/v1/apps/x/run"), _publish_auth_ctx(_account_ctx()): - step(_pipeline_ctx()) # no raise - - -def test_surface_check_rejects_on_miss_and_emits_audit(): - step = SurfaceCheck(accepted=frozenset({SubjectType.EXTERNAL_SSO})) - app = Flask(__name__) - with app.test_request_context("/openapi/v1/apps/x/run"), _publish_auth_ctx(_account_ctx()): - with patch("controllers.openapi.auth.surface_gate.emit_wrong_surface") as emit: - with pytest.raises(Forbidden): - step(_pipeline_ctx()) - emit.assert_called_once() - - -# --------------------------------------------------------------------------- -# _coerce_subject_type — normalises whatever sat on ctx.subject_type -# --------------------------------------------------------------------------- -# -# The gate reads `ctx.subject_type` via `getattr(..., None)`, so the value -# could be a real enum (happy path), a raw string (e.g. rehydrated from a -# dict-shaped context), `None` (attribute missing), or something unexpected -# from a buggy upstream. The coercer must collapse all of that to -# `SubjectType | None` so `check_surface` can do a clean set-membership -# check and emit a clean audit payload. - - -def test_coerce_subject_type_returns_none_for_none(): - assert _coerce_subject_type(None) is None - - -def test_coerce_subject_type_returns_enum_instance_unchanged(): - # Identity matters: we don't want to round-trip through the string - # constructor for an already-valid enum. - assert _coerce_subject_type(SubjectType.ACCOUNT) is SubjectType.ACCOUNT - assert _coerce_subject_type(SubjectType.EXTERNAL_SSO) is SubjectType.EXTERNAL_SSO - - -@pytest.mark.parametrize( - ("raw", "expected"), - [ - ("account", SubjectType.ACCOUNT), - ("external_sso", SubjectType.EXTERNAL_SSO), - ], -) -def test_coerce_subject_type_parses_known_strings(raw: str, expected: SubjectType): - assert _coerce_subject_type(raw) is expected - - -def test_coerce_subject_type_raises_on_unknown_string(): - # Unknown strings reach `SubjectType(raw)` which raises ValueError. - # We surface that loudly rather than silently returning None, because - # a string that *looks* like a subject type but isn't is almost - # certainly an upstream bug worth catching. - with pytest.raises(ValueError): - _coerce_subject_type("not_a_subject") - - -@pytest.mark.parametrize("raw", [123, 1.5, b"account", object(), ["account"], {"account"}]) -def test_coerce_subject_type_returns_none_for_non_string_non_enum(raw: object): - assert _coerce_subject_type(raw) is None diff --git a/api/tests/unit_tests/controllers/openapi/auth/test_verify.py b/api/tests/unit_tests/controllers/openapi/auth/test_verify.py new file mode 100644 index 0000000000..c7e0cd7402 --- /dev/null +++ b/api/tests/unit_tests/controllers/openapi/auth/test_verify.py @@ -0,0 +1,142 @@ +import uuid +from unittest.mock import MagicMock, patch + +import pytest +from werkzeug.exceptions import Forbidden, Unauthorized + +from controllers.openapi.auth.data import AuthData +from controllers.openapi.auth.verify import ( + check_acl, + check_app_access, + check_membership, + check_private_app_permission, + check_scope, +) +from libs.oauth_bearer import Scope, TokenType +from models.account import Tenant +from models.model import App +from services.enterprise.enterprise_service import WebAppAccessMode + + +def _data(**kwargs) -> AuthData: + defaults: dict = {"token_type": TokenType.OAUTH_ACCOUNT, "token_hash": "hash", "scopes": frozenset({Scope.FULL})} + defaults.update(kwargs) + return AuthData(**defaults) + + +def test_check_scope_passes_when_required_is_none(): + check_scope(_data(required_scope=None)) + + +def test_check_scope_passes_when_full_in_scopes(): + check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.FULL}))) + + +def test_check_scope_passes_when_exact_scope_present(): + check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.APPS_RUN}))) + + +def test_check_scope_raises_forbidden_when_scope_missing(): + with pytest.raises(Forbidden, match="insufficient_scope"): + check_scope(_data(required_scope=Scope.APPS_RUN, scopes=frozenset({Scope.APPS_READ}))) + + +def test_check_membership_raises_unauthorized_when_tenant_none(): + with pytest.raises(Unauthorized): + check_membership(_data(tenant=None)) + + +def test_check_membership_calls_check_workspace_membership(): + tenant = MagicMock(spec=Tenant) + tenant.id = "tenant-1" + data = _data( + account_id=uuid.uuid4(), + token_hash="myhash", + tenants={"tenant-1": True}, + tenant=tenant, + ) + with patch("controllers.openapi.auth.verify.check_workspace_membership") as mock_cwm: + check_membership(data) + mock_cwm.assert_called_once_with( + account_id=data.account_id, + tenant_id="tenant-1", + token_hash="myhash", + membership_cache=data.tenants, + ) + + +def test_check_app_access_passes_when_tenant_none(): + check_app_access(_data(tenant=None)) + + +def test_check_app_access_passes_when_member(): + tenant = MagicMock(spec=Tenant) + tenant.id = "t1" + data = _data(account_id=uuid.uuid4(), tenant=tenant) + with patch("controllers.openapi.auth.verify.TenantService.account_belongs_to_tenant", return_value=True): + check_app_access(data) + + +def test_check_app_access_raises_when_not_member(): + tenant = MagicMock(spec=Tenant) + tenant.id = "t1" + data = _data(account_id=uuid.uuid4(), tenant=tenant) + with patch("controllers.openapi.auth.verify.TenantService.account_belongs_to_tenant", return_value=False): + with pytest.raises(Forbidden, match="subject_no_app_access"): + check_app_access(data) + + +def test_check_acl_raises_when_app_or_mode_missing(): + with pytest.raises(Forbidden): + check_acl(_data(app=None, app_access_mode=None)) + + +def test_check_acl_account_allowed_for_public(): + app = MagicMock(spec=App) + data = _data(token_type=TokenType.OAUTH_ACCOUNT, app=app, app_access_mode=WebAppAccessMode.PUBLIC) + check_acl(data) + + +def test_check_acl_external_sso_blocked_for_private(): + app = MagicMock(spec=App) + data = _data( + token_type=TokenType.OAUTH_EXTERNAL_SSO, + app=app, + app_access_mode=WebAppAccessMode.PRIVATE, + ) + with pytest.raises(Forbidden, match="subject_not_allowed_for_access_mode"): + check_acl(data) + + +def test_check_acl_external_sso_allowed_for_sso_verified(): + app = MagicMock(spec=App) + data = _data( + token_type=TokenType.OAUTH_EXTERNAL_SSO, + app=app, + app_access_mode=WebAppAccessMode.SSO_VERIFIED, + ) + check_acl(data) + + +def test_check_private_app_permission_raises_when_app_none(): + with pytest.raises(Forbidden): + check_private_app_permission(_data(app=None)) + + +def test_check_private_app_permission_raises_when_user_not_allowed(): + app = MagicMock(spec=App) + app.id = "app-1" + data = _data(account_id=uuid.uuid4(), app=app) + target = "controllers.openapi.auth.verify.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp" + with patch(target, return_value=False): + with pytest.raises(Forbidden, match="user_not_allowed_for_private_app"): + check_private_app_permission(data) + + +def test_check_private_app_permission_passes_when_allowed(): + app = MagicMock(spec=App) + app.id = "app-1" + data = _data(account_id=uuid.uuid4(), app=app) + target = "controllers.openapi.auth.verify.EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp" + with patch(target, return_value=True): + check_private_app_permission(data) diff --git a/api/tests/unit_tests/controllers/openapi/conftest.py b/api/tests/unit_tests/controllers/openapi/conftest.py index 38dae79a11..18b3b2fabf 100644 --- a/api/tests/unit_tests/controllers/openapi/conftest.py +++ b/api/tests/unit_tests/controllers/openapi/conftest.py @@ -1,20 +1,36 @@ +import uuid + import pytest from flask import Flask from controllers.openapi import bp as openapi_bp -from controllers.openapi.auth.pipeline import Pipeline +from controllers.openapi.auth.data import AuthData +from controllers.openapi.auth.pipeline import PipelineRouter +from libs.oauth_bearer import Scope, TokenType + + +def _stub_execute(self, args, kwargs, view, *, scope=None, allowed_token_types=None, edition=None): + """Bypass all auth logic; inject minimal AuthData and call the view directly.""" + kwargs["auth_data"] = AuthData( + token_type=TokenType.OAUTH_ACCOUNT, + account_id=uuid.uuid4(), + token_hash="test", + token_id=uuid.uuid4(), + scopes=frozenset({Scope.FULL}), + required_scope=scope, + ) + return view(*args, **kwargs) @pytest.fixture def bypass_pipeline(monkeypatch): - """Stub Pipeline.run so endpoint decoration does not invoke real auth. + """Stub PipelineRouter._execute so endpoints skip real auth at request time. - Module-level @OAUTH_BEARER_PIPELINE.guard(...) captures the real - pipeline at import time; mocking the module attribute does not undo - that. Patching Pipeline.run on the class is the bypass that actually - works. + Module-level @auth_router.guard(...) captures the real router at import + time — patching guard itself does nothing. Patching _execute on the class + is the seam that fires at request time. """ - monkeypatch.setattr(Pipeline, "run", lambda self, ctx: None) + monkeypatch.setattr(PipelineRouter, "_execute", _stub_execute) @pytest.fixture diff --git a/api/tests/unit_tests/controllers/openapi/test_account.py b/api/tests/unit_tests/controllers/openapi/test_account.py index 15624305a3..f73dc5c0cc 100644 --- a/api/tests/unit_tests/controllers/openapi/test_account.py +++ b/api/tests/unit_tests/controllers/openapi/test_account.py @@ -86,7 +86,7 @@ def test_subject_match_for_account_filters_by_account_id(): """Account subject scopes queries via account_id.""" import uuid as _uuid - from libs.oauth_bearer import AuthContext, SubjectType + from libs.oauth_bearer import AuthContext, SubjectType, TokenType from services.oauth_device_flow import subject_match_clauses aid = _uuid.uuid4() @@ -98,7 +98,7 @@ def test_subject_match_for_account_filters_by_account_id(): client_id="difyctl", scopes=frozenset({"full"}), token_id=_uuid.uuid4(), - source="oauth_account", + token_type=TokenType.OAUTH_ACCOUNT, expires_at=None, token_hash="h1", verified_tenants={}, @@ -116,7 +116,7 @@ def test_subject_match_for_external_sso_filters_by_email_and_issuer(): """ import uuid as _uuid - from libs.oauth_bearer import AuthContext, SubjectType + from libs.oauth_bearer import AuthContext, SubjectType, TokenType from services.oauth_device_flow import subject_match_clauses ctx = AuthContext( @@ -127,7 +127,7 @@ def test_subject_match_for_external_sso_filters_by_email_and_issuer(): client_id="difyctl", scopes=frozenset({"apps:run"}), token_id=_uuid.uuid4(), - source="oauth_external_sso", + token_type=TokenType.OAUTH_EXTERNAL_SSO, expires_at=None, token_hash="h1", verified_tenants={}, diff --git a/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py b/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py index 8db5033704..8933533af0 100644 --- a/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py +++ b/api/tests/unit_tests/controllers/openapi/test_app_run_streaming.py @@ -57,7 +57,11 @@ def test_stop_task_endpoint_registered(openapi_app): def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, monkeypatch): + import uuid + from controllers.openapi.app_run import AppRunTaskStopApi + from controllers.openapi.auth.data import AuthData + from libs.oauth_bearer import Scope, TokenType queue_mock = Mock() graph_mock = Mock() @@ -69,15 +73,23 @@ def test_stop_task_calls_queue_manager_and_graph_engine(app, bypass_pipeline, mo monkeypatch.setattr(run_module, "GraphEngineManager", graph_mock) monkeypatch.setattr(run_module, "redis_client", object()) + auth_data = AuthData.model_construct( + token_type=TokenType.OAUTH_ACCOUNT, + account_id=uuid.uuid4(), + token_hash="test", + scopes=frozenset({Scope.FULL}), + app=SimpleNamespace(id="app-1", tenant_id="t-1"), + caller=SimpleNamespace(id="acct-1"), + caller_kind="account", + ) + api = AppRunTaskStopApi() with app.test_request_context("/openapi/v1/apps/app-1/tasks/task-1/stop", method="POST"): result = api.post.__wrapped__( api, app_id="app-1", task_id="task-1", - app_model=SimpleNamespace(id="app-1", tenant_id="t-1"), - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=auth_data, ) queue_mock.set_stop_flag_no_user_check.assert_called_once_with("task-1") diff --git a/api/tests/unit_tests/controllers/openapi/test_human_input_form.py b/api/tests/unit_tests/controllers/openapi/test_human_input_form.py index 42ecfc5eb2..52fd0f89d5 100644 --- a/api/tests/unit_tests/controllers/openapi/test_human_input_form.py +++ b/api/tests/unit_tests/controllers/openapi/test_human_input_form.py @@ -4,6 +4,7 @@ from __future__ import annotations import json import sys +import uuid from datetime import UTC, datetime from types import SimpleNamespace from unittest.mock import Mock @@ -11,9 +12,23 @@ from unittest.mock import Mock import pytest from werkzeug.exceptions import NotFound +from controllers.openapi.auth.data import AuthData +from libs.oauth_bearer import Scope, TokenType from models.human_input import RecipientType +def _make_auth_data(app_model, caller, caller_kind): + return AuthData.model_construct( + token_type=TokenType.OAUTH_ACCOUNT, + account_id=uuid.uuid4(), + token_hash="test", + scopes=frozenset({Scope.FULL}), + app=app_model, + caller=caller, + caller_kind=caller_kind, + ) + + class TestOpenApiHumanInputFormGet: def test_get_success(self, app, bypass_pipeline, monkeypatch): from controllers.openapi.human_input_form import OpenApiWorkflowHumanInputFormApi @@ -43,15 +58,14 @@ class TestOpenApiHumanInputFormGet: api = OpenApiWorkflowHumanInputFormApi() app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + caller = SimpleNamespace(id="acct-1") with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"): resp = api.get.__wrapped__( api, app_id="app-1", form_token="tok-1", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) payload = json.loads(resp.get_data(as_text=True)) @@ -71,6 +85,7 @@ class TestOpenApiHumanInputFormGet: api = OpenApiWorkflowHumanInputFormApi() app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + caller = SimpleNamespace(id="acct-1") with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/bad"): with pytest.raises(NotFound): @@ -78,9 +93,7 @@ class TestOpenApiHumanInputFormGet: api, app_id="app-1", form_token="bad", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) def test_get_form_wrong_app(self, app, bypass_pipeline, monkeypatch): @@ -97,6 +110,7 @@ class TestOpenApiHumanInputFormGet: api = OpenApiWorkflowHumanInputFormApi() app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + caller = SimpleNamespace(id="acct-1") with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"): with pytest.raises(NotFound): @@ -104,9 +118,7 @@ class TestOpenApiHumanInputFormGet: api, app_id="app-1", form_token="tok-1", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) def test_get_form_wrong_surface(self, app, bypass_pipeline, monkeypatch): @@ -126,6 +138,7 @@ class TestOpenApiHumanInputFormGet: api = OpenApiWorkflowHumanInputFormApi() app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") + caller = SimpleNamespace(id="acct-1") with app.test_request_context("/openapi/v1/apps/app-1/form/human_input/tok-1"): with pytest.raises(NotFound): @@ -133,9 +146,7 @@ class TestOpenApiHumanInputFormGet: api, app_id="app-1", form_token="tok-1", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) @@ -172,9 +183,7 @@ class TestOpenApiHumanInputFormPost: api, app_id="app-1", form_token="tok-1", - app_model=app_model, - caller=caller, - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) service_mock.submit_form_by_token.assert_called_once_with( @@ -211,9 +220,7 @@ class TestOpenApiHumanInputFormPost: api, app_id="app-1", form_token="tok-1", - app_model=app_model, - caller=caller, - caller_kind="end_user", + auth_data=_make_auth_data(app_model, caller, "end_user"), ) service_mock.submit_form_by_token.assert_called_once_with( diff --git a/api/tests/unit_tests/controllers/openapi/test_workflow_events_openapi.py b/api/tests/unit_tests/controllers/openapi/test_workflow_events_openapi.py index 78b85460b3..78f2d0f20d 100644 --- a/api/tests/unit_tests/controllers/openapi/test_workflow_events_openapi.py +++ b/api/tests/unit_tests/controllers/openapi/test_workflow_events_openapi.py @@ -3,15 +3,30 @@ from __future__ import annotations import sys +import uuid from types import SimpleNamespace from unittest.mock import Mock import pytest from werkzeug.exceptions import NotFound +from controllers.openapi.auth.data import AuthData +from libs.oauth_bearer import Scope, TokenType from models.enums import CreatorUserRole +def _make_auth_data(app_model, caller, caller_kind): + return AuthData.model_construct( + token_type=TokenType.OAUTH_ACCOUNT, + account_id=uuid.uuid4(), + token_hash="test", + scopes=frozenset({Scope.FULL}), + app=app_model, + caller=caller, + caller_kind=caller_kind, + ) + + def _make_workflow_run( *, app_id="app-1", @@ -50,6 +65,7 @@ class TestOpenApiWorkflowEventsApi: from models.model import AppMode app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) + caller = SimpleNamespace(id="acct-1") with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"): with pytest.raises(NotFound): @@ -57,9 +73,7 @@ class TestOpenApiWorkflowEventsApi: api, app_id="app-1", task_id="wf-run-1", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) def test_not_found_when_run_belongs_to_different_app(self, app, bypass_pipeline, monkeypatch): @@ -77,6 +91,7 @@ class TestOpenApiWorkflowEventsApi: from models.model import AppMode app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) + caller = SimpleNamespace(id="acct-1") with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"): with pytest.raises(NotFound): @@ -84,9 +99,7 @@ class TestOpenApiWorkflowEventsApi: api, app_id="app-1", task_id="wf-run-1", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) def test_account_caller_checks_created_by_account(self, app, bypass_pipeline, monkeypatch): @@ -115,6 +128,7 @@ class TestOpenApiWorkflowEventsApi: from models.model import AppMode app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) + caller = SimpleNamespace(id="acct-1") api = self._get_api() with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"): @@ -123,9 +137,7 @@ class TestOpenApiWorkflowEventsApi: api, app_id="app-1", task_id="wf-run-1", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) assert resp.mimetype == "text/event-stream" @@ -143,6 +155,7 @@ class TestOpenApiWorkflowEventsApi: from models.model import AppMode app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) + caller = SimpleNamespace(id="acct-1") api = self._get_api() with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"): @@ -151,9 +164,7 @@ class TestOpenApiWorkflowEventsApi: api, app_id="app-1", task_id="wf-run-1", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) def test_end_user_caller_checks_created_by_end_user(self, app, bypass_pipeline, monkeypatch): @@ -179,6 +190,7 @@ class TestOpenApiWorkflowEventsApi: from models.model import AppMode app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) + caller = SimpleNamespace(id="eu-1") api = self._get_api() with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"): @@ -186,9 +198,7 @@ class TestOpenApiWorkflowEventsApi: api, app_id="app-1", task_id="wf-run-1", - app_model=app_model, - caller=SimpleNamespace(id="eu-1"), - caller_kind="end_user", + auth_data=_make_auth_data(app_model, caller, "end_user"), ) assert resp.mimetype == "text/event-stream" @@ -222,6 +232,7 @@ class TestOpenApiWorkflowEventsApi: from models.model import AppMode app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1", mode=AppMode.WORKFLOW) + caller = SimpleNamespace(id="acct-1") api = self._get_api() with app.test_request_context("/openapi/v1/apps/app-1/tasks/wf-run-1/events"): @@ -229,9 +240,7 @@ class TestOpenApiWorkflowEventsApi: api, app_id="app-1", task_id="wf-run-1", - app_model=app_model, - caller=SimpleNamespace(id="acct-1"), - caller_kind="account", + auth_data=_make_auth_data(app_model, caller, "account"), ) assert resp.mimetype == "text/event-stream" chunks = list(resp.response) diff --git a/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py b/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py index 970b5661e5..6e32487348 100644 --- a/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py +++ b/api/tests/unit_tests/controllers/openapi/test_workspaces_members.py @@ -38,7 +38,7 @@ from controllers.openapi.workspaces import ( WorkspaceMembersApi, WorkspaceSwitchApi, ) -from libs.oauth_bearer import AuthContext, Scope, SubjectType, reset_auth_ctx, set_auth_ctx +from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, reset_auth_ctx, set_auth_ctx from models.account import AccountStatus, TenantAccountRole from services.errors.account import ( AccountAlreadyInTenantError, @@ -97,13 +97,25 @@ def _auth_ctx(account_id: uuid.UUID | None = None) -> AuthContext: client_id="difyctl", scopes=frozenset({Scope.FULL}), token_id=uuid.uuid4(), - source="oauth_account", + token_type=TokenType.OAUTH_ACCOUNT, expires_at=datetime.now(UTC), token_hash="h", verified_tenants={}, ) +def _auth_data(account_id: uuid.UUID) -> AuthData: + from controllers.openapi.auth.data import AuthData + from libs.oauth_bearer import Scope, TokenType + + return AuthData( + token_type=TokenType.OAUTH_ACCOUNT, + account_id=account_id, + token_hash="testhash", + scopes=frozenset({Scope.FULL}), + ) + + def _account(account_id: str = "acct-1", email: str = "u@example.com") -> SimpleNamespace: return SimpleNamespace( id=account_id, @@ -256,7 +268,7 @@ def test_switch_returns_workspace_detail_with_current_true(app, bypass_pipeline, with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/switch", method="POST"): _seed(_auth_ctx(account_id=acct_id)) - body, status = api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id) + body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) assert status == 200 assert body["id"] == ws_id @@ -284,7 +296,7 @@ def test_switch_404s_when_service_raises_account_not_link_tenant(app, bypass_pip with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/switch", method="POST"): _seed(_auth_ctx(account_id=acct_id)) with pytest.raises(NotFound): - api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id) + api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) # --------------------------------------------------------------------------- @@ -318,7 +330,7 @@ def test_members_list_returns_normalized_rows(app, bypass_pipeline, monkeypatch) with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members"): _seed(_auth_ctx(account_id=acct_id)) - body, status = api.get.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id) + body, status = api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) assert status == 200 assert body["page"] == 1 @@ -360,7 +372,7 @@ def test_members_list_paginates_with_query_params(app, bypass_pipeline, monkeypa with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members?page=2&limit=2"): _seed(_auth_ctx(account_id=acct_id)) - body, status = api.get.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id) + body, status = api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) assert status == 200 assert body["page"] == 2 @@ -383,7 +395,7 @@ def test_members_list_rejects_unknown_query_param(app, bypass_pipeline, monkeypa with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members?pg=2"): _seed(_auth_ctx(account_id=acct_id)) with pytest.raises(BadRequest): - api.get.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id) + api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) # --------------------------------------------------------------------------- @@ -421,7 +433,7 @@ def test_invite_happy_path_returns_invite_url_and_member_id(app, bypass_pipeline content_type="application/json", ): _seed(_auth_ctx(account_id=acct_id)) - body, status = api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id) + body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) assert status == 201 assert body["result"] == "success" @@ -506,7 +518,7 @@ def test_invite_blocked_by_saas_members_cap(app, bypass_pipeline, monkeypatch): with _invite_request(app, ws_id, acct_id): _seed(_auth_ctx(account_id=acct_id)) with pytest.raises(Forbidden) as exc_info: - api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id) + api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) body = exc_info.value.response.json assert body["code"] == "members.limit_exceeded" @@ -552,7 +564,7 @@ def test_invite_blocked_by_ee_workspace_members_license(app, bypass_pipeline, mo with _invite_request(app, ws_id, acct_id): _seed(_auth_ctx(account_id=acct_id)) with pytest.raises(Forbidden) as exc_info: - api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id) + api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) body = exc_info.value.response.json assert body["code"] == "workspace_members.license_exceeded" @@ -591,7 +603,7 @@ def test_invite_ce_passes_when_both_caps_disabled(app, bypass_pipeline, monkeypa with _invite_request(app, ws_id, acct_id): _seed(_auth_ctx(account_id=acct_id)) - body, status = api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id) + body, status = api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) assert status == 201 assert body["email"] == "new@example.com" @@ -620,7 +632,7 @@ def test_invite_400_when_already_in_tenant(app, bypass_pipeline, monkeypatch): ): _seed(_auth_ctx(account_id=acct_id)) with pytest.raises(BadRequest): - api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id) + api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) # --------------------------------------------------------------------------- @@ -653,10 +665,8 @@ def test_delete_member_happy_path(app, bypass_pipeline, monkeypatch): method="DELETE", ): _seed(_auth_ctx(account_id=acct_id)) - body, status = api.delete.__wrapped__.__wrapped__.__wrapped__( - api, - workspace_id=ws_id, - member_id=member_id, + body, status = api.delete.__wrapped__.__wrapped__( + api, workspace_id=ws_id, member_id=member_id, auth_data=_auth_data(acct_id) ) assert status == 200 @@ -697,10 +707,11 @@ def test_delete_member_exception_mapping(app, bypass_pipeline, monkeypatch, exc, ): _seed(_auth_ctx(account_id=acct_id)) with pytest.raises(expected): - api.delete.__wrapped__.__wrapped__.__wrapped__( + api.delete.__wrapped__.__wrapped__( api, workspace_id=ws_id, member_id=member_id, + auth_data=_auth_data(acct_id), ) @@ -723,10 +734,11 @@ def test_delete_member_404_when_member_missing(app, bypass_pipeline, monkeypatch ): _seed(_auth_ctx(account_id=acct_id)) with pytest.raises(NotFound): - api.delete.__wrapped__.__wrapped__.__wrapped__( + api.delete.__wrapped__.__wrapped__( api, workspace_id=ws_id, member_id=member_id, + auth_data=_auth_data(acct_id), ) @@ -762,10 +774,8 @@ def test_update_role_happy_path(app, bypass_pipeline, monkeypatch): content_type="application/json", ): _seed(_auth_ctx(account_id=acct_id)) - body, status = api.put.__wrapped__.__wrapped__.__wrapped__( - api, - workspace_id=ws_id, - member_id=member_id, + body, status = api.put.__wrapped__.__wrapped__( + api, workspace_id=ws_id, member_id=member_id, auth_data=_auth_data(acct_id) ) assert status == 200 @@ -810,10 +820,11 @@ def test_update_role_exception_mapping(app, bypass_pipeline, monkeypatch, exc, e ): _seed(_auth_ctx(account_id=acct_id)) with pytest.raises(expected): - api.put.__wrapped__.__wrapped__.__wrapped__( + api.put.__wrapped__.__wrapped__( api, workspace_id=ws_id, member_id=member_id, + auth_data=_auth_data(acct_id), ) @@ -847,9 +858,8 @@ def test_non_member_caller_gets_404_on_switch(app, bypass_pipeline, monkeypatch) # Strip only the bearer + surface-gate wrappers; keep the role gate. # Decorator stack (innermost → outermost): # role_gate → accept_subjects → validate_bearer - # So `post.__wrapped__` unwraps validate_bearer; we then unwrap - # accept_subjects to land on the role-gate wrapper. - gated = api.post.__wrapped__.__wrapped__ + # `post.__wrapped__` is now the role-gate wrapper directly (auth_router.guard is the only outer wrapper). + gated = api.post.__wrapped__ with pytest.raises(NotFound): gated(api, workspace_id=ws_id) @@ -881,7 +891,7 @@ def test_load_tenant_rejects_archived_workspace(app, bypass_pipeline, monkeypatc with app.test_request_context(f"/openapi/v1/workspaces/{ws_id}/members"): _seed(_auth_ctx(account_id=acct_id)) with pytest.raises(NotFound): - api.get.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id) + api.get.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) # --------------------------------------------------------------------------- @@ -915,4 +925,4 @@ def test_invite_400_when_register_error(app, bypass_pipeline, monkeypatch): ): _seed(_auth_ctx(account_id=acct_id)) with pytest.raises(BadRequest): - api.post.__wrapped__.__wrapped__.__wrapped__(api, workspace_id=ws_id) + api.post.__wrapped__.__wrapped__(api, workspace_id=ws_id, auth_data=_auth_data(acct_id)) diff --git a/api/tests/unit_tests/libs/test_oauth_bearer_rate_limit_ordering.py b/api/tests/unit_tests/libs/test_oauth_bearer_rate_limit_ordering.py index dd4304ccb1..d3d2d583f2 100644 --- a/api/tests/unit_tests/libs/test_oauth_bearer_rate_limit_ordering.py +++ b/api/tests/unit_tests/libs/test_oauth_bearer_rate_limit_ordering.py @@ -11,6 +11,7 @@ from libs.oauth_bearer import ( SubjectType, TokenKind, TokenKindRegistry, + TokenType, ) @@ -21,7 +22,7 @@ def _registry_with_resolver(resolver) -> TokenKindRegistry: prefix="dfoa_", subject_type=SubjectType.ACCOUNT, scopes=frozenset({Scope.FULL}), - source="oauth_account", + token_type=TokenType.OAUTH_ACCOUNT, resolver=resolver, ) ] @@ -63,7 +64,7 @@ def test_unknown_prefix_raises_generic_invalid_bearer(): prefix="dfoa_", subject_type=SubjectType.ACCOUNT, scopes=frozenset({Scope.FULL}), - source="oauth_account", + token_type=TokenType.OAUTH_ACCOUNT, resolver=MagicMock(), ) ] diff --git a/api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py b/api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py index 898e4578e6..e8204a6e2e 100644 --- a/api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py +++ b/api/tests/unit_tests/libs/test_oauth_bearer_require_scope.py @@ -19,6 +19,7 @@ from libs.oauth_bearer import ( AuthContext, Scope, SubjectType, + TokenType, require_scope, reset_auth_ctx, set_auth_ctx, @@ -50,7 +51,7 @@ def _ctx(scopes) -> AuthContext: client_id="difyctl", scopes=scopes, token_id=uuid.uuid4(), - source="oauth_account", + token_type=TokenType.OAUTH_ACCOUNT, expires_at=None, token_hash="h1", verified_tenants={}, diff --git a/api/tests/unit_tests/libs/test_workspace_member_helper.py b/api/tests/unit_tests/libs/test_workspace_member_helper.py index 540e19ad9e..f4933e7f59 100644 --- a/api/tests/unit_tests/libs/test_workspace_member_helper.py +++ b/api/tests/unit_tests/libs/test_workspace_member_helper.py @@ -8,7 +8,7 @@ from unittest.mock import MagicMock, patch import pytest from werkzeug.exceptions import Forbidden -from libs.oauth_bearer import AuthContext, Scope, SubjectType, require_workspace_member +from libs.oauth_bearer import AuthContext, Scope, SubjectType, TokenType, require_workspace_member def _ctx(verified: dict[str, bool] | None = None, *, account: bool = True) -> AuthContext: @@ -20,7 +20,7 @@ def _ctx(verified: dict[str, bool] | None = None, *, account: bool = True) -> Au client_id="difyctl", scopes=frozenset({Scope.FULL}), token_id=uuid.uuid4(), - source="oauth_account", + token_type=TokenType.OAUTH_ACCOUNT if account else TokenType.OAUTH_EXTERNAL_SSO, expires_at=None, token_hash="h1", verified_tenants=dict(verified or {}), diff --git a/api/tests/unit_tests/services/test_oauth_device_flow.py b/api/tests/unit_tests/services/test_oauth_device_flow.py index b2e95c93a3..fcb3f29a76 100644 --- a/api/tests/unit_tests/services/test_oauth_device_flow.py +++ b/api/tests/unit_tests/services/test_oauth_device_flow.py @@ -3,7 +3,7 @@ from __future__ import annotations import uuid from unittest.mock import MagicMock -from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT, AuthContext, SubjectType +from libs.oauth_bearer import TOKEN_CACHE_KEY_FMT, AuthContext, SubjectType, TokenType from services.oauth_device_flow import ( list_active_sessions, revoke_oauth_token, @@ -21,7 +21,7 @@ def _account_ctx() -> AuthContext: client_id="difyctl", scopes=frozenset({"full"}), token_id=uuid.uuid4(), - source="oauth_account", + token_type=TokenType.OAUTH_ACCOUNT, expires_at=None, token_hash="h1", verified_tenants={}, @@ -37,7 +37,7 @@ def _sso_ctx() -> AuthContext: client_id="difyctl", scopes=frozenset({"apps:run"}), token_id=uuid.uuid4(), - source="oauth_external_sso", + token_type=TokenType.OAUTH_EXTERNAL_SSO, expires_at=None, token_hash="h1", verified_tenants={},